Diffusion Based Causal Representation Learning
3 Max Planck Institute for Intelligent Systems, Tübingen, Germany
4 Technical University of Munich
5 Department of Computer Science, University of Oxford
Abstract
Causal reasoning can be considered a cornerstone of intelligent systems. Having access
to an underlying causal graph comes with the promise of cause-effect estimation and the
identification of efficient and safe interventions. However, learning causal representations
remains a major challenge, due to the complexity of many real-world systems. Previous
work on causal representation learning has mostly focused on Variational Auto-Encoders
(VAEs). These methods only provide representations from a point estimate, and they are
unsuitable for handling high dimensions. To overcome these problems, we propose a new
Diffusion-based Causal Representation Learning (DCRL) algorithm, which uses
diffusion-based representations for causal discovery. DCRL provides access to
infinite-dimensional latent codes that encode different levels of information.
As a first proof of principle, we investigate the use of DCRL for causal representation
learning and demonstrate experimentally that this approach performs comparably
well in identifying the causal structure and the causal variables.
1 Introduction
Causal representation learning aims to uncover a system’s latent causal factors and their
relationships from observed low-level data. It finds applications in domains such as
autonomous driving [Schölkopf et al., 2021], robotics [Hellström, 2021],
healthcare [Anwar et al., 2014], climate studies [Runge et al., 2019], epidemiology [Hernán
et al., 2000, Robins et al., 2000], and finance [Hiemstra and Jones, 1994]. In these tasks,
the underlying causal variables are often unknown, and we only have access to low-level
representations.
Causal representation learning is a challenging problem: identifying latent causal factors
is generally impossible from observational data alone. There has been an ongoing effort to study
sets of assumptions that ensure the identifiability of causal variables and their relationships
[Brehmer et al., 2022, Liu et al., 2022, Schölkopf et al., 2021, Subramanian et al., 2022, Yang
et al., 2020]. These approaches assume the availability of additional information or place
assumptions on the underlying causal structure of the data-generating process (DGP). Interestingly,
Brehmer et al. [2022] consider a weak form of supervision in which we have access to a data pair
corresponding to the state of the system before and after a random, unknown intervention. Brehmer et al.
[2022] prove that, in this weakly-supervised setting, the structure and the causal variables are
identifiable up to a relabeling and element-wise reparameterization.
There has been a growing interest in leveraging generative models to learn causal representations
with specific properties. For example, disentangled and object-centric representations have
been shown to be helpful for complex downstream tasks and generalization [Dittadi et al., 2022,
Papa et al., 2022, Van Steenkiste et al., 2019, Wu et al., 2022, Yoon et al., 2023]. Variational
Autoencoders (VAE) [Kingma and Welling, 2014] are among the most widely studied generative
models, and they have been successfully used for disentanglement and causal representation
learning [Brehmer et al., 2022, Locatello et al., 2020]. However, the problem of learning causal
representations has not yet been approached with more powerful generative models.
Recently, diffusion models have emerged as state-of-the-art generative models, and they have
demonstrated remarkable success across several domains [Dhariwal and Nichol, 2021a, Ho
et al., 2022b, Höppe et al., 2022, Ramesh et al., 2022, Saharia et al., 2022]. Diffusion models
draw on concepts and principles from diffusion processes to learn the data distribution [Cai
et al., 2020, Chen et al., 2021, Dhariwal and Nichol, 2021b, Ho et al., 2020, 2022a, Luhman
and Luhman, 2021, Mehrjou et al., 2017, Niu et al., 2020, Sajjadi et al., 2018, Saremi et al.,
2018, Sohl-Dickstein et al., 2015, Song et al., 2021a,b,c]. These
models exploit diffusion behavior to produce diverse, high-quality, and realistic samples.
Furthermore, diffusion-based models have the appealing property of infinite-dimensional latent
codes [Abstreiter et al., 2022], which allows representations to be learned efficiently across different
downstream tasks. Despite their remarkable performance and advantages, diffusion models
have not yet been employed for causal representation learning, indicating that their potential
has yet to be explored in this context.
Our contribution. In this work, we study the connection between diffusion-based models
and causal structure learning. In particular, our contributions are the following:
• We propose DCRL, a diffusion-based model for causal representation learning. We study
and test the connection between the learned representations of DCRL and the causal variables,
utilizing both finite and infinite-dimensional representations.
• We derive the Evidence Lower Bound (ELBO) for DCRL, in the case of both finite and
infinite-dimensional representations.
• We empirically illustrate that the noise and diffusion-based representations contain
equivalent information about the underlying causal variables and causal mechanisms,
and can be used interchangeably.
2 Related Work
Diffusion-based Representation Learning. Learning representations with diffusion models
remains a relatively unexplored area. Several works train an external module (e.g.,
an encoder) along with the score function of the diffusion model to extract representations.
Abstreiter et al. [2022] and Mittal et al. [2022] condition the score function of a diffusion model
on a time-independent and a time-dependent encoder, obtaining finite and infinite-dimensional
representations, respectively. Wang et al. [2023] use the same conditioning but regularize
the objective function with the mutual information between the input data and the learned
representations. Traub [2022] also uses this conditioning, but within Latent Diffusion Models
[Rombach et al., 2022], where the inputs of the diffusion model are latent variables obtained by
applying a pre-trained autoencoder to the input. Furthermore, Kwon et al. [2022] propose an
asymmetric reverse process that discovers the semantic latent space of a frozen diffusion model,
in which modifications synthesize various attributes on input images. In principle, however,
diffusion models lack a semantic latent space, and it is unclear how to efficiently learn
representations using their capabilities.
2.1 Overview
The fundamental concept behind diffusion-based generative models is to learn to generate data
by inverting a diffusion process. Diffusion models comprise two processes: a forward process
and a backward process. The forward process gradually adds noise to data and maps data
to (almost) pure noise. The backward process, on the other hand, is used to go from a noise
sample back to the original data space.
The forward process is defined by a stochastic differential equation (SDE) across a continuous
time domain t ∈ [0, 1], aiming to transform the data distribution to a known prior distribution,
typically a standard multivariate Gaussian. Given x0 sampled from a data distribution p(x),
the forward process constructs a trajectory (xt )t∈[0,1] across the time domain. We utilize the
Variance Exploding SDE [Song et al., 2021c] for the forward process, which is defined as:
$$\mathrm{d}x = f(x, t)\,\mathrm{d}t + g(t)\,\mathrm{d}w := \sqrt{\frac{\mathrm{d}[\sigma^2(t)]}{\mathrm{d}t}}\,\mathrm{d}w,$$

where w is the standard Wiener process and σ²(t) is the noise variance of the diffusion process
at time t. The backward process is also formulated as an SDE in the following manner:

$$\mathrm{d}x = \big[f(x, t) - g^2(t)\,\nabla_x \log p_t(x)\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar{w},$$
where w̄ is the standard Wiener process in reverse time.
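To make the forward process concrete, the following sketch samples from the VE-SDE perturbation kernel p_t(x_t|x_0) = N(x_0, [σ²(t) − σ²(0)]I); the values of σ_min and σ_max mirror Appendix B, while the function names and batch shapes are illustrative.

```python
import torch

def sigma(t, sigma_min=0.01, sigma_max=50.0):
    """Noise scale of the Variance Exploding SDE at time t in [0, 1]."""
    return sigma_min * (sigma_max / sigma_min) ** t

def perturb(x0, t, sigma_min=0.01, sigma_max=50.0):
    """Sample x_t ~ p_t(x_t | x_0) = N(x_0, [sigma^2(t) - sigma^2(0)] I)."""
    std = torch.sqrt(sigma(t, sigma_min, sigma_max) ** 2 - sigma_min ** 2)
    std = std.view(-1, *([1] * (x0.dim() - 1)))       # broadcast over data dims
    noise = torch.randn_like(x0)
    return x0 + std * noise, noise, std

# Example: perturb a batch of 16-dimensional inputs at random times.
x0 = torch.randn(64, 16)
xt, noise, std = perturb(x0, torch.rand(64))
```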
Score matching. To use this backward process, the score function ∇x log pt (x) is required.
It is usually approximated by a neural score function sθ (·) which can be trained by Explicit
Score Matching [Hyvärinen and Dayan, 2005] defined as:
" #
h i
2
L(θ) = Et λ(t)Ep(xt ) ||sθ (xt , t) − ∇xt log pt (xt )|| ,
However, the ground-truth score function ∇x log pt (x) is generally not known. Vincent [2011]
addresses this issue by proposing Denoising Score Matching. The approximate score function
is then learned by minimizing the loss function:
" #
h i
2
L(θ) = λ(t)Ex0 Ep(xt |x0 ) ||sθ (xt , t) − ∇xt log pt (xt |x0 )|| ,
where the conditional distribution of xt given x0 is pt(xt|x0) = N(xt; x0, [σ²(t) − σ²(0)]I) and
λ(t) is a positive weighting function. This objective originates from the evidence lower
bound (ELBO) of the data distribution, and it has been shown that, with a specific weighting
function, it becomes exactly a term in the ELBO [Song et al., 2021c]. For
more details, see Appendix A.
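As a sketch of how denoising score matching is optimized in practice, the snippet below implements the loss for the VE SDE with a generic score network; the weighting choice λ(t) = σ²(t) and the approximation σ(0) ≈ 0 are simplifying assumptions of this sketch.

```python
import torch

def dsm_loss(score_net, x0, sigma_min=0.01, sigma_max=50.0):
    """Denoising score matching for the VE SDE (a sketch).

    With p_t(x_t | x_0) ~= N(x_0, sigma(t)^2 I) (sigma(0) is negligibly small),
    the target score is grad_xt log p_t(x_t | x_0) = -(x_t - x_0) / sigma(t)^2.
    """
    t = torch.rand(x0.shape[0], device=x0.device) * (1 - 1e-5) + 1e-5
    std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, *([1] * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    xt = x0 + std * noise                  # sample from the perturbation kernel
    target = -noise / std                  # score of the Gaussian kernel
    weight = std ** 2                      # one possible choice of lambda(t)
    return (weight * (score_net(xt, t) - target) ** 2).mean()
```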
To learn representations, Abstreiter et al. [2022] extend this objective to

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_t\Big[\lambda(t)\,\mathbb{E}_{x_0}\,\mathbb{E}_{p(x_t|x_0)}\big[\|s_\theta(x_t, E_\phi(x_0), t) - \nabla_{x_t}\log p_t(x_t|x_0)\|^2\big]\Big], \tag{1}$$

where the score function is conditioned on a module Eϕ(x0), which provides additional information
about the data to the diffusion model through a learned encoder with parameters ϕ. In
fact, the encoder learns to extract the information from x0, in a reduced-dimensional space,
that helps recover x0 by denoising xt. Abstreiter et al. [2022] also present an alternative
objective where the encoder is a function of time. Formally, the new objective is
" #
h i
L(θ, ϕ) = Et λ(t)Ex0 Ep(xt |x0 ) ||sθ (xt , Eϕ (x0 , t), t) − ∇xt log pt (xt |x0 )||2 , (2)
With this objective, the encoder learns a representation trajectory of x0 instead of a single
representation. Since the objective can in principle be minimized to zero, the encoder Eϕ(·)
is encouraged to learn meaningful, distinct representations at different timesteps
[Abstreiter et al., 2022, Mittal et al., 2022].
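A corresponding sketch of the conditioned objective in Eq. 2, with a hypothetical time-dependent encoder `encoder(x0, t)` whose output is passed to the score network as an additional input; dropping the encoder's time argument recovers the single-code objective of Eq. 1.

```python
import torch

def conditioned_dsm_loss(score_net, encoder, x0, sigma_min=0.01, sigma_max=50.0):
    """Eq. 2 (sketch): the score network additionally receives the encoder
    output E_phi(x_0, t); dropping the time argument gives Eq. 1."""
    t = torch.rand(x0.shape[0], device=x0.device) * (1 - 1e-5) + 1e-5
    std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, *([1] * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    xt = x0 + std * noise                   # noisy input
    code = encoder(x0, t)                   # representation trajectory e_t
    residual = score_net(xt, code, t) + noise / std   # s_theta minus target score
    return ((std * residual) ** 2).mean()   # lambda(t) = sigma^2(t) weighting
```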
Figure 1: Overview of our framework. Here we have a paired image of a face before and after
an intervention (the smile). The paired image is mapped to latent variables by a stochastic
encoder. The intervention target is determined by applying the intervention encoder to these
latent variables. To maintain the weakly supervised structure, the latent variables are projected
into a new pair and then, serve as the conditioning module for a conditional diffusion model
(The projected latent variables are diffusion-based representations of the input pair). Finally,
they are utilized in neural solution functions together with the intervention target to obtain
the latent causal variables.
The representations along the trajectory contain different levels of information, as highlighted by Mittal et al. [2022].
In this work, we first explore a time-independent single code, employing Eq. 1, and show
that with a certain weighting function this objective becomes the ELBO. We then apply
the same experiments with the infinite-dimensional latent code (Eq. 2) and study the
benefits and implications of both formulations for causal representation learning.
3 Problem Description
We consider a system described by an unknown underlying structural causal model (SCM) over the latent causal
variables Z, for which we only have access to low-level data pairs (x, x̃) ∼ p(x, x̃) representing the system
before and after a random, unknown, atomic intervention. Under this
weakly supervised setting, the causal variables and causal mechanisms are identifiable
up to a permutation and elementwise reparameterization of the variables [Brehmer et al.,
2022]. Our objective is to learn an SCM that accurately represents the true underlying SCM
associated with the given data, up to such a permutation and elementwise reparameterization.
To this end, we train an SCM by maximizing the likelihood of the data.
With sufficient data and perfect optimization, we can find the SCM that is equivalent to the
ground-truth SCM.
The Encoding and the Intervention Module. The encoding module consists of two
main parts: the stochastic encoder and the projection module. The stochastic encoder q(e|x)
maps data pairs (x, x̃) to pre-projection latent variables (e, ẽ). The encoded inputs are then
utilized in the intervention module q(I|x, x̃) to infer the intervention target I for the data pair
(x, x̃). Based on our data generation process, the encoded inputs have the property that
e_i ≠ ẽ_i only for the elements that are intervened upon, i ∈ I, while the remaining elements
stay the same. Based on this property, in order to infer interventions, we employ an intervention
module q(I|e, ẽ), defined heuristically as
$$\log q(i \in I \mid x, \tilde{x}) = \frac{1}{Z}\Big(\alpha + \beta\,\big|\mu_e(x)_i - \mu_e(\tilde{x})_i\big| + \gamma\,\big|\mu_e(x)_i - \mu_e(\tilde{x})_i\big|^2\Big),$$
where µ_e(x) is the mean of the stochastic encoder q(e|x), α, β, and γ are learnable parameters,
and Z is a normalization constant. This simple heuristic assigns a higher likelihood to a
component the more it changes between the encoded inputs in response to the intervention.
Once the intervention is inferred from the pre-projection latent variables,
we apply the projection module. The projection module depends on the inferred intervention
target I and projects the encoded input (e, ẽ) to new latent variables such that for the
components that are not intervened upon, i ∉ I, the pre-intervention and post-intervention
latent components are equal, e_i = ẽ_i. This prevents the solution functions from deviating from
the weakly supervised structure.
We write the combination of the encoder and the projection module as q(e, ẽ|x, x̃, I) and refer
to it as the encoding module. By this definition, the encoding module q(e, ẽ|x, x̃, I) maps the
input (x, x̃) to latent variables (e, ẽ), and the intervention module infers the intervention I
based on the pre-projection latent variables.
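The following sketch illustrates the intervention heuristic and the projection step described above. The softmax normalization and the hard argmax choice of a single (atomic) intervention target are illustrative assumptions; in the model, α, β, and γ are learnable and the posterior is used probabilistically.

```python
import torch
import torch.nn.functional as F

def intervention_posterior(mu, mu_tilde, alpha, beta, gamma):
    """q(i in I | x, x~): score each latent component by how much its encoder
    mean changes between the paired inputs; the softmax plays the role of the
    normalization constant Z."""
    diff = (mu - mu_tilde).abs()
    return F.softmax(alpha + beta * diff + gamma * diff ** 2, dim=-1)

def project(e, e_tilde, intervened):
    """Projection: set e~_i = e_i for every non-intervened component i,
    preserving the weakly supervised structure."""
    return e, torch.where(intervened, e_tilde, e)

# Example with a batch of 5-dimensional pre-projection latents.
mu, mu_tilde = torch.randn(64, 5), torch.randn(64, 5)
q_I = intervention_posterior(mu, mu_tilde, alpha=0.0, beta=1.0, gamma=0.1)
intervened = F.one_hot(q_I.argmax(dim=-1), num_classes=5).bool()  # hard atomic pick
e, e_tilde_proj = project(mu, mu_tilde, intervened)
```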
Prior. Given the intervention target I and latent variables (e, ẽ), we define the prior as
p(e, ẽ, I) = p(I)p(e)p(ẽ|e, I). The role of the prior distribution is to implicitly capture
the causal structure and causal mechanisms of the system. Specifically, p(I) and p(e)
denote the prior distributions over intervention targets and latent variables, and are
configured as a uniform categorical and a standard Gaussian distribution, respectively.
According to our data generation process, when an intervention is applied, only the intervened-upon
elements of the latent variables are altered; the other elements remain unchanged
and independent of each other. Consequently, we can define p(ẽ|e, I) as follows:
$$p(\tilde{e} \mid e, I) = \prod_{i \notin I} \delta(\tilde{e}_i - e_i)\;\prod_{i \in I} p(\tilde{e}_i \mid e).$$
In this equation, δ(·) is the Dirac delta function, which enforces equality of the non-intervened
elements of the latent variables.
Neural Solution Functions. Finally, in order to encode the information about the intervened
variables, we incorporate a conditional normalizing flow p(ẽi |e) defined as
$$p(\tilde{e}_i \mid e) = \tilde{p}\big(h_i(\tilde{e}_i; e_i)\big)\left|\frac{\partial h_i(\tilde{e}_i; e_i)}{\partial \tilde{e}_i}\right|,$$
where h(·) are the solution functions of the SCM, defined as invertible affine transformations
whose parameters are learned with neural networks. Therefore, by learning the solution
functions, i.e., learning to transform e to z, we implicitly model the causal graph within the
framework and obtain the latent causal variables. For more details about the implementation,
see Appendix B.
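A minimal sketch of one conditional affine solution function: an invertible affine map of ẽ_i whose shift and log-scale are predicted by a small network. Conditioning on the full latent vector e and using a standard normal base density p̃ are assumptions of this sketch; the exact conditioning and parameterization follow Appendix B.

```python
import math
import torch
import torch.nn as nn

class AffineSolutionFunction(nn.Module):
    """One conditional affine flow p(e~_i | e): an invertible affine map
    h_i(e~_i; .) with shift and log-scale predicted from the conditioning e."""

    def __init__(self, latent_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 2),                       # [shift, log_scale]
        )

    def log_prob(self, e_tilde_i, e):
        # e_tilde_i: (batch, 1), e: (batch, latent_dim)
        shift, log_scale = self.net(e).chunk(2, dim=-1)
        z = (e_tilde_i - shift) * torch.exp(-log_scale)   # h_i(e~_i; .)
        base = -0.5 * (z ** 2 + math.log(2 * math.pi))    # log of standard normal base
        log_det = -log_scale                              # log |d h_i / d e~_i|
        return (base + log_det).squeeze(-1)

flow = AffineSolutionFunction(latent_dim=5)
lp = flow.log_prob(torch.randn(64, 1), torch.randn(64, 5))  # shape (64,)
```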
Training Objective. Combining the encoding module, the intervention module, the prior, and the
conditional diffusion model, the ELBO of a data pair (derived in Appendix A) is

$$\begin{aligned}
\log p(x, \tilde{x}) \geq\; & \mathbb{E}_{q(I|x,\tilde{x})}\,\mathbb{E}_{q(e,\tilde{e}|x,\tilde{x},I)}\,\mathbb{E}_{t\sim U(0,1)}\,\mathbb{E}_{q(u_t|x)}\,\mathbb{E}_{q(\tilde{u}_t|\tilde{x})}\Big[\log p(I) + \log p(e) + \log p(\tilde{e}\,|\,e, I) \\
& - \log q(I\,|\,x,\tilde{x}) - \log q(e,\tilde{e}\,|\,x,\tilde{x},I) + \lambda(t)\,\|s_\theta(u_t, e, t) - \nabla_{u_t}\log p(u_t|x)\|_2^2 \\
& + \lambda(t)\,\|s_\theta(\tilde{u}_t, \tilde{e}, t) - \nabla_{\tilde{u}_t}\log p(\tilde{u}_t|\tilde{x})\|_2^2\Big],
\end{aligned}$$
where λ(t) is a positive weighting function. We train the model by minimizing a reweighted
loss function reminiscent of β-VAEs:
"
Lmodel = Ep(x,x̃) Eq(I|x,x̃) Eq(e,ẽ|x,x̃,I) Et∼U (0,1) Eq(ut |x) Eq(ũt |x̃) λ(t)||sθ (ut , e, t)
h
− ∇ut log p(ut |x)||22 + λ(t)||sθ (ũt , ẽ, t) − ∇ũt log p(ũt |x̃)||22 + β log p(I) + log p(e)
#
i
+ log p(ẽ|e, I) − log q(I|x, x̃) − log q(e, ẽ|x, x̃, I) ,
When using infinite-dimensional representations (Eq. 2), the objective function becomes:
"
Lmodel = Ep(x,x̃) Eq(I|x,x̃) Et∼U (0,1) Eq(et ,e˜t |x,x̃,I) Eq(ut |x) Eq(ũt |x̃) λ(t)||sθ (ut , et , t)
h
− ∇ut log p(ut |x)||22 + λ(t)||sθ (ũt , e˜t , t) − ∇ũt log p(ũt |x̃)||22 + β log p(I) + log p(et )
#
i
+ log p(ẽt |et , I) − log q(I|x, x̃) − log q(et , ẽt |x, x̃, I) , (3)
where (e_t)_{t∈[0,1]} is the trajectory-based representation and e_t ∈ R^d is the single point of the
trajectory at time t. For more details about the problem formulation, see Appendix A. To
prevent a collapse of the latent space to a lower-dimensional subspace, we add the negative
entropy of the batch-aggregated intervention posterior, q_I^batch(I) = E_{(x,x̃)∈batch}[q(I|x, x̃)], as a
regularization term to the loss function:
$$\mathcal{L}_{\text{entropy}} = \mathbb{E}_{\text{batches}}\Big[-\sum_{I} q_I^{\text{batch}}(I)\,\log q_I^{\text{batch}}(I)\Big],$$
where E_batches[·] denotes the expectation over the batches of data. After training,
the framework contains information about the underlying causal structure and the latent causal
variables and can be used in different downstream tasks.
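A sketch of the entropy regularizer for a single batch, following the formula above; averaging the per-batch values over training batches realizes the outer expectation.

```python
import torch

def entropy_regularizer(q_I, eps=1e-8):
    """L_entropy for one batch: aggregate q(I | x, x~) over the batch and
    evaluate the entropy expression above on the aggregated posterior."""
    q_batch = q_I.mean(dim=0)                                  # q_I^batch(I)
    return -(q_batch * torch.log(q_batch + eps)).sum()

# Example: q_I holds q(I | x, x~) for a batch of 64 pairs and 5 components.
loss_entropy = entropy_regularizer(torch.softmax(torch.randn(64, 5), dim=-1))
```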
5 Experiments
5.1 Overview of the Experiments
Here we analyze the performance of the proposed model, DCRL, on synthetic data. We
employ DCRL for the task of causal discovery and subsequently use ENCO [Lippe et al.,
2021], a continuous optimization structure learning method that leverages observational and
interventional data, on top of DCRL to infer the underlying causal graph. Furthermore, we
evaluate the learned latent variables with the DCI framework [Eastwood and Williams, 2018].
Data Generation. To generate latent variables, we adopt random graphs in which each edge in
a fixed topological order is sampled from a Bernoulli distribution with parameter 0.5. We
consider a linear Gaussian SCM and sample the weights from a multivariate normal distribution
with zero mean and unit variance, making sure the weights are not close to zero to avoid
violating the faithfulness assumption. We introduce additive Gaussian noise with equal
variance across all nodes, set to 0.1. Latent causal variables are then sampled using
ancestral sampling, and we generate 10^5 training samples, 10^4 validation samples, and 10^4
test samples. Finally, to generate the input data x, we apply a random linear projection to
the obtained latent variables, keeping the dimension of x fixed to 16. We consider SCMs with
5, 10, and 15 variables. To enhance the robustness of the results, we generate data for 4
different seeds and repeat our experiments for each seed.
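A minimal sketch of this data-generating process (random upper-triangular DAG, linear Gaussian SCM, ancestral sampling, and a random linear projection to 16 dimensions); the rejection of near-zero weights is omitted for brevity.

```python
import numpy as np

def generate_data(n_vars=5, n_samples=100_000, x_dim=16, noise_var=0.1, seed=0):
    """Linear Gaussian SCM over a random DAG, ancestral sampling, and a random
    linear projection of the latent causal variables to x_dim dimensions."""
    rng = np.random.default_rng(seed)
    # Random DAG in a fixed topological order: each edge is present w.p. 0.5.
    adj = np.triu(rng.binomial(1, 0.5, size=(n_vars, n_vars)), k=1)
    # Edge weights from a standard normal (the rejection of near-zero weights
    # used to preserve faithfulness is omitted in this sketch).
    weights = rng.normal(0.0, 1.0, size=(n_vars, n_vars)) * adj
    # Ancestral sampling with additive Gaussian noise of variance noise_var.
    z = np.zeros((n_samples, n_vars))
    for i in range(n_vars):
        z[:, i] = z @ weights[:, i] + rng.normal(0.0, np.sqrt(noise_var), n_samples)
    # Random linear projection to the observed low-level data x.
    x = z @ rng.normal(0.0, 1.0, size=(n_vars, x_dim))
    return z, x, adj, weights

z, x, adj, weights = generate_data()
```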
Figure 2: Comparison of models on different metrics when using single-point representation.
Our approach outperforms or competes favorably with the baseline methods on all metrics.
Particularly in higher dimensions, our method excels by capturing additional information about
the causal variables and the underlying causal structure.
Baselines. We consider ILCM, the implicit latent causal model of Brehmer et al. [2022], as our
main baseline. To the best of our knowledge, there are no other methods that consider the same
weakly-supervised assumptions. We also compare against a variant of the disentanglement VAE
proposed by Locatello et al. [2020], tailored to weakly supervised settings. This model, referred
to as d-VAE, models the weakly supervised process but assumes unconnected factors of variation
instead of a causal relationship among the variables. Similarly, we apply ENCO on top of both
baselines to obtain the learned graph.
5.2 Single-point Representations
Using single-point representations, where e ∈ R^d is independent of time, our method
demonstrates superior or competitive performance compared to the baselines, as indicated by
the metrics shown in Figure 2. In higher dimensions, our method excels by capturing more
information about the causal variables and the underlying causal structure.
6 Conclusion
Identifying the underlying causal variables and mechanisms of a system solely from observational
data is considered impossible without additional assumptions. In this work, we use weak
supervision as an inductive bias and study whether the information encoded in the latent code of
diffusion-based representations carries useful knowledge about the causal variables and the
underlying causal graph.
References
K. Abstreiter, S. Mittal, S. Bauer, B. Schölkopf, and A. Mehrjou. Diffusion-based representation
learning. CoRR, abs/2105.14257, 2022.
N. Chen, Y. Zhang, H. Zen, R. J. Weiss, M. Norouzi, and W. Chan. Wavegrad: Estimating
gradients for waveform generation. In Proc. of ICLR, 2021.
P. Dhariwal and A. Nichol. Diffusion models beat gans on image synthesis. In Proc. of NeurIPS,
pages 8780–8794, 2021a.
P. Dhariwal and A. Q. Nichol. Diffusion models beat gans on image synthesis. In Proc. of
NeurIPS, pages 8780–8794, 2021b.
C. Hiemstra and J. D. Jones. Testing for linear and nonlinear granger causality in the stock
price-volume relation. The Journal of Finance, 49(5):1639–1664, 1994.
J. Ho, A. Jain, and P. Abbeel. Denoising diffusion probabilistic models. In Proc. of NeurIPS,
pages 6840–6851, 2020.
T. Höppe, A. Mehrjou, S. Bauer, D. Nielsen, and A. Dittadi. Diffusion models for video
prediction and infilling. CoRR, abs/2206.07696, 2022.
A. Komanduri, Y. Wu, W. Huang, F. Chen, and X. Wu. Scm-vae: Learning identifiable causal
representations via structural knowledge. In IEEE Big Data, pages 1014–1023, 2022.
M. Kwon, J. Jeong, and Y. Uh. Diffusion models already have a semantic latent space. CoRR,
abs/2210.10960, 2022.
P. Lippe, T. Cohen, and E. Gavves. Efficient neural causal discovery without acyclicity
constraints. CoRR, abs/2107.10483, 2021.
E. Luhman and T. Luhman. Knowledge distillation in iterative generative models for improved
sampling speed. CoRR, abs/2101.02388, 2021.
C. Niu, Y. Song, J. Song, S. Zhao, A. Grover, and S. Ermon. Permutation invariant graph
generation via score-based generative modeling. In Proc. of AISTATS, volume 108, pages
4474–4484, 2020.
S. Papa, O. Winther, and A. Dittadi. Inductive biases for object-centric representations in the
presence of complex textures. In UAI 2022 Workshop on Causal Representation Learning,
2022.
J. Runge, S. Bathiany, E. Bollt, G. Camps-Valls, D. Coumou, E. Deyle, C. Glymour,
M. Kretschmer, M. D. Mahecha, J. Muñoz-Marí, et al. Inferring causation from time
series in earth system sciences. Nature Communications, 10(1):2553, 2019.
J. Song, C. Meng, and S. Ermon. Denoising diffusion implicit models. In Proc. of ICLR, 2021a.
Y. Wang, Y. Schiff, A. Gokaslan, W. Pan, F. Wang, C. De Sa, and V. Kuleshov. Infodiffusion:
Representation learning using information maximizing diffusion models. arXiv preprint
arXiv:2306.08757, 2023.
Z. Wu, N. Dvornik, K. Greff, T. Kipf, and A. Garg. Slotformer: Unsupervised visual dynamics
simulation with object-centric models. CoRR, abs/2210.05861, 2022.
M. Yang, F. Liu, Z. Chen, X. Shen, J. Hao, and J. Wang. Causalvae: Structured causal
disentanglement in variational autoencoder. CoRR, abs/2208.14153, 2020.
J. Yoon, Y.-F. Wu, H. Bae, and S. Ahn. An investigation into pre-training object-centric
representations for reinforcement learning. CoRR, abs/2302.04419, 2023.
Appendix
A Problem Formulation & ELBO
The ELBO for the proposed framework is derived below. For simplicity, we only derive the ELBO
for single representations that are independent of time, i.e., e ∈ R^d; the derivation for the
infinite-dimensional case is analogous.
$$\begin{aligned}
\log p(x, \tilde{x}) &\geq \mathbb{E}_{q(e,\tilde{e},u,\tilde{u},I|x,\tilde{x})}\left[\log\frac{p(x,\tilde{x},u,\tilde{u},e,\tilde{e},I)}{q(e,\tilde{e},I,u,\tilde{u}\,|\,x,\tilde{x})}\right] \\
&= \mathbb{E}_{q(e,\tilde{e},u,\tilde{u},I|x,\tilde{x})}\left[\log\frac{p(I)}{q(I\,|\,x,\tilde{x})} + \log\frac{p(e)\,p(\tilde{e}\,|\,e,I)}{q(e,\tilde{e}\,|\,x,\tilde{x},I)} + \log\frac{p(x,u\,|\,e)}{q(u\,|\,x)} + \log\frac{p(\tilde{x},\tilde{u}\,|\,\tilde{e})}{q(\tilde{u}\,|\,\tilde{x})}\right] \\
&= \mathbb{E}_{q(I|x,\tilde{x})}\,\mathbb{E}_{q(e,\tilde{e}|x,\tilde{x},I)}\,\mathbb{E}_{q(u|x)}\,\mathbb{E}_{q(\tilde{u}|\tilde{x})}\bigg[\Big[\log p(I) + \log p(e) + \log p(\tilde{e}\,|\,e,I) - \log q(I\,|\,x,\tilde{x}) - \log q(e,\tilde{e}\,|\,x,\tilde{x},I)\Big] \\
&\qquad\qquad + \Big[\log\frac{p(x,u\,|\,e)}{q(u\,|\,x)} + \log\frac{p(\tilde{x},\tilde{u}\,|\,\tilde{e})}{q(\tilde{u}\,|\,\tilde{x})}\Big]\bigg].
\end{aligned}$$
The terms in the first bracket correspond to the intervention encoder and the noise encoding
module, respectively, and the terms in the second bracket correspond to the diffusion model
conditioned on pre- and post-intervention noise encodings.
Song et al. [2021c] show that the discretization of the SDE formulation of diffusion models is
equivalent to discrete-time diffusion models. Therefore, for simplicity, we derive the ELBO for
discrete-time diffusion models. Following Luo [2022], for a discrete-time diffusion model with
t ∈ [1, T], we have
" #
p(x, u|e)
Eq(I|x,x̃) Eq(e,ẽ|x,x̃,I) Eq(u|x) Eq(ũ|x̃) log
q(u|x)
"
= Eq(I|x,x̃) Eq(e,ẽ|x,x̃,I) Eq(u|x) Eq(ũ|x̃) Eq(u1 |x) [log p(x|u1 )] − DKL (q(uT |x)||p(uT ))
T
#
X
− Eq(ut |x) [DKL (q(ut−1 |ut , x, e)||p(ut−1 |ut , e)] (4)
t=2
The weight λ(t) of the denoising matching terms is related to the diffusion coefficient of the forward
SDE. For a Variance Exploding SDE, the weight is λ(t) = 2σ²(t) log(σ_max/σ_min),
with σ(t) = σ_min · (σ_max/σ_min)^t.
Therefore, by combining (4) with (5), the ELBO becomes
$$\begin{aligned}
\log p(x, \tilde{x}) \geq\; & \mathbb{E}_{p(x,\tilde{x})}\,\mathbb{E}_{q(I|x,\tilde{x})}\,\mathbb{E}_{q(e,\tilde{e}|x,\tilde{x},I)}\,\mathbb{E}_{t\sim U(0,1)}\,\mathbb{E}_{q(u_t|x)}\,\mathbb{E}_{q(\tilde{u}_t|\tilde{x})}\Big[\log p(I) + \log p(e) + \log p(\tilde{e}\,|\,e, I) \\
& - \log q(I\,|\,x,\tilde{x}) - \log q(e,\tilde{e}\,|\,x,\tilde{x},I) \\
& + \lambda(t)\big[\|s_\theta(u_t, e, t) - \nabla_{u_t}\log p(u_t|x)\|_2^2 + \|s_\theta(\tilde{u}_t, \tilde{e}, t) - \nabla_{\tilde{u}_t}\log p(\tilde{u}_t|\tilde{x})\|_2^2\big]\Big].
\end{aligned}$$
For infinite-dimensional representations, we can derive the ELBO using a similar argument. In
this case, the formula for the ELBO is
$$\begin{aligned}
\log p(x, \tilde{x}) \geq\; & \mathbb{E}_{p(x,\tilde{x})}\,\mathbb{E}_{q(I|x,\tilde{x})}\,\mathbb{E}_{t\sim U(0,1)}\,\mathbb{E}_{q(e_t,\tilde{e}_t|x,\tilde{x},I)}\,\mathbb{E}_{q(u_t|x)}\,\mathbb{E}_{q(\tilde{u}_t|\tilde{x})}\Big[\log p(I) + \log p(e_t) + \log p(\tilde{e}_t\,|\,e_t, I) \\
& - \log q(I\,|\,x,\tilde{x}) - \log q(e_t,\tilde{e}_t\,|\,x,\tilde{x},I) \\
& + \lambda(t)\,\|s_\theta(u_t, e_t, t) - \nabla_{u_t}\log p(u_t|x)\|_2^2 + \lambda(t)\,\|s_\theta(\tilde{u}_t, \tilde{e}_t, t) - \nabla_{\tilde{u}_t}\log p(\tilde{u}_t|\tilde{x})\|_2^2\Big].
\end{aligned}$$
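For reference, the VE-SDE noise scale σ(t) and weighting λ(t) used above can be computed as follows, with σ_min and σ_max as in Appendix B:

```python
import math

def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """sigma(t) = sigma_min * (sigma_max / sigma_min)^t."""
    return sigma_min * (sigma_max / sigma_min) ** t

def ve_weight(t, sigma_min=0.01, sigma_max=50.0):
    """lambda(t) = 2 * sigma(t)^2 * log(sigma_max / sigma_min)."""
    return 2.0 * ve_sigma(t, sigma_min, sigma_max) ** 2 * math.log(sigma_max / sigma_min)
```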
B Implementation Details
Training. For training, we follow the four-phase schedule of Brehmer et al. [2022] but consider
only the first three phases. In summary, we proceed as follows (a schematic sketch is given after the list):
(1) We begin by training the diffusion model and the encoding module together on data
pairs for 20 epochs. This can be interpreted as a warm-up on the diffusion model and
the encoding module to extract meaningful representations of data.
(2) We include all modules except for the solution functions and consider p(ẽi|e) to be a uniform
probability density. We run this phase for 50 epochs.
(3) We include the solution functions and train the whole framework with the proposed loss
for 50 epochs.
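A schematic of the three-phase schedule above; `model.set_phase` and `model.loss` are hypothetical hooks standing in for the per-phase configuration, and only the loop structure and epoch counts follow the description.

```python
def train(model, loader, optimizer):
    """Three-phase schedule: (1) warm up the diffusion model and the encoding
    module, (2) add the intervention module and prior with uniform p(e~_i | e),
    (3) enable the solution functions and train with the full objective."""
    schedule = [("warmup", 20), ("no_solution_functions", 50), ("full", 50)]
    for phase, n_epochs in schedule:
        model.set_phase(phase)                 # hypothetical per-phase switch
        for _ in range(n_epochs):
            for x, x_tilde in loader:
                loss = model.loss(x, x_tilde)  # L_model + L_entropy in phase "full"
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
```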
We find that, given our data generation process, including the fourth training phase
of Brehmer et al. [2022] has no impact on the model's performance; consequently, we disregard it
in our analysis. We use the loss in Eq. 3 as the objective function and set the coefficient of the
regularization term L_entropy to 1, so the overall loss is L = L_model + L_entropy.
Architectures & Hyperparameters. We train the model for 120 epochs with a learning rate
of 3e-4 and a batch size of 64. β is initially set to 0 and increased to 1 during training.
The noise encoder is Gaussian, with mean and standard deviation parameterized by an MLP
with two hidden layers of 64 units each and ReLU activations. The architecture of the score
function of the diffusion model is based on the NCSN++ architecture [Song et al., 2021c] with
the same set of hyperparameters. As the input x is 16-dimensional and the score model follows
a convolutional architecture, we reshape the input into a 4 × 4 format before feeding it into
the diffusion model. Furthermore, in the forward SDE, σ_min and σ_max are set to 0.01 and 50,
respectively.
C Missing Plots