|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import pickle |
| 6 | +import warnings |
6 | 7 | from pathlib import Path |
7 | 8 |
|
| 9 | +import h5py # type: ignore |
8 | 10 | import numpy as np |
9 | | -import pandas as pd |
10 | | -from anndata import AnnData, read_h5ad |
| 11 | +import pandas as pd # type: ignore |
| 12 | +from anndata import AnnData, read_h5ad # type: ignore |
| 13 | + |
| 14 | +from tfmindi.types import _PATTERN_SPEC, _SEQLET_SPEC, Pattern, Seqlet |
11 | 15 |
|
12 | 16 |
|
13 | 17 | def _sanitize_hdf5_keys(data): |
@@ -254,3 +258,97 @@ def _convert_numpy_arrays_to_strings_chunked(df, col, chunk_size=1000): |
254 | 258 |
|
255 | 259 | # Convert to categorical to save memory |
256 | 260 | df[col] = pd.Series(converted_values, index=series.index).astype(str).astype("category") |
| 261 | + |
| 262 | + |
| 263 | +def _save_seqlet(seqlet: Seqlet, grp: h5py.Group) -> None: |
| 264 | + """Save seqlet to h5 group.""" |
| 265 | + grp.attrs["version"] = _SEQLET_SPEC |
| 266 | + for k, v in seqlet.__dict__.items(): |
| 267 | + if v is None: |
| 268 | + continue |
| 269 | + grp[k] = v |
| 270 | + |
| 271 | + |
| 272 | +def _save_pattern(pattern: Pattern, grp: h5py.Group) -> None: |
| 273 | + """Save pattern to h5 group.""" |
| 274 | + grp.attrs["version"] = _PATTERN_SPEC |
| 275 | + for k, v in pattern.__dict__.items(): |
| 276 | + if k == "seqlets": |
| 277 | + continue |
| 278 | + if v is None: |
| 279 | + continue |
| 280 | + grp[k] = v |
| 281 | + seqlets_grp = grp.create_group("seqlets") |
| 282 | + for i, seqlet in enumerate(pattern.seqlets): |
| 283 | + seqlet_grp = seqlets_grp.create_group(f"seqlet_{i}") |
| 284 | + _save_seqlet(seqlet, seqlet_grp) |
| 285 | + |
| 286 | + |
| 287 | +def _read_seqlet(grp: h5py.Group) -> Seqlet: |
| 288 | + """Load seqlet from h5 group.""" |
| 289 | + kwargs = {} |
| 290 | + if grp.attrs["version"] != _SEQLET_SPEC: |
| 291 | + warnings.warn( |
| 292 | + f"The version of the seqlet on disk ({grp.attrs['version']}) does not match with the pattern version in TF-MInDi ({_PATTERN_SPEC})! Will try to read anyway.", |
| 293 | + stacklevel=1, |
| 294 | + ) |
| 295 | + for k in grp.keys(): |
| 296 | + value = grp[k][()] # type: ignore |
| 297 | + if isinstance(value, bytes): |
| 298 | + value = value.decode("utf-8") |
| 299 | + kwargs[k] = value |
| 300 | + return Seqlet(**kwargs) |
| 301 | + |
| 302 | + |
| 303 | +def _load_pattern(grp: h5py.Group) -> Pattern: |
| 304 | + """Load pattern from h5 group.""" |
| 305 | + kwargs = {} |
| 306 | + if grp.attrs["version"] != _PATTERN_SPEC: |
| 307 | + warnings.warn( |
| 308 | + f"The version of the pattern on disk ({grp.attrs['version']}) does not match with the pattern version in TF-MInDi ({_PATTERN_SPEC})! Will try to read anyway.", |
| 309 | + stacklevel=1, |
| 310 | + ) |
| 311 | + for k in grp.keys(): |
| 312 | + if k == "seqlets": |
| 313 | + continue |
| 314 | + value = grp[k][()] # type: ignore |
| 315 | + if isinstance(value, bytes): |
| 316 | + value = value.decode("utf-8") |
| 317 | + kwargs[k] = value |
| 318 | + seqlets: list[Seqlet] = [] |
| 319 | + # Sorted to make sure that the order of the seqlets is the same as when they were saved. |
| 320 | + for seqlet_key in sorted(grp["seqlets"].keys(), key=lambda x: int(x.split("_")[1])): # type: ignore |
| 321 | + seqlets.append(_read_seqlet(grp["seqlets"][seqlet_key])) # type: ignore |
| 322 | + kwargs["seqlets"] = seqlets |
| 323 | + return Pattern(**kwargs) |
| 324 | + |
| 325 | + |
| 326 | +def save_patterns(patterns: dict[str, Pattern], filename: str | Path) -> None: |
| 327 | + """Save dict of Patterns to disk. |
| 328 | +
|
| 329 | + Paramaters |
| 330 | + ---------- |
| 331 | + patterns |
| 332 | + Dict of patterns. |
| 333 | + filename |
| 334 | + output filename. |
| 335 | + """ |
| 336 | + with h5py.File(filename, "w") as h5_handle: |
| 337 | + for key, pattern in patterns.items(): |
| 338 | + pattern_grp = h5_handle.create_group(f"pattern_{key}") |
| 339 | + _save_pattern(pattern, pattern_grp) |
| 340 | + |
| 341 | + |
| 342 | +def load_patterns(filename: str | Path) -> dict[str, Pattern]: |
| 343 | + """Load patterns from disk. |
| 344 | +
|
| 345 | + Parameters |
| 346 | + ---------- |
| 347 | + filename |
| 348 | + input filename. |
| 349 | + """ |
| 350 | + patterns: dict[str, Pattern] = {} |
| 351 | + with h5py.File(filename, "r") as h5_handle: |
| 352 | + for pattern_name in h5_handle.keys(): |
| 353 | + patterns[pattern_name.replace("pattern_", "")] = _load_pattern(h5_handle[pattern_name]) |
| 354 | + return patterns |
0 commit comments