Skip to content

Make coroutine function return type more specific #5052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 17, 2018
18 changes: 16 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,13 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T
# values. IOW, tc is None.
return NoneTyp()

def get_coroutine_return_type(self, return_type: Type) -> Type:
if isinstance(return_type, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=return_type)
assert isinstance(return_type, Instance), "Should only be called on coroutine functions."
# Note: return type is the 3rd type parameter of Coroutine.
return return_type.args[2]

def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Type:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the is_coroutine argument here still useful, or can we remove it now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure we still need it to handle asynchronous generators. In that case, they have both is_generator=True and is_coroutine=True.

"""Given the declared return type of a generator (t), return the type it returns (tr)."""
if isinstance(return_type, AnyType):
Expand Down Expand Up @@ -756,7 +763,10 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
c = defn.is_coroutine
ty = self.get_generator_yield_type(t, c)
tc = self.get_generator_receive_type(t, c)
tr = self.get_generator_return_type(t, c)
if c:
tr = self.get_coroutine_return_type(t)
else:
tr = self.get_generator_return_type(t, c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the symmetry between the branches here. 👍

ret_type = self.named_generic_type('typing.AwaitableGenerator',
[ty, tc, tr, t])
typ = typ.copy_modified(ret_type=ret_type)
Expand Down Expand Up @@ -841,6 +851,8 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')):
return_type = self.get_generator_return_type(self.return_types[-1],
defn.is_coroutine)
elif defn.is_coroutine:
return_type = self.get_coroutine_return_type(self.return_types[-1])
else:
return_type = self.return_types[-1]

Expand Down Expand Up @@ -878,7 +890,7 @@ def is_unannotated_any(t: Type) -> bool:
if is_unannotated_any(ret_type):
self.fail(messages.RETURN_TYPE_EXPECTED, fdef)
elif (fdef.is_coroutine and isinstance(ret_type, Instance) and
is_unannotated_any(ret_type.args[0])):
is_unannotated_any(self.get_coroutine_return_type(ret_type))):
self.fail(messages.RETURN_TYPE_EXPECTED, fdef)
if any(is_unannotated_any(t) for t in fdef.type.arg_types):
self.fail(messages.ARGUMENT_TYPE_EXPECTED, fdef)
Expand Down Expand Up @@ -2211,6 +2223,8 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if defn.is_generator:
return_type = self.get_generator_return_type(self.return_types[-1],
defn.is_coroutine)
elif defn.is_coroutine:
return_type = self.get_coroutine_return_type(self.return_types[-1])
else:
return_type = self.return_types[-1]

Expand Down
2 changes: 1 addition & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,7 +2667,7 @@ def is_async_def(t: Type) -> bool:
and t.type.fullname() == 'typing.AwaitableGenerator'
and len(t.args) >= 4):
t = t.args[3]
return isinstance(t, Instance) and t.type.fullname() == 'typing.Awaitable'
return isinstance(t, Instance) and t.type.fullname() == 'typing.Coroutine'


def map_actuals_to_formals(caller_kinds: List[int],
Expand Down
3 changes: 1 addition & 2 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,7 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
self.as_required_block(n.body, n.lineno),
func_type)
if is_coroutine:
# A coroutine is also a generator, mostly for internal reasons.
func_def.is_generator = func_def.is_coroutine = True
func_def.is_coroutine = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is good you also cleaned this up.

if func_type is not None:
func_type.definition = func_def
func_type.line = n.lineno
Expand Down
8 changes: 5 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,11 @@ def _visit_func_def(self, defn: FuncDef) -> None:
pass
else:
# A coroutine defined as `async def foo(...) -> T: ...`
# has external return type `Awaitable[T]`.
ret_type = self.named_type_or_none('typing.Awaitable', [defn.type.ret_type])
assert ret_type is not None, "Internal error: typing.Awaitable not found"
# has external return type `Coroutine[Any, Any, T]`.
any_type = AnyType(TypeOfAny.special_form)
ret_type = self.named_type_or_none('typing.Coroutine',
[any_type, any_type, defn.type.ret_type])
assert ret_type is not None, "Internal error: typing.Coroutine not found"
defn.type = defn.type.copy_modified(ret_type=ret_type)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
Expand Down
8 changes: 4 additions & 4 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def f() -> int:

async def f() -> int:
return 0
reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]'
reveal_type(f()) # E: Revealed type is 'typing.Coroutine[Any, Any, builtins.int]'
[builtins fixtures/async_await.pyi]
[typing fixtures/typing-full.pyi]

Expand Down Expand Up @@ -378,7 +378,7 @@ def g() -> Generator[Any, None, str]:
[builtins fixtures/async_await.pyi]
[typing fixtures/typing-full.pyi]
[out]
main:6: error: "yield from" can't be applied to "Awaitable[str]"
main:6: error: "yield from" can't be applied to "Coroutine[Any, Any, str]"

[case testAwaitableSubclass]

Expand Down Expand Up @@ -630,9 +630,9 @@ def plain_host_generator() -> Generator[str, None, None]:
yield 'a'
x = 0
x = yield from plain_generator()
x = yield from plain_coroutine() # E: "yield from" can't be applied to "Awaitable[int]"
x = yield from plain_coroutine() # E: "yield from" can't be applied to "Coroutine[Any, Any, int]"
x = yield from decorated_generator()
x = yield from decorated_coroutine() # E: "yield from" can't be applied to "AwaitableGenerator[Any, Any, int, Awaitable[int]]"
x = yield from decorated_coroutine() # E: "yield from" can't be applied to "AwaitableGenerator[Any, Any, int, Coroutine[Any, Any, int]]"
x = yield from other_iterator()
x = yield from other_coroutine() # E: "yield from" can't be applied to "Aw"

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-class-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class XRepr(NamedTuple):
return 0

reveal_type(XMeth(1).double()) # E: Revealed type is 'builtins.int'
reveal_type(XMeth(1).asyncdouble()) # E: Revealed type is 'typing.Awaitable[builtins.int]'
reveal_type(XMeth(1).asyncdouble()) # E: Revealed type is 'typing.Coroutine[Any, Any, builtins.int]'
reveal_type(XMeth(42).x) # E: Revealed type is 'builtins.int'
reveal_type(XRepr(42).__str__()) # E: Revealed type is 'builtins.str'
reveal_type(XRepr(1, 2).__add__(XRepr(3))) # E: Revealed type is 'builtins.int'
Expand Down
10 changes: 10 additions & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ class Awaitable(Protocol[T]):
class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S]):
pass

class Coroutine(Awaitable[V], Generic[T, U, V]):
@abstractmethod
def send(self, value: U) -> T: pass

@abstractmethod
def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass

@abstractmethod
def close(self) -> None: pass

@runtime
class AsyncIterable(Protocol[T]):
@abstractmethod
Expand Down