How to save the trace between sampling runs and then resume sampling later on?

Hello everyone,

I would like to save the trace every n samples or so. Is there a way to do this in PyMC? Specifically, I would like to do something like the following pseudocode:

import arviz as az
import pymc as pm

with pm.Model() as model:

    ## priors

    trace1 = pm.sample(3000, tune=1000, chains=3, random_seed=rng,
                       return_inferencedata=True)
    az.to_netcdf(trace1, 'TRACE1.nc')

    # keep sampling (no further tuning) and save again
    trace2 = pm.sample(4000, tune=0, chains=3, random_seed=rng,
                       return_inferencedata=True)
    az.to_netcdf(trace2, 'TRACE2.nc')

Following this, I would like to resume sampling in a separate script while being consistent with the previous definition of priors and model:

with pm.Model() as model:
    # Exact same priors and model

    # somehow resume from where trace2 stopped
    trace3 = pm.sample(5000, tune=0, chains=3, random_seed=rng,
                       return_inferencedata=True)
    az.to_netcdf(trace3, 'TRACE3.nc')

Is this possible to do? If not, would it at least be possible to implement the first block?
Thank you for your time and help. I would really appreciate any guidance or comments.

@lucianopaz has been working on this.

EDIT: take a look at pymc.backends.zarr.ZarrTrace in the PyMC 5.23.0 documentation.

You can’t do what you want at the moment. BUT, I have a PR open that enables it. The one catch is that you’ll have to use the new and special ZarrTrace instead of leaving the trace argument of pm.sample as the default None.

With the code from that PR you’ll basically be able to do this:

import pymc as pm
import zarr
from pymc.backends.zarr import ZarrTrace

# `model` is the pm.Model you defined earlier
with model:
    trace = ZarrTrace(
        store=zarr.ZipStore("path_to_file.zip"),
        include_transformed=True,
        draws_per_chunk=100,
    )
    # First tuning run
    pm.sample(tune=400, draws=0, trace=trace)

    # Do whatever to decide if you want to continue tuning
    pm.sample(tune=800, draws=0, trace=trace)

    # Switch to sampling
    idata = pm.sample(tune=800, draws=1000, trace=trace)

    # Work with idata as you'd do regularly

The trace object will still be around for you to use, and all of your progress will be saved in the zarr store that you chose at the beginning. The draws_per_chunk argument controls how often the draws are written to the store, so if your process suddenly dies (for example, due to a power outage), the progress up to the last written chunk will have been stored and you’ll be able to pick up your sampling from where you left off. You can load the trace object from the storage like this:

loaded_trace = ZarrTrace.from_store(zarr_store_object)
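
If it helps to see the chunked-write idea in isolation, here is a toy sketch in plain Python. It is not how ZarrTrace is implemented and it does not use PyMC or zarr at all: the random-walk Metropolis sampler, the JSON checkpoint file, and the names sample_with_checkpoints, save_checkpoint, load_checkpoint and crash_at are all made up for illustration. It only shows the general pattern: flush the draws and the sampler state every draws_per_chunk draws, and on restart pick up from the last flush.

```python
import json
import math
import os
import random
import tempfile


def log_density(x):
    """Unnormalized log-density of a standard normal (the toy target)."""
    return -0.5 * x * x


def save_checkpoint(path, draws, x, rng):
    """Write draws, current position and RNG state, replacing the file atomically."""
    version, internal, gauss_next = rng.getstate()
    state = {"draws": draws, "x": x, "rng": [version, list(internal), gauss_next]}
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)  # a crash mid-write cannot corrupt the old file


def load_checkpoint(path, rng):
    """Restore the RNG state in place and return (draws, current position)."""
    with open(path) as f:
        state = json.load(f)
    version, internal, gauss_next = state["rng"]
    rng.setstate((version, tuple(internal), gauss_next))
    return state["draws"], state["x"]


def sample_with_checkpoints(path, total_draws, draws_per_chunk, seed=0, crash_at=None):
    """Toy Metropolis sampler that resumes from `path` if a checkpoint exists.

    `crash_at` is a hypothetical test hook: it raises once that many draws
    exist, simulating a process that dies between two flushes.
    """
    rng = random.Random(seed)
    if os.path.exists(path):
        draws, x = load_checkpoint(path, rng)
    else:
        draws, x = [], 0.0
    while len(draws) < total_draws:
        if crash_at is not None and len(draws) == crash_at:
            raise RuntimeError("simulated crash")
        proposal = x + rng.gauss(0.0, 1.0)
        accept_prob = math.exp(min(0.0, log_density(proposal) - log_density(x)))
        if rng.random() < accept_prob:
            x = proposal
        draws.append(x)
        # Flush once per chunk, and once more at the very end.
        if len(draws) % draws_per_chunk == 0 or len(draws) == total_draws:
            save_checkpoint(path, draws, x, rng)
    return draws


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        reference = sample_with_checkpoints(os.path.join(tmp, "ref.json"), 1000, 100, seed=1)
        path = os.path.join(tmp, "run.json")
        try:
            sample_with_checkpoints(path, 1000, 100, seed=1, crash_at=250)
        except RuntimeError:
            pass  # draws made after the last flush (at 200) are lost
        resumed = sample_with_checkpoints(path, 1000, 100, seed=1)
        print(len(resumed), resumed == reference)
```

The trade-off is the same one draws_per_chunk exposes: a smaller chunk means more frequent writes but less lost work after a crash. Because the toy checkpoint also stores the random number generator state, the resumed chain continues exactly as an uninterrupted one would; whether a real backend gives you that guarantee depends on what it stores.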