Sample, Estimate, Aggregate: A Recipe for Causal Discovery Foundation Models
1 CSAIL, Massachusetts Institute of Technology, Cambridge, MA. 2 Accenture, Mountain View, CA. Correspondence to: Menghua Wu <[email protected]>.

Preprint. Under review.

Abstract

Causal discovery, the task of inferring causal structure from data, promises to accelerate scientific research, inform policy making, and more. However, the per-dataset nature of existing causal discovery algorithms renders them slow, data hungry, and brittle. Inspired by foundation models, we propose a causal discovery framework in which a deep learning model is pretrained to resolve predictions from classical discovery algorithms run over smaller subsets of variables. This method is enabled by the observations that the outputs of classical algorithms are fast to compute on small problems, are informative of (marginal) data structure, and remain comparable as objects across datasets. Our method achieves state-of-the-art performance on synthetic and realistic datasets, generalizes to data-generating mechanisms not seen during training, and offers inference speeds that are orders of magnitude faster than existing models.1

1 Our code and data are publicly available at https://github.com/rmwu/sea

1. Introduction

A fundamental aspect of scientific research is to discover and validate causal hypotheses involving variables of interest. Given observations of these variables, the goal of causal discovery algorithms is to extract such hypotheses in the form of directed graphs, in which edges denote causal relationships (Spirtes et al., 2001). In much of basic science, however, classical statistics remain the de facto basis of data analysis (Replogle et al., 2022). Key barriers to widespread adoption of causal discovery algorithms include their high data requirements and their computational intractability on larger problems.

Current causal discovery algorithms follow two primary strategies. Discrete optimization methods traverse the super-exponential space of graphs by proposing and evaluating changes to a working graph (Glymour et al., 2019). While these methods are quite fast on small graphs, the combinatorial space renders them intractable for exploring larger structures with hundreds or thousands of nodes. Furthermore, their correctness is tied to hypothesis tests, whose results can be erroneous on noisy datasets, especially as graph sizes increase. More recently, a number of works have reframed the discrete graph search as a continuous optimization over weighted adjacency matrices (Zheng et al., 2018; Brouillard et al., 2020). These continuous optimization algorithms often require that a generative model be fit to the full data distribution, a difficult task when the number of variables increases but the data remain sparse.

In this work, we present SEA: Sample, Estimate, Aggregate, a blueprint for developing causal discovery foundation models that enable fast inference on new datasets, perform well in low-data regimes, and generalize to causal mechanisms beyond those seen during training. Our approach is motivated by two observations. First, while classical causal discovery algorithms scale poorly, their bottleneck lies in the exponential search space rather than in individual independence tests. Second, in many cases, statistics like global correlation or inverse covariance are indeed strong indicators of a graph's overall connectivity. Therefore, we propose to leverage 1) the estimates of classical algorithms over small subgraphs and 2) global graph-level statistics as inputs to a deep learning model, pretrained to resolve these statistical descriptors into causal graphs.

Theoretically, we prove that given only marginal estimates over subgraphs, it is possible to recover causally sound global graphs, and that our proposed model has the capacity to recapitulate such reasoning. Empirically, we implement instantiations of SEA using an axial-attention based model (Ho et al., 2020), which takes as input inverse covariance and estimates from the classical FCI or GIES algorithms (Spirtes et al., 1995; Hauser & Bühlmann, 2012). We conduct thorough comparisons to three classical baselines and five deep learning approaches. SEA attains state-of-the-art results on synthetic and real-world causal discovery
tasks, while providing 10-1000x faster inference. While these experimental results reflect specific algorithms and architectures, the overall framework accepts any combination of sampling heuristics, classical causal discovery algorithms, and statistical features. To summarize, our contributions are as follows.

1. To the best of our knowledge, we are the first to propose a method for building fast, robust, and generalizable foundation models for causal discovery.
2. We show that both our overall framework and our specific architecture have the capacity to reproduce causally sound graphs from their inputs.
3. We attain state-of-the-art results in synthetic and realistic settings, and we provide extensive experimental analysis of our model.

2. Background and related work

2.1. Causal graphical models

A causal graphical model is a directed acyclic graph G = (V, E), where each node i ∈ V corresponds to a random variable X_i ∈ X and each edge (i, j) ∈ E represents a causal relationship from X_i → X_j.

We assume that the data distribution P_X is Markov to G,

    ∀i ∈ V,  X_i ⊥⊥ X_{V \ (δ_i ∪ π_i)} | X_{π_i},    (1)

where δ_i is the set of descendants of node i in G, and π_i is the set of parents of i. In addition, we assume that P_X is minimal and faithful – that is, P_X does not include any independence relationships beyond those implied by applying the Markov condition to G (Spirtes et al., 2001).

Causal graphical models allow us to perform interventions on a node i by setting the conditional P(X_i | X_{π_i}) to a different distribution P̃(X_i | X_{π_i}).

2.2. Causal discovery

Causal discovery is also used in the context of pairwise relationships (Zhang & Hyvarinen, 2012; Monti et al., 2019), as well as for non-stationary (Huang et al., 2020) and time-series (Löwe et al., 2022) data. However, this work focuses on causal discovery for stationary graphs, so we discuss the two main approaches to causal discovery in this setting.

Discrete optimization methods make atomic changes to a proposal graph until a stopping criterion is met. Constraint-based algorithms identify edges based on conditional independence tests, and their correctness is inseparable from the empirical results of those tests (Glymour et al., 2019), whose statistical power depends directly on dataset size. These include the FCI and PC algorithms in the observational case (Spirtes et al., 1995), and the JCI algorithm in the interventional case (Mooij et al., 2020).

Score-based methods also make iterative modifications to a working graph, but their goal is to maximize a continuous score over the discrete space of all valid graphs, with the true graph at the optimum. Due to the intractable search space, these methods often make decisions based on greedy heuristics. Classic examples include GES (Chickering, 2002), GIES (Hauser & Bühlmann, 2012), CAM (Bühlmann et al., 2014), and LiNGAM (Shimizu et al., 2006).

Continuous optimization approaches recast the combinatorial space of graphs into a continuous space of weighted adjacency matrices. Many of these works train a generative model to learn the empirical data distribution, which is parameterized through the adjacency matrix (Zheng et al., 2018; Lachapelle et al., 2020; Brouillard et al., 2020). Others focus on properties related to the empirical data distribution, such as the relationship between the underlying graph and the Jacobian of the learned model (Reizinger et al., 2023), or between the Hessian of the data log-likelihood and the topological ordering of the nodes (Sanchez et al., 2023). While these methods bypass the combinatorial search over discrete graphs, they still require copious data and time to train accurate generative models of the data distributions.
Figure 1. An overview of the S EA inference procedure. Given a new dataset, we 1) sample batches and subsets of nodes, 2) estimate
marginal graphs and global statistics over these batches, and 3) aggregate these features using a pretrained model to obtain the underlying
causal graph. Raw data are depicted in green, and graph-related features are depicted in blue.
3.1. Causal discovery framework

SEA is a causal discovery framework that learns to resolve statistical features and estimates of marginal graphs into a global causal graph. The inference procedure is depicted in Figure 1. Specifically, given a new dataset D ∈ R^{M×N} faithful to graph G = (V, E), we apply the following stages.

Sample: takes as input the dataset D, and outputs data batches {D_0, D_1, ..., D_T} and node subsets {S_1, ..., S_T}.

1. Sample T + 1 batches of b ≪ M observations uniformly at random from D.
2. Compute selection scores α ∈ (0, 1)^{N×N} over D_0 (e.g. inverse covariance).
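To make this sampling stage concrete, the following is a minimal sketch, assuming uniform random batches and the absolute inverse covariance of D_0 as the selection score (as in Section 3.3). The function name and the scaling of α are illustrative, not our released implementation.

```python
# Minimal sketch of the Sample stage: draw T + 1 small batches uniformly at random
# and compute an (N, N) selection score from the first batch (here, the absolute
# inverse covariance). Names and the score scaling are illustrative assumptions.
import numpy as np

def sample_stage(D: np.ndarray, T: int, b: int, rng: np.random.Generator):
    M, N = D.shape
    batches = [D[rng.choice(M, size=min(b, M), replace=False)] for _ in range(T + 1)]
    cov = np.cov(batches[0], rowvar=False)
    alpha = np.abs(np.linalg.inv(cov))          # selection scores over D_0
    return batches, alpha

D = np.random.default_rng(0).normal(size=(5000, 20))
batches, alpha = sample_stage(D, T=10, b=500, rng=np.random.default_rng(1))
print(len(batches), alpha.shape)                # 11 (20, 20)
```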
Figure 2. Aggregator architecture. Marginal graph estimates and global statistics are first embedded into the model dimension, and 1D
positional embeddings are added along both rows and columns. These combined embeddings pass through a series of axial attention
blocks, which attend to the graph estimates and the global features. The final layer global features pass through a feedforward network to
predict the causal graph.
The aggregator's architecture is agnostic to the assumptions of the underlying causal discovery algorithm f (e.g. (non)linearity, observational vs. interventional). Therefore, we can freely swap out f and the training datasets to train aggregators with different implicit assumptions, without modifying the architecture itself. Here, our aggregator is pretrained over a diverse set of graph sizes, structures, and data generating mechanisms (details in Section 4.1).

3.2. Model architecture

The core module of SEA is the aggregator (Figure 2), implemented as a sequence of axial attention blocks. The aggregator takes as input global statistics ρ ∈ R^{N×N}, marginal estimates E′_{1...T} ∈ E^{T×k×k}, and node subsets S_{1...T} ∈ [N]^{T×k}, where E is the set of output edge types for the causal discovery algorithm f.

We project global statistics into the model dimension via a learned linear projection matrix W_ρ : R → R^d, and we embed edge types via a learned embedding ebd_E : E → R^d.

To collect estimates of the same edge over all subsets, we align the entries of E′_{1...T} into E^align_T ∈ E^{T×K},

    E^align_{t,e=(i,j)} = E′_{t,i,j}  if i ∈ S_t and j ∈ S_t,  and 0 otherwise,    (2)

where t indexes into the subsets, e indexes into the set of unique edges, and K is the number of unique edges.

We add learned 1D positional embeddings along both dimensions of each input,

    pos-ebd(ρ_{i,j}) = ebd_node(i′) + ebd_node(j′)
    pos-ebd(E^align_{t,e}) = ebd_time(t) + FFN([ebd_node(i′), ebd_node(j′)])

where i′, j′ index into a random permutation on V for invariance to node permutation and graph size.2 Due to the (a)symmetries of their inputs, pos-ebd(ρ_{i,j}) is symmetric, while pos-ebd(E^align_{t,e}) considers the node ordering. In summary, the inputs to our axial attention blocks are

    h^ρ_{i,j} = (W_ρ ρ)_{i,j} + pos-ebd(ρ_{i,j})    (3)
    h^E_{t,e} = ebd_E(E^align_{t,e}) + pos-ebd(E^align_{t,e})    (4)

for i, j ∈ [N]^2, t ∈ [T], e ∈ [K].

2 The sampling of S_t already provides invariance to node order. However, the mapping i′ = σ(V)_i allows us to avoid updating positional embeddings of lower-order positions more than higher-order ones, due to the mixing of graph sizes during training.

Axial attention. An axial attention block contains two axial attention layers (marginal estimates, global features) and a feed-forward network (Figure 2, right).

Given a 2D input, an axial attention layer attends first along the rows, then along the columns. For example, given a matrix of size (R, C, d), one pass of the axial attention layer is equivalent to running standard self-attention along C with batch size R, followed by the reverse. For marginal estimates, R is the number of subsets T and C is the number of unique edges K. For global features, R and C are both the total number of vertices N.

Following (Rao et al., 2021), each self-attention mechanism is preceded by layer normalization and followed by dropout, with residual connections to the input,

    x = x + Dropout(Attn(LayerNorm(x))).    (5)
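As an illustration of Eq. (5), the following is a minimal PyTorch sketch of a single axial attention layer: pre-LayerNorm self-attention within each row, then within each column, each followed by dropout and a residual connection. The class name is hypothetical, and the hidden size and head count simply mirror the defaults in Section 3.3; this is a sketch, not our released implementation.

```python
# Minimal sketch of one axial attention layer (Eq. 5): pre-norm self-attention along
# each axis of a 2D grid of embeddings, with dropout and residual connections.
import torch
import torch.nn as nn

class AxialAttentionLayer(nn.Module):
    def __init__(self, dim=512, heads=8, dropout=0.1):
        super().__init__()
        self.norm_row = nn.LayerNorm(dim)
        self.norm_col = nn.LayerNorm(dim)
        self.attn_row = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.attn_col = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.drop = nn.Dropout(dropout)

    def _attend(self, x, norm, attn):
        # x: (batch, seq, dim); pre-norm self-attention with a residual connection
        h = norm(x)
        h, _ = attn(h, h, h, need_weights=False)
        return x + self.drop(h)

    def forward(self, x):
        # x: (R, C, d), e.g. (T subsets, K edges, d) for the marginal estimates
        # attention within each row: sequence length C, "batch" of R rows
        x = self._attend(x, self.norm_row, self.attn_row)            # (R, C, d)
        # attention within each column: sequence length R, "batch" of C columns
        x = self._attend(x.transpose(0, 1), self.norm_col, self.attn_col)
        return x.transpose(0, 1)                                     # (R, C, d)

x = torch.randn(10, 45, 512)                   # e.g. T = 10 subsets, K = 45 edges
print(AxialAttentionLayer()(x).shape)          # torch.Size([10, 45, 512])
```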
We pass messages between the marginal and global layers to propagate information. Let ϕ^{E,ℓ} be the marginal layer ℓ, let ϕ^{ρ,ℓ} be the global layer ℓ, and let h^{·,ℓ} denote the hidden representations out of layer ℓ.

The marginal-to-global message m^{E→ρ} ∈ R^{N×N×d} contains representations of each edge, averaged over subsets,

    m^{E→ρ,ℓ}_{i,j} = (1/T_e) Σ_t h^{E,ℓ}_{t,e=(i,j)}   if ∃ S_t such that i, j ∈ S_t,
                      ϵ                                  otherwise,    (6)

where T_e is the number of S_t containing e, and missing entries are padded with a learned constant ϵ. The global-to-marginal message m^{ρ→E} ∈ R^{K×d} is simply the hidden representation itself,

    m^{ρ→E,ℓ}_{t,e=(i,j)} = h^{ρ,ℓ}_{i,j}.    (7)

We incorporate these messages as follows.

    h^{E,ℓ} = ϕ^{E,ℓ}(h^{E,ℓ−1})                          (marginal feature)      (8)
    h^{ρ,ℓ−1} ← W^ℓ [h^{ρ,ℓ−1}, m^{E→ρ,ℓ}]                (marginal to global)    (9)
    h^{ρ,ℓ} = ϕ^{ρ,ℓ}(h^{ρ,ℓ−1})                          (global feature)        (10)
    h^{E,ℓ} ← h^{E,ℓ} + m^{ρ→E,ℓ}                         (global to marginal)    (11)

W^ℓ ∈ R^{2d×d} is a learned linear projection, and [·] denotes concatenation.

Graph prediction. For each pair of vertices i ≠ j ∈ V, we predict e = 0, 1, or 2 for no edge, i → j, and j → i respectively. This constraint may be omitted for cyclic graphs. We do not additionally enforce that our predicted graphs are acyclic, similar in spirit to (Lippe et al., 2022).

Given the output of the final axial attention block, h^ρ, we compute logits

    z_{i,j} = FFN([h^ρ_{i,j}, h^ρ_{j,i}]) ∈ R^3,    (12)

which correspond to probabilities after softmax normalization. The overall output Ê ∈ {0, 1}^{N×N} is supervised by the ground truth E, and our model is trained with cross-entropy loss and L2 weight regularization.

3.3. Implementation details

We computed inverse covariance as the global statistic and selection score, due to its relationship to partial correlation and its ease of computation. For comparisons to additional statistics, see C.2. We chose the constraint-based FCI algorithm in the observational setting (Spirtes et al., 2001), and the score-based GIES algorithm in the interventional setting (Hauser & Bühlmann, 2012). For discussion regarding alternative algorithms, see B.5. We sample batches of size b = 500 over k = 5 nodes each (analysis in C.1).

Our model was implemented with 4 layers, 8 attention heads, and hidden dimension 512. It was trained using the AdamW optimizer with a learning rate of 1e-4 (Loshchilov & Hutter, 2019). Please see B.3 for additional details regarding hyperparameters.

3.4. Theoretical analyses

Our theoretical contributions span two aspects.

1. We formalize the notion of marginal estimates and prove that, given sufficient marginal estimates, it is possible to recover a pattern faithful to the global causal graph (Theorem 3.1). We provide bounds on the number of marginal estimates required and motivate global statistics as an efficient means to reduce this bound (Propositions 3.2, 3.3).
2. We show that our proposed axial attention has the capacity to recapitulate the reasoning required for this task. In particular, we show that a stack of 3 axial attention blocks can recover the skeleton and v-structures in O(N) width (Theorem 3.4).

We state our key results below and direct the reader to Appendix A for all proofs. Our proofs assume only that the edge estimates are correct. We do not impose any functional requirements, and any appropriate independence test may be used. We discuss robustness and stability in A.4.

Theorem 3.1 (Marginal estimate resolution). Let G = (V, E) be a directed acyclic graph with maximum degree d. For S ⊆ V, let E′_S denote the marginal estimate over S. Let S_d denote the superset that contains all subsets S ⊆ V of size at most d. There exists a mapping from {E′_S}_{S ∈ S_{d+2}} to a pattern E∗, faithful to G.

Theorem 3.1 formalizes the intuition that it requires at most d independence tests to check whether two nodes are independent, and that, to estimate the full graph structure, it suffices to run a causal discovery algorithm on subsets of size d + 2.

Proposition 3.2 (Skeleton bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d. It is possible to recover the undirected skeleton C = {{i, j} : (i, j) ∈ E} in O(N^2) estimates over subsets of size d + 2.

If we only leverage marginal estimates, we must check at least (N choose 2) subsets to cover each edge at least once. However, we can often approximate the skeleton via global statistics such as correlation, allowing us to use marginal estimates more efficiently, towards answering orientation questions.

Proposition 3.3 (V-structures bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d and ν v-structures. It is possible to identify all v-structures in O(ν) estimates over subsets of at most size d + 2.

Theorem 3.4 (Model capacity). Given a graph G with N nodes, a stack of L axial attention blocks has the capacity to recover its skeleton and v-structures in O(N) width, and to propagate orientations on paths of O(L) length.

Existing literature on the universality and computational power of vanilla Transformers (Yun et al., 2019; Pérez et al., 2019) relies on generous assumptions regarding depth or precision. Here, we show that our axial attention-based model can implement the specific reasoning required to resolve marginal estimates under realistic conditions.
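To make the prediction step of Eq. (12) concrete, the following is a minimal sketch of the pairwise classification head and its cross-entropy loss. The two-layer FFN shape and the random inputs are illustrative assumptions, and in practice the diagonal (i = j) entries would be masked; this is a sketch, not our released implementation.

```python
# Minimal sketch of the edge classifier in Eq. 12: for each ordered pair (i, j), the
# final-layer global features h_rho[i, j] and h_rho[j, i] are concatenated and mapped
# to three logits (no edge, i -> j, j -> i), trained with cross entropy.
import torch
import torch.nn as nn

class EdgeClassifier(nn.Module):
    def __init__(self, dim=512, hidden=512, num_classes=3):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(2 * dim, hidden), nn.ReLU(), nn.Linear(hidden, num_classes)
        )

    def forward(self, h_rho):
        # h_rho: (N, N, d) global features from the last axial attention block
        pair = torch.cat([h_rho, h_rho.transpose(0, 1)], dim=-1)   # [h_ij, h_ji]
        return self.ffn(pair)                                      # (N, N, 3) logits

N, d = 20, 512
h_rho = torch.randn(N, N, d)
labels = torch.randint(0, 3, (N, N))            # 0: none, 1: i -> j, 2: j -> i
logits = EdgeClassifier(dim=d)(h_rho)
loss = nn.CrossEntropyLoss()(logits.reshape(-1, 3), labels.reshape(-1))
print(loss.item())
```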
4. Experimental setup

We pretrained SEA models on synthetic training datasets only and ran inference on held-out testing datasets, which include both seen and unseen causal mechanisms. All baselines were trained and/or run from scratch on each testing dataset using their published code and hyperparameters.

4.1. Data settings

We evaluate our model across diverse synthetic datasets, simulated mRNA datasets (Dibaeinia & Sinha, 2020), and a real protein expression dataset (Sachs et al., 2005).

The synthetic datasets were constructed based on Table 1. For each synthetic dataset, we first sample a graph based on the desired topology and connectivity. Then we topologically sort the graph and sample observations starting from the root nodes, using random instantiations of the designated causal mechanism (details in B.1). We generated 90 training, 5 validation, and 5 testing datasets for each combination (8160 total). To evaluate our model's capacity to generalize to new functional classes, we reserve the Sigmoid and Polynomial causal mechanisms for testing only. We include details on our synthetic mRNA datasets in Appendix C.3.

Table 1. Synthetic data generation settings. The symbol ∗ denotes training only, and † denotes testing only. We take the Cartesian product of all parameters for our settings. (Non-)additive refers to (non-)additive Gaussian noise. Details regarding the causal mechanisms can be found in B.1.

Parameter       Values
Nodes (N)       10, 20, 100
Edges           N, 2N∗, 3N∗, 4N
Points          1000N
Interventions   N
Topology        Erdős-Rényi, Scale-free
Mechanism       Linear, NN additive, NN non-additive, Sigmoid additive†, Polynomial additive†

Table 2. Motivating continuous metrics in causal discovery. SHD alone is a poor measure of prediction quality, as it fails to distinguish between the graph of all 0s and the undirected skeleton. Metrics reported on Sachs.

                SHD ↓   mAP ↑   AUC ↑   EdgeAcc ↑
Ground truth    0       1       1       1
Undirected      17      0.5     0.92    0
All 0s          17      0.14    0.5     0

Edge orientation accuracy: We compute the accuracy of edge orientations as

    EdgeAcc = (1/∥E∥) Σ_{(i,j)∈E} 1{P(i, j) > P(j, i)}.    (13)
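A small sketch of Eq. (13) follows, assuming P is an N × N matrix of predicted edge probabilities and true_edges lists the directed edges of the ground-truth graph; the names are illustrative.

```python
# Minimal sketch of edge orientation accuracy (Eq. 13): the fraction of true edges
# (i, j) for which the predicted probability of i -> j exceeds that of j -> i.
import numpy as np

def edge_orientation_accuracy(P: np.ndarray, true_edges) -> float:
    correct = sum(P[i, j] > P[j, i] for i, j in true_edges)
    return correct / len(true_edges)

P = np.array([[0.0, 0.9, 0.1],
              [0.2, 0.0, 0.7],
              [0.4, 0.3, 0.0]])
print(edge_orientation_accuracy(P, [(0, 1), (1, 2)]))   # 1.0
```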
Table 3. Causal discovery results on synthetic datasets. Each setting encompasses 5 distinct Erdős-Rényi graphs. The symbol † indicates that SEA was not pretrained on this setting. Runtimes are plotted in Figure 3. Details regarding baselines can be found in B.2. Each cell reports mAP ↑ and SHD ↓.

N    E    Model        Linear               NN add.              NN                   Sigmoid†             Polynomial†
100  400  InvCov       0.25±0.0  557.0±12   0.09±0.0  667.8±15   0.14±0.0  639.0±10   0.27±0.0  514.7±23   0.20±0.0  539.4±18
          SEA (FCI)    0.90±0.0  122.0±9    0.28±0.1  361.2±36   0.60±0.0  273.2±14   0.69±0.0  226.9±20   0.38±0.0  327.0±20
          SEA (GIES)   0.91±0.0  116.6±8    0.27±0.1  364.4±35   0.61±0.0  266.8±15   0.69±0.0  218.3±21   0.38±0.0  328.0±22
5. Results

SEA excels across synthetic and realistic causal discovery tasks with fast runtimes. We also show that SEA performs well even in low-data regimes. Additional results and ablations can be found in Appendix C, including classical algorithm hyperparameters, global statistics, and scaling to graphs much larger than those in the training set.

Figure 3. Average wall time required to run each model on a single dataset. The y-axis is plotted in log scale.

5.1. Synthetic experiments

SEA significantly outperforms the baselines across a variety of graph sizes and causal mechanisms in Table 3, and we maintain high performance even for causal mechanisms beyond the training set (Sigmoid, Polynomial). Our pretrained aggregator consistently improves upon the performance of its individual inputs (InvCov, FCI-avg, GIES-avg), demonstrating the value of learning such a model.

In terms of edge orientation accuracy, SEA outperforms the baselines in all settings (Table 4). We omit InvCov from this comparison since it does not orient edges.

We benchmarked the runtimes of each algorithm over all test datasets of N = 10, 20, 100, when possible (Figure 3). Our
model is orders of magnitude faster than other continuous optimization methods. All deep learning models were run on a single V100-PCIE-32GB GPU, except for DiffAN, since we were unable to achieve consistent GPU and R support within a Docker environment using their codebase. For all models, we recorded only computation time (CPU and GPU) and omitted any file-system-related time.

Table 4. Synthetic experiments, edge direction accuracy (higher is better). All standard deviations were within 0.2. The symbol † indicates that SEA was not pretrained on this setting.

N    E    Model        Linear  NN add.  NN    Sig.†  Poly.†
10   10   DCDI-G       0.74    0.80     0.85  0.41   0.44
          DCDI-DSF     0.79    0.62     0.68  0.38   0.39
          DCD-FG       0.50    0.47     0.70  0.43   0.54
          DiffAN       0.61    0.55     0.26  0.53   0.47
          DECI         0.50    0.43     0.62  0.63   0.75
          FCI-avg      0.52    0.43     0.41  0.55   0.40
          GIES-avg     0.76    0.49     0.69  0.67   0.63
          SEA (FCI)    0.92    0.92     0.94  0.76   0.71
          SEA (GIES)   0.94    0.88     0.93  0.84   0.79
20   80   DCDI-G       0.47    0.43     0.82  0.40   0.24
          DCDI-DSF     0.50    0.49     0.78  0.41   0.28
          DCD-FG       0.58    0.65     0.75  0.62   0.48
          DiffAN       0.46    0.28     0.36  0.45   0.21
          DECI         0.30    0.47     0.35  0.48   0.57
          FCI-avg      0.19    0.19     0.22  0.33   0.23
          GIES-avg     0.56    0.73     0.59  0.62   0.61
          SEA (FCI)    0.93    0.90     0.93  0.85   0.89
          SEA (GIES)   0.92    0.88     0.92  0.84   0.89
100  400  DCD-FG       0.46    0.60     0.70  0.67   0.53
          SEA (FCI)    0.93    0.90     0.91  0.87   0.82
          SEA (GIES)   0.94    0.91     0.92  0.87   0.84

Table 5. Causal discovery on the (real) Sachs dataset. While InvCov is a strong baseline for connectivity, it does not predict orientation. The relative performance of the methods depends on the metric.

Model        mAP ↑  AUC ↑  SHD ↓  EdgeAcc ↑  Time (s) ↓
DCDI-G       0.17   0.55   21.0   0.20       2436.5
DCDI-DSF     0.20   0.59   20.0   0.20       1979.6
DCD-FG       0.32   0.59   27.0   0.35       951.4
DiffAN       0.14   0.45   37.0   0.41       293.7
DECI         0.21   0.62   28.0   0.53       609.7
InvCov       0.31   0.61   20.0   —          0.002
FCI-avg      0.27   0.59   18.0   0.24       41.9
GIES-avg     0.21   0.59   17.0   0.24       77.9
SEA (FCI)    0.23   0.54   24.0   0.47       3.2
SEA (GIES)   0.23   0.60   14.0   0.41       2.9

5.2. Realistic experiments

We report results on the real Sachs protein dataset (Sachs et al., 2005) in Table 5. The relative performance of each model differs based on the metric. SEA performs comparably, while maintaining fast inference speeds. However, despite the popularity of this dataset in the causal discovery literature (due to a lack of better alternatives), biological networks are known to be time-resolved and cyclic, so the validity of the ground truth "consensus" graph has been questioned by experts (Mooij et al., 2020).

We also trained a version of SEA (FCI) on 7200 synthetic mRNA datasets (Dibaeinia & Sinha, 2020), exceeding baselines across held-out test datasets (Appendix C.3).

5.3. Performance in low-data regimes

One of the main advantages of foundation models is that they enable high levels of performance in low-resource scenarios. Figure 4 shows that SEA (FCI) only requires around M = 500 data samples for decent performance on graphs with N = 100 nodes. For these experiments, we set batch size b = min(500, M) and node subset size k = 5. This is in contrast to existing continuous optimization methods, which require thousands to tens of thousands of samples to fit completely. Compared to the InvCov baseline estimated over the same number of points b = M = 500 (dotted lines), SEA performs significantly better.

Figure 4. Performance of our model (FCI) as a function of total dataset size. Error bars indicate the 95% confidence interval across the 5 datasets of each setting. Our model only requires approximately 500 samples for an acceptable level of performance on N = 100 graphs. Dashed lines indicate the InvCov estimate on 500 points.

6. Conclusion

In this work, we introduced SEA, a framework for designing causal discovery foundation models. SEA is motivated by the idea that classical discovery algorithms provide powerful descriptors of data that are fast to compute and robust across datasets. Given these statistics, we train a deep learning model to reproduce faithful causal graphs. Theoretically, we demonstrated that it is possible to produce sound causal graphs from marginal estimates, and that our model has the capacity to do so. Empirically, we implemented two proofs of concept of SEA that perform well across a variety of causal discovery tasks. We hope that this work will inspire a new avenue of research in generalizable and scalable causal discovery algorithms.
References

Dibaeinia, P. and Sinha, S. SERGIO: A single-cell expression simulator guided by gene regulatory networks. Cell Systems, 11(3):252–271.e11, 2020. ISSN 2405-4712. doi: https://doi.org/10.1016/j.cels.2020.08.003.

Geffner, T., Antoran, J., Foster, A., Gong, W., Ma, C., Kiciman, E., Sharma, A., Lamb, A., Kukla, M., Pawlowski, N., Allamanis, M., and Zhang, C. Deep end-to-end causal inference, 2022.

Glymour, C., Zhang, K., and Spirtes, P. Review of causal discovery methods based on graphical models. Frontiers in Genetics, 10, 2019. ISSN 1664-8021. doi: 10.3389/fgene.2019.00524.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In International Conference on Machine Learning, 2017.

Hauser, A. and Bühlmann, P. Characterization and greedy learning of interventional Markov equivalence classes of directed acyclic graphs. 2012.

Ho, J., Kalchbrenner, N., Weissenborn, D., and Salimans, T. Axial attention in multidimensional transformers, 2020.

Hornik, K., Stinchcombe, M., and White, H. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989. ISSN 0893-6080. doi: https://doi.org/10.1016/0893-6080(89)90020-8.

Huang, B., Zhang, K., Zhang, J., Ramsey, J., Sanchez-Romero, R., Glymour, C., and Schölkopf, B. Causal discovery from heterogeneous/nonstationary data with independent changes, 2020.

Hägele, A., Rothfuss, J., Lorch, L., Somnath, V. R., Schölkopf, B., and Krause, A. BaCaDI: Bayesian causal discovery with unknown interventions, 2023.

Kalainathan, D., Goudet, O., and Dutta, R. Causal discovery toolbox: Uncovering causal relationships in Python. Journal of Machine Learning Research, 21(37):1–5, 2020.

Lachapelle, S., Brouillard, P., Deleu, T., and Lacoste-Julien, S. Gradient-based neural DAG learning, 2020.

Lam, W.-Y., Andrews, B., and Ramsey, J. Greedy relaxations of the sparsest permutation algorithm. In Cussens, J. and Zhang, K. (eds.), Proceedings of the Thirty-Eighth Conference on Uncertainty in Artificial Intelligence, volume 180 of Proceedings of Machine Learning Research, pp. 1052–1062. PMLR, 01–05 Aug 2022.

Ledoit, O. and Wolf, M. A well-conditioned estimator for large-dimensional covariance matrices. Journal of Multivariate Analysis, 88(2):365–411, 2004. ISSN 0047-259X. doi: https://doi.org/10.1016/S0047-259X(03)00096-4.

Lippe, P., Cohen, T., and Gavves, E. Efficient neural causal discovery without acyclicity constraints. In International Conference on Learning Representations, 2022.

Lopez, R., Hütter, J.-C., Pritchard, J. K., and Regev, A. Large-scale differentiable causal discovery of factor graphs. In Advances in Neural Information Processing Systems, 2022.

Loshchilov, I. and Hutter, F. Decoupled weight decay regularization, 2019.

Löwe, S., Madras, D., Zemel, R., and Welling, M. Amortized causal discovery: Learning to infer causal graphs from time-series data, 2022.

Monti, R. P., Zhang, K., and Hyvarinen, A. Causal discovery with general non-linear relationships using non-linear ICA, 2019.

Mooij, J. M., Magliacane, S., and Claassen, T. Joint causal inference from multiple contexts. 2020.

Pérez, J., Marinković, J., and Barceló, P. On the Turing completeness of modern neural network architectures. In International Conference on Learning Representations, 2019.

Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., Krueger, G., and Sutskever, I. Learning transferable visual models from natural language supervision, 2021.

Rao, R. M., Liu, J., Verkuil, R., Meier, J., Canny, J., Abbeel, P., Sercu, T., and Rives, A. MSA Transformer. In Meila, M. and Zhang, T. (eds.), Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pp. 8844–8856. PMLR, 18–24 Jul 2021.

Reizinger, P., Sharma, Y., Bethge, M., Schölkopf, B., Huszár, F., and Brendel, W. Jacobian-based causal discovery with nonlinear ICA. Transactions on Machine Learning Research, 2023. ISSN 2835-8856.

Replogle, J. M., Saunders, R. A., Pogson, A. N., Hussmann, J. A., Lenail, A., Guna, A., Mascibroda, L., Wagner, E. J., Adelman, K., Lithwick-Yanai, G., Iremadze, N., Oberstrass, F., Lipson, D., Bonnar, J. L., Jost, M., Norman, T. M., and Weissman, J. S. Mapping information-rich genotype-phenotype landscapes with genome-scale Perturb-seq. Cell, 185(14):2559–2575, Jul 2022.

Sachs, K., Perez, O., Pe'er, D., Lauffenburger, D. A., and Nolan, G. P. Causal protein-signaling networks derived from multiparameter single-cell data. Science, 308(5721):523–529, 2005. doi: 10.1126/science.1105809.
Sanchez, P., Liu, X., O'Neil, A. Q., and Tsaftaris, S. A. Diffusion models for causal discovery via topological ordering. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023.

Schaeffer, R., Miranda, B., and Koyejo, O. Are emergent abilities of large language models a mirage? Advances in Neural Information Processing Systems, abs/2304.15004, 2023.

Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. DAGs with NO TEARS: Continuous optimization for structure learning, 2018.

Zheng, Y., Huang, B., Chen, W., Ramsey, J., Gong, M., Cai, R., Shimizu, S., Spirtes, P., and Zhang, K. Causal-learn: Causal discovery in Python. arXiv preprint arXiv:2307.16405, 2023.
Figure 5. Resolving marginal graphs. Subsets of nodes revealed to the PC algorithm (circled in row 1) and its outputs (row 2).
In each of the four cases, the PC algorithm estimates the respective graphs as follows.
(A) We remove edge (X, Y ) via (2) and orient the v-structure.
(B) We remove edge (X, Y ) via (2) and orient the v-structure.
(C) We remove edge (X, W ) via (3) by conditioning on Z. There are no v-structures, so the edges remain undirected.
(D) We remove edge (Y, W ) via (3) by conditioning on Z. There are no v-structures, so the edges remain undirected.
The outputs (A-D) admit the full PC algorithm output as the only consistent graph on four nodes.
• X and Y are unconditionally independent, so no subset will reveal an edge between (X, Y ).
• There are no edges between (X, W ) and (Y, W ). Otherwise, (C) and (D) would yield the undirected triangle.
• X, Y, Z must be oriented as X → Z ← Y . Paths X → Z → Y and X ← Z ← Y would induce an (X, Y ) edge in
(B). Reversing orientations X ← Z → Y would contradict (A).
• (Y, Z) must be oriented as Y → Z. Otherwise, (A) would remain unoriented.
Given data faithful to G, a number of classical constraint-based algorithms produce patterns that are faithful to G. We
denote this set of algorithms as F.
Theorem A.3 (Theorem 5.1 from (Spirtes et al., 2001)). If the input to the PC, SGS, PC-1, PC-2, PC∗, or IG algorithms is faithful to directed acyclic graph G, the output is a pattern that represents the faithful indistinguishability class of G.
The algorithms in F are sound and complete if there are no unobserved confounders.
If we apply any f ∈ F to D[S], the results are not necessarily faithful to G[S], as now there may be latent confounders in
V \ S (by construction). We introduce the term marginal estimate to denote the resultant pattern that, while not faithful to
G[S], is still informative.
Definition A.4 (Marginal estimate). A pattern E ′ is a marginal estimate of G[S] if and only if
1. for all vertices X, Y of S, X and Y are adjacent if and only if X and Y are dependent conditional on every set of
vertices of S that does not include X or Y ; and
2. for all vertices X, Y, Z, such that X is adjacent to Y and Y is adjacent to Z and X and Z are not adjacent, X → Y ← Z
is a subgraph of S if and only if X, Z are dependent conditional on every set containing Y but not X or Z.
We will show that the following algorithm produces the desired answer. On a high level, lines 3-8 recover the undirected
“skeleton” graph of E ∗ , lines 9-15 recover the v-structures, and line 16 references step 5 in Section A.1.
Remark A.5. In the PC algorithm ((Spirtes et al., 2001), A.1), its derivatives, and Algorithm 1, there is no need to consider
separating sets with cardinality greater than maximum degree d, since the maximum number of independence tests required
to separate any node from the rest of the graph is equal to the number of its parents plus its children (due to the Markov
assumption).
Lemma A.6. The undirected skeleton of E∗ is equivalent to the undirected skeleton of E′.
Lemma A.7. A v-structure i → j ← k exists in E ∗ if and only if there exists the same v-structure in E ′ .
Proof. The PC algorithm orients v-structures i → j ← k in E∗ if there is an edge between {i, j} and {j, k} but not {i, k},
and if j was not in the conditioning set that removed {i, k}. Algorithm 1 orients v-structures i → j ← k in E′ if they are
oriented as such in any E′_S, and if {i, j}, {j, k} ∈ E′ and {i, k} ∉ E′.
⇒ Suppose for contradiction that i → j ← k is oriented as a v-structure in E ∗ , but not in E ′ . There are two cases.
1. No ES′ contains the undirected path i − j − k. If either i − j or j − k are missing from any ES′ , then E ∗ would not
contain (i, j) or (k, j). Otherwise, if all S contain {i, k}, then E ∗ would not be missing {i, k} (Lemma A.6).
2. In every ES′ that contains i − j − k, j is in the conditioning set that removed {i, k}, i.e. i ⊥⊥ k | S, S ∋ j. This would
violate the faithfulness property, as j is neither a parent of i nor of k in E∗, and the outputs of the PC algorithm are faithful
to the equivalence class of G (Theorem 5.1 (Spirtes et al., 2001)).
⇐ Suppose for contradiction that i → j ← k is oriented as a v-structure in E ′ , but not in E ∗ . By Lemma A.6, the path
i − j − k must exist in E ∗ . There are two cases.
1. If i → j → k or i ← j ← k, then j must be in the conditioning set that removes {i, k}, so no ES′ containing {i, j, k}
would orient them as v-structures.
2. If j is the root of a fork i ← j → k, then as the parent of both i and k, j must be in the conditioning set that removes
{i, k}, so no ES′ containing {i, j, k} would orient them as v-structures.
Therefore, all v-structures in E ′ are also v-structures in E ∗ .
Proof of Theorem 3.1. Given data that is faithful to G, Algorithm 1 produces a pattern E ′ with the same connectivity and
v-structures as E ∗ . Any additional orientations in both patterns are propagated using identical, deterministic procedures, so
E′ = E∗.
This proof presents a deterministic but inefficient algorithm for resolving marginal subgraph estimates. In reality, it is
possible to recover the undirected skeleton and the v-structures of G without checking all subsets S ∈ Sd+2 .
Proposition 3.2 (Skeleton bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d. It is possible to
recover the undirected skeleton C = {{i, j} : (i, j) ∈ E} in O(N 2 ) estimates over subsets of size d + 2.
Proof. Following Lemma A.6, an edge (i, j) is not present in C if it is not present in any of the size d + 2 estimates.
Therefore, every pair of nodes {i, j} requires only a single estimate of size d + 2, so it is possible to recover C in (N choose 2) estimates.
Proposition 3.3 (V-structures bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d and ν
v-structures. It is possible to identify all v-structures in O(ν) estimates over subsets of at most size d + 2.
Definition A.9. Let Q, F, QΦ , FΦ be finite sets. Let f be a map from Q to F , and let Φ be a finite set of maps {ϕ : QΦ →
FΦ }. We say Φ has the capacity to implement f if and only if there exists at least one element ϕ ∈ Φ that implements f .
Theorem 3.4 (Model capacity). Given a graph G with N nodes, a stack of L axial attention blocks has the capacity to
recover its skeleton and v-structures in O(N ) width, and propagate orientations on paths of O(L) length.
Proof. We consider axial attention blocks with dot-product attention and omit layer normalization from our analysis, as is common in the Transformer universality literature (Yun et al., 2019). Our inputs X ∈ R^{d×R×C} consist of d-dimensional embeddings over R rows and C columns. Since our axial attention only operates over one dimension at a time, we use X_{·,c} to denote a 1D sequence of length R, given a fixed column c, and X_{r,·} to denote a 1D sequence of length C, given a fixed row r. A single axial attention layer (with one head) consists of two attention layers and a feedforward network. We write h^ℓ for the hidden representations of E and ρ at layer ℓ; the outputs of the axial attention block are h^{ρ,ℓ} and h^{E,ℓ}.
We construct a stack of L ≥ 3 axial attention blocks that implement Algorithm 1.
Model inputs. Consider edge estimate E′_{i,j} ∈ E in a graph of size N. Let e_i, e_j denote the endpoints of (i, j). Outputs of the PC algorithm can be expressed by three endpoints: {∅, •, ▶}. A directed edge from i → j has endpoints (•, ▶), the reversed edge i ← j has endpoints (▶, •), an undirected edge has endpoints (•, •), and the lack of any edge between i, j has endpoints (∅, ∅).

Let one-hot_N(i) denote the N-dimensional one-hot column vector where element i is 1. We define the embedding of (i, j) as a d = 2N + 6 dimensional vector,

    g_in(E_{t,(i,j)}) = h^{E,0}_{(i,j)} = [ one-hot_3(e_i);  one-hot_3(e_j);  one-hot_N(i);  one-hot_N(j) ].    (21)

To recover graph structures from h^E, we simply read off the indices of non-zero entries (g_out). We can set h^{ρ,0} to any R^{d×N×N} matrix, as we do not consider its values in this analysis and discard it during the first step.

For example, if (i, j) is oriented as (▶, •), then we expect (j, i) to be oriented (•, ▶).
Step 1: Undirected skeleton. We use the first axial attention block to recover the undirected skeleton C′. We set all attentions to the identity, set W^{ρ,1} ∈ R^{2d×d} to a d × d zero matrix stacked on top of a d × d identity matrix (discarding ρ), and set FFN_E to the identity (inputs are positive). This yields

    h^{ρ,0}_{i,j} = m^{E→ρ,1}_{i,j} = [ P_{e_i}(∅);  P_{e_i}(•);  P_{e_i}(▶);  ...;  one-hot_N(i);  one-hot_N(j) ],    (22)
where P_{e_i}(·) is the frequency that endpoint e_i = · within the sampled subsets. FFNs with one hidden layer are universal approximators of continuous functions (Hornik et al., 1989), so we use FFN_ρ to map

    FFN_ρ(X_{i,u,v}) = 0            if i ≤ 6,
                       0            if i > 6 and X_{1,u,v} = 0,
                       −X_{i,u,v}   otherwise,    (23)

where i ∈ [2N + 6] indexes into the feature dimension, and u, v index into the rows and columns. This allows us to remove edges not present in C′ from consideration:

    m^{ρ→E,1} = h^{ρ,1},
    h^{E,1}_{i,j} ← h^{E,1}_{i,j} + m^{ρ→E,1}_{i,j} = 0               if (i, j) ∉ C′,
                                                      h^{E,0}_{i,j}   otherwise.    (24)
Step 2: V-structures. The second and third axial attention blocks recover v-structures. We run the same procedure twice: once to capture v-structures that point towards the first node in an ordered pair, and once to capture v-structures that point towards the latter node.

We start with the first row attention over edge estimates, given a fixed subset t. We set the key and query attention matrices to

    W_K = k · blockdiag( [0 0 1], [0 1 0], I_N, −I_N ),
    W_Q = k · blockdiag( [0 0 1], [0 1 0], I_N, I_N ),    (25)

where k is a large constant, I_N denotes the size-N identity matrix, and all entries outside the listed blocks are 0s.
Recall that a v-structure is a pair of directed edges that share a target node. We claim that two edges (i, j), (u, v) form a v-structure in E′, pointing towards i = u, if this inner product takes on the maximum value

    ⟨ (W_K h^{E,1})_{i,j}, (W_Q h^{E,1})_{u,v} ⟩ = 3.    (26)
Suppose both edges (i, j) and (u, v) still remain in C ′ . There are two components to consider.
1. If i = u, then their shared node contributes +1 to the inner product (prior to scaling by k). If j = v, then the inner
product accrues −1.
2. Nodes that do not share the same endpoint contribute 0 to the inner product. Of edges that share one node, only
endpoints that match ▶ at the starting node, or • at the ending node contribute +1 to the inner product each. We
provide some examples below.
(ei , ej ) (eu , ev ) contribution note
(▶, •) (•, ▶) 0 no shared node
(•, ▶) (•, ▶) 0 wrong endpoints
(•, •) (•, •) 1 one correct endpoint
(▶, •) (▶, •) 2 v-structure
All edges with endpoints ∅ were "removed" in step 1, resulting in an inner product of zero, since their node embeddings were set to zero. We set k to some large constant (empirically, k^2 = 1000 is more than enough) to ensure that after softmax scaling, σ_{e,e′} > 0 only if e, e′ form a v-structure.
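As a numeric check of this argument, the sketch below instantiates the embedding of Eq. (21) and the block-diagonal W_K, W_Q as reconstructed in Eq. (25). The placement of the negated node block is an assumption of this sketch (either placement yields the same scores), and the scaling constant k is dropped since it only rescales the inner products.

```python
# Numeric check: the key/query inner product reaches 3 exactly for two edges that
# share their first node with arrowheads pointing into it (a v-structure).
import numpy as np

N = 5
ENDPOINT = {"none": 0, "tail": 1, "head": 2}    # (∅, •, ▶)

def embed(ei, ej, i, j):
    # embedding of Eq. 21: [one-hot3(e_i); one-hot3(e_j); one-hotN(i); one-hotN(j)]
    v = np.zeros(2 * 3 + 2 * N)
    v[ENDPOINT[ei]] = 1
    v[3 + ENDPOINT[ej]] = 1
    v[6 + i] = 1
    v[6 + N + j] = 1
    return v

def blockdiag(*blocks):
    rows, cols = sum(b.shape[0] for b in blocks), sum(b.shape[1] for b in blocks)
    out = np.zeros((rows, cols))
    r = c = 0
    for b in blocks:
        out[r:r + b.shape[0], c:c + b.shape[1]] = b
        r += b.shape[0]; c += b.shape[1]
    return out

sel_head = np.array([[0.0, 0.0, 1.0]])          # selects ▶
sel_tail = np.array([[0.0, 1.0, 0.0]])          # selects •
W_K = blockdiag(sel_head, sel_tail, np.eye(N), -np.eye(N))
W_Q = blockdiag(sel_head, sel_tail, np.eye(N), np.eye(N))

# edges (0, 1) and (0, 2) oriented 1 -> 0 <- 2, i.e. arrowheads at node 0
e1 = embed("head", "tail", 0, 1)
e2 = embed("head", "tail", 0, 2)
print((W_K @ e1) @ (W_Q @ e2))                  # 3.0: a v-structure at node 0
print((W_K @ e1) @ (W_Q @ e1))                  # 2.0: same edge, not a v-structure
```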
Given ordered pair e = (i, j), let V_i ⊂ V denote the set of nodes that form a v-structure with e, with shared node i. Note that V_i excludes j itself, since the settings of W_K, W_Q exclude edges that share both nodes. We set W_V to the identity, and we multiply by the attention weights σ to obtain

    (W_V h^{E,1} σ)_{e=(i,j)} = [ ...;  one-hot_N(i);  α_j · binary_N(V_j) ],    (27)

where binary_N(S) denotes the N-dimensional binary vector with ones at the elements of S, and the scaling factor

    α_j = (1 / ∥V_j∥) · 1{∥V_j∥ > 0} ∈ [0, 1]    (28)

results from softmax normalization. We set

    W_O = [ 0_{N+6};  0.5 · I_N ]    (29)

to preserve the original endpoint values and to distinguish between the edge's own node identity and newly recognized v-structures. To summarize, the output of this row attention layer is

    Attn_row(X_{·,c}) = X_{·,c} + W_O W_V X_{·,c} · σ,

which is equal to its input h^{E,1} plus additional positive values ∈ (0, 0.5) in the last N positions that indicate the presence of v-structures that exist in the overall E′.
Our final step is to "copy" newly assigned edge directions into all the edges. We set the ϕ_E column attention, FFN_E, and the ϕ_ρ attentions to the identity mapping. We also set W^{ρ,2} to a d × d zeros matrix, stacked on top of a d × d identity matrix. This passes the output of the ϕ_E row attention, aggregated over subsets, directly to FFN_{ρ,2}.

For endpoint dimensions e = [6], we let FFN_{ρ,2} implement

    FFN_{ρ,2}(X_{e,u,v}) = [0, 0, 1, 0, 1, 0]^T − X_{e,u,v}   if 0 < Σ_{i>N+6} X_{i,u,v} < 0.5,
                           0                                  otherwise.    (30)

Subtracting X_{e,u,v} "erases" the original endpoints and replaces them with (▶, •) after the update h^{E,1}_{i,j} ← h^{E,1}_{i,j} + m^{ρ→E,1}_{i,j}. The overall operation translates to checking whether any v-structure points towards i, and if so, assigning edge directions accordingly. For dimensions i > 6,

    FFN_{ρ,2}(X_{i,u,v}) = −X_{i,u,v}   if X_{i,u,v} ≤ 0.5,
                           0            otherwise,    (31)

effectively erasing the stored v-structures from the representation and remaining consistent with (21).
At this point, we have copied all v-structures once. However, our orientations are not necessarily symmetric. For example,
given v-structure i → j ← k, our model orients edges (j, i) and (j, k), but not (i, j) or (k, j).
The simplest way to symmetrize these edges (for the writer and the reader) is to run another axial attention block, in which
we focus on v-structures that point towards the second node of a pair. The only changes are as follows.
• For W_K and W_Q, we swap columns 1-3 with 4-6, and columns 7 to N+6 with the last N columns.
• (h^{E,2} σ)_{i,j} sees the third and fourth blocks swapped.
• W_O swaps the N × N blocks that correspond to the node embeddings of i and j.
• FFN_{ρ,3} sets the endpoint embedding to [0, 1, 0, 0, 0, 1]^T − X_{e,u,v} if dimensions i = 7, ..., N+6 sum to a value between 0 and 0.5.
The result is h^{E,3} with all v-structures oriented symmetrically, satisfying A.10.
Step 3: Orientation propagation. To propagate orientations, we would like to identify cases (i, j), (i, k) ∈ E′, (j, k) ∉ E′ with shared node i and corresponding endpoints (▶, •), (•, •). We use ϕ_E to identify triangles, and ϕ_ρ to identify edges (i, j), (i, k) ∈ E′ with the desired endpoints, while ignoring triangles.

Marginal layer. The row attention in ϕ_E fixes a subset t and varies the edge (i, j). Given edge (i, j), we want to extract all (i, k) that share node i. We set the key and query attention matrices to

    W_K, W_Q = k · blockdiag( [0 1 1], [0 1 1], I_N, ±I_N ),    (32)

and multiplying by the attention weights yields

    (W_V h^E σ)_{e=(i,k)} = [ ...;  one-hot_N(i);  α_k · binary_N(V_k) ],    (33)

where V_k is the set of nodes k that share any edge with i. To distinguish between k and V_k, we again set W_O as in (29). Finally, we set FFN_E to the identity and pass h^E directly to ϕ_ρ. To summarize, h^E is equal to its input, with values ∈ (0, 0.5) in the last N locations indicating the 1-hop neighbors of each edge.

Global layer. Now we would like to identify cases (i, k), (j, k) with corresponding endpoints (•, ▶), (•, •). We set the key and query attention matrices to

    W_K = k · blockdiag( [0 0 1], ..., I_N, I_N ),
    W_Q = k · blockdiag( [0 1 −1], ..., I_N, −I_N ).    (34)

The key allows us to check that endpoint i is directed, and the query allows us to check that (i, k) exists in C′ and does not already point elsewhere. After softmax normalization, for sufficiently large k, we obtain σ_{(i,j),(i,k)} > 0 if and only if (i, k) should be oriented (•, ▶) and the inner product attains the maximum possible value.
Orientation assignment. Our final step is to assign our new edge orientations. Let the column attention take on the identity mapping. For endpoint dimensions e = (4, 5, 6), we let FFN_ρ implement

    FFN_ρ(X_{e,u,v}) = [0, 0, 1]^T − X_{e,u,v}   if 0 < Σ_{i>N+6} X_{i,u,v} < 0.5,
                       0                          otherwise.    (36)
This translates to checking whether any incoming edge points towards v, and if so, assigning the new edge direction accordingly. For dimensions i > 6,

    FFN_ρ(X_{i,u,v}) = 0           if X_{i,u,v} ≤ 0.5,
                       X_{i,u,v}   otherwise,    (37)

effectively erasing the stored assignments from the representation. Thus, we are left with h^{E,ℓ} that conforms to the same format as the initial embedding in (21).
To symmetrize these edges, we run another axial attention block that focuses on paths pointing towards the second node of a pair, with changes analogous to those listed for Step 2.

The result is h^E with symmetric 1-hop orientation propagation, satisfying A.10. We may repeat this procedure k times to capture k-hop paths.
To summarize, we used axial attention block 1 to recover the undirected skeleton C ′ , blocks 2-3 to identify and copy
v-structures in E ′ , and all subsequent L − 3 layers to propagate orientations on paths up to ⌊(L − 3)/2⌋ length. Overall,
this particular construction requires O(N ) width for O(L) paths.
Final remarks Information theoretically, it should be possible to encode the same information in log N space, and achieve
O(log N ) width. For ease of construction, we have allowed for wider networks than optimal.
On the other hand, if we increase the width and encode each edge symmetrically, e.g. (ei , ej , ej , ei | i, j, j, i), we can
reduce the number of blocks by half, since we no longer need to run each operation twice. However, attention weights scale
quadratically, so we opted for an asymmetric construction.
Finally, a strict limitation of our model is that it only considers 1D pairwise interactions. In the graph layer, we cannot
compare different edges’ estimates at different times in a single step. In the feature layer, we cannot compare (i, j) to (j, i)
in a single step either. However, the graph layer does enable us to compare all edges at once (sparsely), and the feature
layer looks at a time-collapsed version of the whole graph. Therefore, though we opted for this design for computational
efficiency, we have shown that it is able to capture significant graph reasoning.
SEA (GIES) and SEA (FCI) achieve high edge accuracy. Therefore, while the underlying algorithms may not be stable with respect to edge orientation, our pretrained aggregator appears to be robust.
B. Experimental details
B.1. Synthetic data generation
Synthetic datasets were generated using code from D CDI (Brouillard et al., 2020), which extended the Causal Discovery
Toolkit data generators to interventional data (Kalainathan et al., 2020).
We considered the following causal mechanisms. Let y be the node in question, let X be its parents, let E be an independent noise variable (details below), and let W be randomly initialized weight matrices.

• Linear: y = XW + E
• Polynomial: y = W_0 + XW_1 + X^2 W_2 + E
• Sigmoid additive: y = Σ_{i=1}^{d} W_i · sigmoid(X_i) + E
• Randomly initialized neural network (NN): y = Tanh((X, E) W_in) W_out
• Randomly initialized neural network, additive (NN additive): y = Tanh(X W_in) W_out + E

Root causal mechanisms, noise variables, and interventional distributions maintained the DCDI defaults.
Ablation datasets with N > 100 nodes contained 100,000 points each (same as N = 100).
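A hedged sketch of this sampling procedure, for the linear and NN-additive mechanisms above, is given below. The toy graph, noise scales, and weight ranges are illustrative assumptions; the actual datasets were produced with the DCDI / Causal Discovery Toolbox generators rather than this code.

```python
# Illustrative sketch: sample one synthetic dataset in topological order, using the
# linear and NN-additive mechanisms. Graph, noise scale, and weight ranges are toy choices.
import numpy as np

rng = np.random.default_rng(0)

def sample_linear(parents, n, rng):
    W = rng.uniform(0.5, 2.0, size=parents.shape[1]) * rng.choice([-1, 1], parents.shape[1])
    return parents @ W + rng.normal(0, 0.4, size=n)

def sample_nn_additive(parents, n, rng, hidden=16):
    W_in = rng.normal(size=(parents.shape[1], hidden))
    W_out = rng.normal(size=hidden)
    return np.tanh(parents @ W_in) @ W_out + rng.normal(0, 0.4, size=n)

# toy graph 0 -> 1 -> 2 and 0 -> 2, listed in topological order
edges = {1: [0], 2: [0, 1]}
n = 1000
X = np.zeros((n, 3))
X[:, 0] = rng.normal(size=n)                        # root node
X[:, 1] = sample_linear(X[:, edges[1]], n, rng)
X[:, 2] = sample_nn_additive(X[:, edges[2]], n, rng)
print(X.mean(axis=0))
```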
DiffAN (Sanchez et al., 2023) was trained on each of the N = 10, 20 datasets using their published hyperparameters. The authors write that "most hyperparameters are hard-coded into [the] constructor of the DiffAN class and we verified they work across a wide set of datasets." We used the original, non-approximation version of their algorithm by maintaining residue=True in their codebase. We were unable to consistently run DiffAN with both R and GPU support within a Docker container, and the authors did not respond to questions regarding reproducibility, so all models were trained on the CPU only. We observed approximately a 10x speedup in the < 5 cases that were able to complete running on the GPU.
For our final model, we selected learned positional embeddings, 4 layers, 8 heads, and learning rate η = 1e − 4.
2. Global-statistic-based selection: α = ρ.
3. Uncertainty-based selection: α = Ĥ(E_t), where H denotes the information entropy

    α_{i,j} = − Σ_{e ∈ {0,1,2}} p(e) log p(e).    (38)

Let c^t_{i,j} be the number of times edge (i, j) was selected in S_1 ... S_{t−1}, and let α^t = α / sqrt(c^t_{i,j}). We consider two strategies for selecting S_t based on α^t.

Greedy selection: Throughout our experiments, we used a greedy algorithm for subset selection. We normalize probabilities to 1 before constructing each Categorical distribution. Initialize

    S_t ← { i : i ∼ Categorical(α^t_1, ..., α^t_N) },    (39)

where α^t_i = Σ_{j≠i ∈ V} α^t_{i,j}. While |S_t| < k, update

    S_t ← S_t ∪ { j : j ∼ Categorical(α^t_{1,S_t}, ..., α^t_{N,S_t}) },    (40)

where

    α^t_{j,S_t} = Σ_{i ∈ S_t} α^t_{i,j}   if j ∉ S_t,
                  0                        otherwise.    (41)
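The following is a minimal sketch of this greedy sampler; the count-based discount α / sqrt(c^t_{i,j}) is omitted for brevity, and the score matrix shown is illustrative.

```python
# Minimal sketch of the greedy node-subset sampler (Eqs. 39-41): seed with a node
# drawn in proportion to its total score, then repeatedly add nodes in proportion
# to their score against the current subset.
import numpy as np

def greedy_subset(alpha: np.ndarray, k: int, rng: np.random.Generator) -> list:
    N = alpha.shape[0]
    node_scores = alpha.sum(axis=1) - np.diag(alpha)      # alpha_i = sum_{j != i} alpha_ij
    probs = node_scores / node_scores.sum()
    subset = [rng.choice(N, p=probs)]
    while len(subset) < k:
        scores = alpha[subset, :].sum(axis=0)             # alpha_{j, S_t}
        scores[subset] = 0.0                              # never re-pick selected nodes
        probs = scores / scores.sum()
        subset.append(rng.choice(N, p=probs))
    return subset

rng = np.random.default_rng(0)
alpha = np.abs(rng.normal(size=(10, 10)))                 # e.g. |inverse covariance|
print(greedy_subset(alpha, k=5, rng=rng))
```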
Subset selection: We also considered the following subset-level selection procedure and observed a minor performance gain for a significantly longer runtime (the linear program takes around 1 second per batch). Therefore, we opted for the greedy method instead.

We solve the following integer linear program to select a subset S_t of size k that maximizes Σ_{i,j ∈ S_t} α^t_{i,j}. Let ν_i ∈ {0, 1} denote the selection of node i, and let ϵ_{i,j} ∈ {0, 1} denote the selection of edge (i, j). Our objective is to

    maximize    Σ_{i,j} α^t_{i,j} · ϵ_{i,j}
    subject to  Σ_i ν_i = k                    (subset size)
                ϵ_{i,j} ≥ ν_i + ν_j − 1        (node-edge consistency)
                ϵ_{i,j} ≤ ν_i
                ϵ_{i,j} ≤ ν_j
                ν_i ∈ {0, 1}
                ϵ_{i,j} ∈ {0, 1}
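For reference, the sketch below encodes this integer linear program with scipy.optimize.milp (SciPy ≥ 1.9). The matrix construction and the symmetrization of the pair scores are illustrative choices, not our implementation.

```python
# Illustrative encoding of the subset-selection ILP: binary node variables nu_i and
# binary edge variables eps_ij (i < j), maximizing the selected edge score subject
# to |S_t| = k and node-edge consistency.
import numpy as np
from scipy.optimize import milp, LinearConstraint, Bounds

def ilp_subset(alpha: np.ndarray, k: int) -> list:
    N = alpha.shape[0]
    pairs = [(i, j) for i in range(N) for j in range(i + 1, N)]
    n_var = N + len(pairs)                       # [nu_0..nu_{N-1}, eps_0..eps_{P-1}]
    c = np.zeros(n_var)
    c[N:] = -np.array([alpha[i, j] + alpha[j, i] for i, j in pairs])  # maximize score

    rows, lbs, ubs = [], [], []
    size_row = np.zeros(n_var); size_row[:N] = 1                      # sum nu_i == k
    rows.append(size_row); lbs.append(k); ubs.append(k)
    for p, (i, j) in enumerate(pairs):
        r = np.zeros(n_var); r[N + p] = 1; r[i] = -1; r[j] = -1       # eps >= nu_i + nu_j - 1
        rows.append(r); lbs.append(-1); ubs.append(np.inf)
        for node in (i, j):                                           # eps <= nu_node
            r = np.zeros(n_var); r[N + p] = 1; r[node] = -1
            rows.append(r); lbs.append(-np.inf); ubs.append(0)

    res = milp(c, constraints=LinearConstraint(np.array(rows), lbs, ubs),
               integrality=np.ones(n_var), bounds=Bounds(0, 1))
    return [i for i in range(N) if res.x[i] > 0.5]

alpha = np.abs(np.random.default_rng(0).normal(size=(8, 8)))
print(ilp_subset(alpha, k=4))
```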
C. Additional analyses
C.1. Traditional algorithm parameters
We investigated model performance with respect to the settings of our graph estimation parameters. Our model is sensitive
to the size of batches used to estimate global features and marginal graphs (Figure 6). In particular, at least 250 points are
required per batch for an acceptable level of performance. Our model is not particularly sensitive to the number of batches
sampled (Figure 7), or to the number of variables sampled in each subset (Figure 8).
Figure 6. Performance of our model (GIES) as a function of traditional algorithm batch size. Error bars indicate 95% confidence interval
across the 10 datasets of each setting. The global feature and marginal graph estimates are sensitive to batch size and require at least 250
points per batch to achieve an acceptable level of performance.
Figure 7. Performance of our model (GIES) as a function of number of batches sampled. Error bars indicate 95% confidence interval
across the 10 datasets of each setting. Our model is relatively insensitive to the number of batches sampled, though more batches are
beneficial in harder cases, e.g. sigmoid mechanism with additive noise or smaller batch size.
graph-level statistics on our synthetic datasets. Discretization thresholds for SHD were obtained by computing the pth
quantile of the computed values, where p = 1 − (E/N ). This is not entirely fair, as no other baseline receives the same
calibration, but these ablation studies only seek to compare state-of-the-art causal discovery methods with the “best” possible
(oracle) statistical alternatives.
Corr refers to global correlation,

    ρ_{i,j} = ( E(X_i X_j) − E(X_i) E(X_j) ) / ( sqrt(E(X_i^2) − E(X_i)^2) · sqrt(E(X_j^2) − E(X_j)^2) ).    (42)

D-Corr refers to distance correlation, computed between all pairs of variables. Distance correlation captures both linear and non-linear dependencies, and D-Corr(X_i, X_j) = 0 if and only if X_i ⊥⊥ X_j. Please refer to (Székely et al., 2007) for the full derivation. Despite its power to capture non-linear dependencies, we opted not to use D-Corr because it is quite slow to compute between all pairs of variables.

InvCov refers to inverse covariance, computed globally,

    ρ = E[ (X − E(X)) (X − E(X))^T ]^{−1}.    (43)

For graphs with N < 100, inverse covariance was computed directly using NumPy. For graphs with N ≥ 100, inverse covariance was computed using Ledoit-Wolf shrinkage at inference time (Ledoit & Wolf, 2004).
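A short sketch of this InvCov computation follows, using NumPy for smaller graphs and scikit-learn's LedoitWolf estimator for larger ones; the exact switch-over and the choice of estimator here are assumptions of the sketch.

```python
# Minimal sketch of the InvCov statistic: plain inverse covariance for small graphs,
# Ledoit-Wolf shrinkage (precision matrix) for larger ones.
import numpy as np
from sklearn.covariance import LedoitWolf

def inverse_covariance(X: np.ndarray) -> np.ndarray:
    n, d = X.shape
    if d < 100:
        return np.linalg.inv(np.cov(X, rowvar=False))
    return LedoitWolf().fit(X).precision_

X = np.random.default_rng(0).normal(size=(500, 20))
rho = inverse_covariance(X)
print(rho.shape)   # (20, 20)
```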
Figure 8. Performance of our model (GIES) as a function of subset size |S| (number of variables sampled). Error bars indicate 95%
confidence interval across the 10 datasets of each setting. Our model was trained on |S| = 5, but it is insensitive to the number of variables
sampled per subset at inference. Runtime scales exponentially upwards.
Table 7. Comparison of global statistics (SHD). Discretization thresholds for SHD were obtained by computing the p-th quantile of the computed values, where p = 1 − E/N.
Nonetheless, our current model already obtains reasonable performance on larger graphs, out of the box.
Due to the scope of this project and our computing resources, we did not train very “big” models in the modern sense. There is much room to scale, both in terms of model architecture and the datasets covered. Table 11 probes the generalization limits of the two implementations of SEA in this paper.
We found that our models, which were trained primarily on additive noise, achieve reasonable performance but do not generalize reliably to causal mechanisms with multiplicative noise. For example, we tested additional datasets with the following mechanisms (same format as B.1); a small simulation sketch follows the list.
• Sigmoid mix: $y = \sum_{i=1}^{d} W_i \cdot \mathrm{sigmoid}(X_i) \times E$
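The sketch below illustrates how such a multiplicative-noise mechanism could be simulated from the formula above; the weight range and noise distribution are assumptions for illustration and may differ from the actual generation settings.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_sigmoid_mix(parents, rng, noise_scale=0.4):
    """y = (sum_i W_i * sigmoid(X_i)) * E, with multiplicative noise E.
    parents : (n_samples, d) array of parent values X_1..X_d.
    The weight range and noise distribution are illustrative assumptions.
    """
    n, d = parents.shape
    W = rng.uniform(0.5, 2.0, size=d) * rng.choice([-1.0, 1.0], size=d)
    E = rng.normal(loc=1.0, scale=noise_scale, size=n)  # multiplicative noise
    return (sigmoid(parents) @ W) * E

# Example: three parent variables, 1000 samples.
rng = np.random.default_rng(0)
y = sample_sigmoid_mix(rng.normal(size=(1000, 3)), rng)
```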
We anticipate that incorporating these data into the training set would alleviate some of this gap (just as training on synthetic
mRNA data enabled us to perform well there, despite its non-standard data distributions). However, we did not have time to
test this hypothesis empirically.
DCDI learns a new generative model over each dataset, and its more powerful deep sigmoidal flow variant seems to perform well in some (but not all) of these harder cases.
Tables 12 and 13 report the full results on N = 100 graphs.
Tables 14 and 15 report our results on scale-free graphs.
Table 8. Causal discovery results on simulated mRNA data. Each setting encompasses 5 distinct scale-free graphs. Data were generated
via SERGIO (Dibaeinia & Sinha, 2020).
Table 9. Comparison between the heuristics-based samplers (random and inverse covariance) and the model confidence-based sampler. Details regarding the samplers can be found in B.6. The suffix -L indicates the greedy confidence-based sampler. Each setting encompasses 5 distinct Erdős-Rényi graphs. The symbol † indicates that SEA was not pretrained on this setting. Bold indicates best of all models considered (including baselines not pictured).
Table 10. Scaling to synthetic graphs larger than those seen in training. Each setting encompasses 5 distinct Erdős-Rényi graphs. For all analyses in this table, we took T = 500 subsets of nodes, with b = 500 examples per batch. Here, the mean AUC values are artificially high due to the high negative rates, since the number of actual edges scales linearly in N while the number of possible edges scales quadratically.
Figure 9. mAP on graphs larger than those seen during training. Due to an insufficient maximum number of subset embeddings, we were only able to sample 500 batches, which appears to be too few for larger graphs. These values correspond to the numbers in Table 10.
Table 11. Generalization limits of our current implementations. Each setting represents 5 distinct graphs. Our models were not pretrained on multiplicative noise, so they do not generalize reliably to these cases. While much slower, DCDI variants learn the data distribution from scratch each time, so they seem to perform well in some of these cases. We anticipate that training on multiplicative data would alleviate this generalization gap, but we did not have time to test this empirically.
N   E   Model        Sigmoid mix†          Polynomial mix†       Sigmoid mix (SF)†     Polynomial mix (SF)†
                     mAP↑   EA↑   SHD↓     mAP↑   EA↑   SHD↓     mAP↑   EA↑   SHD↓     mAP↑   EA↑   SHD↓
10  10  DCDI-G       0.84   0.81    0.9    0.11   0.00   10.2    0.67   0.92   15.2    0.12   0.00   10.6
        DCDI-DSF     0.96   0.91    0.3    0.39   0.32    7.3    0.81   0.98   13.6    0.39   0.36    9.8
        DCD-FG       0.52   0.61   18.0    0.11   0.00   10.2    0.57   0.66   15.5    0.12   0.00   10.6
        DIFFAN       0.14   0.31   19.9    0.11   0.09   14.0    0.10   0.22   17.8    0.11   0.14   17.2
        DECI         0.17   0.43   19.1    0.11   0.07   14.1    0.20   0.53   15.0    0.12   0.06   13.6
        INVCOV       0.34   0.51   12.1    0.18   0.51   16.4    0.31   0.56   11.4    0.16   0.50   16.1
        FCI-AVG      0.56   0.54    9.1    0.11   0.00   10.2    0.53   0.52    8.1    0.12   0.01   10.6
        GIES-AVG     0.91   0.88    3.1    0.22   0.38   10.2    0.93   0.88    2.3    0.22   0.37   10.6
        SEA (FCI)    0.58   0.51    5.7    0.25   0.53   10.2    0.67   0.56    4.5    0.20   0.46   10.6
        SEA (GIES)   0.41   0.39    8.1    0.18   0.51   10.2    0.48   0.48    5.4    0.21   0.46   10.6

10  40  DCDI-G       0.58   0.35   25.0    0.44   0.00   39.8    0.76   0.83   22.0    0.31   0.00   28.2
        DCDI-DSF     0.79   0.46   12.2    0.48   0.09   35.6    0.92   0.88   15.8    0.35   0.07   26.4
        DCD-FG       0.60   0.40   26.8    0.44   0.00   39.8    0.52   0.46   22.1    0.31   0.00   28.2
        DIFFAN       0.38   0.32   31.8    0.40   0.32   30.0    0.26   0.30   35.0    0.28   0.29   29.3
        DECI         0.46   0.39   27.5    0.43   0.03   39.5    0.35   0.51   24.9    0.30   0.05   30.1
        INVCOV       0.48   0.51   38.7    0.48   0.54   40.3    0.40   0.51   32.5    0.34   0.48   38.2
        FCI-AVG      0.44   0.18   39.7    0.44   0.01   39.8    0.43   0.30   25.2    0.32   0.01   28.2
        GIES-AVG     0.43   0.37   38.2    0.50   0.48   39.8    0.48   0.45   22.8    0.36   0.39   28.2
        SEA (FCI)    0.56   0.58   28.6    0.81   0.85   39.7    0.62   0.67   17.3    0.58   0.85   28.2
        SEA (GIES)   0.58   0.65   30.2    0.80   0.86   39.5    0.64   0.69   16.7    0.60   0.81   28.2

20  20  DCDI-G       0.57   0.72    6.2    0.06   0.01   21.7    0.47   0.98   45.4    0.06   0.01   20.7
        DCDI-DSF     0.91   0.94    1.1    0.21   0.46   34.6    0.56   0.97   40.6    0.29   0.80   83.0
        DCD-FG       0.50   0.68   62.1    0.06   0.00   24.3    0.63   0.81  256.6    0.06   0.11  186.4
        DIFFAN       0.08   0.26   47.5    0.06   0.15   51.5    0.08   0.30   42.6    0.06   0.22   53.7
        DECI         0.17   0.58   38.0    0.06   0.05   36.0    0.18   0.60   32.9    0.06   0.03   38.9
        INVCOV       0.31   0.59   22.3    0.09   0.52   35.9    0.24   0.44   24.5    0.07   0.50   35.4
        FCI-AVG      0.64   0.69   17.4    0.06   0.01   21.3    0.58   0.66   15.3    0.06   0.00   21.0
        GIES-AVG     0.82   0.75    8.5    0.15   0.44   21.3    0.83   0.78    7.8    0.12   0.42   21.0
        SEA (FCI)    0.61   0.60    8.5    0.12   0.55   21.3    0.63   0.56    9.0    0.11   0.55   21.1
        SEA (GIES)   0.41   0.48   13.0    0.11   0.55   21.6    0.50   0.49   12.3    0.11   0.52   21.3

20  80  DCDI-G       0.60   0.59   32.3    0.21   0.05   86.4    0.54   0.84   78.6    0.18   0.14   74.8
        DCDI-DSF     0.89   0.81    8.4    0.24   0.25  102.3    0.65   0.89   60.1    0.27   0.60  201.6
        DCD-FG       0.46   0.72  222.2    0.21   0.00   81.8    0.52   0.76  202.2    0.18   0.15  225.9
        DIFFAN       0.18   0.30  151.1    0.19   0.31  130.8    0.15   0.30  137.0    0.16   0.31  127.9
        DECI         0.31   0.43   70.5    0.20   0.03   89.0    0.25   0.41   66.5    0.17   0.02   79.0
        INVCOV       0.30   0.51   98.1    0.22   0.51  115.4    0.27   0.53   89.1    0.19   0.49  108.1
        FCI-AVG      0.37   0.30   76.0    0.21   0.01   79.3    0.38   0.34   59.6    0.18   0.01   66.5
        GIES-AVG     0.54   0.64   68.4    0.23   0.37   79.3    0.55   0.65   50.7    0.19   0.36   66.5
        SEA (FCI)    0.61   0.64   52.6    0.41   0.87   79.3    0.62   0.70   36.7    0.34   0.82   66.7
        SEA (GIES)   0.53   0.64   58.1    0.43   0.86   79.0    0.59   0.74   41.4    0.35   0.82   66.8
Table 12. Causal discovery results on synthetic datasets with 100 nodes, continuous metrics. Each setting encompasses 5 distinct
Erdős-Rényi graphs. The symbol † indicates that the model was not trained on this setting. All standard deviations were within 0.1.
Table 13. Causal discovery results on synthetic datasets with 100 nodes, discrete metrics. Each setting encompasses 5 distinct Erdős-Rényi
graphs. The symbol † indicates that the model was not trained on this setting.
Figure 10. Runtime of the heuristics-based greedy sampler vs. the model uncertainty-based greedy sampler (suffix -L). For sampling, the model was run on CPU only, due to the difficulty of invoking the GPU inside the PyTorch data sampler.
Table 14. Causal discovery results on synthetic scale-free datasets, continuous metrics. Each setting encompasses 5 distinct scale-free graphs. The symbol † indicates that SEA was not pretrained on this setting. Details regarding baselines can be found in B.2.
Table 15. Causal discovery results on synthetic scale-free datasets, discrete metrics. Each setting encompasses 5 distinct scale-free graphs. The symbol † indicates that SEA was not pretrained on this setting. Details regarding baselines can be found in B.2.