Skip to content

[mypyc] Support yields while values are live #16305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 22, 2025
Prev Previous commit
Next Next commit
WIP: start on spilling
  • Loading branch information
msullivan committed Oct 21, 2023
commit 7efb8dbc2dab63226fe1449328bccfaabf659fc8
20 changes: 15 additions & 5 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
Expand All @@ -25,6 +26,7 @@
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
Expand Down Expand Up @@ -79,12 +81,11 @@ def __str__(self) -> str:
return "\n".join(lines)


def get_cfg(blocks: list[BasicBlock]) -> CFG:
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
"""Calculate basic block control-flow graph.

The result is a dictionary like this:

basic block index -> (successor blocks, predecessor blocks)
If use_yields is set, then we treat returns inserted by yields as gotos
instead of exits.
"""
succ_map = {}
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
Expand All @@ -94,7 +95,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
isinstance(op, ControlOp) for op in block.ops[:-1]
), "Control-flow ops must be at the end of blocks"

succ = list(block.terminator.targets())
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
succ = [block.terminator.yield_target]
else:
succ = list(block.terminator.targets())
if not succ:
exits.add(block)

Expand Down Expand Up @@ -494,6 +498,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
    """Return the op's non-trivial sources paired with an empty set.

    NOTE(review): presumably the pair is (gen, kill), per the GenAndKill
    return type -- confirm against the enclosing visitor.
    """
    used = non_trivial_sources(op)
    return used, set()

def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
    """Return a pair of empty sets: an IncRef contributes nothing to this analysis."""
    gen: set[Value] = set()
    kill: set[Value] = set()
    return gen, kill

def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
    """Return a pair of empty sets: a DecRef contributes nothing to this analysis."""
    gen: set[Value] = set()
    kill: set[Value] = set()
    return gen, kill


def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
"""Calculate live registers at each CFG location.
Expand Down
9 changes: 9 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from mypyc.options import CompilerOptions
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.spill import insert_spills
from mypyc.transform.uninit import insert_uninit_checks

# All of the modules being compiled are divided into "groups". A group
Expand Down Expand Up @@ -225,6 +226,10 @@ def compile_scc_to_ir(
if errors.num_errors > 0:
return modules

# XXX: How will we deal with refcounting on the spilled values?
# Maybe do the spilling last, after refcount insertion, and only
# do the uninit checks before it?

# Insert uninit checks.
for module in modules.values():
for fn in module.functions:
Expand All @@ -237,6 +242,10 @@ def compile_scc_to_ir(
for module in modules.values():
for fn in module.functions:
insert_ref_count_opcodes(fn)
for module in modules.values():
for cls in module.classes:
if cls.env_user_function:
insert_spills(cls.env_user_function, cls)

return modules

Expand Down
7 changes: 7 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def __init__(
# value of an attribute is the same as the error value.
self.bitmap_attrs: list[str] = []

# If this is a generator environment class, the function that actually uses it; otherwise None.
self.env_user_function: FuncIR | None = None

def __repr__(self) -> str:
return (
"ClassIR("
Expand Down Expand Up @@ -391,6 +394,7 @@ def serialize(self) -> JsonDict:
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
"init_self_leak": self.init_self_leak,
"env_user_function": self.env_user_function.id if self.env_user_function else None,
}

@classmethod
Expand Down Expand Up @@ -442,6 +446,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
ir.init_self_leak = data["init_self_leak"]
ir.env_user_function = (
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
)

return ir

Expand Down
9 changes: 8 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,16 @@ class Return(ControlOp):

error_kind = ERR_NEVER

def __init__(self, value: Value, line: int = -1) -> None:
def __init__(
self, value: Value, line: int = -1, *, yield_target: BasicBlock | None = None
) -> None:
super().__init__(line)
self.value = value
# If this return is created by a yield, keep track of the next
# basic block. This doesn't affect the code we generate but
# can feed into analyses that need to understand the
# *original* CFG.
self.yield_target = yield_target

def sources(self) -> list[Value]:
return [self.value]
Expand Down
1 change: 1 addition & 0 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def c() -> None:
# Re-enter the FuncItem and visit the body of the function this time.
builder.enter(fn_info)
setup_env_for_generator_class(builder)

load_outer_envs(builder, builder.fn_info.generator_class)
top_level = builder.top_level_fn_info()
if (
Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def add_helper_to_generator_class(
)
fn_info.generator_class.ir.methods["__mypyc_generator_helper__"] = helper_fn_ir
builder.functions.append(helper_fn_ir)
fn_info.env_class.env_user_function = helper_fn_ir

return helper_fn_decl


Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def emit_yield(builder: IRBuilder, val: Value, line: int) -> Value:
next_label = len(cls.continuation_blocks)
cls.continuation_blocks.append(next_block)
builder.assign(cls.next_label_target, Integer(next_label), line)
builder.add(Return(retval))
builder.add(Return(retval, yield_target=next_block))
builder.activate_block(next_block)

add_raise_exception_blocks_to_generator_class(builder, line)
Expand Down
18 changes: 18 additions & 0 deletions mypyc/test-data/run-generators.test
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,21 @@ def test_basic() -> None:
with context:
assert context.x == 1
assert context.x == 0


[case testYieldSpill]
from typing import Generator

def f() -> int:
return 1

def yield_spill() -> Generator[str, int, int]:
return f() + (yield "foo")

[file driver.py]
from native import yield_spill
from testutil import run_generator

yields, val = run_generator(yield_spill(), [2])
assert yields == ('foo',)
assert val == 3, val
108 changes: 108 additions & 0 deletions mypyc/transform/spill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Insert spills for values that are live across yields."""

from __future__ import annotations

from mypyc.analysis.dataflow import AnalysisResult, analyze_live_regs, get_cfg
from mypyc.common import TEMP_ATTR_NAME
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.ops import (
BasicBlock,
Branch,
DecRef,
GetAttr,
IncRef,
LoadErrorValue,
Register,
SetAttr,
Value,
)


def insert_spills(ir: FuncIR, env: ClassIR) -> None:
    """Spill the non-register values that are live at the start of ``ir``.

    Liveness is computed over a CFG built with ``use_yields=True``. Every
    value that is live before the first op of the first block is a spill
    candidate, except registers, which are filtered out here. The blocks
    returned by ``spill_regs`` replace ``ir.blocks``.
    """
    yield_cfg = get_cfg(ir.blocks, use_yields=True)
    liveness = analyze_live_regs(ir.blocks, yield_cfg)
    live_at_entry = liveness.before[ir.blocks[0], 0]

    # Argument registers are never candidates for spilling.
    without_args = {
        value for value in live_at_entry if not isinstance(value, Register) or not value.is_arg
    }
    # XXX: Actually for now, no Registers at all -- we keep the manual spills
    candidates = {value for value in without_args if not isinstance(value, Register)}

    ir.blocks = spill_regs(ir.blocks, env, candidates, liveness)


def spill_regs(
    blocks: list[BasicBlock], env: ClassIR, to_spill: set[Value], live: AnalysisResult[Value]
) -> list[BasicBlock]:
    """Rewrite ``blocks`` so every value in ``to_spill`` is kept in an attribute of ``env``.

    Each spilled value gets its own new attribute on ``env``. The value is
    stored into that attribute right after the op that produces it, and
    every op that uses it gets a fresh ``GetAttr`` read inserted just before
    the use.

    Args:
        blocks: Basic blocks of the function; their op lists are rebuilt in place.
        env: Environment class that receives one new attribute per spilled value.
        to_spill: Values to move into attributes of ``env``.
        live: Liveness results, used to decide when a spill slot is cleared.

    Returns:
        The same ``blocks`` list that was passed in.

    Raises:
        AssertionError: If the first block contains no ``GetAttr`` of
            ``__mypyc_env__``, or if an ``IncRef`` of a spilled value is found.
    """
    # All spill reads and writes go through the env object loaded in the first block.
    for op in blocks[0].ops:
        if isinstance(op, GetAttr) and op.attr == "__mypyc_env__":
            env_reg = op
            break
    else:
        raise AssertionError("could not find __mypyc_env__")

    # Allocate one attribute on the environment class per spilled value.
    spill_locs: dict[Value, str] = {}
    for i, val in enumerate(to_spill):
        name = f"{TEMP_ATTR_NAME}2_{i}"
        env.attributes[name] = val.type
        spill_locs[val] = name

    for block in blocks:
        ops = block.ops
        block.ops = []

        for i, op in enumerate(ops):
            to_decref: list[Value] = []

            if isinstance(op, IncRef) and op.src in spill_locs:
                raise AssertionError("not sure what to do with an incref of a spill...")
            if isinstance(op, DecRef) and op.src in spill_locs:
                # When we decref a spilled value, we turn that into
                # NULLing out the attribute, but only if the spilled
                # value is not live *when we include yields in the
                # CFG*. (The original decrefs are computed without that.)
                #
                # We also skip a decref if the env register is not
                # live. That should only happen when an exception is
                # being raised, so everything should be handled there.
                if op.src not in live.after[block, i] and env_reg in live.after[block, i]:
                    # Skip the DecRef but null out the spilled location
                    null = LoadErrorValue(op.src.type)
                    block.ops.extend([null, SetAttr(env_reg, spill_locs[op.src], null, op.line)])
                # The original DecRef is always dropped for a spilled value.
                continue

            if (
                any(src in spill_locs for src in op.sources())
                # N.B: IS_ERROR should be before a spill happens
                # XXX: but could we have a regular branch?
                and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR)
            ):
                new_sources: list[Value] = []
                # Ask which sources the op steals before they are replaced below.
                stolen = op.stolen()
                for src in op.sources():
                    if src in spill_locs:
                        read = GetAttr(env_reg, spill_locs[src], op.line)
                        block.ops.append(read)
                        new_sources.append(read)
                        # The read yields its own reference. Release it after
                        # the op unless the op steals it; decref'ing a stolen
                        # reference would over-release the value.
                        if src.type.is_refcounted and src not in stolen:
                            to_decref.append(read)
                    else:
                        new_sources.append(src)

                op.set_sources(new_sources)

            block.ops.append(op)

            for dec in to_decref:
                block.ops.append(DecRef(dec))

            if op in spill_locs:
                # XXX: could we set uninit?
                block.ops.append(SetAttr(env_reg, spill_locs[op], op, op.line))

    return blocks