Skip to content

Commit 5f7c2da

Browse files
authored
Merge pull request #19 from aertslab/feature-save_patterns_to_disk
Add functionality to save and load patterns to and from disk.
2 parents db5c114 + 3c9ab8a commit 5f7c2da

File tree

7 files changed

+158
-4
lines changed

7 files changed

+158
-4
lines changed

docs/api/io.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ Functions for reading and writing AnnData objects optimized for seqlet analysis.
1010
1111
save_h5ad
1212
load_h5ad
13+
save_patterns
14+
load_patterns
1315
```

docs/notebooks/02_analysis_tutorial.ipynb

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@
306306
},
307307
{
308308
"cell_type": "code",
309-
"execution_count": 18,
309+
"execution_count": 3,
310310
"id": "e174f226",
311311
"metadata": {},
312312
"outputs": [
@@ -326,6 +326,17 @@
326326
")"
327327
]
328328
},
329+
{
330+
"cell_type": "markdown",
331+
"id": "2b6ced23",
332+
"metadata": {},
333+
"source": [
334+
"```{tip}\n",
335+
"\n",
336+
"You can save patterns using {func}`tfmindi.save_patterns` and load them back using {func}`tfmindi.load_patterns`.\n",
337+
"```"
338+
]
339+
},
329340
{
330341
"cell_type": "code",
331342
"execution_count": 19,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ classifiers = [
2222
]
2323
dependencies = [
2424
"anndata",
25+
"h5py>=3.14.0",
2526
"igraph",
2627
"lda",
2728
"logomaker",

src/tfmindi/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
load_motif_collection,
2323
load_motif_to_dbd,
2424
)
25-
from tfmindi.io import load_h5ad, save_h5ad # noqa: E402
25+
from tfmindi.io import load_h5ad, load_patterns, save_h5ad, save_patterns # noqa: E402
2626
from tfmindi.types import Pattern, Seqlet # noqa: E402
2727

2828
__all__ = [
@@ -41,6 +41,8 @@
4141
"load_motif_to_dbd",
4242
"save_h5ad",
4343
"load_h5ad",
44+
"save_patterns",
45+
"load_patterns",
4446
]
4547

4648
__version__ = version("tfmindi")

src/tfmindi/io.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
from __future__ import annotations
44

55
import pickle
6+
import warnings
67
from pathlib import Path
78

9+
import h5py # type: ignore
810
import numpy as np
9-
import pandas as pd
10-
from anndata import AnnData, read_h5ad
11+
import pandas as pd # type: ignore
12+
from anndata import AnnData, read_h5ad # type: ignore
13+
14+
from tfmindi.types import _PATTERN_SPEC, _SEQLET_SPEC, Pattern, Seqlet
1115

1216

1317
def _sanitize_hdf5_keys(data):
@@ -254,3 +258,97 @@ def _convert_numpy_arrays_to_strings_chunked(df, col, chunk_size=1000):
254258

255259
# Convert to categorical to save memory
256260
df[col] = pd.Series(converted_values, index=series.index).astype(str).astype("category")
261+
262+
263+
def _save_seqlet(seqlet: Seqlet, grp: h5py.Group) -> None:
    """Write one seqlet into an HDF5 group.

    The seqlet spec version is stored as the ``version`` attribute of ``grp``.
    Each instance attribute of ``seqlet`` that is not ``None`` is then written
    to ``grp`` under the attribute's own name.

    Parameters
    ----------
    seqlet
        Seqlet to serialize.
    grp
        Group that receives the version attribute and one entry per field.
    """
    grp.attrs["version"] = _SEQLET_SPEC
    for name, value in vars(seqlet).items():
        # Fields set to None are simply left out of the file.
        if value is not None:
            grp[name] = value
270+
271+
272+
def _save_pattern(pattern: Pattern, grp: h5py.Group) -> None:
    """Write one pattern, including its seqlets, into an HDF5 group.

    The pattern spec version is stored as the ``version`` attribute of ``grp``.
    Every instance attribute other than ``seqlets`` that is not ``None`` is
    written to ``grp`` under its own name. The seqlets go into a ``seqlets``
    sub-group, one child group per seqlet named ``seqlet_<index>``.

    Parameters
    ----------
    pattern
        Pattern to serialize.
    grp
        Group that receives the pattern data.
    """
    grp.attrs["version"] = _PATTERN_SPEC
    for name, value in vars(pattern).items():
        # Seqlets get their own sub-group below; None fields are left out.
        if name == "seqlets" or value is None:
            continue
        grp[name] = value

    container = grp.create_group("seqlets")
    for index, seqlet in enumerate(pattern.seqlets):
        # The index in the group name records the original list order.
        _save_seqlet(seqlet, container.create_group(f"seqlet_{index}"))
285+
286+
287+
def _read_seqlet(grp: h5py.Group) -> Seqlet:
    """Load a seqlet from an HDF5 group.

    Every entry of ``grp`` is read in full and passed to ``Seqlet`` as a
    keyword argument named after the entry. Values that come back as ``bytes``
    are decoded as UTF-8 first.

    Parameters
    ----------
    grp
        Group holding one seqlet; must carry a ``version`` attribute.

    Returns
    -------
    The reconstructed seqlet.

    Warns
    -----
    UserWarning
        If the ``version`` attribute differs from the seqlet spec version of
        this TF-MInDi installation. Reading is attempted anyway.
    """
    disk_version = grp.attrs["version"]
    if disk_version != _SEQLET_SPEC:
        # Report the *seqlet* spec here; this is the constant the check uses.
        warnings.warn(
            f"The version of the seqlet on disk ({disk_version}) does not match with the seqlet version in TF-MInDi ({_SEQLET_SPEC})! Will try to read anyway.",
            stacklevel=1,
        )
    kwargs = {}
    for k in grp.keys():
        value = grp[k][()]  # type: ignore
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        kwargs[k] = value
    # NOTE(review): fields that were None at save time are absent from the
    # group, so this presumably relies on Seqlet providing defaults - confirm.
    return Seqlet(**kwargs)
301+
302+
303+
def _load_pattern(grp: h5py.Group) -> Pattern:
    """Rebuild a pattern, including its seqlets, from an HDF5 group.

    Every entry of ``grp`` except ``seqlets`` is read in full and becomes a
    keyword argument for ``Pattern``; ``bytes`` values are decoded as UTF-8.
    The children of the ``seqlets`` sub-group are read in ascending order of
    the integer after the underscore in their name.

    A warning is issued (and reading continues) when the ``version`` attribute
    of ``grp`` differs from the pattern spec version of this installation.

    Parameters
    ----------
    grp
        Group holding one pattern.

    Returns
    -------
    The reconstructed pattern.
    """
    if grp.attrs["version"] != _PATTERN_SPEC:
        warnings.warn(
            f"The version of the pattern on disk ({grp.attrs['version']}) does not match with the pattern version in TF-MInDi ({_PATTERN_SPEC})! Will try to read anyway.",
            stacklevel=1,
        )

    fields = {}
    for name in grp.keys():
        if name != "seqlets":
            raw = grp[name][()]  # type: ignore
            fields[name] = raw.decode("utf-8") if isinstance(raw, bytes) else raw

    seqlets_grp = grp["seqlets"]
    # Numeric sort restores the order in which the seqlets were saved
    # (a plain string sort would place "seqlet_10" before "seqlet_2").
    ordered_names = sorted(seqlets_grp.keys(), key=lambda name: int(name.split("_")[1]))  # type: ignore
    fields["seqlets"] = [_read_seqlet(seqlets_grp[name]) for name in ordered_names]  # type: ignore
    return Pattern(**fields)
324+
325+
326+
def save_patterns(patterns: dict[str, Pattern], filename: str | Path) -> None:
    """Save dict of Patterns to disk.

    Each pattern is written to its own top-level HDF5 group named
    ``pattern_<key>``. The file is opened in ``"w"`` mode.

    Parameters
    ----------
    patterns
        Dict of patterns.
    filename
        output filename.

    Raises
    ------
    ValueError
        If a key contains ``"/"``. HDF5 treats ``"/"`` as a path separator, so
        such a key would create nested groups that cannot be loaded back.
    """
    # Validate before opening the file so a bad key cannot leave a
    # half-written (or freshly truncated) file behind.
    for key in patterns:
        if "/" in str(key):
            raise ValueError(f"Pattern key {key!r} contains '/', which is not allowed in an HDF5 group name.")

    with h5py.File(filename, "w") as h5_handle:
        for key, pattern in patterns.items():
            pattern_grp = h5_handle.create_group(f"pattern_{key}")
            _save_pattern(pattern, pattern_grp)
340+
341+
342+
def load_patterns(filename: str | Path) -> dict[str, Pattern]:
    """Load patterns from disk.

    Every top-level group of the file is read as one pattern. The leading
    ``pattern_`` that ``save_patterns`` adds to group names is stripped to
    recover the original dictionary key.

    Parameters
    ----------
    filename
        input filename.

    Returns
    -------
    Dict mapping pattern key to the loaded pattern.
    """
    prefix = "pattern_"
    patterns: dict[str, Pattern] = {}
    with h5py.File(filename, "r") as h5_handle:
        for pattern_name in h5_handle.keys():
            # Strip only the leading prefix: str.replace would also remove
            # "pattern_" occurring inside the key itself.
            key = pattern_name[len(prefix):] if pattern_name.startswith(prefix) else pattern_name
            patterns[key] = _load_pattern(h5_handle[pattern_name])  # type: ignore
    return patterns

src/tfmindi/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
_BASE_TO_BIN = {"A": 0, "C": 1, "G": 2, "T": 3}
1111
_BIN_TO_BASE = {0: "A", 1: "C", 2: "G", 3: "T"}
1212

13+
# Change these version numbers when breaking changes are introduced in Pattern and/or Seqlet.
14+
# That way incompatibilities can be detected when serializing to or from disk.
15+
# This version number is saved on disk along with the pattern and seqlet data.
16+
_PATTERN_SPEC = "1.0"
17+
_SEQLET_SPEC = "1.0"
18+
1319

1420
@dataclass
1521
class Seqlet:

tests/test_io.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,37 @@ def test_save_h5ad_preserves_original(sample_adata_with_arrays):
204204
assert isinstance(sample_adata_with_arrays.var["motif_pwm"].iloc[0], np.ndarray)
205205
np.testing.assert_array_equal(sample_adata_with_arrays.obs["seqlet_matrix"].iloc[0], original_obs_array)
206206
np.testing.assert_array_equal(sample_adata_with_arrays.var["motif_pwm"].iloc[0], original_var_array)
207+
208+
209+
def _assert_same_attributes(expected, actual, skip=()):
    """Assert two objects have the same attribute names and equal values.

    Attributes named in ``skip`` are excluded from the value comparison (but
    not from the name comparison). NumPy arrays are compared element-wise.
    """
    expected_attrs = vars(expected)
    actual_attrs = vars(actual)
    assert set(expected_attrs.keys()) == set(actual_attrs.keys())
    for attr in actual_attrs.keys():
        if attr in skip:
            continue
        if isinstance(expected_attrs[attr], np.ndarray):
            np.testing.assert_array_equal(expected_attrs[attr], actual_attrs[attr])
        else:
            assert expected_attrs[attr] == actual_attrs[attr]


def test_save_and_load_pattern(sample_patterns):
    """Round-trip patterns through save_patterns/load_patterns and compare them."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        filepath = Path(tmp_dir) / "test_pattern.hdf5"

        tm.save_patterns(sample_patterns, filepath)
        loaded_patterns = tm.load_patterns(filepath)

        assert set(sample_patterns.keys()) == set(loaded_patterns.keys())

        for name, pattern_loaded in loaded_patterns.items():
            pattern_orig = sample_patterns[name]
            # Seqlets are compared one by one below.
            _assert_same_attributes(pattern_orig, pattern_loaded, skip=("seqlets",))

            assert len(pattern_orig.seqlets) == len(pattern_loaded.seqlets)
            for s_orig, s_loaded in zip(pattern_orig.seqlets, pattern_loaded.seqlets, strict=True):
                _assert_same_attributes(s_orig, s_loaded)

0 commit comments

Comments
 (0)