Skip to content

Commit 517bce7

Browse files
committed
Try fixing edge case in check_self_arg()
1 parent f229456 commit 517bce7

File tree

7 files changed

+30
-16
lines changed

7 files changed

+30
-16
lines changed

mypy/checker.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,8 +1373,12 @@ def check_func_def(
13731373
if not is_same_type(arg_type, ref_type):
13741374
# This level of erasure matches the one in checkmember.check_self_arg(),
13751375
# better keep these two checks consistent.
1376-
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
1377-
if not is_subtype(ref_type, erased, ignore_type_params=True):
1376+
erased = get_proper_type(
1377+
erase_typevars(erase_to_bound(arg_type), use_upper_bound=True)
1378+
)
1379+
if not is_subtype(
1380+
ref_type, erased, ignore_type_params=True, always_covariant=True
1381+
):
13781382
if (
13791383
isinstance(erased, Instance)
13801384
and erased.type.is_protocol

mypy/checkmember.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
erase_to_bound,
4545
freeze_all_type_vars,
4646
function_type,
47-
get_all_type_vars,
4847
get_type_vars,
4948
make_simplified_union,
5049
supported_self_type,
@@ -1060,12 +1059,8 @@ def f(self: S) -> T: ...
10601059
# better keep these two checks consistent.
10611060
if subtypes.is_subtype(
10621061
dispatched_arg_type,
1063-
erase_typevars(erase_to_bound(selfarg)),
1064-
# This is to work around the fact that erased ParamSpec and TypeVarTuple
1065-
# callables are not always compatible with non-erased ones both ways.
1066-
always_covariant=any(
1067-
not isinstance(tv, TypeVarType) for tv in get_all_type_vars(selfarg)
1068-
),
1062+
erase_typevars(erase_to_bound(selfarg), use_upper_bound=True),
1063+
always_covariant=True,
10691064
ignore_pos_arg_names=True,
10701065
):
10711066
new_items.append(item)

mypy/erasetype.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
140140
raise RuntimeError("Type aliases should be expanded before accepting this visitor")
141141

142142

143-
def erase_typevars(t: Type, ids_to_erase: Container[TypeVarId] | None = None) -> Type:
143+
def erase_typevars(
144+
t: Type, ids_to_erase: Container[TypeVarId] | None = None, use_upper_bound: bool = False
145+
) -> Type:
144146
"""Replace all type variables in a type with any,
145147
or just the ones in the provided collection.
146148
"""
@@ -150,7 +152,7 @@ def erase_id(id: TypeVarId) -> bool:
150152
return True
151153
return id in ids_to_erase
152154

153-
return t.accept(TypeVarEraser(erase_id, AnyType(TypeOfAny.special_form)))
155+
return t.accept(TypeVarEraser(erase_id, AnyType(TypeOfAny.special_form), use_upper_bound))
154156

155157

156158
def replace_meta_vars(t: Type, target_type: Type) -> Type:
@@ -161,13 +163,21 @@ def replace_meta_vars(t: Type, target_type: Type) -> Type:
161163
class TypeVarEraser(TypeTranslator):
162164
"""Implementation of type erasure"""
163165

164-
def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
166+
def __init__(
167+
self,
168+
erase_id: Callable[[TypeVarId], bool],
169+
replacement: Type,
170+
use_upper_bound: bool = False,
171+
) -> None:
165172
super().__init__()
166173
self.erase_id = erase_id
167174
self.replacement = replacement
175+
self.use_upper_bound = use_upper_bound
168176

169177
def visit_type_var(self, t: TypeVarType) -> Type:
170178
if self.erase_id(t.id):
179+
if self.use_upper_bound:
180+
return t.upper_bound
171181
return self.replacement
172182
return t
173183

@@ -204,11 +214,16 @@ def visit_tuple_type(self, t: TupleType) -> Type:
204214
return result
205215

206216
def visit_callable_type(self, t: CallableType) -> Type:
217+
use_upper_bound = self.use_upper_bound
218+
# This is to work around the fact that erased callables are not compatible
219+
# with non-erased ones (due to contravariance in arg types).
220+
self.use_upper_bound = False
207221
result = super().visit_callable_type(t)
208222
assert isinstance(result, ProperType) and isinstance(result, CallableType)
209223
# Usually this is done in semanal_typeargs.py, but erasure can create
210224
# a non-normal callable from normal one.
211225
result.normalize_trivial_unpack()
226+
self.use_upper_bound = use_upper_bound
212227
return result
213228

214229
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:

mypy/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def is_trivial_self(self) -> bool:
592592
if not item.is_trivial_self:
593593
self._is_trivial_self = False
594594
return False
595-
elif item.decorators or not item.func.is_trivial_self:
595+
elif len(item.decorators) > 1 or not item.func.is_trivial_self:
596596
self._is_trivial_self = False
597597
return False
598598
self._is_trivial_self = True

mypy/typeshed/stdlib/os/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ class _Environ(MutableMapping[AnyStr, AnyStr], Generic[AnyStr]):
721721
unsetenv: Callable[[AnyStr, AnyStr], object],
722722
) -> None: ...
723723

724-
def setdefault(self, key: AnyStr, value: AnyStr) -> AnyStr: ...
724+
def setdefault(self, key: AnyStr, value: AnyStr) -> AnyStr: ... # type: ignore[override]
725725
def copy(self) -> dict[AnyStr, AnyStr]: ...
726726
def __delitem__(self, key: AnyStr) -> None: ...
727727
def __getitem__(self, key: AnyStr) -> AnyStr: ...

mypy/typeshed/stdlib/weakref.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class WeakValueDictionary(MutableMapping[_KT, _VT]):
110110
def items(self) -> Iterator[tuple[_KT, _VT]]: ... # type: ignore[override]
111111
def itervaluerefs(self) -> Iterator[KeyedRef[_KT, _VT]]: ...
112112
def valuerefs(self) -> list[KeyedRef[_KT, _VT]]: ...
113-
def setdefault(self, key: _KT, default: _VT) -> _VT: ...
113+
def setdefault(self, key: _KT, default: _VT) -> _VT: ... # type: ignore[override]
114114
@overload
115115
def pop(self, key: _KT) -> _VT: ...
116116
@overload

test-data/unit/check-overloading.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6762,7 +6762,7 @@ class D(Generic[T]):
67626762
def f(self, x: int) -> int: ...
67636763
@overload
67646764
def f(self, x: str) -> str: ...
6765-
def f(Self, x): ...
6765+
def f(self, x): ...
67666766

67676767
a: D[str] # E: Type argument "str" of "D" must be a subtype of "C"
67686768
reveal_type(a.f(1)) # N: Revealed type is "builtins.int"

0 commit comments

Comments
 (0)