Skip to content

[mypyc] Generate smaller code for casts #12839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[mypyc] Simplify C code generated for casts
  • Loading branch information
JukkaL committed May 20, 2022
commit 5e9938b919f3a41d6a960f731f48325f62d7691c
120 changes: 82 additions & 38 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from mypy.backports import OrderedDict
from typing import List, Set, Dict, Optional, Callable, Union, Tuple
from typing_extensions import Final

import sys

from mypyc.common import (
Expand All @@ -23,6 +25,10 @@
from mypyc.sametype import is_same_type
from mypyc.codegen.literals import Literals

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS: Final = False


class HeaderDeclaration:
"""A representation of a declaration in C.
Expand Down Expand Up @@ -104,6 +110,20 @@ def __init__(self, label: str) -> None:
self.label = label


class TracebackAndGotoHandler(ErrorHandler):
"""Add traceback item and goto label on error."""

def __init__(self,
label: str,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
self.label = label
self.source_path = source_path
self.module_name = module_name
self.traceback_entry = traceback_entry


class ReturnHandler(ErrorHandler):
"""Return a constant value on error."""

Expand Down Expand Up @@ -439,18 +459,6 @@ def emit_cast(self,
likely: If the cast is likely to succeed (can be False for unions)
"""
error = error or AssignHandler()
if isinstance(error, AssignHandler):
handle_error = '%s = NULL;' % dest
elif isinstance(error, GotoHandler):
handle_error = 'goto %s;' % error.label
else:
assert isinstance(error, ReturnHandler)
handle_error = 'return %s;' % error.value
if raise_exception:
raise_exc = f'CPy_TypeError("{self.pretty_name(typ)}", {src}); '
err = raise_exc + handle_error
else:
err = handle_error

# Special case casting *from* optional
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
Expand All @@ -465,9 +473,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
return

# TODO: Verify refcount handling.
Expand Down Expand Up @@ -500,9 +508,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_bytes_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -512,9 +520,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_tuple_rprimitive(typ):
if declare_dest:
self.emit_line(f'{self.ctype(typ)} {dest};')
Expand All @@ -525,9 +533,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif isinstance(typ, RInstance):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -551,10 +559,10 @@ def emit_cast(self,
check = f'(likely{check})'
self.emit_arg_check(src, dest, typ, check, optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
f' {dest} = {src};'.format(dest, src),
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_none_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -565,9 +573,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_object_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -576,21 +584,43 @@ def emit_cast(self,
if optional:
self.emit_line('}')
elif isinstance(typ, RUnion):
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
self.emit_union_cast(src, dest, typ, declare_dest, error, optional, src_type,
raise_exception)
elif isinstance(typ, RTuple):
assert not optional
self.emit_tuple_cast(src, dest, typ, declare_dest, err, src_type)
self.emit_tuple_cast(src, dest, typ, declare_dest, error, src_type)
else:
assert False, 'Cast not implemented: %s' % typ

def emit_cast_error_handler(self,
error: ErrorHandler,
src: str,
dest: str,
typ: RType,
raise_exception: bool) -> None:
if raise_exception:
self.emit_line('CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src))
if isinstance(error, AssignHandler):
self.emit_line('%s = NULL;' % dest)
elif isinstance(error, GotoHandler):
self.emit_line('goto %s;' % error.label)
elif isinstance(error, TracebackAndGotoHandler):
self.emit_line('%s = NULL;' % dest)
self.emit_traceback(error.source_path, error.module_name, error.traceback_entry)
self.emit_line('goto %s;' % error.label)
else:
assert isinstance(error, ReturnHandler)
self.emit_line('return %s;' % error.value)

def emit_union_cast(self,
src: str,
dest: str,
typ: RUnion,
declare_dest: bool,
err: str,
error: ErrorHandler,
optional: bool,
src_type: Optional[RType]) -> None:
src_type: Optional[RType],
raise_exception: bool) -> None:
"""Emit cast to a union type.

The arguments are similar to emit_cast.
Expand All @@ -613,11 +643,11 @@ def emit_union_cast(self,
likely=False)
self.emit_line(f'if ({dest} != NULL) goto {good_label};')
# Handle cast failure.
self.emit_line(err)
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_label(good_label)

def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool,
err: str, src_type: Optional[RType]) -> None:
error: ErrorHandler, src_type: Optional[RType]) -> None:
"""Emit cast to a tuple type.

The arguments are similar to emit_cast.
Expand Down Expand Up @@ -740,7 +770,8 @@ def emit_unbox(self,
self.emit_line('} else {')

cast_temp = self.temp_name()
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, err='', src_type=None)
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, error=error,
src_type=None)
self.emit_line(f'if (unlikely({cast_temp} == NULL)) {{')

# self.emit_arg_check(src, dest, typ,
Expand Down Expand Up @@ -886,3 +917,16 @@ def emit_gc_clear(self, target: str, rtype: RType) -> None:
self.emit_line(f'Py_CLEAR({target});')
else:
assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype)

def emit_traceback(self,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
globals_static = self.static_name('globals', module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
source_path.replace("\\", "\\\\"),
traceback_entry[0],
traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
35 changes: 20 additions & 15 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.codegen.emit import Emitter
from mypyc.codegen.emit import Emitter, TracebackAndGotoHandler, DEBUG_ERRORS
from mypyc.ir.ops import (
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
Expand All @@ -23,10 +23,6 @@
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.analysis.blockfreq import frequently_executed_blocks

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS = False


def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
args = ', '.join(emitter.ctype(arg.type) for arg in fn.args) or 'void'
Expand Down Expand Up @@ -322,7 +318,7 @@ def visit_get_attr(self, op: GetAttr) -> None:
and branch.traceback_entry is not None
and not branch.negated):
# Generate code for the following branch here to avoid
# redundant branches in the generate code.
# redundant branches in the generated code.
self.emit_attribute_error(branch, cl.name, op.attr)
self.emit_line('goto %s;' % self.label(branch.true))
merged_branch = branch
Expand Down Expand Up @@ -485,8 +481,24 @@ def visit_box(self, op: Box) -> None:
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

def visit_cast(self, op: Cast) -> None:
branch = self.next_branch()
handler = None
if branch is not None:
if (branch.value is op
and branch.op == Branch.IS_ERROR
and branch.traceback_entry is not None
and not branch.negated
and branch.false is self.next_block):
# Generate code also for the following branch here to avoid
# redundant branches in the generated code.
handler = TracebackAndGotoHandler(self.label(branch.true),
self.source_path,
self.module_name,
branch.traceback_entry)
self.op_index += 1

self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
src_type=op.src.type)
src_type=op.src.type, error=handler)

def visit_unbox(self, op: Unbox) -> None:
self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)
Expand Down Expand Up @@ -647,14 +659,7 @@ def emit_declaration(self, line: str) -> None:

def emit_traceback(self, op: Branch) -> None:
if op.traceback_entry is not None:
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
self.emitter.emit_traceback(self.source_path, self.module_name, op.traceback_entry)

def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
assert op.traceback_entry is not None
Expand Down
15 changes: 15 additions & 0 deletions mypyc/test-data/run-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,18 @@ def sub(s: str, f: Callable[[str], str]) -> str: ...
def sub(s: bytes, f: Callable[[bytes], bytes]) -> bytes: ...
def sub(s, f):
return f(s)

[case testContextManagerSpecialCase]
from typing import Generator, Callable, Iterator
from contextlib import contextmanager

@contextmanager
def f() -> Iterator[None]:
yield

def g() -> None:
a = ['']
with f():
a.pop()

g()
3 changes: 2 additions & 1 deletion mypyc/test/test_emitwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_check_list(self) -> None:
'if (likely(PyList_Check(obj_x)))',
' arg_x = obj_x;',
'else {',
' CPy_TypeError("list", obj_x); return NULL;',
' CPy_TypeError("list", obj_x);',
' return NULL;',
'}',
], lines)

Expand Down