Skip to content

Commit 0d4af91

Browse files
committed
Improve: Unlock GIL in Str.write_to
Closes ashvardanian#105
1 parent ca3a410 commit 0d4af91

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

python/lib.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,15 +1272,21 @@ static PyObject *Str_write_to(PyObject *self, PyObject *args, PyObject *kwargs)
12721272
}
12731273
memcpy(path_buffer, path.start, path.length);
12741274

1275+
// Unlock the Global Interpreter Lock (GIL) to allow other threads to run
1276+
// while the current thread is waiting for the file to be written.
1277+
PyThreadState *gil_state = PyEval_SaveThread();
12751278
FILE *file_pointer = fopen(path_buffer, "wb");
12761279
if (file_pointer == NULL) {
1280+
PyEval_RestoreThread(gil_state);
12771281
PyErr_SetFromErrnoWithFilename(PyExc_OSError, path_buffer);
12781282
free(path_buffer);
1283+
PyEval_RestoreThread(gil_state);
12791284
return NULL;
12801285
}
12811286

12821287
setbuf(file_pointer, NULL); // Set the stream to unbuffered
12831288
int status = fwrite(text.start, 1, text.length, file_pointer);
1289+
PyEval_RestoreThread(gil_state);
12841290
if (status != text.length) {
12851291
PyErr_SetFromErrnoWithFilename(PyExc_OSError, path_buffer);
12861292
free(path_buffer);

scripts/test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from random import choice, randint
22
from string import ascii_lowercase
33
from typing import Optional
4+
import tempfile
5+
import os
46

57
import pytest
68

@@ -104,6 +106,25 @@ def test_unit_buffer_protocol():
104106
assert "".join([c.decode("utf-8") for c in arr.tolist()]) == "hello"
105107

106108

109+
def test_str_write_to():
110+
native = "line1\nline2\nline3"
111+
big = Str(native) # Assuming Str is your custom class
112+
113+
# Create a temporary file
114+
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
115+
temp_filename = tmpfile.name # Store the name for later use
116+
117+
try:
118+
big.write_to(temp_filename)
119+
with open(temp_filename, "r") as file:
120+
content = file.read()
121+
assert (
122+
content == native
123+
), "The content of the file does not match the expected output"
124+
finally:
125+
os.remove(temp_filename)
126+
127+
107128
def test_unit_split():
108129
native = "line1\nline2\nline3"
109130
big = Str(native)

0 commit comments

Comments
 (0)