diff --git a/library/_compile_test.py b/library/_compile_test.py index c48488974..a112a150f 100644 --- a/library/_compile_test.py +++ b/library/_compile_test.py @@ -811,7 +811,6 @@ def foo(x): """\ LOAD_FAST_REVERSE_UNCHECKED x STORE_FAST_REVERSE y -DELETE_FAST x LOAD_FAST_REVERSE_UNCHECKED y RETURN_VALUE """, @@ -1485,6 +1484,250 @@ async def foo(): # TODO(emacs): Test with (multiple context managers) + def test_dead_local_store_is_removed(self): + source = """ +def foo(): + x = 123 +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(), None) + + def test_dead_local_store_in_one_branch_is_removed(self): + source = """ +def foo(cond): + if cond: + x = 123 +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_FAST_REVERSE_UNCHECKED cond +POP_JUMP_IF_FALSE 8 +LOAD_CONST 123 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(True), None) + + def test_dead_local_store_in_all_branches_is_removed(self): + source = """ +def foo(cond): + if cond: + x = 123 + else: + x = 456 +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_FAST_REVERSE_UNCHECKED cond +POP_JUMP_IF_FALSE 10 +LOAD_CONST 123 +POP_TOP +JUMP_FORWARD 4 +LOAD_CONST 456 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(True), None) + + def test_local_store_used_in_same_branch_removed_in_other_branch(self): + source = """ +def foo(cond): + if cond: + x = 123 + f(x) + else: + x = 456 +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_FAST_REVERSE_UNCHECKED cond +POP_JUMP_IF_FALSE 18 +LOAD_CONST 123 +STORE_FAST_REVERSE x +LOAD_GLOBAL f +LOAD_FAST_REVERSE_UNCHECKED x +CALL_FUNCTION 1 +POP_TOP +JUMP_FORWARD 4 +LOAD_CONST 456 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + 
self.assertEqual(func(False), None) + + def test_store_before_store_removed(self): + source = """ +def foo(): + x = 2 + x = 3 + return x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 2 +POP_TOP +LOAD_CONST 3 +STORE_FAST_REVERSE x +LOAD_FAST_REVERSE_UNCHECKED x +RETURN_VALUE +""", + ) + self.assertEqual(func(), 3) + + def test_store_before_del_removed(self): + source = """ +def foo(): + x = 2 + del x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 2 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(), None) + + @unittest.skip("TODO: Figure out how to leave one DELETE_FAST") + def test_del_before_del_leaves_one_del(self): + source = """ +def foo(): + x = 2 + del x + del x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 2 +POP_TOP +DELETE_FAST x +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(), None) + + def test_store_before_del_and_use_removed(self): + # TODO(max): Remove the unused STORE_FAST and DELETE_FAST + source = """ +def foo(): + x = 2 + del x + return x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +DELETE_FAST_REVERSE_UNCHECKED x +LOAD_CONST 2 +STORE_FAST_REVERSE x +DELETE_FAST x +LOAD_FAST x +RETURN_VALUE +""", + ) + with self.assertRaises(UnboundLocalError): + func() + + @unittest.skip("TODO(max): Figure out why this test is failing") + def test_dead_local_store_used_in_other_branch_is_removed(self): + source = """ +def foo(cond): + if cond: + x = 123 + else: + f(x) +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +DELETE_FAST_REVERSE_UNCHECKED x +LOAD_FAST_REVERSE_UNCHECKED cond +POP_JUMP_IF_FALSE 12 +LOAD_CONST 123 +POP_TOP +JUMP_FORWARD 8 +LOAD_GLOBAL f +LOAD_FAST x +CALL_FUNCTION 1 +POP_TOP +LOAD_CONST None +RETURN_VALUE +""", + ) + 
self.assertEqual(func(True), None) + + def test_dead_store_not_removed_if_calling_locals_function(self): + source = """ +def foo(): + x = 123 + return locals() +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +STORE_FAST_REVERSE x +LOAD_GLOBAL locals +CALL_FUNCTION 0 +RETURN_VALUE +""") + self.assertEqual(func(), {"x": 123}) + + def test_store_self_removes_last_store(self): + # TODO(max): See if we can remove the first store too + source = """ +def foo(): + x = 123 + x = x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +STORE_FAST_REVERSE x +LOAD_FAST_REVERSE_UNCHECKED x +POP_TOP +LOAD_CONST None +RETURN_VALUE +""") + self.assertEqual(func(), None) + + # TODO(max): Test loops + if __name__ == "__main__": unittest.main() diff --git a/library/_compiler.py b/library/_compiler.py index 6dfc1324b..c1853dc79 100644 --- a/library/_compiler.py +++ b/library/_compiler.py @@ -212,6 +212,10 @@ class PyroFlowGraph(PyFlowGraph38): opcode = opcodepyro.opcode def optimizeLoadFast(self): + # TODO(max): Bail out early if gen/coro/asyncgen/itercoro? 
+ # TODO(max): Bail out early if exception handling opcodes + # TODO(max): Make edges between all opcodes, not just basic blocks + # TODO(max): Profile number of iterations until fixpoint blocks = self.getBlocksInOrder() preds = tuple(set() for i in range(self.block_count)) for block in blocks: @@ -277,10 +281,12 @@ def process_one_block(block, modify=False): return True changed = True + num_iterations = 0 while changed: changed = False for block in blocks: changed |= process_one_block(block) + num_iterations += 1 for block in blocks: process_one_block(block, modify=True) @@ -296,8 +302,104 @@ def process_one_block(block, modify=False): ] self.entry.insts = deletes + self.entry.insts + def getInstructions(self): + for block in self.getBlocksInOrder(): + for instr in block.getInstructions(): + yield instr + + def optimizeDeadStores(self): + all_instrs = self.getInstructions() + if any( + instr.opname + in ( + # Exception handling opcodes + "POP_BLOCK", + "SETUP_ASYNC_WITH", + "SETUP_FINALLY", + "SETUP_WITH", + "WITH_CLEANUP_START", + "YIELD_FROM", + "YIELD_VALUE", + "END_ASYNC_FOR", + ) + for instr in all_instrs + ): + return + if "locals" in self.names: + # This is a hack to avoid optimizing away dead locals in the presence + # of one particular kind of call to locals(). + return + # TODO(max): Bail out early if gen/coro/asyncgen/itercoro? 
+ # TODO(max): Make edges between all opcodes, not just basic blocks + # TODO(max): Profile number of iterations until fixpoint + blocks = self.getBlocksInOrder() + preds = tuple(set() for i in range(self.block_count)) + succs = tuple(set() for i in range(self.block_count)) + for block in blocks: + for child in block.get_children(): + if child is not None: + preds[child.bid].add(block.bid) + succs[block.bid].add(child.bid) + + num_locals = len(self.varnames) + Top = 0 + live_out = [Top] * self.block_count + total_locals = num_locals + len(self.cellvars) + len(self.freevars) + + def reverse_local_idx(idx): + return total_locals - idx - 1 + + def meet(args): + result = Top + for arg in args: + result |= arg + return result + + def process_one_block(block, modify=False): + bid = block.bid + if len(succs[bid]) == 0: + live = Top + else: + live = meet(live_out[succ] for succ in succs[bid]) + for instr in reversed(block.getInstructions()): + if ( + instr.opname == "DELETE_FAST" + and modify + and not (live & (1 << instr.ioparg)) + ): + instr.opname = "NOP" + instr.oparg = None + instr.ioparg = 0 + live &= ~(1 << instr.ioparg) + if instr.opname == "LOAD_FAST" or instr.opname == "DELETE_FAST": + live |= 1 << instr.ioparg + elif instr.opname == "STORE_FAST": + if modify and not (live & (1 << instr.ioparg)): + instr.opname = "POP_TOP" + instr.oparg = None + instr.ioparg = 0 + live &= ~(1 << instr.ioparg) + if live == live_out[bid]: + return False + live_out[bid] = live + return True + + changed = True + num_iterations = 0 + while changed: + changed = False + for block in blocks: + changed |= process_one_block(block) + num_iterations += 1 + + for block in blocks: + process_one_block(block, modify=True) + def getCode(self): - self.optimizeLoadFast() + # Do this first; it can't (yet?) handle LOAD_FAST_REVERSE_UNCHECKED et + # al. 
+ # self.optimizeDeadStores() + # self.optimizeLoadFast() return super().getCode() diff --git a/runtime/bytecode-test.cpp b/runtime/bytecode-test.cpp index ac9fe498c..b19cf766e 100644 --- a/runtime/bytecode-test.cpp +++ b/runtime/bytecode-test.cpp @@ -765,8 +765,9 @@ TEST_F( word argcount = 1; word nlocals = 3; const byte bytecode[] = { - LOAD_FAST, 2, LOAD_FAST, 1, LOAD_FAST, 0, STORE_FAST, 2, - STORE_FAST, 1, STORE_FAST, 0, DELETE_FAST, 0, + LOAD_FAST, 2, LOAD_FAST, 1, LOAD_FAST, 0, + STORE_FAST, 2, STORE_FAST, 1, STORE_FAST, 0, + DELETE_FAST, 0, RETURN_VALUE, 0, }; Bytes code_code(&scope, runtime_->newBytesWithAll(bytecode)); Object empty_tuple(&scope, runtime_->emptyTuple()); @@ -791,7 +792,7 @@ TEST_F( LOAD_FAST_REVERSE, 2, 0, 0, LOAD_FAST_REVERSE, 3, 0, 0, LOAD_FAST_REVERSE, 4, 0, 0, STORE_FAST_REVERSE, 2, 0, 0, STORE_FAST_REVERSE, 3, 0, 0, STORE_FAST_REVERSE, 4, 0, 0, - DELETE_FAST, 0, 0, 0, + DELETE_FAST, 0, 0, 0, RETURN_VALUE, 0, 0, 0, }; Object rewritten_bytecode(&scope, function.rewrittenBytecode()); EXPECT_TRUE(isMutableBytesEqualsBytes(rewritten_bytecode, expected));