Skip to content

Commit a370550

Browse files
Michael0x2amsullivan
authored andcommitted
Make joins of callables respect positional parameter names (#4920)
For example, consider the following program: def f(x: int) -> int: ... def g(y: int) -> int: ... lst = [f, g] Previously mypy would treat the final line as an error since 'f' and 'g' have different types due to the different parameter names. Now, mypy infers that `lst` has type `list[def (int) -> int]`, effectively erasing the parameter name from the inferred type. This commit does not attempt to handle keyword-only arguments. Fixes #2777.
1 parent 5a683cf commit a370550

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

mypy/join.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
1414
is_protocol_implementation
1515
)
16+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT
1617

1718
from mypy import experiments
1819

@@ -348,7 +349,6 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
348349
"""Return True if t and s have identical numbers of
349350
arguments, default arguments and varargs.
350351
"""
351-
352352
return (len(t.arg_types) == len(s.arg_types) and t.min_args == s.min_args and
353353
t.is_var_arg == s.is_var_arg)
354354

@@ -366,6 +366,7 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
366366
else:
367367
fallback = s.fallback
368368
return t.copy_modified(arg_types=arg_types,
369+
arg_names=combine_arg_names(t, s),
369370
ret_type=join_types(t.ret_type, s.ret_type),
370371
fallback=fallback,
371372
name=None)
@@ -383,11 +384,42 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
383384
else:
384385
fallback = s.fallback
385386
return t.copy_modified(arg_types=arg_types,
387+
arg_names=combine_arg_names(t, s),
386388
ret_type=join_types(t.ret_type, s.ret_type),
387389
fallback=fallback,
388390
name=None)
389391

390392

393+
def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]:
394+
"""Produces a list of argument names compatible with both callables.
395+
396+
For example, suppose 't' and 's' have the following signatures:
397+
398+
- t: (a: int, b: str, X: str) -> None
399+
- s: (a: int, b: str, Y: str) -> None
400+
401+
This function would return ["a", "b", None]. This information
402+
is then used above to compute the join of t and s, which results
403+
in a signature of (a: int, b: str, str) -> None.
404+
405+
Note that the third argument's name is omitted and 't' and 's'
406+
are both valid subtypes of this inferred signature.
407+
408+
Precondition: is_similar_types(t, s) is true.
409+
"""
410+
num_args = len(t.arg_types)
411+
new_names = []
412+
named = (ARG_NAMED, ARG_NAMED_OPT)
413+
for i in range(num_args):
414+
t_name = t.arg_names[i]
415+
s_name = s.arg_names[i]
416+
if t_name == s_name or t.arg_kinds[i] in named or s.arg_kinds[i] in named:
417+
new_names.append(t_name)
418+
else:
419+
new_names.append(None)
420+
return new_names
421+
422+
391423
def object_from_instance(instance: Instance) -> Instance:
392424
"""Construct the type 'builtins.object' from an instance type."""
393425
# Use the fact that 'object' is always the last class in the mro.

test-data/unit/check-inference.test

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,13 +856,59 @@ s = s_s() # E: Incompatible types in assignment (expression has type "Set[str]",
856856
[builtins fixtures/set.pyi]
857857

858858
[case testSetWithStarExpr]
859-
860859
s = {1, 2, *(3, 4)}
861860
t = {1, 2, *s}
862861
reveal_type(s) # E: Revealed type is 'builtins.set[builtins.int*]'
863862
reveal_type(t) # E: Revealed type is 'builtins.set[builtins.int*]'
864863
[builtins fixtures/set.pyi]
865864

865+
[case testListLiteralWithFunctionsErasesNames]
866+
def f1(x: int) -> int: ...
867+
def g1(y: int) -> int: ...
868+
def h1(x: int) -> int: ...
869+
870+
list_1 = [f1, g1]
871+
list_2 = [f1, h1]
872+
reveal_type(list_1) # E: Revealed type is 'builtins.list[def (builtins.int) -> builtins.int]'
873+
reveal_type(list_2) # E: Revealed type is 'builtins.list[def (x: builtins.int) -> builtins.int]'
874+
875+
def f2(x: int, z: str) -> int: ...
876+
def g2(y: int, z: str) -> int: ...
877+
def h2(x: int, z: str) -> int: ...
878+
879+
list_3 = [f2, g2]
880+
list_4 = [f2, h2]
881+
reveal_type(list_3) # E: Revealed type is 'builtins.list[def (builtins.int, z: builtins.str) -> builtins.int]'
882+
reveal_type(list_4) # E: Revealed type is 'builtins.list[def (x: builtins.int, z: builtins.str) -> builtins.int]'
883+
[builtins fixtures/list.pyi]
884+
885+
[case testListLiteralWithSimilarFunctionsErasesName]
886+
from typing import Union
887+
888+
class A: ...
889+
class B(A): ...
890+
class C: ...
891+
class D: ...
892+
893+
def f(x: Union[A, C], y: B) -> A: ...
894+
def g(z: Union[B, D], y: A) -> B: ...
895+
def h(x: Union[B, D], y: A) -> B: ...
896+
897+
list_1 = [f, g]
898+
list_2 = [f, h]
899+
reveal_type(list_1) # E: Revealed type is 'builtins.list[def (__main__.B, y: __main__.B) -> __main__.A]'
900+
reveal_type(list_2) # E: Revealed type is 'builtins.list[def (x: __main__.B, y: __main__.B) -> __main__.A]'
901+
[builtins fixtures/list.pyi]
902+
903+
[case testListLiteralWithNameOnlyArgsDoesNotEraseNames]
904+
def f(*, x: int) -> int: ...
905+
def g(*, y: int) -> int: ...
906+
def h(*, x: int) -> int: ...
907+
908+
list_1 = [f, g] # E: List item 0 has incompatible type "Callable[[NamedArg(int, 'x')], int]"; expected "Callable[[NamedArg(int, 'y')], int]"
909+
list_2 = [f, h]
910+
[builtins fixtures/list.pyi]
911+
866912

867913
-- For statements
868914
-- --------------

0 commit comments

Comments
 (0)