Skip to content

[mypyc] Speed up and improve multiple assignment #9800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
str_rprimitive, is_tagged
str_rprimitive, is_tagged, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive
)
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.primitives.registry import CFunctionDescription, function_ops
from mypyc.primitives.list_ops import to_list, list_pop_last
from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
from mypyc.primitives.misc_ops import import_op
from mypyc.primitives.misc_ops import import_op, check_unpack_count_op
from mypyc.crash import catch_errors
from mypyc.options import CompilerOptions
from mypyc.errors import Errors
Expand Down Expand Up @@ -465,8 +465,10 @@ def read(self, target: Union[Value, AssignmentTarget], line: int = -1) -> Value:

assert False, 'Unsupported lvalue: %r' % target

def assign(self, target: Union[Register, AssignmentTarget],
rvalue_reg: Value, line: int) -> None:
def assign(self,
target: Union[Register, AssignmentTarget],
rvalue_reg: Value,
line: int) -> None:
if isinstance(target, Register):
self.add(Assign(target, rvalue_reg))
elif isinstance(target, AssignmentTargetRegister):
Expand All @@ -491,11 +493,39 @@ def assign(self, target: Union[Register, AssignmentTarget],
for i in range(len(rtypes)):
item_value = self.add(TupleGet(rvalue_reg, i, line))
self.assign(target.items[i], item_value, line)
elif ((is_list_rprimitive(rvalue_reg.type) or is_tuple_rprimitive(rvalue_reg.type))
and target.star_idx is None):
self.process_sequence_assignment(target, rvalue_reg, line)
else:
self.process_iterator_tuple_assignment(target, rvalue_reg, line)
else:
assert False, 'Unsupported assignment target'

def process_sequence_assignment(self,
target: AssignmentTargetTuple,
rvalue: Value,
line: int) -> None:
"""Process assignment like 'x, y = s', where s is a variable-length list or tuple."""
# Check the length of sequence.
expected_len = self.add(LoadInt(len(target.items), rtype=c_pyssize_t_rprimitive))
self.builder.call_c(check_unpack_count_op, [rvalue, expected_len], line)

# Read sequence items.
values = []
for i in range(len(target.items)):
item = target.items[i]
index = self.builder.load_static_int(i)
if is_list_rprimitive(rvalue.type):
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
else:
item_value = self.builder.gen_method_call(
rvalue, '__getitem__', [index], item.type, line)
values.append(item_value)

# Assign sequence items to the target lvalues.
for lvalue, value in zip(target.items, values):
self.assign(lvalue, value, line)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The iterator code below assigns as it goes. Do we know which Python does? It's all pretty marginal, and if we wanted to say we aren't aiming to exactly replicate execution order in some marginal cases, that's probably fine, but we should do it with eyes open. (In general I've tried to match things but I know we mismatch in at least one case with method calls.)

I guess in this case we wouldn't be able to do unsafe get items if we were assigning as we went (because it could get modified.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPython reads all the values before performing the assignments, so this brings the semantics a bit closer to Python. I'll create an issue about updating the generic iterable case to also match Python semantics.


def process_iterator_tuple_assignment_helper(self,
litem: AssignmentTarget,
ritem: Value, line: int) -> None:
Expand Down
31 changes: 20 additions & 11 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mypy.nodes import (
Block, ExpressionStmt, ReturnStmt, AssignmentStmt, OperatorAssignmentStmt, IfStmt, WhileStmt,
ForStmt, BreakStmt, ContinueStmt, RaiseStmt, TryStmt, WithStmt, AssertStmt, DelStmt,
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr, ListExpr,
StarExpr
)

from mypyc.ir.ops import (
Expand Down Expand Up @@ -69,39 +70,47 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:


def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
assert len(stmt.lvalues) >= 1
builder.disallow_class_assignments(stmt.lvalues, stmt.line)
lvalue = stmt.lvalues[0]
lvalues = stmt.lvalues
assert len(lvalues) >= 1
builder.disallow_class_assignments(lvalues, stmt.line)
first_lvalue = lvalues[0]
if stmt.type and isinstance(stmt.rvalue, TempNode):
# This is actually a variable annotation without initializer. Don't generate
# an assignment but we need to call get_assignment_target since it adds a
# name binding as a side effect.
builder.get_assignment_target(lvalue, stmt.line)
builder.get_assignment_target(first_lvalue, stmt.line)
return

# multiple assignment
if (isinstance(lvalue, TupleExpr) and isinstance(stmt.rvalue, TupleExpr)
and len(lvalue.items) == len(stmt.rvalue.items)):
# Special case multiple assignments like 'x, y = e1, e2'.
if (isinstance(first_lvalue, (TupleExpr, ListExpr))
and isinstance(stmt.rvalue, (TupleExpr, ListExpr))
and len(first_lvalue.items) == len(stmt.rvalue.items)
and all(is_simple_lvalue(item) for item in first_lvalue.items)
and len(lvalues) == 1):
temps = []
for right in stmt.rvalue.items:
rvalue_reg = builder.accept(right)
temp = builder.alloc_temp(rvalue_reg.type)
builder.assign(temp, rvalue_reg, stmt.line)
temps.append(temp)
for (left, temp) in zip(lvalue.items, temps):
for (left, temp) in zip(first_lvalue.items, temps):
assignment_target = builder.get_assignment_target(left)
builder.assign(assignment_target, temp, stmt.line)
return

line = stmt.rvalue.line
rvalue_reg = builder.accept(stmt.rvalue)
if builder.non_function_scope() and stmt.is_final_def:
builder.init_final_static(lvalue, rvalue_reg)
for lvalue in stmt.lvalues:
builder.init_final_static(first_lvalue, rvalue_reg)
for lvalue in lvalues:
target = builder.get_assignment_target(lvalue)
builder.assign(target, rvalue_reg, line)


def is_simple_lvalue(expr: Expression) -> bool:
return not isinstance(expr, (StarExpr, ListExpr, TupleExpr))


def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignmentStmt) -> None:
"""Operator assignment statement such as x += 1"""
builder.disallow_class_assignments([stmt.lvalue], stmt.line)
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ void CPyDebug_Print(const char *msg);
void CPy_Init(void);
int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
const char *, char **, ...);
int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected);


#ifdef __cplusplus
Expand Down
14 changes: 14 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,17 @@ void CPyDebug_Print(const char *msg) {
printf("%s\n", msg);
fflush(stdout);
}

int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected) {
Py_ssize_t actual = Py_SIZE(sequence);
if (unlikely(actual != expected)) {
if (actual < expected) {
PyErr_Format(PyExc_ValueError, "not enough values to unpack (expected %zd, got %zd)",
expected, actual);
} else {
PyErr_Format(PyExc_ValueError, "too many values to unpack (expected %zd)", expected);
}
return -1;
}
return 0;
}
10 changes: 9 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE
from mypyc.ir.rtypes import (
bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive,
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive
)
from mypyc.primitives.registry import (
function_op, custom_op, load_address_op, ERR_NEG_INT
Expand Down Expand Up @@ -176,3 +176,11 @@
return_type=bit_rprimitive,
c_function_name='CPyDataclass_SleightOfHand',
error_kind=ERR_FALSE)

# Raise ValueError if length of first argument is not equal to the second argument.
# The first argument must be a list or a variable-length tuple.
check_unpack_count_op = custom_op(
arg_types=[object_rprimitive, c_pyssize_t_rprimitive],
return_type=c_int_rprimitive,
c_function_name='CPySequence_CheckUnpackCount',
error_kind=ERR_NEG_INT)
42 changes: 0 additions & 42 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3596,48 +3596,6 @@ L0:
r2 = truncate r0: int32 to builtins.bool
return r2

[case testMultipleAssignment]
from typing import Tuple

def f(x: int, y: int) -> Tuple[int, int]:
x, y = y, x
return (x, y)

def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
a, b, c = x, y, z
return (c, b, a)
[out]
def f(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2
def f2(x, y, z):
x :: int
y :: str
z :: float
r0 :: int
r1 :: str
r2 :: float
a :: int
b :: str
c :: float
r3 :: tuple[float, str, int]
L0:
r0 = x
r1 = y
r2 = z
a = r0
b = r1
c = r2
r3 = (c, b, a)
return r3

[case testLocalImportSubmodule]
def f() -> int:
import p.m
Expand Down
97 changes: 96 additions & 1 deletion mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,63 @@ L9:
L10:
return s

[case testMultipleAssignment]
[case testMultipleAssignmentWithNoUnpacking]
from typing import Tuple

def f(x: int, y: int) -> Tuple[int, int]:
x, y = y, x
return (x, y)

def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
a, b, c = x, y, z
return (c, b, a)

def f3(x: int, y: int) -> Tuple[int, int]:
[x, y] = [y, x]
return (x, y)
[out]
def f(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2
def f2(x, y, z):
x :: int
y :: str
z :: float
r0 :: int
r1 :: str
r2 :: float
a :: int
b :: str
c :: float
r3 :: tuple[float, str, int]
L0:
r0 = x
r1 = y
r2 = z
a = r0
b = r1
c = r2
r3 = (c, b, a)
return r3
def f3(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2

[case testMultipleAssignmentBasicUnpacking]
from typing import Tuple, Any

def from_tuple(t: Tuple[int, str]) -> None:
Expand Down Expand Up @@ -599,6 +655,45 @@ L0:
z = r6
return 1

[case testMultipleAssignmentUnpackFromSequence]
from typing import List, Tuple

def f(l: List[int], t: Tuple[int, ...]) -> None:
x: object
y: int
x, y = l
x, y = t
[out]
def f(l, t):
l :: list
t :: tuple
x :: object
y :: int
r0 :: int32
r1 :: bit
r2, r3 :: object
r4 :: int
r5 :: int32
r6 :: bit
r7, r8 :: object
r9 :: int
L0:
r0 = CPySequence_CheckUnpackCount(l, 2)
r1 = r0 >= 0 :: signed
r2 = CPyList_GetItemUnsafe(l, 0)
r3 = CPyList_GetItemUnsafe(l, 2)
x = r2
r4 = unbox(int, r3)
y = r4
r5 = CPySequence_CheckUnpackCount(t, 2)
r6 = r5 >= 0 :: signed
r7 = CPySequenceTuple_GetItem(t, 0)
r8 = CPySequenceTuple_GetItem(t, 2)
r9 = unbox(int, r8)
x = r7
y = r9
return 1

[case testAssert]
from typing import Optional

Expand Down
Loading