Skip to content

[mypyc] Optimize str.startswith and str.endswith with tuple argument #18678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Methods
* ``s.encode(encoding: str)``
* ``s.encode(encoding: str, errors: str)``
* ``s1.endswith(s2: str)``
* ``s1.endswith(t: tuple[str, ...])``
* ``s.join(x: Iterable)``
* ``s.removeprefix(prefix: str)``
* ``s.removesuffix(suffix: str)``
Expand All @@ -43,6 +44,7 @@ Methods
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
* ``s1.startswith(s2: str)``
* ``s1.startswith(t: tuple[str, ...])``

.. note::

Expand Down
4 changes: 2 additions & 2 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
int CPyStr_Startswith(PyObject *self, PyObject *subobj);
int CPyStr_Endswith(PyObject *self, PyObject *subobj);
PyObject *CPyStr_Removeprefix(PyObject *self, PyObject *prefix);
PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
bool CPyStr_IsTrue(PyObject *obj);
Expand Down
40 changes: 38 additions & 2 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,51 @@ PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace);
}

bool CPyStr_Startswith(PyObject *self, PyObject *subobj) {
int CPyStr_Startswith(PyObject *self, PyObject *subobj) {
Py_ssize_t start = 0;
Py_ssize_t end = PyUnicode_GET_LENGTH(self);
if (PyTuple_Check(subobj)) {
Py_ssize_t i;
for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
PyObject *substring = PyTuple_GET_ITEM(subobj, i);
if (!PyUnicode_Check(substring)) {
PyErr_Format(PyExc_TypeError,
"tuple for startswith must only contain str, "
"not %.100s",
Py_TYPE(substring)->tp_name);
return -1;
}
int result = PyUnicode_Tailmatch(self, substring, start, end, -1);
if (result) {
return 1;
}
}
return 0;
}
return PyUnicode_Tailmatch(self, subobj, start, end, -1);
}

bool CPyStr_Endswith(PyObject *self, PyObject *subobj) {
int CPyStr_Endswith(PyObject *self, PyObject *subobj) {
Py_ssize_t start = 0;
Py_ssize_t end = PyUnicode_GET_LENGTH(self);
if (PyTuple_Check(subobj)) {
Py_ssize_t i;
for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
PyObject *substring = PyTuple_GET_ITEM(subobj, i);
if (!PyUnicode_Check(substring)) {
PyErr_Format(PyExc_TypeError,
"tuple for endswith must only contain str, "
"not %.100s",
Py_TYPE(substring)->tp_name);
return -1;
}
int result = PyUnicode_Tailmatch(self, substring, start, end, 1);
if (result) {
return 1;
}
}
return 0;
}
return PyUnicode_Tailmatch(self, subobj, start, end, 1);
}

Expand Down
27 changes: 25 additions & 2 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.primitives.registry import (
ERR_NEG_INT,
Expand Down Expand Up @@ -104,20 +105,42 @@
method_op(
name="startswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.startswith(tuple) (return -1/0/1)
method_op(
name="startswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Startswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.endswith(str)
method_op(
name="endswith",
arg_types=[str_rprimitive, str_rprimitive],
return_type=bool_rprimitive,
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

# str.endswith(tuple) (return -1/0/1)
method_op(
name="endswith",
arg_types=[str_rprimitive, tuple_rprimitive],
return_type=c_int_rprimitive,
c_function_name="CPyStr_Endswith",
truncated_type=bool_rprimitive,
error_kind=ERR_NEG_INT,
)

# str.removeprefix(str)
method_op(
name="removeprefix",
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def strip (self, item: str) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
def upper(self) -> str: ...
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
def removeprefix(self, prefix: str, /) -> str: ...
Expand Down
8 changes: 7 additions & 1 deletion mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def eq(x: str) -> int:
return 2
def match(x: str, y: str) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def match_tuple(x: str, y: Tuple[str, ...]) -> Tuple[bool, bool]:
return (x.startswith(y), x.endswith(y))
def remove_prefix_suffix(x: str, y: str) -> Tuple[str, str]:
return (x.removeprefix(y), x.removesuffix(y))

[file driver.py]
from native import f, g, tostr, booltostr, concat, eq, match, remove_prefix_suffix
from native import f, g, tostr, booltostr, concat, eq, match, match_tuple, remove_prefix_suffix
import sys

assert f() == 'some string'
Expand All @@ -45,6 +47,10 @@ assert match('abc', '') == (True, True)
assert match('abc', 'a') == (True, False)
assert match('abc', 'c') == (False, True)
assert match('', 'abc') == (False, False)
assert match_tuple('abc', ('d', 'e')) == (False, False)
assert match_tuple('abc', ('a', 'c')) == (True, True)
assert match_tuple('abc', ('a',)) == (True, False)
assert match_tuple('abc', ('c',)) == (False, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case where startswith matches a non-first tuple item. Also add test for error case (tuple contains a non-string).

It would be good to add an irbuild test for a tuple literal argument, since it's easy to imagine how a fixed-length tuple literal wouldn't match against a variable-length tuple in the primitive arg type.


assert remove_prefix_suffix('', '') == ('', '')
assert remove_prefix_suffix('abc', 'a') == ('bc', 'abc')
Expand Down