diff --git a/library/_compile_test.py b/library/_compile_test.py index c48488974..2f10ccedb 100644 --- a/library/_compile_test.py +++ b/library/_compile_test.py @@ -1486,5 +1486,89 @@ async def foo(): # TODO(emacs): Test with (multiple context managers) +@pyro_only +class OptStoreFastTests(unittest.TestCase): + def test_store_with_load_not_replaced(self): + source = """ +def foo(): + _ = 123 + return _ +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +STORE_FAST_REVERSE _ +LOAD_FAST_REVERSE_UNCHECKED _ +RETURN_VALUE +""", + ) + self.assertEqual(func(), 123) + + def test_store_with_no_load_replaced_with_pop_top(self): + source = """ +def foo(): + _ = 123 + return 456 +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +POP_TOP +LOAD_CONST 456 +RETURN_VALUE +""", + ) + self.assertEqual(func(), 456) + + def test_store_in_loop_replaced_with_pop_top(self): + source = """ +def foo(x): + for _ in x: + pass +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_FAST_REVERSE_UNCHECKED x +GET_ITER +FOR_ITER 4 +POP_TOP +JUMP_ABSOLUTE 4 +LOAD_CONST None +RETURN_VALUE +""", + ) + self.assertEqual(func(()), None) + + def test_multiple_store_no_read_replaced_with_pop_top(self): + source = """ +def foo(): + x = 123 + y = 456 + z = 789 + return x +""" + func = compile_function(source, "foo") + self.assertEqual( + dis(func.__code__), + """\ +LOAD_CONST 123 +STORE_FAST_REVERSE x +LOAD_CONST 456 +POP_TOP +LOAD_CONST 789 +POP_TOP +LOAD_FAST_REVERSE_UNCHECKED x +RETURN_VALUE +""", + ) + self.assertEqual(func(), 123) + + if __name__ == "__main__": unittest.main() diff --git a/library/_compiler.py b/library/_compiler.py index 6dfc1324b..17db4baf8 100644 --- a/library/_compiler.py +++ b/library/_compiler.py @@ -211,6 +211,25 @@ def visitBinOp(self, node: ast.BinOp) -> ast.expr: class PyroFlowGraph(PyFlowGraph38): opcode = opcodepyro.opcode + def optimizeStoreFast(self): + if "locals" in self.varnames or "locals" in self.names: + # A bit of a hack: if someone is using locals(), we shouldn't mess + # with them. + return + used = set() + for block in self.getBlocksInOrder(): + for instr in block.getInstructions(): + if instr.opname == "LOAD_FAST" or instr.opname == "DELETE_FAST": + used.add(instr.oparg) + # We never read from or delete the local, so we can replace all stores + # to it with POP_TOP. + for block in self.getBlocksInOrder(): + for instr in block.getInstructions(): + if instr.opname == "STORE_FAST" and instr.oparg not in used: + instr.opname = "POP_TOP" + instr.oparg = 0 + instr.ioparg = 0 + def optimizeLoadFast(self): blocks = self.getBlocksInOrder() preds = tuple(set() for i in range(self.block_count)) @@ -297,6 +316,7 @@ def process_one_block(block, modify=False): self.entry.insts = deletes + self.entry.insts def getCode(self): + self.optimizeStoreFast() self.optimizeLoadFast() return super().getCode()