Skip to content

Commit 83d54ff

Browse files
Use polymorphic inference in unification (#17348)
Moving towards #15907 Fixes #17206 This PR enables polymorphic inference during unification. This will allow us to handle even more tricky situations involving generic higher-order functions (see a random example I added in tests). Implementation is mostly straightforward, few notes: * This uncovered another issue with unions in solver, unfortunately current constraint inference algorithm can sometimes infer weird constraints like `T <: Union[T, int]`, that later confuse the solver. * This uncovered another possible type variable clash scenario that was not handled properly. In overloaded generic function, each overload should have a different namespace for type variables (currently they all just get function name). I use `module.some_func#0` etc. for overloads namespaces instead. * Another thing with overloads is that the switch caused unsafe overlap check to change: after some back and forth I am keeping it mostly the same to avoid possible regressions (unfortunately this requires some extra refreshing of type variables). * This makes another `ParamSpec` crash to happen more often so I fix it in this same PR. * Finally this uncovered a bug in handling of overloaded `__init__()` that I am fixing here as well. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5ae9e69 commit 83d54ff

12 files changed

+253
-65
lines changed

mypy/checker.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,9 +791,21 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
791791
if impl_type is not None:
792792
assert defn.impl is not None
793793

794+
# This is what we want from implementation, it should accept all arguments
795+
# of an overload, but the return types should go the opposite way.
796+
if is_callable_compatible(
797+
impl_type,
798+
sig1,
799+
is_compat=is_subtype,
800+
is_proper_subtype=False,
801+
is_compat_return=lambda l, r: is_subtype(r, l),
802+
):
803+
continue
804+
# If the above check didn't work, we repeat some key steps in
805+
# is_callable_compatible() to give a better error message.
806+
794807
# We perform a unification step that's very similar to what
795-
# 'is_callable_compatible' would have done if we had set
796-
# 'unify_generics' to True -- the only difference is that
808+
# 'is_callable_compatible' does -- the only difference is that
797809
# we check and see if the impl_type's return value is a
798810
# *supertype* of the overload alternative, not a *subtype*.
799811
#

mypy/constraints.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,14 +688,19 @@ def visit_unpack_type(self, template: UnpackType) -> list[Constraint]:
688688

689689
def visit_parameters(self, template: Parameters) -> list[Constraint]:
690690
# Constraining Any against C[P] turns into infer_against_any([P], Any)
691-
# ... which seems like the only case this can happen. Better to fail loudly otherwise.
692691
if isinstance(self.actual, AnyType):
693692
return self.infer_against_any(template.arg_types, self.actual)
694693
if type_state.infer_polymorphic and isinstance(self.actual, Parameters):
695694
# For polymorphic inference we need to be able to infer secondary constraints
696695
# in situations like [x: T] <: P <: [x: int].
697696
return infer_callable_arguments_constraints(template, self.actual, self.direction)
698-
raise RuntimeError("Parameters cannot be constrained to")
697+
if type_state.infer_polymorphic and isinstance(self.actual, ParamSpecType):
698+
# Similar for [x: T] <: Q <: Concatenate[int, P].
699+
return infer_callable_arguments_constraints(
700+
template, self.actual.prefix, self.direction
701+
)
702+
# There also may be unpatched types after a user error, simply ignore them.
703+
return []
699704

700705
# Non-leaf types
701706

mypy/message_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
180180
)
181181
INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
182182
INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position"
183+
INVALID_PARAM_SPEC_LOCATION: Final = "Invalid location for ParamSpec {}"
184+
INVALID_PARAM_SPEC_LOCATION_NOTE: Final = (
185+
'You can use ParamSpec as the first argument to Callable, e.g., "Callable[{}, int]"'
186+
)
183187

184188
# TypeVar
185189
INCOMPATIBLE_TYPEVAR_VALUE: Final = 'Value of type variable "{}" of {} cannot be {}'

mypy/semanal.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ def __init__(
479479
# new uses of this, as this may cause leaking `UnboundType`s to type checking.
480480
self.allow_unbound_tvars = False
481481

482+
# Used to pass information about current overload index to visit_func_def().
483+
self.current_overload_item: int | None = None
484+
482485
# mypyc doesn't properly handle implementing an abstractproperty
483486
# with a regular attribute so we make them properties
484487
@property
@@ -869,6 +872,11 @@ def visit_func_def(self, defn: FuncDef) -> None:
869872
with self.scope.function_scope(defn):
870873
self.analyze_func_def(defn)
871874

875+
def function_fullname(self, fullname: str) -> str:
876+
if self.current_overload_item is None:
877+
return fullname
878+
return f"{fullname}#{self.current_overload_item}"
879+
872880
def analyze_func_def(self, defn: FuncDef) -> None:
873881
if self.push_type_args(defn.type_args, defn) is None:
874882
self.defer(defn)
@@ -895,17 +903,16 @@ def analyze_func_def(self, defn: FuncDef) -> None:
895903
self.prepare_method_signature(defn, self.type, has_self_type)
896904

897905
# Analyze function signature
898-
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
906+
fullname = self.function_fullname(defn.fullname)
907+
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
899908
if defn.type:
900909
self.check_classvar_in_signature(defn.type)
901910
assert isinstance(defn.type, CallableType)
902911
# Signature must be analyzed in the surrounding scope so that
903912
# class-level imported names and type variables are in scope.
904913
analyzer = self.type_analyzer()
905914
tag = self.track_incomplete_refs()
906-
result = analyzer.visit_callable_type(
907-
defn.type, nested=False, namespace=defn.fullname
908-
)
915+
result = analyzer.visit_callable_type(defn.type, nested=False, namespace=fullname)
909916
# Don't store not ready types (including placeholders).
910917
if self.found_incomplete_ref(tag) or has_placeholder(result):
911918
self.defer(defn)
@@ -1117,7 +1124,8 @@ def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem)
11171124
if defn is generic. Return True, if the signature contains typing.Self
11181125
type, or False otherwise.
11191126
"""
1120-
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
1127+
fullname = self.function_fullname(defn.fullname)
1128+
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
11211129
a = self.type_analyzer()
11221130
fun_type.variables, has_self_type = a.bind_function_type_variables(fun_type, defn)
11231131
if has_self_type and self.type is not None:
@@ -1175,6 +1183,14 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
11751183
with self.scope.function_scope(defn):
11761184
self.analyze_overloaded_func_def(defn)
11771185

1186+
@contextmanager
1187+
def overload_item_set(self, item: int | None) -> Iterator[None]:
1188+
self.current_overload_item = item
1189+
try:
1190+
yield
1191+
finally:
1192+
self.current_overload_item = None
1193+
11781194
def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
11791195
# OverloadedFuncDef refers to any legitimate situation where you have
11801196
# more than one declaration for the same function in a row. This occurs
@@ -1187,7 +1203,8 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
11871203

11881204
first_item = defn.items[0]
11891205
first_item.is_overload = True
1190-
first_item.accept(self)
1206+
with self.overload_item_set(0):
1207+
first_item.accept(self)
11911208

11921209
if isinstance(first_item, Decorator) and first_item.func.is_property:
11931210
# This is a property.
@@ -1272,7 +1289,8 @@ def analyze_overload_sigs_and_impl(
12721289
if i != 0:
12731290
# Assume that the first item was already visited
12741291
item.is_overload = True
1275-
item.accept(self)
1292+
with self.overload_item_set(i if i < len(defn.items) - 1 else None):
1293+
item.accept(self)
12761294
# TODO: support decorated overloaded functions properly
12771295
if isinstance(item, Decorator):
12781296
callable = function_type(item.func, self.named_type("builtins.function"))
@@ -1444,15 +1462,17 @@ def add_function_to_symbol_table(self, func: FuncDef | OverloadedFuncDef) -> Non
14441462
self.add_symbol(func.name, func, func)
14451463

14461464
def analyze_arg_initializers(self, defn: FuncItem) -> None:
1447-
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
1465+
fullname = self.function_fullname(defn.fullname)
1466+
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
14481467
# Analyze default arguments
14491468
for arg in defn.arguments:
14501469
if arg.initializer:
14511470
arg.initializer.accept(self)
14521471

14531472
def analyze_function_body(self, defn: FuncItem) -> None:
14541473
is_method = self.is_class_scope()
1455-
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
1474+
fullname = self.function_fullname(defn.fullname)
1475+
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
14561476
# Bind the type variables again to visit the body.
14571477
if defn.type:
14581478
a = self.type_analyzer()

mypy/semanal_typeargs.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from mypy import errorcodes as codes, message_registry
1313
from mypy.errorcodes import ErrorCode
1414
from mypy.errors import Errors
15+
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
1516
from mypy.messages import format_type
1617
from mypy.mixedtraverser import MixedTraverserVisitor
1718
from mypy.nodes import ARG_STAR, Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
@@ -146,13 +147,25 @@ def validate_args(
146147
for (i, arg), tvar in zip(enumerate(args), type_vars):
147148
if isinstance(tvar, TypeVarType):
148149
if isinstance(arg, ParamSpecType):
149-
# TODO: Better message
150150
is_error = True
151-
self.fail(f'Invalid location for ParamSpec "{arg.name}"', ctx)
151+
self.fail(
152+
INVALID_PARAM_SPEC_LOCATION.format(format_type(arg, self.options)),
153+
ctx,
154+
code=codes.VALID_TYPE,
155+
)
152156
self.note(
153-
"You can use ParamSpec as the first argument to Callable, e.g., "
154-
"'Callable[{}, int]'".format(arg.name),
157+
INVALID_PARAM_SPEC_LOCATION_NOTE.format(arg.name),
158+
ctx,
159+
code=codes.VALID_TYPE,
160+
)
161+
continue
162+
if isinstance(arg, Parameters):
163+
is_error = True
164+
self.fail(
165+
f"Cannot use {format_type(arg, self.options)} for regular type variable,"
166+
" only for ParamSpec",
155167
ctx,
168+
code=codes.VALID_TYPE,
156169
)
157170
continue
158171
if tvar.values:
@@ -204,6 +217,7 @@ def validate_args(
204217
"Can only replace ParamSpec with a parameter types list or"
205218
f" another ParamSpec, got {format_type(arg, self.options)}",
206219
ctx,
220+
code=codes.VALID_TYPE,
207221
)
208222
return is_error
209223

mypy/solve.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,8 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
514514
is a linear constraint. This is however not true in presence of union types, for example
515515
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
516516
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
517-
solution T = Union[S, int], S = <free>.
517+
solution T = Union[S, int], S = <free>. A similar scenario is when we get T <: Union[T, int],
518+
such constraints carry no information, and will equally confuse linearity check.
518519
519520
TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
520521
this would require passing around a flag through all infer_constraints() calls.
@@ -525,7 +526,13 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
525526
if isinstance(p_target, UnionType):
526527
for item in p_target.items:
527528
if isinstance(item, TypeVarType):
529+
if item == c.origin_type_var and c.op == SUBTYPE_OF:
530+
reverse_union_cs.add(c)
531+
continue
532+
# These two forms are semantically identical, but are different from
533+
# the point of view of Constraint.__eq__().
528534
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
535+
reverse_union_cs.add(Constraint(c.origin_type_var, c.op, item))
529536
return [c for c in cs if c not in reverse_union_cs]
530537

531538

mypy/subtypes.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
import mypy.constraints
99
import mypy.typeops
1010
from mypy.erasetype import erase_type
11-
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
11+
from mypy.expandtype import (
12+
expand_self_type,
13+
expand_type,
14+
expand_type_by_instance,
15+
freshen_function_type_vars,
16+
)
1217
from mypy.maptype import map_instance_to_supertype
1318

1419
# Circular import; done in the function instead.
@@ -1860,6 +1865,11 @@ def unify_generic_callable(
18601865
"""
18611866
import mypy.solve
18621867

1868+
if set(type.type_var_ids()) & {v.id for v in mypy.typeops.get_all_type_vars(target)}:
1869+
# Overload overlap check does nasty things like unifying in opposite direction.
1870+
# This can easily create type variable clashes, so we need to refresh.
1871+
type = freshen_function_type_vars(type)
1872+
18631873
if return_constraint_direction is None:
18641874
return_constraint_direction = mypy.constraints.SUBTYPE_OF
18651875

@@ -1882,7 +1892,9 @@ def unify_generic_callable(
18821892
constraints = [
18831893
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
18841894
]
1885-
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
1895+
inferred_vars, _ = mypy.solve.solve_constraints(
1896+
type.variables, constraints, allow_polymorphic=True
1897+
)
18861898
if None in inferred_vars:
18871899
return None
18881900
non_none_inferred_vars = cast(List[Type], inferred_vars)

mypy/typeanal.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010
from mypy import errorcodes as codes, message_registry, nodes
1111
from mypy.errorcodes import ErrorCode
1212
from mypy.expandtype import expand_type
13-
from mypy.messages import MessageBuilder, format_type_bare, quote_type_string, wrong_type_arg_count
13+
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
14+
from mypy.messages import (
15+
MessageBuilder,
16+
format_type,
17+
format_type_bare,
18+
quote_type_string,
19+
wrong_type_arg_count,
20+
)
1421
from mypy.nodes import (
1522
ARG_NAMED,
1623
ARG_NAMED_OPT,
@@ -1782,12 +1789,14 @@ def anal_type(
17821789
analyzed = AnyType(TypeOfAny.from_error)
17831790
else:
17841791
self.fail(
1785-
f'Invalid location for ParamSpec "{analyzed.name}"', t, code=codes.VALID_TYPE
1792+
INVALID_PARAM_SPEC_LOCATION.format(format_type(analyzed, self.options)),
1793+
t,
1794+
code=codes.VALID_TYPE,
17861795
)
17871796
self.note(
1788-
"You can use ParamSpec as the first argument to Callable, e.g., "
1789-
"'Callable[{}, int]'".format(analyzed.name),
1797+
INVALID_PARAM_SPEC_LOCATION_NOTE.format(analyzed.name),
17901798
t,
1799+
code=codes.VALID_TYPE,
17911800
)
17921801
analyzed = AnyType(TypeOfAny.from_error)
17931802
return analyzed

mypy/typeops.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,14 @@ def type_object_type_from_function(
152152
# ...
153153
#
154154
# We need to map B's __init__ to the type (List[T]) -> None.
155-
signature = bind_self(signature, original_type=default_self, is_classmethod=is_new)
155+
signature = bind_self(
156+
signature,
157+
original_type=default_self,
158+
is_classmethod=is_new,
159+
# Explicit instance self annotations have special handling in class_callable(),
160+
# we don't need to bind any type variables in them if they are generic.
161+
ignore_instances=True,
162+
)
156163
signature = cast(FunctionLike, map_type_from_supertype(signature, info, def_info))
157164

158165
special_sig: str | None = None
@@ -244,7 +251,9 @@ class C(D[E[T]], Generic[T]): ...
244251
return expand_type_by_instance(typ, inst_type)
245252

246253

247-
def supported_self_type(typ: ProperType, allow_callable: bool = True) -> bool:
254+
def supported_self_type(
255+
typ: ProperType, allow_callable: bool = True, allow_instances: bool = True
256+
) -> bool:
248257
"""Is this a supported kind of explicit self-types?
249258
250259
Currently, this means an X or Type[X], where X is an instance or
@@ -257,14 +266,19 @@ def supported_self_type(typ: ProperType, allow_callable: bool = True) -> bool:
257266
# as well as callable self for callback protocols.
258267
return True
259268
return isinstance(typ, TypeVarType) or (
260-
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
269+
allow_instances and isinstance(typ, Instance) and typ != fill_typevars(typ.type)
261270
)
262271

263272

264273
F = TypeVar("F", bound=FunctionLike)
265274

266275

267-
def bind_self(method: F, original_type: Type | None = None, is_classmethod: bool = False) -> F:
276+
def bind_self(
277+
method: F,
278+
original_type: Type | None = None,
279+
is_classmethod: bool = False,
280+
ignore_instances: bool = False,
281+
) -> F:
268282
"""Return a copy of `method`, with the type of its first parameter (usually
269283
self or cls) bound to original_type.
270284
@@ -288,9 +302,10 @@ class B(A): pass
288302
289303
"""
290304
if isinstance(method, Overloaded):
291-
return cast(
292-
F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items])
293-
)
305+
items = [
306+
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
307+
]
308+
return cast(F, Overloaded(items))
294309
assert isinstance(method, CallableType)
295310
func = method
296311
if not func.arg_types:
@@ -310,7 +325,9 @@ class B(A): pass
310325
# this special-casing looks not very principled, there is nothing meaningful we can infer
311326
# from such definition, since it is inherently indefinitely recursive.
312327
allow_callable = func.name is None or not func.name.startswith("__call__ of")
313-
if func.variables and supported_self_type(self_param_type, allow_callable=allow_callable):
328+
if func.variables and supported_self_type(
329+
self_param_type, allow_callable=allow_callable, allow_instances=not ignore_instances
330+
):
314331
from mypy.infer import infer_type_arguments
315332

316333
if original_type is None:

0 commit comments

Comments
 (0)