Skip to content

[mypyc] Generate smaller code for casts #12839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 118 additions & 38 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from mypy.backports import OrderedDict
from typing import List, Set, Dict, Optional, Callable, Union, Tuple
from typing_extensions import Final

import sys

from mypyc.common import (
Expand All @@ -23,6 +25,10 @@
from mypyc.sametype import is_same_type
from mypyc.codegen.literals import Literals

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS: Final = False


class HeaderDeclaration:
"""A representation of a declaration in C.
Expand Down Expand Up @@ -104,6 +110,20 @@ def __init__(self, label: str) -> None:
self.label = label


class TracebackAndGotoHandler(ErrorHandler):
"""Add traceback item and goto label on error."""

def __init__(self,
label: str,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
self.label = label
self.source_path = source_path
self.module_name = module_name
self.traceback_entry = traceback_entry


class ReturnHandler(ErrorHandler):
"""Return a constant value on error."""

Expand Down Expand Up @@ -439,18 +459,6 @@ def emit_cast(self,
likely: If the cast is likely to succeed (can be False for unions)
"""
error = error or AssignHandler()
if isinstance(error, AssignHandler):
handle_error = '%s = NULL;' % dest
elif isinstance(error, GotoHandler):
handle_error = 'goto %s;' % error.label
else:
assert isinstance(error, ReturnHandler)
handle_error = 'return %s;' % error.value
if raise_exception:
raise_exc = f'CPy_TypeError("{self.pretty_name(typ)}", {src}); '
err = raise_exc + handle_error
else:
err = handle_error

# Special case casting *from* optional
if src_type and is_optional_type(src_type) and not is_object_rprimitive(typ):
Expand All @@ -465,9 +473,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
return

# TODO: Verify refcount handling.
Expand Down Expand Up @@ -500,9 +508,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_bytes_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -512,9 +520,9 @@ def emit_cast(self,
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_tuple_rprimitive(typ):
if declare_dest:
self.emit_line(f'{self.ctype(typ)} {dest};')
Expand All @@ -525,9 +533,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif isinstance(typ, RInstance):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -551,10 +559,10 @@ def emit_cast(self,
check = f'(likely{check})'
self.emit_arg_check(src, dest, typ, check, optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
f' {dest} = {src};'.format(dest, src),
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_none_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -565,9 +573,9 @@ def emit_cast(self,
check.format(src), optional)
self.emit_lines(
f' {dest} = {src};',
'else {',
err,
'}')
'else {')
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_line('}')
elif is_object_rprimitive(typ):
if declare_dest:
self.emit_line(f'PyObject *{dest};')
Expand All @@ -576,21 +584,51 @@ def emit_cast(self,
if optional:
self.emit_line('}')
elif isinstance(typ, RUnion):
self.emit_union_cast(src, dest, typ, declare_dest, err, optional, src_type)
self.emit_union_cast(src, dest, typ, declare_dest, error, optional, src_type,
raise_exception)
elif isinstance(typ, RTuple):
assert not optional
self.emit_tuple_cast(src, dest, typ, declare_dest, err, src_type)
self.emit_tuple_cast(src, dest, typ, declare_dest, error, src_type)
else:
assert False, 'Cast not implemented: %s' % typ

def emit_cast_error_handler(self,
error: ErrorHandler,
src: str,
dest: str,
typ: RType,
raise_exception: bool) -> None:
if raise_exception:
if isinstance(error, TracebackAndGotoHandler):
# Merge raising and emitting traceback entry into a single call.
self.emit_type_error_traceback(
error.source_path, error.module_name, error.traceback_entry,
typ=typ,
src=src)
self.emit_line('goto %s;' % error.label)
return
self.emit_line('CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src))
if isinstance(error, AssignHandler):
self.emit_line('%s = NULL;' % dest)
elif isinstance(error, GotoHandler):
self.emit_line('goto %s;' % error.label)
elif isinstance(error, TracebackAndGotoHandler):
self.emit_line('%s = NULL;' % dest)
self.emit_traceback(error.source_path, error.module_name, error.traceback_entry)
self.emit_line('goto %s;' % error.label)
else:
assert isinstance(error, ReturnHandler)
self.emit_line('return %s;' % error.value)

def emit_union_cast(self,
src: str,
dest: str,
typ: RUnion,
declare_dest: bool,
err: str,
error: ErrorHandler,
optional: bool,
src_type: Optional[RType]) -> None:
src_type: Optional[RType],
raise_exception: bool) -> None:
"""Emit cast to a union type.

The arguments are similar to emit_cast.
Expand All @@ -613,11 +651,11 @@ def emit_union_cast(self,
likely=False)
self.emit_line(f'if ({dest} != NULL) goto {good_label};')
# Handle cast failure.
self.emit_line(err)
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
self.emit_label(good_label)

def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool,
err: str, src_type: Optional[RType]) -> None:
error: ErrorHandler, src_type: Optional[RType]) -> None:
"""Emit cast to a tuple type.

The arguments are similar to emit_cast.
Expand Down Expand Up @@ -740,7 +778,8 @@ def emit_unbox(self,
self.emit_line('} else {')

cast_temp = self.temp_name()
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, err='', src_type=None)
self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, error=error,
src_type=None)
self.emit_line(f'if (unlikely({cast_temp} == NULL)) {{')

# self.emit_arg_check(src, dest, typ,
Expand Down Expand Up @@ -886,3 +925,44 @@ def emit_gc_clear(self, target: str, rtype: RType) -> None:
self.emit_line(f'Py_CLEAR({target});')
else:
assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype)

def emit_traceback(self,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int]) -> None:
return self._emit_traceback('CPy_AddTraceback', source_path, module_name, traceback_entry)

def emit_type_error_traceback(
self,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int],
*,
typ: RType,
src: str) -> None:
func = 'CPy_TypeErrorTraceback'
type_str = f'"{self.pretty_name(typ)}"'
return self._emit_traceback(
func, source_path, module_name, traceback_entry, type_str=type_str, src=src)

def _emit_traceback(self,
func: str,
source_path: str,
module_name: str,
traceback_entry: Tuple[str, int],
type_str: str = '',
src: str = '') -> None:
globals_static = self.static_name('globals', module_name)
line = '%s("%s", "%s", %d, %s' % (
func,
source_path.replace("\\", "\\\\"),
traceback_entry[0],
traceback_entry[1],
globals_static)
if type_str:
assert src
line += f', {type_str}, {src}'
line += ');'
self.emit_line(line)
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
35 changes: 20 additions & 15 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.codegen.emit import Emitter
from mypyc.codegen.emit import Emitter, TracebackAndGotoHandler, DEBUG_ERRORS
from mypyc.ir.ops import (
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
Expand All @@ -23,10 +23,6 @@
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.analysis.blockfreq import frequently_executed_blocks

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
DEBUG_ERRORS = False


def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
args = ', '.join(emitter.ctype(arg.type) for arg in fn.args) or 'void'
Expand Down Expand Up @@ -322,7 +318,7 @@ def visit_get_attr(self, op: GetAttr) -> None:
and branch.traceback_entry is not None
and not branch.negated):
# Generate code for the following branch here to avoid
# redundant branches in the generate code.
# redundant branches in the generated code.
self.emit_attribute_error(branch, cl.name, op.attr)
self.emit_line('goto %s;' % self.label(branch.true))
merged_branch = branch
Expand Down Expand Up @@ -485,8 +481,24 @@ def visit_box(self, op: Box) -> None:
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

def visit_cast(self, op: Cast) -> None:
branch = self.next_branch()
handler = None
if branch is not None:
if (branch.value is op
and branch.op == Branch.IS_ERROR
and branch.traceback_entry is not None
and not branch.negated
and branch.false is self.next_block):
# Generate code also for the following branch here to avoid
# redundant branches in the generated code.
handler = TracebackAndGotoHandler(self.label(branch.true),
self.source_path,
self.module_name,
branch.traceback_entry)
self.op_index += 1

self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
src_type=op.src.type)
src_type=op.src.type, error=handler)

def visit_unbox(self, op: Unbox) -> None:
self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)
Expand Down Expand Up @@ -647,14 +659,7 @@ def emit_declaration(self, line: str) -> None:

def emit_traceback(self, op: Branch) -> None:
if op.traceback_entry is not None:
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
self.emitter.emit_traceback(self.source_path, self.module_name, op.traceback_entry)

def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
assert op.traceback_entry is not None
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,8 @@ void _CPy_GetExcInfo(PyObject **p_type, PyObject **p_value, PyObject **p_traceba
void CPyError_OutOfMemory(void);
void CPy_TypeError(const char *expected, PyObject *value);
void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyObject *globals);
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
PyObject *globals, const char *expected, PyObject *value);
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals);

Expand Down
7 changes: 7 additions & 0 deletions mypyc/lib-rt/exc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyOb
_PyErr_ChainExceptions(exc, val, tb);
}

CPy_NOINLINE
void CPy_TypeErrorTraceback(const char *filename, const char *funcname, int line,
PyObject *globals, const char *expected, PyObject *value) {
CPy_TypeError(expected, value);
CPy_AddTraceback(filename, funcname, line, globals);
}

void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals) {
char buf[500];
Expand Down
15 changes: 15 additions & 0 deletions mypyc/test-data/run-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,18 @@ def sub(s: str, f: Callable[[str], str]) -> str: ...
def sub(s: bytes, f: Callable[[bytes], bytes]) -> bytes: ...
def sub(s, f):
return f(s)

[case testContextManagerSpecialCase]
from typing import Generator, Callable, Iterator
from contextlib import contextmanager

@contextmanager
def f() -> Iterator[None]:
yield

def g() -> None:
a = ['']
with f():
a.pop()

g()
Loading