Skip to content

Commit ee364ce

Browse files
authored
Allow Any to match sequence patterns in match/case (#18448)
Fixes #17095 (comment, the primary issue was already fixed somewhere before). Fixes #16272. Fixes #12532. Fixes #12770. Prior to this PR mypy did not consider that `Any` can match any patterns, including sequence patterns (e.g. `case [_]`). This PR allows matching `Any` against any such patterns.
1 parent 9685171 commit ee364ce

File tree

3 files changed

+124
-77
lines changed

3 files changed

+124
-77
lines changed

mypy/checkpattern.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,8 @@ def should_self_match(self, typ: Type) -> bool:
713713
return False
714714

715715
def can_match_sequence(self, typ: ProperType) -> bool:
716+
if isinstance(typ, AnyType):
717+
return True
716718
if isinstance(typ, UnionType):
717719
return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
718720
for other in self.non_sequence_match_types:
@@ -763,6 +765,8 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
763765
or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
764766
"""
765767
proper_type = get_proper_type(outer_type)
768+
if isinstance(proper_type, AnyType):
769+
return outer_type
766770
if isinstance(proper_type, UnionType):
767771
types = [
768772
self.construct_sequence_child(item, inner_type)
@@ -772,7 +776,6 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
772776
return make_simplified_union(types)
773777
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
774778
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
775-
proper_type = get_proper_type(outer_type)
776779
if isinstance(proper_type, TupleType):
777780
proper_type = tuple_fallback(proper_type)
778781
assert isinstance(proper_type, Instance)

mypyc/test-data/irbuild-match.test

Lines changed: 88 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,14 +1378,15 @@ def f(x):
13781378
r15 :: bit
13791379
r16 :: bool
13801380
r17 :: native_int
1381-
r18, rest :: object
1382-
r19 :: str
1383-
r20 :: object
1384-
r21 :: str
1385-
r22 :: object
1386-
r23 :: object[1]
1387-
r24 :: object_ptr
1388-
r25, r26 :: object
1381+
r18 :: object
1382+
r19, rest :: list
1383+
r20 :: str
1384+
r21 :: object
1385+
r22 :: str
1386+
r23 :: object
1387+
r24 :: object[1]
1388+
r25 :: object_ptr
1389+
r26, r27 :: object
13891390
L0:
13901391
r0 = CPySequence_Check(x)
13911392
r1 = r0 != 0
@@ -1414,21 +1415,23 @@ L3:
14141415
L4:
14151416
r17 = r2 - 0
14161417
r18 = PySequence_GetSlice(x, 2, r17)
1417-
rest = r18
1418+
r19 = cast(list, r18)
1419+
rest = r19
14181420
L5:
1419-
r19 = 'matched'
1420-
r20 = builtins :: module
1421-
r21 = 'print'
1422-
r22 = CPyObject_GetAttr(r20, r21)
1423-
r23 = [r19]
1424-
r24 = load_address r23
1425-
r25 = _PyObject_Vectorcall(r22, r24, 1, 0)
1426-
keep_alive r19
1421+
r20 = 'matched'
1422+
r21 = builtins :: module
1423+
r22 = 'print'
1424+
r23 = CPyObject_GetAttr(r21, r22)
1425+
r24 = [r20]
1426+
r25 = load_address r24
1427+
r26 = _PyObject_Vectorcall(r23, r25, 1, 0)
1428+
keep_alive r20
14271429
goto L7
14281430
L6:
14291431
L7:
1430-
r26 = box(None, 1)
1431-
return r26
1432+
r27 = box(None, 1)
1433+
return r27
1434+
14321435
[case testMatchSequenceWithStarPatternInTheMiddle_python3_10]
14331436
def f(x):
14341437
match x:
@@ -1455,14 +1458,15 @@ def f(x):
14551458
r16 :: bit
14561459
r17 :: bool
14571460
r18 :: native_int
1458-
r19, rest :: object
1459-
r20 :: str
1460-
r21 :: object
1461-
r22 :: str
1462-
r23 :: object
1463-
r24 :: object[1]
1464-
r25 :: object_ptr
1465-
r26, r27 :: object
1461+
r19 :: object
1462+
r20, rest :: list
1463+
r21 :: str
1464+
r22 :: object
1465+
r23 :: str
1466+
r24 :: object
1467+
r25 :: object[1]
1468+
r26 :: object_ptr
1469+
r27, r28 :: object
14661470
L0:
14671471
r0 = CPySequence_Check(x)
14681472
r1 = r0 != 0
@@ -1492,21 +1496,23 @@ L3:
14921496
L4:
14931497
r18 = r2 - 1
14941498
r19 = PySequence_GetSlice(x, 1, r18)
1495-
rest = r19
1499+
r20 = cast(list, r19)
1500+
rest = r20
14961501
L5:
1497-
r20 = 'matched'
1498-
r21 = builtins :: module
1499-
r22 = 'print'
1500-
r23 = CPyObject_GetAttr(r21, r22)
1501-
r24 = [r20]
1502-
r25 = load_address r24
1503-
r26 = _PyObject_Vectorcall(r23, r25, 1, 0)
1504-
keep_alive r20
1502+
r21 = 'matched'
1503+
r22 = builtins :: module
1504+
r23 = 'print'
1505+
r24 = CPyObject_GetAttr(r22, r23)
1506+
r25 = [r21]
1507+
r26 = load_address r25
1508+
r27 = _PyObject_Vectorcall(r24, r26, 1, 0)
1509+
keep_alive r21
15051510
goto L7
15061511
L6:
15071512
L7:
1508-
r27 = box(None, 1)
1509-
return r27
1513+
r28 = box(None, 1)
1514+
return r28
1515+
15101516
[case testMatchSequenceWithStarPatternAtTheStart_python3_10]
15111517
def f(x):
15121518
match x:
@@ -1530,14 +1536,15 @@ def f(x):
15301536
r17 :: bit
15311537
r18 :: bool
15321538
r19 :: native_int
1533-
r20, rest :: object
1534-
r21 :: str
1535-
r22 :: object
1536-
r23 :: str
1537-
r24 :: object
1538-
r25 :: object[1]
1539-
r26 :: object_ptr
1540-
r27, r28 :: object
1539+
r20 :: object
1540+
r21, rest :: list
1541+
r22 :: str
1542+
r23 :: object
1543+
r24 :: str
1544+
r25 :: object
1545+
r26 :: object[1]
1546+
r27 :: object_ptr
1547+
r28, r29 :: object
15411548
L0:
15421549
r0 = CPySequence_Check(x)
15431550
r1 = r0 != 0
@@ -1568,21 +1575,23 @@ L3:
15681575
L4:
15691576
r19 = r2 - 2
15701577
r20 = PySequence_GetSlice(x, 0, r19)
1571-
rest = r20
1578+
r21 = cast(list, r20)
1579+
rest = r21
15721580
L5:
1573-
r21 = 'matched'
1574-
r22 = builtins :: module
1575-
r23 = 'print'
1576-
r24 = CPyObject_GetAttr(r22, r23)
1577-
r25 = [r21]
1578-
r26 = load_address r25
1579-
r27 = _PyObject_Vectorcall(r24, r26, 1, 0)
1580-
keep_alive r21
1581+
r22 = 'matched'
1582+
r23 = builtins :: module
1583+
r24 = 'print'
1584+
r25 = CPyObject_GetAttr(r23, r24)
1585+
r26 = [r22]
1586+
r27 = load_address r26
1587+
r28 = _PyObject_Vectorcall(r25, r27, 1, 0)
1588+
keep_alive r22
15811589
goto L7
15821590
L6:
15831591
L7:
1584-
r28 = box(None, 1)
1585-
return r28
1592+
r29 = box(None, 1)
1593+
return r29
1594+
15861595
[case testMatchBuiltinClassPattern_python3_10]
15871596
def f(x):
15881597
match x:
@@ -1634,14 +1643,15 @@ def f(x):
16341643
r2 :: native_int
16351644
r3, r4 :: bit
16361645
r5 :: native_int
1637-
r6, rest :: object
1638-
r7 :: str
1639-
r8 :: object
1640-
r9 :: str
1641-
r10 :: object
1642-
r11 :: object[1]
1643-
r12 :: object_ptr
1644-
r13, r14 :: object
1646+
r6 :: object
1647+
r7, rest :: list
1648+
r8 :: str
1649+
r9 :: object
1650+
r10 :: str
1651+
r11 :: object
1652+
r12 :: object[1]
1653+
r13 :: object_ptr
1654+
r14, r15 :: object
16451655
L0:
16461656
r0 = CPySequence_Check(x)
16471657
r1 = r0 != 0
@@ -1654,21 +1664,23 @@ L1:
16541664
L2:
16551665
r5 = r2 - 0
16561666
r6 = PySequence_GetSlice(x, 0, r5)
1657-
rest = r6
1667+
r7 = cast(list, r6)
1668+
rest = r7
16581669
L3:
1659-
r7 = 'matched'
1660-
r8 = builtins :: module
1661-
r9 = 'print'
1662-
r10 = CPyObject_GetAttr(r8, r9)
1663-
r11 = [r7]
1664-
r12 = load_address r11
1665-
r13 = _PyObject_Vectorcall(r10, r12, 1, 0)
1666-
keep_alive r7
1670+
r8 = 'matched'
1671+
r9 = builtins :: module
1672+
r10 = 'print'
1673+
r11 = CPyObject_GetAttr(r9, r10)
1674+
r12 = [r8]
1675+
r13 = load_address r12
1676+
r14 = _PyObject_Vectorcall(r11, r13, 1, 0)
1677+
keep_alive r8
16671678
goto L5
16681679
L4:
16691680
L5:
1670-
r14 = box(None, 1)
1671-
return r14
1681+
r15 = box(None, 1)
1682+
return r15
1683+
16721684
[case testMatchTypeAnnotatedNativeClass_python3_10]
16731685
class A:
16741686
a: int

test-data/unit/check-python310.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,3 +2439,35 @@ def foo(x: T) -> T:
24392439
return out
24402440

24412441
[builtins fixtures/isinstance.pyi]
2442+
2443+
[case testMatchSequenceReachableFromAny]
2444+
# flags: --warn-unreachable
2445+
from typing import Any
2446+
2447+
def maybe_list(d: Any) -> int:
2448+
match d:
2449+
case []:
2450+
return 0
2451+
case [[_]]:
2452+
return 1
2453+
case [_]:
2454+
return 1
2455+
case _:
2456+
return 2
2457+
2458+
def with_guard(d: Any) -> None:
2459+
match d:
2460+
case [s] if isinstance(s, str):
2461+
reveal_type(s) # N: Revealed type is "builtins.str"
2462+
match d:
2463+
case (s,) if isinstance(s, str):
2464+
reveal_type(s) # N: Revealed type is "builtins.str"
2465+
2466+
def nested_in_dict(d: dict[str, Any]) -> int:
2467+
match d:
2468+
case {"src": ["src"]}:
2469+
return 1
2470+
case _:
2471+
return 0
2472+
2473+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)