Skip to content

[mypyc] Optimize str.startswith and str.endswith with tuple argument #18678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Methods
* ``s.encode(encoding: str)``
* ``s.encode(encoding: str, errors: str)``
* ``s1.endswith(s2: str)``
* ``s1.endswith(t: tuple[str, ...])``
* ``s.join(x: Iterable)``
* ``s.removeprefix(prefix: str)``
* ``s.removesuffix(suffix: str)``
Expand All @@ -43,6 +44,7 @@ Methods
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
* ``s1.startswith(s2: str)``
* ``s1.startswith(t: tuple[str, ...])``

.. note::

Expand Down
4 changes: 2 additions & 2 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
int CPyStr_Startswith(PyObject *self, PyObject *subobj);
int CPyStr_Endswith(PyObject *self, PyObject *subobj);
PyObject *CPyStr_Removeprefix(PyObject *self, PyObject *prefix);
PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
bool CPyStr_IsTrue(PyObject *obj);
Expand Down
40 changes: 38 additions & 2 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,51 @@ PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace);
}

bool CPyStr_Startswith(PyObject *self, PyObject *subobj) {
int CPyStr_Startswith(PyObject *self, PyObject *subobj) {
Py_ssize_t start = 0;
Py_ssize_t end = PyUnicode_GET_LENGTH(self);
if (PyTuple_Check(subobj)) {
Py_ssize_t i;
for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
PyObject *substring = PyTuple_GET_ITEM(subobj, i);
if (!PyUnicode_Check(substring)) {
PyErr_Format(PyExc_TypeError,
"tuple for startswith must only contain str, "
"not %.100s",
Py_TYPE(substring)->tp_name);
return -1;
}
int result = PyUnicode_Tailmatch(self, substring, start, end, -1);
if (result) {
return 1;
}
}
return 0;
}
return PyUnicode_Tailmatch(self, subobj, start, end, -1);
}

bool CPyStr_Endswith(PyObject *self, PyObject *subobj) {
int CPyStr_Endswith(PyObject *self, PyObject *subobj) {
Py_ssize_t start = 0;
Py_ssize_t end = PyUnicode_GET_LENGTH(self);
if (PyTuple_Check(subobj)) {
Py_ssize_t i;
for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
PyObject *substring = PyTuple_GET_ITEM(subobj, i);
if (!PyUnicode_Check(substring)) {
PyErr_Format(PyExc_TypeError,
"tuple for endswith must only contain str, "
"not %.100s",
Py_TYPE(substring)->tp_name);
return -1;
}
int result = PyUnicode_Tailmatch(self, substring, start, end, 1);
if (result) {
return 1;
}
}
return 0;
}
return PyUnicode_Tailmatch(self, subobj, start, end, 1);
}

Expand Down
27 changes: 25 additions & 2 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.primitives.registry import (
ERR_NEG_INT,
Expand Down Expand Up @@ -104,20 +105,42 @@
method_op(
name="startswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.startswith(tuple) (return -1/0/1)
method_op(
name="startswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.endswith(str)
method_op(
name="endswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.endswith(tuple) (return -1/0/1)
method_op(
name="endswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.removeprefix(str)
method_op(
name="removeprefix",
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def strip (self, item: str) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
def removeprefix(self, prefix: str, /) -> str: ...
Expand Down
67 changes: 67 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,73 @@ L4:
L5:
unreachable

[case testStrStartswithEndswithTuple]
from typing import Tuple

def do_startswith(s1: str, s2: Tuple[str, ...]) -> bool:
return s1.startswith(s2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm more interested in what happens when startswith is given a tuple literal argument (e.g. return s1.startswith(('x', 'y'))), since this is probably the most common use case. Can you test also this? It would be good to have also a run test for this.


def do_endswith(s1: str, s2: Tuple[str, ...]) -> bool:
return s1.endswith(s2)

def do_tuple_literal_args(s1: str) -> None:
x = s1.startswith(("a", "b"))
y = s1.endswith(("a", "b"))
[out]
def do_startswith(s1, s2):
s1 :: str
s2 :: tuple
r0 :: i32
r1 :: bit
r2 :: bool
L0:
r0 = CPyStr_Startswith(s1, s2)
r1 = r0 >= 0 :: signed
r2 = truncate r0: i32 to builtins.bool
return r2
def do_endswith(s1, s2):
s1 :: str
s2 :: tuple
r0 :: i32
r1 :: bit
r2 :: bool
L0:
r0 = CPyStr_Endswith(s1, s2)
r1 = r0 >= 0 :: signed
r2 = truncate r0: i32 to builtins.bool
return r2
def do_tuple_literal_args(s1):
s1, r0, r1 :: str
r2 :: tuple[str, str]
r3 :: object
r4 :: i32
r5 :: bit
r6, x :: bool
r7, r8 :: str
r9 :: tuple[str, str]
r10 :: object
r11 :: i32
r12 :: bit
r13, y :: bool
L0:
r0 = 'a'
r1 = 'b'
r2 = (r0, r1)
r3 = box(tuple[str, str], r2)
r4 = CPyStr_Startswith(s1, r3)
r5 = r4 >= 0 :: signed
r6 = truncate r4: i32 to builtins.bool
x = r6
r7 = 'a'
r8 = 'b'
r9 = (r7, r8)
r10 = box(tuple[str, str], r9)
r11 = CPyStr_Endswith(s1, r10)
r12 = r11 >= 0 :: signed
r13 = truncate r11: i32 to builtins.bool
y = r13
return 1

[case testStrToBool]
def is_true(x: str) -> bool:
if x:
Expand Down
22 changes: 21 additions & 1 deletion mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@ def eq(x: str) -> int:
return 2
def match(x: str, y: str) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def match_tuple(x: str, y: Tuple[str, ...]) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def match_tuple_literal_args(x: str, y: str, z: str) -> Tuple[bool, bool]:
return (x.startswith((y, z)), x.endswith((y, z)))
def remove_prefix_suffix(x: str, y: str) -> Tuple[str, str]:
return (x.removeprefix(y), x.removesuffix(y))

[file driver.py]
from native import f, g, tostr, booltostr, concat, eq, match, remove_prefix_suffix
from native import (
f, g, tostr, booltostr, concat, eq, match, match_tuple,
match_tuple_literal_args, remove_prefix_suffix
)
import sys
from testutil import assertRaises

assert f() == 'some string'
assert f() is sys.intern('some string')
Expand All @@ -45,6 +53,18 @@ assert match('abc', '') == (True, True)
assert match('abc', 'a') == (True, False)
assert match('abc', 'c') == (False, True)
assert match('', 'abc') == (False, False)
assert match_tuple('abc', ('d', 'e')) == (False, False)
assert match_tuple('abc', ('a', 'c')) == (True, True)
assert match_tuple('abc', ('a',)) == (True, False)
assert match_tuple('abc', ('c',)) == (False, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case where startswith matches a non-first tuple item. Also add test for error case (tuple contains a non-string).

It would be good to add an irbuild test for a tuple literal argument, since it's easy to imagine how a fixed-length tuple literal wouldn't match against a variable-length tuple in the primitive arg type.

assert match_tuple('abc', ('x', 'y', 'z')) == (False, False)
assert match_tuple('abc', ('x', 'y', 'z', 'a', 'c')) == (True, True)
with assertRaises(TypeError, "tuple for startswith must only contain str"):
assert match_tuple('abc', (None,))
with assertRaises(TypeError, "tuple for endswith must only contain str"):
assert match_tuple('abc', ('a', None))
assert match_tuple_literal_args('abc', 'z', 'a') == (True, False)
assert match_tuple_literal_args('abc', 'z', 'c') == (False, True)

assert remove_prefix_suffix('', '') == ('', '')
assert remove_prefix_suffix('abc', 'a') == ('bc', 'abc')
Expand Down