
Sample, estimate, aggregate:

A recipe for causal discovery foundation models

Menghua Wu 1 Yujia Bao 2 Regina Barzilay 1 Tommi Jaakkola 1

arXiv:2402.01929v1 [cs.LG] 2 Feb 2024

¹CSAIL, Massachusetts Institute of Technology, Cambridge, MA. ²Accenture, Mountain View, CA. Correspondence to: Menghua Wu <[email protected]>.

Preprint. Under review.

Abstract

Causal discovery, the task of inferring causal structure from data, promises to accelerate scientific research, inform policy making, and more. However, the per-dataset nature of existing causal discovery algorithms renders them slow, data hungry, and brittle. Inspired by foundation models, we propose a causal discovery framework where a deep learning model is pretrained to resolve predictions from classical discovery algorithms run over smaller subsets of variables. This method is enabled by the observations that the outputs from classical algorithms are fast to compute for small problems, informative of (marginal) data structure, and comparable across datasets as structured objects. Our method achieves state-of-the-art performance on synthetic and realistic datasets, generalizes to data generating mechanisms not seen during training, and offers inference speeds that are orders of magnitude faster than existing models.¹

¹Our code and data are publicly available here: https://github.com/rmwu/sea

1. Introduction

A fundamental aspect of scientific research is to discover and validate causal hypotheses involving variables of interest. Given observations of these variables, the goal of causal discovery algorithms is to extract such hypotheses in the form of directed graphs, in which edges denote causal relationships (Spirtes et al., 2001). In much of basic science, however, classical statistics remain the de facto basis of data analysis (Replogle et al., 2022). Key barriers to widespread adoption of causal discovery algorithms include their high data requirements and their computational intractability on larger problems.

Current causal discovery algorithms follow two primary approaches that differ in their treatment of the underlying causal graph. Discrete optimization algorithms explore the super-exponential space of graphs by proposing and evaluating changes to a working graph (Glymour et al., 2019). While these methods are quite fast on small graphs, the combinatorial space renders them intractable for exploring larger structures, with hundreds or thousands of nodes. Furthermore, their correctness is tied to hypothesis tests, whose results can be erroneous for noisy datasets, especially as graph sizes increase. More recently, a number of works have reframed the discrete graph search as a continuous optimization over weighted adjacency matrices (Zheng et al., 2018; Brouillard et al., 2020). These continuous optimization algorithms often require that a generative model be fit to the full data distribution, a difficult task when the number of variables increases but the data remain sparse.

In this work, we present SEA: Sample, Estimate, Aggregate, a blueprint for developing causal discovery foundation models that enable fast inference on new datasets, perform well in low data regimes, and generalize to causal mechanisms beyond those seen during training. Our approach is motivated by two observations. First, while classical causal discovery algorithms scale poorly, their bottleneck lies in the exponential search space, rather than individual independence tests. Second, in many cases, statistics like global correlation or inverse covariance are indeed strong indicators for a graph's overall connectivity. Therefore, we propose to leverage 1) the estimates of classical algorithms over small subgraphs, and 2) global graph-level statistics, as inputs to a deep learning model, pretrained to resolve these statistical descriptors into causal graphs.

Theoretically, we prove that given only marginal estimates over subgraphs, it is possible to recover causally sound global graphs, and that our proposed model has the capacity to recapitulate such reasoning. Empirically, we implement instantiations of SEA using an axial-attention based model (Ho et al., 2020), which takes as input inverse covariance and estimates from the classical FCI or GIES algorithms (Spirtes et al., 1995; Hauser & Bühlmann, 2012). We conduct a thorough comparison to three classical baselines and five deep learning approaches. SEA attains state-of-the-art results on synthetic and real-world causal discovery tasks, while providing 10-1000x faster inference.


While these experimental results reflect specific algorithms and architectures, the overall framework accepts any combination of sampling heuristics, classical causal discovery algorithms, and statistical features. To summarize, our contributions are as follows.

1. To the best of our knowledge, we are the first to propose a method for building fast, robust, and generalizable foundation models for causal discovery.
2. We show that both our overall framework and specific architecture have the capacity to reproduce causally sound graphs from their inputs.
3. We attain state-of-the-art results on synthetic and realistic settings, and we provide extensive experimental analysis of our model.

2. Background and related work

2.1. Causal graphical models

A causal graphical model is a directed, acyclic graph G = (V, E), where each node i ∈ V corresponds to a random variable Xi ∈ X and each edge (i, j) ∈ E represents a causal relationship from Xi → Xj.

We assume that the data distribution PX is Markov to G,

    ∀i ∈ V,  Xi ⊥⊥ V \ (Xδi ∪ Xπi) | Xπi    (1)

where δi is the set of descendants of node i in G, and πi is the set of parents of i. In addition, we assume that PX is minimal and faithful – that is, PX does not include any independence relationships beyond those implied by applying the Markov condition to G (Spirtes et al., 2001).

Causal graphical models allow us to perform interventions on nodes i by setting the conditional P(Xi | Xπi) to a different distribution P̃(Xi | Xπi).
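To make the Markov condition in Eq. (1) concrete, the short snippet below simulates a linear-Gaussian chain X1 → X2 → X3 (a toy example of ours, not from the paper) and checks that X1 and X3 are marginally dependent but become independent once the parent X2 is conditioned on, using partial correlation as the test statistic.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
x1 = rng.normal(size=n)                # root node
x2 = 2.0 * x1 + rng.normal(size=n)     # X1 -> X2
x3 = -1.5 * x2 + rng.normal(size=n)    # X2 -> X3

def partial_corr(a, b, c):
    """Correlation of a and b after regressing out c; approximately zero
    iff a is independent of b given c in the linear-Gaussian case."""
    res_a = a - np.polyval(np.polyfit(c, a, 1), c)
    res_b = b - np.polyval(np.polyfit(c, b, 1), c)
    return np.corrcoef(res_a, res_b)[0, 1]

print(np.corrcoef(x1, x3)[0, 1])   # clearly nonzero: X1 and X3 are dependent
print(partial_corr(x1, x3, x2))    # approximately 0: X1 is independent of X3 given X2
```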
2.2. Causal discovery

Given a dataset D ∼ PX, the goal of causal discovery is to recover G. There are two main challenges. First, the number of possible graphs is super-exponential in the number of nodes N, so causal discovery algorithms must navigate this combinatorial search space efficiently. In terms of graph size, the limits of current algorithms range from tens of nodes (Hägele et al., 2023) to hundreds of nodes (Lopez et al., 2022), where simplifying assumptions regarding the graph structure are often made in the latter case. Second, depending on data availability and the underlying data generation process, causal discovery algorithms may or may not be able to recover G in practice. In fact, many algorithms are only analyzed in the infinite-data regime and require at least thousands of data samples for reasonable empirical performance (Spirtes et al., 2001; Brouillard et al., 2020).

Causal discovery is also used in the context of pairwise relationships (Zhang & Hyvarinen, 2012; Monti et al., 2019), as well as for non-stationary (Huang et al., 2020) and time-series (Löwe et al., 2022) data. However, this work focuses on causal discovery for stationary graphs, so we will discuss the two main approaches to causal discovery in this setting.

Discrete optimization methods make atomic changes to a proposal graph until a stopping criterion is met. Constraint-based algorithms identify edges based on conditional independence tests, and their correctness is inseparable from the empirical results of those tests (Glymour et al., 2019), whose statistical power depends directly on dataset size. These include the FCI and PC algorithms in the observational case (Spirtes et al., 1995), and the JCI algorithm in the interventional case (Mooij et al., 2020).

Score-based methods also make iterative modifications to a working graph, but their goal is to maximize a continuous score over the discrete space of all valid graphs, with the true graph at the optimum. Due to the intractable search space, these methods often make decisions based on greedy heuristics. Classic examples include GES (Chickering, 2002), GIES (Hauser & Bühlmann, 2012), CAM (Bühlmann et al., 2014), and LiNGAM (Shimizu et al., 2006).

Continuous optimization approaches recast the combinatorial space of graphs into a continuous space of weighted adjacency matrices. Many of these works train a generative model to learn the empirical data distribution, which is parameterized through the adjacency matrix (Zheng et al., 2018; Lachapelle et al., 2020; Brouillard et al., 2020). Others focus on properties related to the empirical data distribution, such as a relationship between the underlying graph and the Jacobian of the learned model (Reizinger et al., 2023), or between the Hessian of the data log-likelihood and the topological ordering of the nodes (Sanchez et al., 2023). While these methods bypass the combinatorial search over discrete graphs, they still require copious data and time to train accurate generative models of the data distributions.

2.3. Foundation models

The concept of foundation models has revolutionized the machine learning workflow in a variety of disciplines: instead of training domain-specific models from scratch, we can query a pretrained, general-purpose "foundation" model (Radford et al., 2021; Brown et al., 2020; Bommasani et al., 2022). Recent work has explored foundation models for causal inference (Zhang et al., 2023), but this method addresses causal inference rather than causal discovery, so it assumes that the causal structures are already known.


Figure 1. An overview of the S EA inference procedure. Given a new dataset, we 1) sample batches and subsets of nodes, 2) estimate
marginal graphs and global statistics over these batches, and 3) aggregate these features using a pretrained model to obtain the underlying
causal graph. Raw data are depicted in green, and graph-related features are depicted in blue.

3. Methods

We present SEA, a framework for developing fast and scalable causal discovery foundation models. We follow two key insights. First, while classical causal discovery algorithms scale poorly to large graphs, they are highly performant on small graphs. Second, in many cases, simple statistics like global correlation or inverse covariance are strong baselines for a graph's overall connectivity. Our framework combines global statistics and marginal estimates (the outputs of classical causal discovery algorithms on subsets of nodes) as inputs to a deep learning model, pretrained to aggregate these features into causal graphs (Section 3.1). As proofs of concept, we describe instantiations of SEA using specific algorithms and architectures (Sections 3.2, 3.3). Finally, we prove that our algorithm has the capacity to produce sound causal graphs, in theory and in context of our implementation (Section 3.4).

3.1. Causal discovery framework

SEA is a causal discovery framework that learns to resolve statistical features and estimates of marginal graphs into a global causal graph. The inference procedure is depicted in Figure 1. Specifically, given a new dataset D ∈ R^(M×N) faithful to graph G = (V, E), we apply the following stages.

Sample: takes as input dataset D; and outputs data batches {D0, D1, ..., DT} and node subsets {S1, ..., ST}.
1. Sample T + 1 batches of b ≪ M observations uniformly at random from D.
2. Compute selection scores α ∈ (0, 1)^(N×N) over D0 (e.g. inverse covariance).
3. Sample T node subsets of size k. Each subset St ⊆ V is constructed iteratively, with additional nodes sampled one at a time with probability proportional to Σ_{j∈St} αi,j (details in Section B.6).

Estimate: takes as inputs data batches and node subsets; and outputs global statistics ρ and marginal estimates {E′1, ..., E′T}.
1. Compute global statistics ρ ∈ R^(N×N) over D0.
2. Run causal discovery algorithm f to obtain marginal estimates f(Dt[St]) = E′t for t = 1 ... T.

We use Dt[St] to denote the observations in Dt that correspond only to the variables in St. Each estimate E′t is a k × k adjacency matrix, corresponding to the k nodes in St.

Aggregate: takes as inputs global statistics, marginal estimates, and node subsets. A pretrained aggregator model outputs the predicted global causal graph Ê ∈ (0, 1)^(N×N).

The aggregator's architecture is agnostic to the assumptions of the underlying causal discovery algorithm f (e.g. (non)linearity, observational vs. interventional). Therefore, we can freely swap out f and the training datasets to train aggregators with different implicit assumptions, without modifying the architecture itself. Here, our aggregator is pretrained over a diverse set of graph sizes, structures, and data generating mechanisms (details in Section 4.1).
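The following is a schematic Python sketch of the Sample and Estimate stages (the Aggregate stage is the pretrained network of Section 3.2). The subset-growing heuristic is simplified relative to Section B.6, and `run_discovery_on_subset` is a placeholder for the classical algorithm f (FCI or GIES in the paper's implementation); all function and variable names here are illustrative, not the released code.

```python
import numpy as np

def sample_and_estimate(D, run_discovery_on_subset, T=100, b=500, k=5, seed=0):
    """Schematic Sample + Estimate stages of SEA (simplified sketch).
    D: (M, N) data matrix. Returns the global statistic rho, node subsets,
    and one k x k marginal estimate per subset."""
    rng = np.random.default_rng(seed)
    M, N = D.shape
    batches = [D[rng.choice(M, size=min(b, M), replace=False)] for _ in range(T + 1)]

    # Global statistic over D0: inverse covariance (related to partial correlation).
    rho = np.linalg.pinv(np.cov(batches[0], rowvar=False))
    alpha = np.abs(rho)                          # selection scores

    subsets, estimates = [], []
    for t in range(1, T + 1):
        # Grow a node subset, adding nodes with probability proportional to
        # their total score against the current subset (simplified B.6).
        S = [int(rng.integers(N))]
        while len(S) < k:
            w = alpha[:, S].sum(axis=1)
            w[S] = 0.0
            S.append(int(rng.choice(N, p=w / w.sum())))
        S = sorted(S)
        subsets.append(S)
        # Marginal estimate: run the classical algorithm f on the restricted data.
        estimates.append(run_discovery_on_subset(batches[t][:, S]))
    return rho, subsets, estimates
```

Each returned estimate corresponds only to the k sampled nodes, matching the description of Dt[St] above.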
3.2. Model architecture

Figure 2. Aggregator architecture. Marginal graph estimates and global statistics are first embedded into the model dimension, and 1D
positional embeddings are added along both rows and columns. These combined embeddings pass through a series of axial attention
blocks, which attend to the graph estimates and the global features. The final layer global features pass through a feedforward network to
predict the causal graph.


The core module of SEA is the aggregator (Figure 2), implemented as a sequence of axial attention blocks. The aggregator takes as input global statistics ρ ∈ R^(N×N), marginal estimates E′_{1...T} ∈ E^(T×k×k), and node subsets S_{1...T} ∈ [N]^(T×k), where E is the set of output edge types for the causal discovery algorithm f.

We project global statistics into the model dimension via a learned linear projection matrix Wρ : R → R^d, and we embed edge types via a learned embedding ebdE : E → R^d. To collect estimates of the same edge over all subsets, we align entries of E′_{1...T} into E′^align ∈ E^(T×K),

    E′^align_{t, e=(i,j)} = E′_{t,i,j} if i ∈ St and j ∈ St, and 0 otherwise,    (2)

where t indexes into the subsets, e indexes into the set of unique edges, and K is the number of unique edges.

We add learned 1D positional embeddings along both dimensions of each input,

    pos-ebd(ρ_{i,j}) = ebd_node(i′) + ebd_node(j′)
    pos-ebd(E′^align_{t,e}) = ebd_time(t) + FFN([ebd_node(i′), ebd_node(j′)])

where i′, j′ index into a random permutation on V for invariance to node permutation and graph size.² Due to the (a)symmetries of their inputs, pos-ebd(ρ_{i,j}) is symmetric, while pos-ebd(E′^align_{t,e}) considers the node ordering. In summary, the inputs to our axial attention blocks are

    h^ρ_{i,j} = (Wρ ρ)_{i,j} + pos-ebd(ρ_{i,j})    (3)
    h^E_{t,e} = ebdE(E′^align_{t,e}) + pos-ebd(E′^align_{t,e})    (4)

for i, j ∈ [N]², t ∈ [T], e ∈ [K].

²The sampling of St already provides invariance to node order. However, the mapping i′ = σ(V)_i allows us to avoid updating positional embeddings of lower order positions more than higher order ones, due to the mixing of graph sizes during training.

Axial attention. An axial attention block contains two axial attention layers (marginal estimates, global features) and a feed-forward network (Figure 2, right).

Given a 2D input, an axial attention layer attends first along the rows, then along the columns. For example, given a matrix of size (R, C, d), one pass of the axial attention layer is equivalent to running standard self-attention along C with batch size R, followed by the reverse. For marginal estimates, R is the number of subsets T, and C is the number of unique edges K. For global features, R and C are both the total number of vertices N.

Following (Rao et al., 2021), each self-attention mechanism is preceded by layer normalization and followed by dropout, with residual connections to the input,

    x = x + Dropout(Attn(LayerNorm(x))).    (5)
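For concreteness, here is a minimal PyTorch sketch of a single axial attention layer as described above: self-attention along one axis and then the other, each wrapped in the pre-LayerNorm, dropout, and residual pattern of Eq. (5). The class name, default sizes, and use of `nn.MultiheadAttention` are our own illustrative choices rather than the paper's exact implementation.

```python
import torch
import torch.nn as nn

class AxialAttentionLayer(nn.Module):
    """Minimal sketch of one axial attention layer over a (R, C, d) input."""
    def __init__(self, dim=512, heads=8, dropout=0.1):
        super().__init__()
        self.norm_row = nn.LayerNorm(dim)
        self.norm_col = nn.LayerNorm(dim)
        self.attn_row = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.attn_col = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):                 # x: (R, C, d)
        # Attend along C, treating R as the batch dimension (Eq. 5 residual).
        h = self.norm_row(x)
        x = x + self.dropout(self.attn_row(h, h, h, need_weights=False)[0])
        # Attend along R, treating C as the batch dimension.
        xt = x.transpose(0, 1)            # (C, R, d)
        h = self.norm_col(xt)
        xt = xt + self.dropout(self.attn_col(h, h, h, need_weights=False)[0])
        return xt.transpose(0, 1)         # back to (R, C, d)
```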


We pass messages between the marginal and global layers to propagate information. Let ϕ^(E,ℓ) be marginal layer ℓ, let ϕ^(ρ,ℓ) be global layer ℓ, and let h^(·,ℓ) denote the hidden representations out of layer ℓ.

The marginal to global message m^(E→ρ) ∈ R^(N×N×d) contains representations of each edge averaged over subsets,

    m^(E→ρ,ℓ)_{i,j} = (1/Te) Σ_t h^(E,ℓ)_{t,e=(i,j)} if ∃ St with i, j ∈ St, and ϵ otherwise,    (6)

where Te is the number of St containing e, and missing entries are padded to learned constant ϵ. The global to marginal message m^(ρ→E) ∈ R^(K×d) is simply the hidden representation itself,

    m^(ρ→E,ℓ)_{t,e=(i,j)} = h^(ρ,ℓ)_{i,j}.    (7)

We incorporate these messages as follows.

    h^(E,ℓ) = ϕ^(E,ℓ)(h^(E,ℓ−1))                 (marginal feature)    (8)
    h^(ρ,ℓ−1) ← W^ℓ [h^(ρ,ℓ−1), m^(E→ρ,ℓ)]       (marginal to global)  (9)
    h^(ρ,ℓ) = ϕ^(ρ,ℓ)(h^(ρ,ℓ−1))                 (global feature)      (10)
    h^(E,ℓ) ← h^(E,ℓ) + m^(ρ→E,ℓ)                (global to marginal)  (11)

W^ℓ ∈ R^(2d×d) is a learned linear projection, and [·] denotes concatenation.
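A possible wiring of Eqs. (6)–(11) inside one aggregator block is sketched below, reusing the `AxialAttentionLayer` from the previous sketch. For simplicity it averages marginal features over all T subsets rather than only over the subsets that contain each edge (a simplification of Eq. (6)); class and argument names are ours.

```python
import torch
import torch.nn as nn

class AggregatorBlock(nn.Module):
    """Sketch of one aggregator block with marginal/global message passing."""
    def __init__(self, dim=512):
        super().__init__()
        self.phi_E = AxialAttentionLayer(dim)      # marginal layer
        self.phi_rho = AxialAttentionLayer(dim)    # global layer
        self.W = nn.Linear(2 * dim, dim)           # marginal-to-global projection, Eq. (9)
        self.eps = nn.Parameter(torch.zeros(dim))  # learned pad for edges never sampled

    def forward(self, h_E, h_rho, edge_index):
        # h_E: (T, K, d) marginal features; h_rho: (N, N, d) global features;
        # edge_index: (K, 2) long tensor mapping unique edge e to its node pair (i, j).
        h_E = self.phi_E(h_E)                                   # Eq. (8)
        m = h_E.mean(dim=0)                                     # (K, d), simplified Eq. (6)
        msg = self.eps.expand_as(h_rho).clone()
        msg[edge_index[:, 0], edge_index[:, 1]] = m             # scatter edge messages into N x N grid
        h_rho = self.W(torch.cat([h_rho, msg], dim=-1))         # Eq. (9)
        h_rho = self.phi_rho(h_rho)                             # Eq. (10)
        h_E = h_E + h_rho[edge_index[:, 0], edge_index[:, 1]].unsqueeze(0)  # Eq. (11)
        return h_E, h_rho
```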
Graph prediction. For each pair of vertices i ≠ j ∈ V, we predict e = 0, 1, or 2 for no edge, i → j, and j → i respectively. This constraint may be omitted for cyclic graphs. We do not additionally enforce that our predicted graphs are acyclic, similar in spirit to (Lippe et al., 2022).

Given the output of the final axial attention block, h^ρ, we compute logits

    z_{{i,j}} = FFN([h^ρ_{i,j}, h^ρ_{j,i}]) ∈ R³    (12)

which correspond to probabilities after softmax normalization. The overall output Ê ∈ {0, 1}^(N×N) is supervised by the ground truth E, and our model is trained with cross entropy loss and L2 weight regularization.
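The prediction head of Eq. (12) can be sketched as follows; the hidden-layer width and the explicit cross-entropy call are illustrative choices, not the paper's exact code.

```python
import torch
import torch.nn as nn

class EdgeClassifier(nn.Module):
    """Sketch of the pairwise prediction head (Eq. 12): for each ordered pair
    (i, j), concatenate h_rho[i, j] and h_rho[j, i] and predict one of
    {no edge, i -> j, j -> i}."""
    def __init__(self, dim=512):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, 3))

    def forward(self, h_rho):                                    # h_rho: (N, N, d)
        pair = torch.cat([h_rho, h_rho.transpose(0, 1)], dim=-1)  # (N, N, 2d): [h_ij, h_ji]
        return self.ffn(pair)                                     # logits z_{i,j} in R^3

# Training signal (sketch): cross entropy against ground-truth edge labels,
# with L2 regularization typically handled via the optimizer's weight decay.
# loss = nn.functional.cross_entropy(logits.view(-1, 3), labels.view(-1))
```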
3.3. Implementation details

We computed inverse covariance as the global statistic and selection score, due to its relationship to partial correlation and ease of computation. For comparison to additional statistics, see C.2. We chose the constraint-based FCI algorithm in the observational setting (Spirtes et al., 2001), and the score-based GIES algorithm in the interventional setting (Hauser & Bühlmann, 2012). For discussion regarding alternative algorithms, see B.5. We sample batches of size b = 500 over k = 5 nodes each (analysis in C.1).

Our model was implemented with 4 layers with 8 attention heads and hidden dimension 512. Our model was trained using the AdamW optimizer with a learning rate of 1e-4 (Loshchilov & Hutter, 2019). Please see B.3 for additional details regarding hyperparameters.

3.4. Theoretical analyses

Our theoretical contributions span two aspects.

1. We formalize the notion of marginal estimates and prove that given sufficient marginal estimates, it is possible to recover a pattern faithful to the global causal graph (Theorem 3.1). We provide bounds on the number of marginal estimates required and motivate global statistics as an efficient means to reduce this bound (Propositions 3.2, 3.3).
2. We show that our proposed axial attention has the capacity to recapitulate the reasoning required for this task. In particular, we show that a stack of 3 axial attention blocks can recover the skeleton and v-structures in O(N) width (Theorem 3.4).

We state our key results below, and direct the reader towards Appendix A for all proofs. Our proofs assume only that the edge estimates are correct. We do not impose any functional requirements, and any appropriate independence test may be used. We discuss robustness and stability in A.4.

Theorem 3.1 (Marginal estimate resolution). Let G = (V, E) be a directed acyclic graph with maximum degree d. For S ⊆ V, let E′S denote the marginal estimate over S. Let Sd denote the superset that contains all subsets S ⊆ V of size at most d. There exists a mapping from {E′S}_{S∈S_{d+2}} to pattern E∗, faithful to G.

Theorem 3.1 formalizes the intuition that it requires at most d independence tests to check whether two nodes are independent, and to estimate the full graph structure, it suffices to run a causal discovery algorithm on subsets of d + 2.

Proposition 3.2 (Skeleton bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d. It is possible to recover the undirected skeleton C = {{i, j} : (i, j) ∈ E} in O(N²) estimates over subsets of size d + 2.

If we only leverage marginal estimates, we must check at least (N choose 2) subsets to cover each edge at least once. However, we can often approximate the skeleton via global statistics such as correlation, allowing us to use marginal estimates more efficiently, towards answering orientation questions.

Proposition 3.3 (V-structures bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d and ν v-structures. It is possible to identify all v-structures in O(ν) estimates over subsets of at most size d + 2.

Theorem 3.4 (Model capacity). Given a graph G with N nodes, a stack of L axial attention blocks has the capacity to recover its skeleton and v-structures in O(N) width, and propagate orientations on paths of O(L) length.

Existing literature on the universality and computational power of vanilla Transformers (Yun et al., 2019; Pérez et al., 2019) relies on generous assumptions regarding depth or precision. Here, we show that our axial attention-based model can implement the specific reasoning required to resolve marginal estimates under realistic conditions.


4. Experimental setup

We pretrained SEA models on synthetic training datasets only and ran inference on held-out testing datasets, which include both seen and unseen causal mechanisms. All baselines were trained and/or run from scratch on each testing dataset using their published code and hyperparameters.

4.1. Data settings

We evaluate our model across diverse synthetic datasets, simulated mRNA datasets (Dibaeinia & Sinha, 2020), and a real protein expression dataset (Sachs et al., 2005).

The synthetic datasets were constructed based on Table 1. For each synthetic dataset, we first sample a graph based on the desired topology and connectivity. Then we topologically sort the graph and sample observations starting from the root nodes, using random instantiations of the designated causal mechanism (details in B.1). We generated 90 training, 5 validation, and 5 testing datasets for each combination (8160 total). To evaluate our model's capacity to generalize to new functional classes, we reserve the Sigmoid and Polynomial causal mechanisms for testing only. We include details on our synthetic mRNA datasets in Appendix C.3.

Table 1. Synthetic data generation settings. The symbol ∗ denotes training only, and † denotes testing only. We take the Cartesian product of all parameters for our settings. (Non)-additive refers to (non)-additive Gaussian noise. Details regarding the causal mechanisms can be found in B.1.

    Parameter       Values
    Nodes (N)       10, 20, 100
    Edges           N, 2N∗, 3N∗, 4N
    Points          1000N
    Interventions   N
    Topology        Erdős-Rényi, Scale Free
    Mechanism       Linear, NN additive, NN non-additive, Sigmoid additive†, Polynomial additive†
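The snippet below illustrates this generation process for the simplest setting (a linear mechanism with additive Gaussian noise): sample a random DAG, then sample observations in topological order from the roots. The graph sampler, coefficient ranges, and function name are our own simplifications, not the exact generator of Appendix B.1.

```python
import numpy as np

def sample_linear_dataset(n_nodes=10, n_edges=10, n_points=1000, seed=0):
    """Toy data generator: random DAG + linear mechanism with Gaussian noise."""
    rng = np.random.default_rng(seed)
    # Random DAG: choose edges only from earlier to later nodes in a random order,
    # which guarantees acyclicity and gives a topological order for free.
    order = rng.permutation(n_nodes)
    adj = np.zeros((n_nodes, n_nodes), dtype=bool)
    candidates = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]
    chosen = rng.choice(len(candidates), size=min(n_edges, len(candidates)), replace=False)
    for idx in chosen:
        i, j = candidates[idx]
        adj[order[i], order[j]] = True

    weights = rng.uniform(0.5, 2.0, size=(n_nodes, n_nodes)) * rng.choice([-1, 1], size=(n_nodes, n_nodes))
    X = np.zeros((n_points, n_nodes))
    for j in order:                               # sample in topological order
        parents = np.where(adj[:, j])[0]
        X[:, j] = X[:, parents] @ weights[parents, j] + rng.normal(size=n_points)
    return adj, X
```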
4.2. Causal discovery metrics

Our causal discovery experiments consider both discrete and continuous metrics. In addition to standard metrics like SHD (Tsamardinos et al., 2006), we advocate for the inclusion of continuous metrics, as neural networks can be notoriously uncalibrated (Guo et al., 2017), and arbitrary discretization thresholds reflect an incomplete picture of model performance (Schaeffer et al., 2023). For example, the graph of all 0s has the same accuracy as a graph's undirected skeleton, though the two differ vastly in information content (Table 2). For all continuous metrics, we exclude the diagonal from evaluation, since several baselines manually set it to zero (Brouillard et al., 2020; Lopez et al., 2022).

Table 2. Motivating continuous metrics in causal discovery. SHD alone is a poor measure of prediction quality, as it fails to distinguish between the graph of all 0s and the undirected skeleton. Metrics reported on Sachs.

                    SHD ↓   mAP ↑   AUC ↑   EdgeAcc ↑
    Ground truth    0       1       1       1
    Undirected      17      0.5     0.92    0
    All 0s          17      0.14    0.5     0

SHD: Structural Hamming distance is the minimum number of edge edits required to match two graphs (Tsamardinos et al., 2006). Discretization thresholds are as published.

mAP: Mean average precision computes the area under the precision-recall curve per edge and averages over the graph. The random guessing baseline depends on the positive rate.

AUC: Area under the ROC curve (Bradley, 1997) computed per edge (binary prediction) and averaged over the graph. For each edge, 0.5 indicates random guessing, while 1 indicates perfect performance.

Edge orientation accuracy: We compute the accuracy of edge orientations as

    EdgeAcc = ( Σ_{(i,j)∈E} 1{P(i, j) > P(j, i)} ) / ∥E∥.    (13)

Since this quantity is normalized by the size of E, it is invariant to the positive rate. In contrast to edge orientation F1 (Geffner et al., 2022), this quantity is also invariant to the assignment of forward/reverse edges as positive/negative. By design, symmetric predictions (like the undirected graph) have an edge orientation accuracy of 0.
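For reference, the edge orientation accuracy of Eq. (13) can be computed directly from a matrix of predicted edge probabilities; the function below is our own minimal implementation.

```python
import numpy as np

def edge_orientation_accuracy(P, true_edges):
    """Eq. (13): fraction of true directed edges (i, j) whose predicted
    probability exceeds that of the reverse direction. P is an N x N matrix
    of edge probabilities; true_edges is an iterable of (i, j) pairs."""
    correct = sum(P[i, j] > P[j, i] for i, j in true_edges)
    return correct / len(true_edges)

# A symmetric prediction (e.g. an undirected skeleton) scores 0, as noted above.
P = np.full((3, 3), 0.5)
print(edge_orientation_accuracy(P, [(0, 1), (1, 2)]))  # 0.0
```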
4.3. Baselines

We consider several deep learning and classical baselines. The following deep learning baselines all start by fitting a generative model to the data.

DCDI (Brouillard et al., 2020) extracts the underlying graph as a model parameter. The G and DSF variants use Gaussian or deep sigmoidal flow likelihoods, respectively. DCD-FG (Lopez et al., 2022) follows DCDI-G, but factorizes the graph into a product of two low-rank matrices for scalability.

DECI (Geffner et al., 2022) takes a Bayesian approach and extracts the underlying graph as a model parameter.

DIFFAN (Sanchez et al., 2023) uses the trained model's Hessian to obtain a topological ordering, followed by a classical pruning algorithm.

The following classical baselines (ablations) quantify the causal discovery utility of the individual inputs to our model.


Table 3. Causal discovery results on synthetic datasets. Each setting encompasses 5 distinct Erdős-Rényi graphs. The symbol † indicates that SEA was not pretrained on this setting. Runtimes are plotted in Figure 3. Details regarding baselines can be found in B.2. Each cell reports mAP ↑ / SHD ↓.

N = 10, E = 10
    Model       | Linear              | NN add.             | NN non-add.         | Sigmoid†            | Polynomial†
    DCDI-G      | 0.74±0.2 / 2.8±2.0  | 0.79±0.1 / 2.2±2.5  | 0.89±0.1 / 1.0±0.6  | 0.46±0.2 / 5.8±2.7  | 0.41±0.1 / 8.9±5.3
    DCDI-DSF    | 0.82±0.2 / 2.0±2.6  | 0.57±0.2 / 3.0±2.3  | 0.50±0.2 / 4.2±1.3  | 0.38±0.2 / 6.3±2.7  | 0.29±0.1 / 11.2±5.2
    DCD-FG      | 0.45±0.2 / 20.4±2.9 | 0.41±0.1 / 21.2±3.3 | 0.59±0.1 / 19.2±3.8 | 0.40±0.2 / 19.8±3.6 | 0.50±0.2 / 18.5±5.1
    DIFFAN      | 0.25±0.1 / 14.0±4.4 | 0.32±0.2 / 13.6±12.9| 0.12±0.0 / 21.8±6.7 | 0.24±0.1 / 12.0±5.0 | 0.20±0.1 / 15.0±6.1
    DECI        | 0.18±0.1 / 19.4±4.4 | 0.16±0.1 / 13.8±5.8 | 0.23±0.1 / 16.2±2.6 | 0.29±0.2 / 13.9±6.9 | 0.46±0.2 / 7.8±3.9
    INVCOV      | 0.49±0.0 / 11.0±2.8 | 0.45±0.1 / 11.4±5.5 | 0.36±0.1 / 13.6±2.9 | 0.44±0.1 / 11.4±4.1 | 0.45±0.1 / 10.9±3.5
    FCI-AVG     | 0.52±0.1 / 10.0±2.3 | 0.38±0.2 / 8.2±3.8  | 0.40±0.1 / 9.8±2.1  | 0.56±0.2 / 9.1±2.5  | 0.41±0.1 / 10.0±3.3
    GIES-AVG    | 0.81±0.1 / 3.6±1.9  | 0.61±0.2 / 6.0±4.1  | 0.71±0.2 / 4.8±2.2  | 0.70±0.1 / 5.9±3.0  | 0.61±0.1 / 7.1±3.2
    SEA (FCI)   | 0.97±0.0 / 1.6±1.4  | 0.95±0.1 / 2.4±3.3  | 0.92±0.1 / 2.8±0.7  | 0.83±0.1 / 3.7±1.9  | 0.69±0.1 / 6.7±2.6
    SEA (GIES)  | 0.99±0.0 / 1.2±0.7  | 0.94±0.1 / 2.6±3.8  | 0.91±0.1 / 3.2±1.3  | 0.85±0.1 / 4.0±2.5  | 0.70±0.1 / 5.8±2.6

N = 20, E = 80
    Model       | Linear              | NN add.             | NN non-add.         | Sigmoid†            | Polynomial†
    DCDI-G      | 0.46±0.1 / 44.0±5   | 0.41±0.1 / 61.6±10  | 0.82±0.0 / 37.4±30  | 0.48±0.1 / 44.2±5   | 0.37±0.0 / 59.7±5
    DCDI-DSF    | 0.48±0.1 / 41.2±3   | 0.44±0.0 / 60.0±11  | 0.74±0.1 / 28.4±23  | 0.48±0.0 / 43.6±5   | 0.38±0.0 / 57.6±5
    DCD-FG      | 0.32±0.0 / 171.8±24 | 0.33±0.1 / 156.0±37 | 0.41±0.1 / 162.2±44 | 0.47±0.1 / 80.1±13  | 0.49±0.0 / 79.8±7
    DIFFAN      | 0.21±0.0 / 127.2±5  | 0.19±0.0 / 153.6±9  | 0.18±0.0 / 144.6±6  | 0.22±0.0 / 116.8±20 | 0.18±0.0 / 157.1±7
    DECI        | 0.25±0.0 / 87.2±3   | 0.29±0.1 / 104.4±6  | 0.26±0.0 / 79.6±8   | 0.31±0.0 / 71.0±6   | 0.43±0.1 / 58.9±11
    INVCOV      | 0.35±0.0 / 94.2±8   | 0.27±0.0 / 107.8±7  | 0.30±0.0 / 100.2±8  | 0.34±0.0 / 91.7±8   | 0.32±0.0 / 94.4±5
    FCI-AVG     | 0.30±0.0 / 75.8±10  | 0.31±0.0 / 80.2±4   | 0.30±0.0 / 74.4±7   | 0.41±0.1 / 72.3±6   | 0.34±0.0 / 76.6±5
    GIES-AVG    | 0.41±0.1 / 70.0±9   | 0.44±0.0 / 75.2±4   | 0.46±0.1 / 67.4±6   | 0.50±0.1 / 65.6±7   | 0.49±0.0 / 68.1±5
    SEA (FCI)   | 0.86±0.0 / 29.6±7   | 0.55±0.1 / 73.6±4   | 0.72±0.0 / 51.8±6   | 0.77±0.0 / 42.8±7   | 0.61±0.0 / 61.8±5
    SEA (GIES)  | 0.89±0.0 / 26.8±8   | 0.58±0.1 / 71.4±8   | 0.73±0.0 / 50.6±6   | 0.76±0.1 / 45.0±7   | 0.65±0.0 / 60.1±5

N = 100, E = 400
    Model       | Linear              | NN add.             | NN non-add.         | Sigmoid†            | Polynomial†
    DCD-FG      | 0.05±0.0 / 3068±132 | 0.07±0.0 / 3428±155 | 0.10±0.0 / 3510±601 | 0.13±0.0 / 3601±273 | 0.12±0.0 / 3316±698
    INVCOV      | 0.25±0.0 / 557.0±12 | 0.09±0.0 / 667.8±15 | 0.14±0.0 / 639.0±10 | 0.27±0.0 / 514.7±23 | 0.20±0.0 / 539.4±18
    SEA (FCI)   | 0.90±0.0 / 122.0±9  | 0.28±0.1 / 361.2±36 | 0.60±0.0 / 273.2±14 | 0.69±0.0 / 226.9±20 | 0.38±0.0 / 327.0±20
    SEA (GIES)  | 0.91±0.0 / 116.6±8  | 0.27±0.1 / 364.4±35 | 0.61±0.0 / 266.8±15 | 0.69±0.0 / 218.3±21 | 0.38±0.0 / 328.0±22

INVCOV computes inverse covariance over 2000 examples. This does not orient edges, but it is a strong connectivity baseline. We discretize based on ground truth (oracle) E.

FCI-AVG, GIES-AVG run the FCI and GIES algorithms, respectively, over all nodes, on 100 batches with 500 examples each. We take the mean P((i, j)) over all batches. This procedure yielded higher performance compared to running the algorithm only once, over a larger batch.

5. Results

SEA excels across synthetic and realistic causal discovery tasks with fast runtimes. We also show that SEA performs well, even in low-data regimes. Additional results and ablations can be found in Appendix C, including classical algorithm hyperparameters, global statistics, and scaling to graphs much larger than those in the training set.

5.1. Synthetic experiments

SEA significantly outperforms the baselines across a variety of graph sizes and causal mechanisms in Table 3, and we maintain high performance even for causal mechanisms beyond the training set (Sigmoid, Polynomial). Our pretrained aggregator consistently improves upon the performance of its individual inputs (INVCOV, FCI-AVG, GIES-AVG), demonstrating the value in learning such a model.

In terms of edge orientation accuracy, SEA outperforms the baselines in all settings (Table 4). We have omitted INVCOV from this comparison since it does not orient edges.

We benchmarked the runtimes of each algorithm over all test datasets of N = 10, 20, 100, when possible (Figure 3).

Figure 3. Average wall time required to run each model on a single dataset. The y-axis is plotted in log scale.


Table 4. Synthetic experiments, edge direction accuracy (higher is better). All standard deviations were within 0.2. The symbol † indicates that SEA was not pretrained on this setting.

N = 10, E = 10
    Model       Linear  NN add. NN non-add. Sigmoid†    Polynomial†
    DCDI-G      0.74    0.80    0.85        0.41        0.44
    DCDI-DSF    0.79    0.62    0.68        0.38        0.39
    DCD-FG      0.50    0.47    0.70        0.43        0.54
    DIFFAN      0.61    0.55    0.26        0.53        0.47
    DECI        0.50    0.43    0.62        0.63        0.75
    FCI-AVG     0.52    0.43    0.41        0.55        0.40
    GIES-AVG    0.76    0.49    0.69        0.67        0.63
    SEA (FCI)   0.92    0.92    0.94        0.76        0.71
    SEA (GIES)  0.94    0.88    0.93        0.84        0.79

N = 20, E = 80
    Model       Linear  NN add. NN non-add. Sigmoid†    Polynomial†
    DCDI-G      0.47    0.43    0.82        0.40        0.24
    DCDI-DSF    0.50    0.49    0.78        0.41        0.28
    DCD-FG      0.58    0.65    0.75        0.62        0.48
    DIFFAN      0.46    0.28    0.36        0.45        0.21
    DECI        0.30    0.47    0.35        0.48        0.57
    FCI-AVG     0.19    0.19    0.22        0.33        0.23
    GIES-AVG    0.56    0.73    0.59        0.62        0.61
    SEA (FCI)   0.93    0.90    0.93        0.85        0.89
    SEA (GIES)  0.92    0.88    0.92        0.84        0.89

N = 100, E = 400
    Model       Linear  NN add. NN non-add. Sigmoid†    Polynomial†
    DCD-FG      0.46    0.60    0.70        0.67        0.53
    SEA (FCI)   0.93    0.90    0.91        0.87        0.82
    SEA (GIES)  0.94    0.91    0.92        0.87        0.84

Table 5. Causal discovery on (real) Sachs dataset. While INVCOV is a strong baseline for connectivity, it does not predict orientation. The relative performance of the methods depends on metric.

    Model       mAP ↑   AUC ↑   SHD ↓   EdgeAcc ↑   Time (s) ↓
    DCDI-G      0.17    0.55    21.0    0.20        2436.5
    DCDI-DSF    0.20    0.59    20.0    0.20        1979.6
    DCD-FG      0.32    0.59    27.0    0.35        951.4
    DIFFAN      0.14    0.45    37.0    0.41        293.7
    DECI        0.21    0.62    28.0    0.53        609.7
    INVCOV      0.31    0.61    20.0    —           0.002
    FCI-AVG     0.27    0.59    18.0    0.24        41.9
    GIES-AVG    0.21    0.59    17.0    0.24        77.9
    SEA (FCI)   0.23    0.54    24.0    0.47        3.2
    SEA (GIES)  0.23    0.60    14.0    0.41        2.9

Our model is orders of magnitude faster than other continuous optimization methods. All deep learning models were run on a single V100-PCIE-32GB GPU, except for DIFFAN, since we were unable to achieve consistent GPU and R support within a Docker environment using their codebase. For all models, we recorded only computation time (CPU and GPU) and omitted any file system-related time.

5.2. Realistic experiments

We report results on the real Sachs protein dataset (Sachs et al., 2005) in Table 5. The relative performance of each model differs based on metric. SEA performs comparably, while maintaining fast inference speeds. However, despite the popularity of this dataset in causal discovery literature (due to lack of better alternatives), biological networks are known to be time-resolved and cyclic, so the validity of the ground truth "consensus" graph has been questioned by experts (Mooij et al., 2020).

We also trained a version of SEA (FCI) on 7200 synthetic mRNA datasets (Dibaeinia & Sinha, 2020), exceeding baselines across held-out test datasets (Appendix C.3).

5.3. Performance in low-data regimes

One of the main advantages of foundation models is that they enable high levels of performance in low resource scenarios. Figure 4 shows that SEA (FCI) only requires around M = 500 data samples for decent performance on graphs with N = 100 nodes. For these experiments, we set batch size b = min(500, M) and node subset size k = 5. This is in contrast to existing continuous optimization methods, which require thousands to tens of thousands of samples to fit completely. Compared to the INVCOV baseline estimated over the same number of points b = M = 500 (dotted lines), SEA is able to perform significantly better.

Figure 4. Performance of our model (FCI) as a function of total dataset size. Error bars indicate 95% confidence interval across the 5 datasets of each setting. Our model only requires approximately 500 samples for an acceptable level of performance on N = 100 graphs. Dashed lines indicate the INVCOV estimate on 500 points.

6. Conclusion

In this work, we introduced SEA, a framework for designing causal discovery foundation models. SEA is motivated by the idea that classical discovery algorithms provide powerful descriptors of data that are fast to compute and robust across datasets. Given these statistics, we train a deep learning model to reproduce faithful causal graphs. Theoretically, we demonstrated that it is possible to produce sound causal graphs from marginal estimates, and that our model has the capacity to do so. Empirically, we implemented two proofs of concept of SEA that perform well across a variety of causal discovery tasks. We hope that this work will inspire a new avenue of research in generalizable and scalable causal discovery algorithms.

7. Impact statement

This paper presents work whose goal is to advance the field of machine learning and causal inference. When applied to the biological sciences, these techniques may help uncover new knowledge regarding the interaction of genes and proteins. This information may then be used to develop both therapeutics and toxic substances. However, these techniques are only useful in the context of large-scale, high quality data, controlled access to which mitigates risk much more effectively than the restriction of any model discussed within this work.

References

Bommasani, R., Hudson, D. A., Adeli, E., Altman, R., Arora, S., von Arx, S., et al. On the opportunities and risks of foundation models, 2022.

Bradley, A. P. The use of the area under the ROC curve in the evaluation of machine learning algorithms. Pattern Recognition, 30(7):1145–1159, 1997.

Brouillard, P., Lachapelle, S., Lacoste, A., Lacoste-Julien, S., and Drouin, A. Differentiable causal discovery from interventional data, 2020.

Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., et al. Language models are few-shot learners, 2020.

Bühlmann, P., Peters, J., and Ernest, J. CAM: Causal additive models, high-dimensional order search and penalized regression. The Annals of Statistics, 42(6):2526–2556, 2014.

Chickering, D. M. Optimal structure identification with greedy search. Journal of Machine Learning Research, 3:507–554, 2002.

Dibaeinia, P. and Sinha, S. SERGIO: A single-cell expression simulator guided by gene regulatory networks. Cell Systems, 11(3):252–271.e11, 2020.

Geffner, T., Antoran, J., Foster, A., Gong, W., Ma, C., Kiciman, E., Sharma, A., Lamb, A., Kukla, M., Pawlowski, N., Allamanis, M., and Zhang, C. Deep end-to-end causal inference, 2022.

Glymour, C., Zhang, K., and Spirtes, P. Review of causal discovery methods based on graphical models. Frontiers in Genetics, 10, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In International Conference on Machine Learning, 2017.

Hauser, A. and Bühlmann, P. Characterization and greedy learning of interventional Markov equivalence classes of directed acyclic graphs. 2012.

Ho, J., Kalchbrenner, N., Weissenborn, D., and Salimans, T. Axial attention in multidimensional transformers, 2020.

Hornik, K., Stinchcombe, M., and White, H. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.

Huang, B., Zhang, K., Zhang, J., Ramsey, J., Sanchez-Romero, R., Glymour, C., and Schölkopf, B. Causal discovery from heterogeneous/nonstationary data with independent changes, 2020.

Hägele, A., Rothfuss, J., Lorch, L., Somnath, V. R., Schölkopf, B., and Krause, A. BaCaDI: Bayesian causal discovery with unknown interventions, 2023.

Kalainathan, D., Goudet, O., and Dutta, R. Causal discovery toolbox: Uncovering causal relationships in Python. Journal of Machine Learning Research, 21(37):1–5, 2020.

Lachapelle, S., Brouillard, P., Deleu, T., and Lacoste-Julien, S. Gradient-based neural DAG learning, 2020.

Lam, W.-Y., Andrews, B., and Ramsey, J. Greedy relaxations of the sparsest permutation algorithm. In Proceedings of the Thirty-Eighth Conference on Uncertainty in Artificial Intelligence, volume 180 of Proceedings of Machine Learning Research, pp. 1052–1062. PMLR, 2022.

Ledoit, O. and Wolf, M. A well-conditioned estimator for large-dimensional covariance matrices. Journal of Multivariate Analysis, 88(2):365–411, 2004.

Lippe, P., Cohen, T., and Gavves, E. Efficient neural causal discovery without acyclicity constraints. In International Conference on Learning Representations, 2022.

Lopez, R., Hütter, J.-C., Pritchard, J. K., and Regev, A. Large-scale differentiable causal discovery of factor graphs. In Advances in Neural Information Processing Systems, 2022.

Loshchilov, I. and Hutter, F. Decoupled weight decay regularization, 2019.

Löwe, S., Madras, D., Zemel, R., and Welling, M. Amortized causal discovery: Learning to infer causal graphs from time-series data, 2022.

Monti, R. P., Zhang, K., and Hyvarinen, A. Causal discovery with general non-linear relationships using non-linear ICA, 2019.

Mooij, J. M., Magliacane, S., and Claassen, T. Joint causal inference from multiple contexts. 2020.

Pérez, J., Marinković, J., and Barceló, P. On the Turing completeness of modern neural network architectures. In International Conference on Learning Representations, 2019.

Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., Krueger, G., and Sutskever, I. Learning transferable visual models from natural language supervision, 2021.

Rao, R. M., Liu, J., Verkuil, R., Meier, J., Canny, J., Abbeel, P., Sercu, T., and Rives, A. MSA Transformer. In Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pp. 8844–8856. PMLR, 2021.

Reizinger, P., Sharma, Y., Bethge, M., Schölkopf, B., Huszár, F., and Brendel, W. Jacobian-based causal discovery with nonlinear ICA. Transactions on Machine Learning Research, 2023.

Replogle, J. M., Saunders, R. A., Pogson, A. N., Hussmann, J. A., Lenail, A., Guna, A., Mascibroda, L., Wagner, E. J., Adelman, K., Lithwick-Yanai, G., Iremadze, N., Oberstrass, F., Lipson, D., Bonnar, J. L., Jost, M., Norman, T. M., and Weissman, J. S. Mapping information-rich genotype-phenotype landscapes with genome-scale Perturb-seq. Cell, 185(14):2559–2575, 2022.

Sachs, K., Perez, O., Pe'er, D., Lauffenburger, D. A., and Nolan, G. P. Causal protein-signaling networks derived from multiparameter single-cell data. Science, 308(5721):523–529, 2005.


Sanchez, P., Liu, X., O'Neil, A. Q., and Tsaftaris, S. A. Diffusion models for causal discovery via topological ordering. In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net, 2023.

Schaeffer, R., Miranda, B., and Koyejo, O. Are emergent abilities of large language models a mirage? Advances in Neural Information Processing Systems, 2023.

Schwarz, G. Estimating the dimension of a model. The Annals of Statistics, 6(2):461–464, 1978.

Shimizu, S., Hoyer, P. O., Hyvarinen, A., and Kerminen, A. A linear non-Gaussian acyclic model for causal discovery. Journal of Machine Learning Research, 7(72):2003–2030, 2006.

Spirtes, P., Glymour, C., and Scheines, R. Causality from probability. In Conference Proceedings: Advanced Computing for the Social Sciences, 1990.

Spirtes, P., Meek, C., and Richardson, T. Causal inference in the presence of latent variables and selection bias. In Proceedings of the Eleventh Conference on Uncertainty in Artificial Intelligence, UAI'95, pp. 499–506, San Francisco, CA, USA, 1995. Morgan Kaufmann Publishers Inc.

Spirtes, P., Glymour, C., and Scheines, R. Causation, Prediction, and Search. MIT Press, 2001.

Székely, G. J., Rizzo, M. L., and Bakirov, N. K. Measuring and testing dependence by correlation of distances. Annals of Statistics, 35:2769–2794, 2007.

Tsamardinos, I., Brown, L. E., and Aliferis, C. F. The max-min hill-climbing Bayesian network structure learning algorithm. Machine Learning, 65(1):31–78, 2006.

Verma, T. S. and Pearl, J. On the equivalence of causal models. In Proceedings of the Sixth Conference on Uncertainty in Artificial Intelligence, 1990.

Yun, C., Bhojanapalli, S., Rawat, A. S., Reddi, S. J., and Kumar, S. Are transformers universal approximators of sequence-to-sequence functions? CoRR, abs/1912.10077, 2019.

Zhang, J., Jennings, J., Zhang, C., and Ma, C. Towards causal foundation model: on duality between causal inference and attention, 2023.

Zhang, K. and Hyvarinen, A. On the identifiability of the post-nonlinear causal model, 2012.

Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. DAGs with no tears: Continuous optimization for structure learning, 2018.

Zheng, Y., Huang, B., Chen, W., Ramsey, J., Gong, M., Cai, R., Shimizu, S., Spirtes, P., and Zhang, K. Causal-learn: Causal discovery in Python. arXiv preprint arXiv:2307.16405, 2023.

A. Proofs and derivations


Our theoretical contributions focus on two primary directions.
1. We formalize the notion of marginal estimates and prove that given sufficient marginal estimates, it is possible to
recover a pattern faithful to the global causal graph. We provide lower bounds on the number of marginal estimates
required for such a task, and motivate global statistics as an efficient means to reduce this bound.
2. We show that our proposed axial attention has the capacity to recapitulate the reasoning required for marginal estimate
resolution. We provide realistic, finite bounds on the width and depth required for this task.
Before these formal discussions, we start with a toy example to provide intuition regarding marginal estimates and
constraint-based causal discovery algorithms.

A.1. Toy example: Resolving marginal graphs


Consider the Y-shaped graph with four nodes in Figure 5. Suppose we run the PC algorithm on all subsets of three nodes,
and we would like to recover the result of the PC algorithm on the full graph. We illustrate how one might resolve the
marginal graph estimates. The PC algorithm consists of the following steps (Spirtes et al., 2001); a short code sketch of the skeleton phase (steps 1–3) follows the list.
1. Start from the fully connected, undirected graph on N nodes.
2. Remove all edges (i, j) where Xi ⊥⊥ Xj .
3. For each edge (i, j) and subsets S ⊆ [N ] \ {i, j} of increasing size n = 1, 2, . . . , d, where d is the maximum degree
in G, and all k ∈ S are connected to either i or j: if Xi ⊥⊥ Xj | S, remove edge (i, j).
4. For each triplet (i, j, k), such that only edges (i, k) and (j, k) remain, if k was not in the set S that eliminated edge
(i, j), then orient the “v-structure” as i → k ← j.
5. (Orientation propagation) If i → j, edge (j, k) remains, and edge (i, k) has been removed, orient j → k. If there is a
directed path i ⇝ j and an undirected edge (i, j), then orient i → j.
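A minimal sketch of the skeleton phase (steps 1–3) is given below; `ci_test` stands in for any conditional independence test, and the refinement that conditioning sets contain only neighbors of i or j is omitted for brevity. Names and structure are ours, not a reference implementation.

```python
from itertools import combinations

def pc_skeleton(nodes, ci_test, max_degree):
    """Steps 1-3 of the PC algorithm: start from the complete undirected graph
    and remove edges that are (conditionally) independent.
    ci_test(i, j, S) should return True iff X_i is independent of X_j given S."""
    edges = {frozenset(e) for e in combinations(nodes, 2)}
    sep_set = {}
    for size in range(0, max_degree + 1):                  # |S| = 0 covers step 2
        for i, j in [tuple(sorted(e)) for e in edges]:
            others = [k for k in nodes if k not in (i, j)]
            for S in combinations(others, size):
                if ci_test(i, j, set(S)):
                    edges.discard(frozenset((i, j)))
                    sep_set[frozenset((i, j))] = set(S)    # used by step 4 to orient v-structures
                    break
    return edges, sep_set
```

The separating sets recorded here are exactly what step 4 consults when deciding whether to orient a triple as a v-structure.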

Figure 5. Resolving marginal graphs. Subsets of nodes revealed to the PC algorithm (circled in row 1) and its outputs (row 2).

In each of the four cases, the PC algorithm estimates the respective graphs as follows.
(A) We remove edge (X, Y ) via (2) and orient the v-structure.
(B) We remove edge (X, Y ) via (2) and orient the v-structure.
(C) We remove edge (X, W ) via (3) by conditioning on Z. There are no v-structures, so the edges remain undirected.
(D) We remove edge (Y, W ) via (3) by conditioning on Z. There are no v-structures, so the edges remain undirected.
The outputs (A-D) admit the full PC algorithm output as the only consistent graph on four nodes.

• X and Y are unconditionally independent, so no subset will reveal an edge between (X, Y ).
• There are no edges between (X, W ) and (Y, W ). Otherwise, (C) and (D) would yield the undirected triangle.
• X, Y, Z must be oriented as X → Z ← Y . Paths X → Z → Y and X ← Z ← Y would induce an (X, Y ) edge in
(B). Reversing orientations X ← Z → Y would contradict (A).
• (Y, Z) must be oriented as Y → Z. Otherwise, (A) would remain unoriented.


A.2. Resolving marginal estimates into global graphs


A.2.1. P RELIMINARIES
Classical results have characterized the Markov equivalency class of directed acyclic graphs. Two graphs are observationally
equivalent if they have the same skeleton and v-structures (Verma & Pearl, 1990). Thus, a pattern P is faithful to a graph G
if and only if they share the same skeletons and v-structures (Spirtes et al., 1990).
Definition A.1. Let G = (V, E) be a directed acyclic graph. A pattern P is a set of directed and undirected edges over V .
Definition A.2 (Theorem 3.4 from Spirtes et al. (2001)). If pattern P is faithful to some directed acyclic graph, then P is
faithful to G if and only if
1. for all vertices X, Y of G, X and Y are adjacent if and only if X and Y are dependent conditional on every set of
vertices of G that does not include X or Y ; and
2. for all vertices X, Y, Z, such that X is adjacent to Y and Y is adjacent to Z and X and Z are not adjacent, X → Y ← Z
is a subgraph of G if and only if X, Z are dependent conditional on every set containing Y but not X or Z.

Given data faithful to G, a number of classical constraint-based algorithms produce patterns that are faithful to G. We
denote this set of algorithms as F.
Theorem A.3 (Theorem 5.1 from (Spirtes et al., 2001)). If the input to the PC, SGS, PC-1, PC-2, PC∗, or IG algorithms is faithful to directed acyclic graph G, the output is a pattern that represents the faithful indistinguishability class of G.

The algorithms in F are sound and complete if there are no unobserved confounders.

A.2.2. M ARGINAL ESTIMATES


Let PV be a probability distribution that is Markov, minimal, and faithful to G. Let D ∈ RM ×N ∼ PV be a dataset of M
observations over all N = |V | nodes.
Consider a subset S ⊆ V . Let D[S] denote the subset of D over S,

    D[S] = {x_{i,v} : v ∈ S, i = 1, . . . , M},    (14)

and let G[S] denote the subgraph of G induced by S

    G[S] = (S, {(i, j) : i, j ∈ S, (i, j) ∈ E}).    (15)

If we apply any f ∈ F to D[S], the results are not necessarily faithful to G[S], as now there may be latent confounders in
V \ S (by construction). We introduce the term marginal estimate to denote the resultant pattern that, while not faithful to
G[S], is still informative.
Definition A.4 (Marginal estimate). A pattern E ′ is a marginal estimate of G[S] if and only if
1. for all vertices X, Y of S, X and Y are adjacent if and only if X and Y are dependent conditional on every set of
vertices of S that does not include X or Y ; and
2. for all vertices X, Y, Z, such that X is adjacent to Y and Y is adjacent to Z and X and Z are not adjacent, X → Y ← Z
is a subgraph of S if and only if X, Z are dependent conditional on every set containing Y but not X or Z.

A.2.3. M ARGINAL ESTIMATE RESOLUTION


We claim that given marginal estimates on sufficient subsets of nodes, it is always possible to recover a pattern faithful to
the entire graph. First we construct a mapping from marginal estimates to the desired pattern, and then we provide tighter
bounds on the number of estimates required.
Theorem 3.1 (Marginal estimate resolution). Let G = (V, E) be a directed acyclic graph with maximum degree d. For
S ⊆ V , let ES′ denote the marginal estimate over S. Let Sd denote the superset that contains all subsets S ⊆ V of size at
most d. There exists a mapping from {ES′ }S∈Sd+2 to pattern E ∗ , faithful to G.

We will show that the following algorithm produces the desired answer. On a high level, lines 3-8 recover the undirected
“skeleton” graph of E ∗ , lines 9-15 recover the v-structures, and line 16 references step 5 in Section A.1.


Algorithm 1 Resolve marginal estimates of f ∈ F


1: Input: Data DG faithful to G
2: Initialize E ′ ← KN as the complete undirected graph on N nodes.
3: for S ∈ Sd+2 do
4: Compute ES′ = f (DG[S] )
5: for (i, j) ̸∈ ES′ do
6: Remove (i, j) from E ′
7: end for
8: end for
9: for ES′ ∈ {ES′ }Sd+2 do
10: for v-structure i → j ← k in ES′ do
11: if {i, j}, {j, k} ∈ E ′ and {i, k} ̸∈ E ′ then
12: Assign orientation i → j ← k in E ′
13: end if
14: end for
15: end for
16: Propagate orientations in E ′ (optional).
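For concreteness, Algorithm 1 (minus the final orientation-propagation step) could be implemented along the following lines; the data structures (sets of frozensets for skeletons, triples for v-structures) are our own choices, not the paper's code.

```python
from itertools import combinations

def resolve_marginal_estimates(nodes, marginal_estimates):
    """Sketch of Algorithm 1. `marginal_estimates` maps each node subset S
    (a frozenset) to a pair (skeleton_S, vstructs_S): skeleton_S is a set of
    undirected edges {i, j} and vstructs_S a set of triples (i, j, k) meaning
    i -> j <- k, as produced by a constraint-based algorithm on D[S]."""
    # Lines 2-8: start complete, remove any edge absent from an estimate covering it.
    skeleton = {frozenset(e) for e in combinations(nodes, 2)}
    for S, (skel_S, _) in marginal_estimates.items():
        for i, j in combinations(S, 2):
            if frozenset((i, j)) not in skel_S:
                skeleton.discard(frozenset((i, j)))
    # Lines 9-15: keep a v-structure i -> j <- k if some estimate contains it
    # and it is consistent with the resolved skeleton.
    oriented = set()
    for S, (_, vstructs_S) in marginal_estimates.items():
        for i, j, k in vstructs_S:
            if (frozenset((i, j)) in skeleton and frozenset((j, k)) in skeleton
                    and frozenset((i, k)) not in skeleton):
                oriented.add((i, j, k))
    # Line 16 (orientation propagation) is omitted in this sketch.
    return skeleton, oriented
```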

Remark A.5. In the PC algorithm ((Spirtes et al., 2001), A.1), its derivatives, and Algorithm 1, there is no need to consider
separating sets with cardinality greater than maximum degree d, since the maximum number of independence tests required
to separate any node from the rest of the graph is equal to number of its parents plus its children (due to the Markov
assumption).
Lemma A.6. The undirected skeleton of E ∗ is equivalent to the undirected skeleton of E ′

C ∗ := {{i, j} | (i, j) ∈ E ∗ or (j, i) ∈ E ∗ } = {{i, j} | (i, j) ∈ E ′ or (j, i) ∈ E ′ } := C ′ . (16)

That is, {i, j} ∈ C ∗ ⇐⇒ {i, j} ∈ C ′ .

Proof. It is equivalent to show that {i, j} ̸∈ C ∗ ⇐⇒ {i, j} ̸∈ C ′


⇒ If {i, j} ∉ C∗, then there must exist a separating set S in G of at most size d such that i ⊥⊥ j | S. Then S ∪ {i, j} is a set of at most size d + 2, where {i, j} ∉ C_{S∪{i,j}}. Thus, {i, j} would have been removed from C′ in line 6 of Algorithm 1.
⇐ If {i, j} ∉ C′, let S be a separating set in S_{d+2} such that {i, j} ∉ C_{S∪{i,j}} and i ⊥⊥ j | S. S is also a separating set in G, and conditioning on S removes {i, j} from C∗.

Lemma A.7. A v-structure i → j ← k exists in E ∗ if and only if there exists the same v-structure in E ′ .

Proof. The PC algorithm orients v-structures i → j ← k in E ∗ if there is an edge between {i, j} and {j, k} but not {i, k};
and if j was not in the conditioning set that removed {i, k}. Algorithm 1 orients v-structures i → j ← k in E ′ if they are
oriented as such in any ES′ ; and if {i, j}, {j, k} ∈ E ′ , {i, k} ̸∈ E ′
⇒ Suppose for contradiction that i → j ← k is oriented as a v-structure in E ∗ , but not in E ′ . There are two cases.
1. No ES′ contains the undirected path i − j − k. If either i − j or j − k are missing from any ES′ , then E ∗ would not
contain (i, j) or (k, j). Otherwise, if all S contain {i, k}, then E ∗ would not be missing {i, k} (Lemma A.6).
2. In every ES′ that contains i − j − k, j is in the conditioning set that removed {i, k}, i.e. i ⊥⊥ k | S, S ∋ j. This would
violate the faithfulness property, as j is neither a parent of i or k in E ∗ , and the outputs of the PC algorithm are faithful
to the equivalence class of G (Theorem 5.1 (Spirtes et al., 2001)).
⇐ Suppose for contradiction that i → j ← k is oriented as a v-structure in E ′ , but not in E ∗ . By Lemma A.6, the path
i − j − k must exist in E ∗ . There are two cases.
1. If i → j → k or i ← j ← k, then j must be in the conditioning set that removes {i, k}, so no ES′ containing {i, j, k}
would orient them as v-structures.


2. If j is the root of a fork i ← j → k, then as the parent of both i and k, j must be in the conditioning set that removes
{i, k}, so no ES′ containing {i, j, k} would orient them as v-structures.
Therefore, all v-structures in E ′ are also v-structures in E ∗ .

Proof of Theorem 3.1. Given data that is faithful to G, Algorithm 1 produces a pattern E ′ with the same connectivity and
v-structures as E ∗ . Any additional orientations in both patterns are propagated using identical, deterministic procedures, so
E′ = E∗.

This proof presents a deterministic but inefficient algorithm for resolving marginal subgraph estimates. In reality, it is
possible to recover the undirected skeleton and the v-structures of G without checking all subsets S ∈ Sd+2 .
Proposition 3.2 (Skeleton bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d. It is possible to
recover the undirected skeleton C = {{i, j} : (i, j) ∈ E} in O(N 2 ) estimates over subsets of size d + 2.

Proof. Following Lemma A.6, an edge (i, j) is not present in C if it is not present in any of the size d + 2 estimates.
Therefore, every pair of nodes {i, j} requires only a single estimate of size d + 2, so it is possible to recover C in $\binom{N}{2}$
estimates.

Proposition 3.3 (V-structures bounds). Let G = (V, E) be a directed acyclic graph with maximum degree d and ν
v-structures. It is possible to identify all v-structures in O(ν) estimates over subsets of at most size d + 2.

Proof. Each v-structure i → j ← k falls under two cases.

1. i ⊥⊥ k unconditionally. Then an estimate over {i, j, k} will identify the v-structure.
2. i ⊥⊥ k | S, where j ̸∈ S ⊂ V . Then an estimate over S ∪ {i, j, k} will identify the v-structure. Note that |S ∪ {i, j, k}| ≤ d + 2,
since the degree of i is at least |S| + 1.
Therefore, each v-structure only requires one estimate, and it is possible to identify all v-structures in O(ν) estimates.

There are three takeaways from this section.


1. If we exhaustively run a constraint-based algorithm on all subsets of size d + 2, it is trivial to recover the estimate of
the full graph. However, this is no more efficient than running the causal discovery algorithm on the full graph.
2. In theory, it is possible to recover the undirected graph in O(N 2 ) estimates, and the v-structures in O(ν) estimates.
However, we may not know the appropriate subsets ahead of time.
3. In practice, if we have a surrogate for connectivity, such as the global statistics used in SEA, then we can vastly reduce
the number of estimates used to eliminate edges from consideration, and more effectively focus on sampling subsets
for orientation determination.

A.3. Computational power of the axial attention model


In this section, we focus on the computational capacity of our axial attention-based architecture. We show that three blocks
can recover the skeleton and v-structures in O(N ) width, and additional blocks have the capacity to propagate orientations.
We first formalize the notion of a neural network architecture’s capacity to “implement” an algorithm. Then we prove
Theorem 3.4 by construction.
Definition A.8. Let f be a map between finite sets Q and F , and let ϕ be a map between finite sets QΦ and FΦ . We say ϕ implements
f if there exist an injection gin : Q → QΦ and a surjection gout : FΦ → F such that

∀q ∈ Q, gout (ϕ(gin (q))) = f (q). (17)

Definition A.9. Let Q, F, QΦ , FΦ be finite sets. Let f be a map from Q to F , and let Φ be a finite set of maps {ϕ : QΦ →
FΦ }. We say Φ has the capacity to implement f if and only if there exists at least one element ϕ ∈ Φ that implements f .
Theorem 3.4 (Model capacity). Given a graph G with N nodes, a stack of L axial attention blocks has the capacity to
recover its skeleton and v-structures in O(N ) width, and propagate orientations on paths of O(L) length.


Proof. We consider axial attention blocks with dot-product attention and omit layer normalization from our analysis, as
is common in the Transformer universality literature (Yun et al., 2019). Our inputs X ∈ Rd×R×C consist of d-dimensional
embeddings over R rows and C columns. Since our axial attention only operates over one dimension at a time, we use X·,c
to denote a 1D sequence of length R, given a fixed column c, and Xr,· to denote a 1D sequence of length C, given a fixed
row r. A single axial attention layer (with one head) consists of two attention layers and a feedforward network,

$$\mathrm{Attn}_{\mathrm{row}}(X_{\cdot,c}) = X_{\cdot,c} + W_O W_V X_{\cdot,c} \cdot \sigma\!\left((W_K X_{\cdot,c})^T W_Q X_{\cdot,c}\right), \qquad X \leftarrow \mathrm{Attn}_{\mathrm{row}}(X) \tag{18}$$
$$\mathrm{Attn}_{\mathrm{col}}(X_{r,\cdot}) = X_{r,\cdot} + W_O W_V X_{r,\cdot} \cdot \sigma\!\left((W_K X_{r,\cdot})^T W_Q X_{r,\cdot}\right), \qquad X \leftarrow \mathrm{Attn}_{\mathrm{col}}(X) \tag{19}$$
$$\mathrm{FFN}(X) = X + W_2 \cdot \mathrm{ReLU}(W_1 \cdot X + b_1 \mathbf{1}_L^T) + b_2 \mathbf{1}_L^T, \tag{20}$$

where WO ∈ Rd×d , WV , WK , WQ ∈ Rd×d , W2 ∈ Rd×m , W1 ∈ Rm×d , b2 ∈ Rd , b1 ∈ Rm , and m is the hidden layer
size of the feedforward network. For concision, we have omitted the r and c subscripts on the W's, but the row and
column attentions use different parameters. Any row or column attention can take on the identity mapping by setting
WO , WV , WK , WQ to d × d matrices of zeros.
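As a reference for the notation in (18)-(20), the following is a minimal single-head PyTorch sketch of one axial attention layer over a tensor X ∈ Rd×R×C (illustrative only; layer normalization and multi-head attention are omitted, matching the analysis, and the per-slice loops are written for clarity rather than speed).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AxialAttentionLayer(nn.Module):
    """Single-head axial attention over X of shape (d, R, C): row attention,
    then column attention, then a feedforward network, each with a residual
    connection (a simplified sketch of Eqs. (18)-(20); LayerNorm omitted)."""

    def __init__(self, d: int, m: int):
        super().__init__()
        # Separate parameters for the row and the column attention.
        self.row = nn.ParameterDict(
            {k: nn.Parameter(torch.randn(d, d) / d ** 0.5) for k in ("WO", "WV", "WK", "WQ")})
        self.col = nn.ParameterDict(
            {k: nn.Parameter(torch.randn(d, d) / d ** 0.5) for k in ("WO", "WV", "WK", "WQ")})
        self.ffn = nn.Sequential(nn.Linear(d, m), nn.ReLU(), nn.Linear(m, d))

    @staticmethod
    def _attend(W, x):
        # x has shape (d, L); computes x + WO WV x softmax((WK x)^T WQ x).
        scores = (W["WK"] @ x).transpose(0, 1) @ (W["WQ"] @ x)  # (L, L)
        sigma = F.softmax(scores, dim=0)                        # normalize over keys
        return x + W["WO"] @ (W["WV"] @ x @ sigma)

    def forward(self, X):
        d, R, C = X.shape
        # Row attention: attend over the R rows within each column c.
        X = torch.stack([self._attend(self.row, X[:, :, c]) for c in range(C)], dim=2)
        # Column attention: attend over the C columns within each row r.
        X = torch.stack([self._attend(self.col, X[:, r, :]) for r in range(R)], dim=1)
        # Position-wise feedforward network with a residual connection.
        return X + self.ffn(X.permute(1, 2, 0)).permute(2, 0, 1)
```

For example, with N = 10 (so d = 2N + 6 = 26), AxialAttentionLayer(d=26, m=64)(torch.randn(26, 8, 5)) returns a tensor of the same shape.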
A single axial attention block consists of two axial attention layers ϕE and ϕρ , connected via messages (Section 3.2)

$$h^{E,\ell} = \phi_{E,\ell}(h^{E,\ell-1})$$
$$h^{\rho,\ell-1} \leftarrow W_{\rho,\ell}\left[h^{\rho,\ell-1},\, m^{E\to\rho,\ell}\right]$$
$$h^{\rho,\ell} = \phi_{\rho,\ell}(h^{\rho,\ell-1})$$
$$h^{E,\ell} \leftarrow h^{E,\ell} + m^{\rho\to E,\ell}$$

where hℓ denote the hidden representations of E and ρ at layer ℓ, and the outputs of the axial attention block are hρ,ℓ , hE,ℓ .
We construct a stack of L ≥ 3 axial attention blocks that implement Algorithm 1.


Model inputs Consider edge estimate Ei,j ∈ E in a graph of size N . Let ei , ej denote the endpoints of (i, j). Outputs of
the PC algorithm can be expressed by three endpoints: {∅, •, ▶}. A directed edge from i → j has endpoints (•, ▶), the
reversed edge i ← j has endpoints (▶, •), an undirected edge has endpoints (•, •), and the lack of any edge between i, j
has endpoints (∅, ∅).
Let one-hotN (i) denote the N -dimensional one-hot column vector where element i is 1. We define the embedding of (i, j)
as a d = 2N + 6 dimensional vector,
 
$$g_{\mathrm{in}}(E_{t,(i,j)}) = h^{E,0}_{(i,j)} = \begin{pmatrix} \text{one-hot}_3(e_i) \\ \text{one-hot}_3(e_j) \\ \text{one-hot}_N(i) \\ \text{one-hot}_N(j) \end{pmatrix}. \tag{21}$$

To recover graph structures from hE , we simply read off the indices of non-zero entries (gout ). We can set hρ,0 to any
Rd×N ×N matrix, as we do not consider its values in this analysis and discard it during the first step.
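To illustrate (21), here is a small NumPy helper (hypothetical, for exposition only; the endpoint-to-index ordering is an arbitrary choice of this sketch) that builds the 2N + 6 dimensional edge embedding.

```python
import numpy as np

# Endpoint symbols of the PC algorithm output: absent, tail/circle, arrowhead.
ENDPOINTS = {"none": 0, "tail": 1, "arrow": 2}

def one_hot(index: int, size: int) -> np.ndarray:
    v = np.zeros(size)
    v[index] = 1.0
    return v

def embed_edge(e_i: str, e_j: str, i: int, j: int, n_nodes: int) -> np.ndarray:
    """Build the (2N + 6)-dimensional embedding of Eq. (21): one-hot endpoint
    codes for e_i and e_j, followed by one-hot node identities for i and j."""
    return np.concatenate([
        one_hot(ENDPOINTS[e_i], 3),
        one_hot(ENDPOINTS[e_j], 3),
        one_hot(i, n_nodes),
        one_hot(j, n_nodes),
    ])

# A directed edge i -> j has endpoints (tail, arrow); e.g. with N = 10 nodes:
x = embed_edge("tail", "arrow", i=2, j=7, n_nodes=10)
assert x.shape == (2 * 10 + 6,)
```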

Claim A.10. (Consistency) The outputs of each step


1. are consistent with (21), and
2. are equivariant to the ordering of nodes in edges.

For example, if (i, j) is oriented as (▶, •), then we expect (j, i) to be oriented (•, ▶).

Step 1: Undirected skeleton We use the first axial attention block to recover the undirected skeleton C ′ . We set all
attentions to the identity, set Wρ,1 ∈ R2d×d to a d × d zeros matrix, stacked on top of a d × d identity matrix (discard ρ),


and set FFNE to the identity (inputs are positive). This yields
 
$$h^{\rho,0}_{i,j} = m^{E\to\rho,1}_{i,j} = \begin{pmatrix} P_{e_i}(\varnothing) \\ P_{e_i}(\bullet) \\ P_{e_i}(\blacktriangleright) \\ \vdots \\ \text{one-hot}_N(i) \\ \text{one-hot}_N(j) \end{pmatrix}, \tag{22}$$
where Pei (·) is the frequency that endpoint ei = · within the subsets sampled. FFNs with 1 hidden layer are universal
approximators of continuous functions (Hornik et al., 1989), so we use FFNρ to map

0
 i≤6
FFNρ (Xi,u,v ) = 0 i > 6, X1,u,v = 0 (23)

−Xi,u,v otherwise,

where i ∈ [2N + 6] indexes into the feature dimension, and u, v index into the rows and columns. This allows us to remove
edges not present in C ′ from consideration:
$$m^{\rho\to E,1} = h^{\rho,1}, \qquad h^{E,1}_{i,j} \leftarrow h^{E,1}_{i,j} + m^{\rho\to E,1}_{i,j} = \begin{cases} 0 & (i,j) \notin C' \\ h^{E,0}_{i,j} & \text{otherwise.} \end{cases} \tag{24}$$

This yields (i, j) ∈ C ′ if and only if $h^{E,1}_{i,j} \ne 0$. We satisfy A.10 since our inputs are valid PC algorithm outputs, for which
$P_{e_i}(\varnothing) = P_{e_j}(\varnothing)$.

Step 2: V-structures The second and third axial attention blocks recover v-structures. We run the same procedure twice,
once to capture v-structures that point towards the first node in an ordered pair, and once to capture v-structures that point
towards the latter node.
We start with the first row attention over edge estimates, given a fixed subset t. We set the key and query attention matrices
   
$$W_K = k \cdot \begin{pmatrix} 0\ 0\ 1 \\ 0\ 1\ 0 \\ \ddots \\ I_N \\ -I_N \end{pmatrix}, \qquad W_Q = k \cdot \begin{pmatrix} 0\ 0\ 1 \\ 0\ 1\ 0 \\ \ddots \\ I_N \\ I_N \end{pmatrix} \tag{25}$$
where k is a large constant, IN denotes the size N identity matrix, and all unmarked entries are 0s.
Recall that a v-structure is a pair of directed edges that share a target node. We claim that two edges (i, j), (u, v) form a
v-structure in E ′ , pointing towards i = u, if this inner product takes on the maximum value
$$\left\langle (W_K h^{E,1})_{i,j},\ (W_Q h^{E,1})_{u,v} \right\rangle = 3. \tag{26}$$
Suppose both edges (i, j) and (u, v) still remain in C ′ . There are two components to consider.
1. If i = u, then their shared node contributes +1 to the inner product (prior to scaling by k). If j = v, then the inner
product accrues −1.
2. Nodes that do not share the same endpoint contribute 0 to the inner product. Of edges that share one node, only
endpoints that match ▶ at the starting node, or • at the ending node contribute +1 to the inner product each. We
provide some examples below.
(ei , ej ) (eu , ev ) contribution note
(▶, •) (•, ▶) 0 no shared node
(•, ▶) (•, ▶) 0 wrong endpoints
(•, •) (•, •) 1 one correct endpoint
(▶, •) (▶, •) 2 v-structure


All edges with endpoints ∅ were “removed” in step 1, resulting in an inner product of zero, since their node embeddings
were set to zero. We set k to some large constant (empirically, k² = 1000 is more than enough) to ensure that after softmax
scaling, σe,e′ > 0 only if e, e′ form a v-structure.
Given ordered pair e = (i, j), let Vi ⊂ V denote the set of nodes that form a v-structure with e, with shared node i. Note
that Vi excludes j itself, since the setting of WK , WQ excludes edges that share both nodes. We set WV to the identity, and we
multiply by the attention weights σ to obtain
$$(W_V h^{E,1} \sigma)_{e=(i,j)} = \begin{pmatrix} \vdots \\ \text{one-hot}_N(i) \\ \alpha_j \cdot \text{binary}_N(V_j) \end{pmatrix}, \tag{27}$$
where binaryN (S) denotes the N -dimensional binary vector with ones at elements in S, and the scaling factor
αj = (1/∥Vj ∥) · 1{∥Vj ∥ > 0} ∈ [0, 1] (28)
results from softmax normalization. We set
 
$$W_O = \begin{pmatrix} 0_{N+6} & \\ & 0.5 \cdot I_N \end{pmatrix} \tag{29}$$
to preserve the original endpoint values, and to distinguish between the edge’s own node identity and newly recognized
v-structures. To summarize, the output of this row attention layer is
Attnrow (X·,c ) = X·,c + WO WV X·,c · σ,
which is equal to its input hE,1 plus additional positive values ∈ (0, 0.5) in the last N positions that indicate the presence of
v-structures that exist in the overall E ′ .
Our final step is to “copy” newly assigned edge directions into all the edges. We set the ϕE column attention, FFNE and the
ϕρ attentions to the identity mapping. We also set Wρ,2 to a d × d zeros matrix, stacked on top of a d × d identity matrix.
This passes the output of the ϕE row attention, aggregated over subsets, directly to FFNρ,2 .
For the endpoint dimensions e ∈ [6], we let FFNρ,2 implement
$$\mathrm{FFN}_{\rho,2}(X_{e,u,v}) = \begin{cases} [0, 0, 1, 0, 1, 0]^T - X_{e,u,v} & 0 < \sum_{i > N+6} X_{i,u,v} < 0.5 \\ 0 & \text{otherwise.} \end{cases} \tag{30}$$
Subtracting $X_{e,u,v}$ “erases” the original endpoints and replaces them with (▶, •) after the update
$h^{E,1}_{i,j} \leftarrow h^{E,1}_{i,j} + m^{\rho\to E,1}_{i,j}$.
The overall operation translates to checking whether any v-structure points towards i, and if so, assigning edge directions
accordingly. For dimensions i > 6,
$$\mathrm{FFN}_{\rho,2}(X_{i,u,v}) = \begin{cases} -X_{i,u,v} & X_{i,u,v} \le 0.5 \\ 0 & \text{otherwise}, \end{cases} \tag{31}$$
effectively erasing the stored v-structures from the representation and remaining consistent with (21).
At this point, we have copied all v-structures once. However, our orientations are not necessarily symmetric. For example,
given v-structure i → j ← k, our model orients edges (j, i) and (j, k), but not (i, j) or (k, j).
The simplest way to symmetrize these edges (for the writer and the reader) is to run another axial attention block, in which
we focus on v-structures that point towards the second node of a pair. The only changes are as follows.

• For WK and WQ , we swap columns 1-3 with 4-6, and columns 7 to N + 6 with the last N columns.
• (hE,2 σ)i,j sees the third and fourth blocks swapped.
• WO swaps the N × N blocks that correspond to i and j’s node embeddings.
• FFNρ,3 sets the endpoint embedding to $[0, 1, 0, 0, 0, 1]^T - X_{e,u,v}$ if dimensions i = 7, ..., N + 6 sum to a value between 0 and 0.5.

The result is hE,3 with all v-structures oriented symmetrically, satisfying A.10.


Step 3: Orientation propagation To propagate orientations, we would like to identify cases (i, j), (i, k) ∈ E ′ , (j, k) ̸∈ E ′
with shared node i and corresponding endpoints (▶, •), (•, •). We use ϕE to identify triangles, and ϕρ to identify edges
(i, j), (i, k) ∈ E ′ with the desired endpoints, while ignoring triangles.

Marginal layer The row attention in ϕE fixes a subset t and varies the edge (i, j).
Given edge (i, j), we want to extract all (i, k) that share node i. We set the key and query attention matrices to
 
$$W_K, W_Q = k \cdot \begin{pmatrix} 0\ 1\ 1 \\ \ddots \\ I_N \\ \pm I_N \end{pmatrix} \tag{32}$$

We set WV to the identity to obtain
$$(W_V h^{E} \sigma)_{e=(i,k)} = \begin{pmatrix} \vdots \\ \text{one-hot}_N(i) \\ \alpha_k \cdot \text{binary}_N(V_k) \end{pmatrix}, \tag{33}$$

where Vk is the set of nodes that share an edge with i. To distinguish between k and Vk , we again set WO as
in (29). Finally, we set FFNE to the identity and pass hE directly to ϕρ . To summarize, we have hE equal to its input, with
values ∈ (0, 0.5) in the last N locations indicating 1-hop neighbors of each edge.

Global layer Now we would like to identify cases (i, k), (j, k) with corresponding endpoints (•, ▶), (•, •). We set the
key and query attention matrices
   
$$W_K = k \cdot \begin{pmatrix} 0\ 0\ 1\ 0\ 1\ {-1} \\ \ddots \\ I_N \\ I_N \end{pmatrix}, \qquad W_Q = k \cdot \begin{pmatrix} 0\ 1\ {-1} \\ \ddots \\ I_N \\ -I_N \end{pmatrix} \tag{34}$$

The key allows us to check that endpoint i is directed, and the query allows us to check that (i, k) exists in C ′ , and does not
already point elsewhere. After softmax normalization, for sufficiently large k, we obtain σ(i,j),(i,k) > 0 if and only if (i, k)
should be oriented (•, ▶), and the inner product attains the maximum possible value

⟨(WK hρ )i,j , (WQ hρ )i,k ⟩ = 2. (35)

We consider two components.


1. If the endpoints match our desired endpoints, we gain a +1 contribution to the inner product.
2. A match between the first nodes contributes +1. If the second node shares any overlap (either same edge, or a triangle),
then a negative value would be added to the overall inner product.
Therefore, we can only attain the maximal inner product if only one edge is directed, and if there exists no triangle.
We set Wo to the same as in (29), and we add hρ to the input of the next ϕE . To summarize, we have hρ equal to its input,
with values ∈ (0, 0.5) in the last N locations indicating incoming edges.

Orientation assignment Our final step is to assign our new edge orientations. Let the column attention take on the identity
mapping. For the endpoint dimensions e = (4, 5, 6), we let FFNρ implement
$$\mathrm{FFN}_{\rho}(X_{e,u,v}) = \begin{cases} [0, 0, 1]^T - X_{e,u,v} & 0 < \sum_{i > N+6} X_{i,u,v} < 0.5 \\ 0 & \text{otherwise.} \end{cases} \tag{36}$$


This translates to checking whether any incoming edge points towards v, and if so, assigning the new edge direction
accordingly. For dimensions i > 6,
$$\mathrm{FFN}_{\rho}(X_{i,u,v}) = \begin{cases} 0 & X_{i,u,v} \le 0.5 \\ X_{i,u,v} & \text{otherwise}, \end{cases} \tag{37}$$

effectively erasing the stored assignments from the representation. Thus, we are left with hE,ℓ that conforms to the same
format as the initial embedding in (21).
To symmetrize these edges, we run another axial attention block, in which we focus on paths that point towards the second
node of a pair. The only changes are as follows.

• For ϕE layer WK and WQ (32), we swap IN and ±IN .

• For ϕρ layer WK and WQ (34), we swap IN and ±IN .

• WO swaps the N × N blocks that correspond to i and j’s node embeddings.

• For FFNρ (36), we let e = (1, 2, 3) instead.

The result is hE with symmetric 1-hop orientation propagation, satisfying A.10. We may repeat this procedure k times to
capture k-hop paths.
To summarize, we used axial attention block 1 to recover the undirected skeleton C ′ , blocks 2-3 to identify and copy
v-structures in E ′ , and all subsequent L − 3 blocks to propagate orientations on paths of length up to ⌊(L − 3)/2⌋. Overall,
this particular construction requires O(N ) width for paths of length O(L).

Final remarks Information theoretically, it should be possible to encode the same information in log N space, and achieve
O(log N ) width. For ease of construction, we have allowed for wider networks than optimal.
On the other hand, if we increase the width and encode each edge symmetrically, e.g. (ei , ej , ej , ei | i, j, j, i), we can
reduce the number of blocks by half, since we no longer need to run each operation twice. However, attention weights scale
quadratically, so we opted for an asymmetric construction.
Finally, a strict limitation of our model is that it only considers 1D pairwise interactions. In the graph layer, we cannot
compare different edges’ estimates at different times in a single step. In the feature layer, we cannot compare (i, j) to (j, i)
in a single step either. However, the graph layer does enable us to compare all edges at once (sparsely), and the feature
layer looks at a time-collapsed version of the whole graph. Therefore, though we opted for this design for computational
efficiency, we have shown that it is able to capture significant graph reasoning.

A.4. Robustness and stability


We discuss the notion of stability informally, in the context of Spirtes et al. (2001). There are two cases in which our
framework may receive erroneous inputs: low/noisy data settings, and functionally misspecified situations. We consider our
framework’s robustness to these cases, in terms of recovering the skeleton and orienting edges.

A.4.1. DATA NOISE


In the case of noisy data, edges may be erroneously added, removed, or misdirected in the marginal estimates E ′ . Our
framework provides two avenues for mitigating such noise.
1. We observe that global statistics can be estimated reliably in low data scenarios. For example, Figure 6 suggests that
200 examples suffice to provide a robust estimate over 100 variables in our synthetic settings. Therefore, even if the
marginal estimates are erroneous, the neural network can learn the skeleton from the global statistics.
2. Most classical causal discovery algorithms are not stable with respect to edge orientation assignment. That is, an error
in a single edge may propagate throughout the graph. Empirically, we observe that the majority vote of GIES achieves
reasonable accuracy even without any training, while FCI suffers in this assessment (Table 4). However, both SEA (GIES)
and SEA (FCI) achieve high edge accuracy. Therefore, while the underlying algorithms may not be stable with
respect to edge orientation, our pretrained aggregator seems to be robust.

A.4.2. F UNCTIONAL MISSPECIFICATION


It is also possible that our global statistics and marginal estimates make misspecified assumptions regarding the data
generating mechanisms. The degree of misspecification can vary case by case, so it is hard to provide any broad guarantees
about the performance of our algorithm, in general. However, we can make the following observation.
If two variables are truly independent, Xi ⊥⊥ Xj , they will also appear independent under our working assumptions, e.g. linear
Gaussian. If Xi , Xj exhibit more complex functional dependencies, however, they may be erroneously deemed independent.
Therefore, any systematic errors are necessarily one-sided, and the model can learn to recover the connectivity based on
global statistics.

B. Experimental details
B.1. Synthetic data generation
Synthetic datasets were generated using code from DCDI (Brouillard et al., 2020), which extended the Causal Discovery
Toolbox data generators to interventional data (Kalainathan et al., 2020).
We considered the following causal mechanisms. Let y be the node in question, let X be its parents, let E be an independent
noise variable (details below), and let W be randomly initialized weight matrices.

• Linear: y = XW + E.
• Polynomial: y = W0 + XW1 + X²W2 + E
• Sigmoid additive: y = Σ_{i=1}^d Wi · sigmoid(Xi) + E
• Randomly initialized neural network (NN): y = Tanh((X, E)Win)Wout
• Randomly initialized neural network, additive (NN additive): y = Tanh(XWin)Wout + E

Root causal mechanisms, noise variables, and interventional distributions maintained the DCDI defaults.

• Root causal mechanisms were set to Uniform(−2, 2).


• Noise was set to E ∼ 0.4 · N(0, σ²), where σ² ∼ Uniform(1, 2).
• Interventions were applied to all nodes (one at a time) by setting their causal mechanisms to N (0, 1).

Ablation datasets with N > 100 nodes contained 100,000 points each (same as N = 100).
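For illustration, the following is a minimal NumPy sketch of the linear mechanism with the defaults above (a simplified stand-in for the DCDI-based generators we actually used; the Uniform(−1, 1) weight range and the handling of root nodes are assumptions of this sketch, and interventions are omitted).

```python
import numpy as np

def sample_linear_scm(adjacency, n_samples, seed=None):
    """Sample observational data from a linear-Gaussian SCM in topological order.

    adjacency: (N, N) binary matrix with adjacency[i, j] = 1 for edge i -> j,
    assumed to be ordered so that parents precede children.
    """
    rng = np.random.default_rng(seed)
    n = adjacency.shape[0]
    # Edge weights: the Uniform(-1, 1) range is an assumption of this sketch.
    weights = rng.uniform(-1.0, 1.0, size=(n, n)) * adjacency
    sigma2 = rng.uniform(1.0, 2.0, size=n)          # sigma^2 ~ Uniform(1, 2)
    X = np.zeros((n_samples, n))
    for j in range(n):                              # topological order
        parents = np.flatnonzero(adjacency[:, j])
        if parents.size == 0:
            # Root causal mechanism: Uniform(-2, 2).
            X[:, j] = rng.uniform(-2.0, 2.0, size=n_samples)
        else:
            # y = XW + E with E ~ 0.4 * N(0, sigma^2).
            noise = 0.4 * rng.normal(0.0, np.sqrt(sigma2[j]), size=n_samples)
            X[:, j] = X[:, parents] @ weights[parents, j] + noise
    return X
```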

B.2. Baseline details


We considered the following baselines. All baselines were run using official implementations published by the authors.
DCDI (Brouillard et al., 2020) was trained on each of the N = 10, 20 datasets using their published hyperparameters. We
denote the Gaussian and Deep Sigmoidal Flow versions as DCDI-G and DCDI-DSF respectively. DCDI could not scale to
graphs with N = 100 due to memory constraints (did not fit on a 32GB V100 GPU).
DCD-FG (Lopez et al., 2022) was trained on all of the test datasets using their published hyperparameters. We set the
number of factors to 5, 10, 20 for each of N = 10, 20, 100, based on their ablation studies. Due to numerical instability on
N = 100, we clamped augmented Lagrangian multipliers µ and γ to 10 and stopped training if elements of the learned
adjacency matrix reached NaN values. After discussion with the authors, we also tried adjusting the µ multiplier from 2 to
1.1, but the model did not converge within 48 hours.
DECI (Geffner et al., 2022) was trained on all of the test datasets using their published hyperparameters. However, on all
N = 100 cases, the model failed to produce any meaningful results (adjacency matrices nearly all remained 0s with AUCs
of 0.5). Thus, we only report results on N = 10, 20.


DiffAN (Sanchez et al., 2023) was trained on each of the N = 10, 20 datasets using their published hyperparameters.
The authors write that “most hyperparameters are hard-coded into [the] constructor of the DiffAN class and we verified
they work across a wide set of datasets.” We used the original, non-approximation version of their algorithm by maintaining
residue=True in their codebase. We were unable to consistently run DiffAN with both R and GPU support within a
Docker container, and the authors did not respond to questions regarding reproducibility, so all models were trained on the
CPU only. We observed approximately a 10x speedup in the < 5 cases that were able to complete running on the GPU.

B.3. Neural network design


Hyperparameters and architectural choices were selected by training the model on 20% of the training and validation
data for approximately 50k steps (several hours). We considered the following parameters in sequence.

• learned positional embedding vs. sinusoidal positional embedding

• number of layers × number of heads: {4, 8} × {4, 8}

• learning rate η = {1e − 4, 5e − 5, 1e − 5}

For our final model, we selected learned positional embeddings, 4 layers, 8 heads, and learning rate η = 1e − 4.
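In configuration form, the selected settings are (an illustrative dictionary, not our actual configuration file):

```python
# Final aggregator hyperparameters selected by the search above (illustrative).
final_config = {
    "positional_embedding": "learned",  # chosen over sinusoidal
    "num_layers": 4,
    "num_heads": 8,
    "learning_rate": 1e-4,
}
```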

B.4. Training and hardware details


The models were trained across 2 NVIDIA RTX A6000 GPUs and 60 CPU cores. We used the GPU exclusively for running
the aggregator, and retained all classical algorithm execution on the CPUs (during data loading). The total pretraining time
took approximately 14 hours for the final FCI model and 16 hours for the final GIES model.
For the scope of this paper, our models and datasets are fairly small. We did not scale further due to hardware constraints.
Our primary bottlenecks to scaling up lay in availability of CPU cores and networking speed across nodes, rather than GPU
memory or utilization.
We are able to run inference comfortably over N = 500 graphs with T = 500 subsets of k = 5 nodes each, on a single
32GB V100 GPU. For runtime analysis, we used a batch size of 1, with 1 data worker per dataset. Runtime could be further
improved if we amortized the GPU utilization across batches.

B.5. Choice of classical causal discovery algorithm


We selected FCI (Spirtes et al., 1995) as the underlying discovery algorithm in the observational setting over GES (Chickering,
2002) and GRaSP (Lam et al., 2022) due to its superior downstream performance. We hypothesize this may be due to its richer
output (ancestral graph) providing more signal to the Transformer model. We also tried Causal Additive Models (Bühlmann
et al., 2014), but its runtime was too slow for consistent GPU utilization. Observational algorithm implementations were
provided by the causal-learn library (Zheng et al., 2023). The code for running these alternative classical algorithms is
available in our codebase.
We selected GIES as the underlying discovery algorithm in the interventional setting because a Python implementation was
readily available at https://github.com/juangamella/gies.
We tried incorporating implementations from the Causal Discovery Toolbox via a Docker image (Kalainathan et al., 2020),
but there was excessive overhead associated with calling an R subroutine and reading/writing the inputs/results from disk.
Finally, we considered other independence tests for richer characterization, such as kernel-based methods. However, due to
speed, we chose to remain with the default Fisherz conditional independence test for FCI, and BIC for GIES (Schwarz,
1978).
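For reference, the Fisher-z conditional independence test can be sketched in a few lines of NumPy/SciPy (a textbook version for exposition, not the causal-learn implementation we call in practice):

```python
import numpy as np
from scipy.stats import norm

def fisher_z_test(data, i, j, cond=()):
    """Test X_i independent of X_j given X_cond via partial correlation.

    data: (n_samples, n_variables) array. Returns a two-sided p-value; small
    values indicate dependence.
    """
    sub = data[:, [i, j, *cond]]
    # Partial correlation of the first two variables given the rest, read off
    # the precision (inverse covariance) matrix of the selected columns.
    precision = np.linalg.pinv(np.cov(sub, rowvar=False))
    r = -precision[0, 1] / np.sqrt(precision[0, 0] * precision[1, 1])
    r = np.clip(r, -0.999999, 0.999999)
    z = 0.5 * np.log((1 + r) / (1 - r))              # Fisher transform
    stat = np.sqrt(data.shape[0] - len(cond) - 3) * abs(z)
    return 2 * (1 - norm.cdf(stat))
```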

B.6. Sampling procedure


Selection scores: We consider three strategies for computing selection scores α. We include an empirical comparison of
these strategies in Table 9.
1. Random selection: α is an N × N matrix of ones.


2. Global-statistic-based selection: α = ρ.
3. Uncertainty-based selection: α = Ĥ(Et ), where H denotes the information entropy
$$\alpha_{i,j} = -\sum_{e \in \{0,1,2\}} p(e) \log p(e). \tag{38}$$
Let $c^t_{i,j}$ be the number of times edge (i, j) was selected in $S_1 \ldots S_{t-1}$, and let $\alpha^t = \alpha / \sqrt{c^t_{i,j}}$. We consider two strategies
for selecting $S_t$ based on $\alpha^t$.

Greedy selection: Throughout our experiments, we used a greedy algorithm for subset selection. We normalize probabilities
to 1 before constructing each Categorical. Initialize
$$S_t \leftarrow \{i : i \sim \mathrm{Categorical}(\alpha^t_1 \ldots \alpha^t_N)\}, \tag{39}$$
where $\alpha^t_i = \sum_{j \ne i \in V} \alpha^t_{i,j}$. While $|S_t| < k$, update
$$S_t \leftarrow S_t \cup \{j : j \sim \mathrm{Categorical}(\alpha^t_{1,S_t} \ldots \alpha^t_{N,S_t})\}, \tag{40}$$
where
$$\alpha^t_{j,S_t} = \begin{cases} \sum_{i \in S_t} \alpha^t_{i,j} & j \notin S_t \\ 0 & \text{otherwise.} \end{cases} \tag{41}$$
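A compact NumPy sketch of this greedy procedure (illustrative; in particular, the guard against unvisited pairs and the variable names are simplifications of this sketch):

```python
import numpy as np

def greedy_subset(alpha, counts, k, seed=None):
    """Greedily sample a subset of k nodes following Eqs. (39)-(41).

    alpha: (N, N) selection scores; counts: (N, N) visit counts c_{i,j} from
    previously sampled subsets.
    """
    rng = np.random.default_rng(seed)
    n = alpha.shape[0]
    # Discount frequently visited pairs; max(counts, 1) is a simplification
    # here to avoid dividing by zero for unvisited pairs.
    alpha_t = alpha / np.sqrt(np.maximum(counts, 1.0))
    np.fill_diagonal(alpha_t, 0.0)

    def sample(scores):
        return int(rng.choice(n, p=scores / scores.sum()))  # normalize to 1

    # Eq. (39): first node drawn proportionally to its total score.
    subset = [sample(alpha_t.sum(axis=1))]
    # Eqs. (40)-(41): remaining nodes drawn proportionally to their score
    # against the nodes already selected, excluding already selected nodes.
    while len(subset) < k:
        scores = alpha_t[subset, :].sum(axis=0)
        scores[subset] = 0.0
        subset.append(sample(scores))
    return subset
```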

Subset selection: We also considered the following subset-level selection procedure, and observed minor performance
gain for significantly longer runtime (linear program takes around 1 second per batch). Therefore, we opted for the greedy
method instead.
We solve the following integer linear program to select a subset $S_t$ of size k that maximizes $\sum_{i,j \in S_t} \alpha^t_{i,j}$. Let $\nu_i \in \{0, 1\}$
denote the selection of node i, and let $\epsilon_{i,j} \in \{0, 1\}$ denote the selection of edge (i, j). Our objective is to

$$\begin{array}{lll}
\text{maximize} & \sum_{i,j} \alpha^t_{i,j} \cdot \epsilon_{i,j} & \\
\text{subject to} & \sum_i \nu_i = k & \text{subset size} \\
& \epsilon_{i,j} \ge \nu_i + \nu_j - 1 & \text{node-edge consistency} \\
& \epsilon_{i,j} \le \nu_i & \\
& \epsilon_{i,j} \le \nu_j & \\
& \nu_i \in \{0, 1\} & \\
& \epsilon_{i,j} \in \{0, 1\} &
\end{array}$$

for (i, j) ∈ V × V and i ∈ V . $S_t$ is the set of non-zero indices in ν.
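This program can be prototyped directly with SciPy's mixed-integer solver (assuming SciPy ≥ 1.9, which provides scipy.optimize.milp); the dense constraint matrices below are only practical for small N, consistent with the roughly 1 second per batch noted above.

```python
import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp

def ilp_subset(alpha, k):
    """Select k nodes maximizing the sum of pairwise scores (ILP of B.6).

    Decision variables are x = [nu (N), eps (N * N)], all binary.
    """
    n = alpha.shape[0]
    n_vars = n + n * n
    eps = lambda i, j: n + i * n + j                 # index of eps_{i,j} in x

    # Objective: maximize sum_{i,j} alpha_{i,j} eps_{i,j} -> minimize negative.
    c = np.zeros(n_vars)
    c[n:] = -alpha.flatten()

    rows, lbs, ubs = [], [], []
    # Subset size: sum_i nu_i = k.
    row = np.zeros(n_vars); row[:n] = 1.0
    rows.append(row); lbs.append(k); ubs.append(k)
    for i in range(n):
        for j in range(n):
            # Node-edge consistency: eps_{i,j} >= nu_i + nu_j - 1.
            row = np.zeros(n_vars)
            row[eps(i, j)] = 1.0; row[i] -= 1.0; row[j] -= 1.0
            rows.append(row); lbs.append(-1.0); ubs.append(np.inf)
            # eps_{i,j} <= nu_i and eps_{i,j} <= nu_j.
            for node in (i, j):
                row = np.zeros(n_vars)
                row[eps(i, j)] = 1.0; row[node] -= 1.0
                rows.append(row); lbs.append(-np.inf); ubs.append(0.0)

    res = milp(c,
               constraints=LinearConstraint(np.array(rows), lbs, ubs),
               integrality=np.ones(n_vars),
               bounds=Bounds(0.0, 1.0))
    return np.flatnonzero(res.x[:n] > 0.5)
```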


The final algorithm used the greedy selection strategy, with the first half of batches sampled according to global statistics,
and the latter half sampled randomly, with visit counts shared. This strategy was selected heuristically, and we did not
observe significant improvements or drops in performance when switching to other strategies (e.g. all greedy statistics-based,
greedy uncertainty-based, linear program uncertainty-based, etc.)

C. Additional analyses
C.1. Traditional algorithm parameters
We investigated model performance with respect to the settings of our graph estimation parameters. Our model is sensitive
to the size of batches used to estimate global features and marginal graphs (Figure 6). In particular, at least 250 points are
required per batch for an acceptable level of performance. Our model is not particularly sensitive to the number of batches
sampled (Figure 7), or to the number of variables sampled in each subset (Figure 8).

C.2. Choice of global statistic


We selected inverse covariance as our global feature due to its ease of computation and its relationship to partial correlation.
For context, we also provide the performance analysis of several alternatives. Tables 6 and 7 compare the results of different graph-level statistics on our synthetic datasets.


Figure 6. Performance of our model (GIES) as a function of traditional algorithm batch size. Error bars indicate 95% confidence interval
across the 10 datasets of each setting. The global feature and marginal graph estimates are sensitive to batch size and require at least 250
points per batch to achieve an acceptable level of performance.

Figure 7. Performance of our model (GIES) as a function of number of batches sampled. Error bars indicate 95% confidence interval
across the 10 datasets of each setting. Our model is relatively insensitive to the number of batches sampled, though more batches are
beneficial in harder cases, e.g. sigmoid mechanism with additive noise or smaller batch size.

Discretization thresholds for SHD were obtained by computing the pth quantile of the computed values, where p = 1 − (E/N ).
This is not entirely fair, as no other baseline receives the same calibration, but these ablation studies only seek to compare
state-of-the-art causal discovery methods with the “best” possible (oracle) statistical alternatives.
CORR refers to global correlation,
$$\rho_{i,j} = \frac{\mathbb{E}(X_i X_j) - \mathbb{E}(X_i)\,\mathbb{E}(X_j)}{\sqrt{\mathbb{E}(X_i^2) - \mathbb{E}(X_i)^2} \cdot \sqrt{\mathbb{E}(X_j^2) - \mathbb{E}(X_j)^2}}. \tag{42}$$

D-CORR refers to distance correlation, computed between all pairs of variables. Distance correlation captures both linear
and non-linear dependencies, and D-CORR(Xi , Xj ) = 0 if and only if Xi ⊥⊥ Xj . Please refer to (Székely et al., 2007) for
the full derivation. Despite its power to capture non-linear dependencies, we opted not to use D-CORR because it is quite
slow to compute between all pairs of variables.
INVCOV refers to inverse covariance, computed globally,
$$\rho = \left( \mathbb{E}\left[(X - \mathbb{E}(X))(X - \mathbb{E}(X))^T\right] \right)^{-1}. \tag{43}$$

For graphs N < 100, inverse covariance was computed directly using NumPy. For graphs N ≥ 100, inverse covariance was
computed using Ledoit-Wolf shrinkage at inference time (Ledoit & Wolf, 2004).
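For reference, both variants of this statistic can be computed in a few lines (a sketch using NumPy and scikit-learn's LedoitWolf estimator; the pseudo-inverse is used here for numerical safety):

```python
import numpy as np
from sklearn.covariance import LedoitWolf

def inverse_covariance(X, shrinkage=False):
    """Global statistic of Eq. (43): the precision (inverse covariance) matrix.

    X: (n_samples, n_variables) data. shrinkage=True switches to the
    Ledoit-Wolf estimator, which is better conditioned for large graphs."""
    if shrinkage:
        return LedoitWolf().fit(X).precision_
    return np.linalg.pinv(np.cov(X, rowvar=False))
```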


Figure 8. Performance of our model (GIES) as a function of subset size |S| (number of variables sampled). Error bars indicate 95%
confidence interval across the 10 datasets of each setting. Our model was trained on |S| = 5, but it is insensitive to the number of variables
sampled per subset at inference. Runtime scales exponentially upwards.

Table 6. Comparison of global statistics (continuous metrics).


N E Model Linear NN add. NN non-add. Sigmoid Polynomial
mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑
C ORR 0.45±0.0 0.87±0.0 0.41±0.0 0.86±0.1 0.41±0.1 0.85±0.0 0.46±0.0 0.86±0.1 0.45±0.0 0.85±0.1
10 10 D-C ORR 0.42±0.0 0.86±0.0 0.41±0.1 0.87±0.1 0.40±0.0 0.87±0.0 0.43±0.0 0.86±0.1 0.45±0.0 0.89±0.1
I NV C OV 0.49±0.0 0.87±0.0 0.45±0.1 0.86±0.1 0.36±0.1 0.81±0.1 0.44±0.1 0.86±0.1 0.45±0.1 0.83±0.1
C ORR 0.47±0.0 0.53±0.0 0.47±0.0 0.52±0.0 0.46±0.0 0.52±0.0 0.48±0.0 0.53±0.0 0.48±0.0 0.54±0.0
10 40 D-C ORR 0.46±0.0 0.53±0.0 0.46±0.0 0.51±0.0 0.46±0.0 0.54±0.0 0.48±0.0 0.53±0.0 0.47±0.0 0.54±0.0
I NV C OV 0.50±0.0 0.57±0.0 0.48±0.0 0.52±0.0 0.47±0.0 0.53±0.0 0.47±0.0 0.50±0.0 0.48±0.0 0.52±0.0
C ORR 0.42±0.0 0.99±0.0 0.25±0.0 0.94±0.0 0.25±0.0 0.93±0.0 0.42±0.0 0.98±0.0 0.35±0.0 0.91±0.0
100 100 D-C ORR 0.41±0.0 0.99±0.0 0.25±0.0 0.96±0.0 0.26±0.0 0.96±0.0 0.41±0.0 0.98±0.0 0.37±0.0 0.94±0.0
I NV C OV 0.40±0.0 0.99±0.0 0.22±0.0 0.94±0.0 0.16±0.0 0.87±0.0 0.40±0.0 0.97±0.0 0.36±0.0 0.90±0.0
C ORR 0.19±0.0 0.80±0.0 0.10±0.0 0.63±0.0 0.14±0.0 0.72±0.0 0.27±0.0 0.84±0.0 0.20±0.0 0.72±0.0
100 400 D-C ORR 0.19±0.0 0.80±0.0 0.10±0.0 0.63±0.0 0.14±0.0 0.75±0.0 0.26±0.0 0.84±0.0 0.21±0.0 0.74±0.0
I NV C OV 0.25±0.0 0.91±0.0 0.09±0.0 0.62±0.0 0.14±0.0 0.77±0.0 0.27±0.0 0.86±0.0 0.20±0.0 0.67±0.0

C.3. Results on simulated mRNA data


We generated mRNA data using the SERGIO simulator (Dibaeinia & Sinha, 2020). We sampled datasets with the Hill
coefficient set to {0.25, 0.5, 1, 2, 4} for training, and 2 for testing (2 is the default). We set the decay rate to the default 0.8,
and the noise parameter to the default of 1.0. We sampled 400 graphs for each of N = {10, 20} and E = {N, 2N }.
These data distributions are quite different from typical synthetic datasets, as they simulate steady-state measurements and
the data are lower bounded at 0 (gene counts).

C.4. Additional results on synthetic data


For completeness, we include additional results and analysis.
Table 9 compares the heuristics-based greedy sampler (inverse covariance + random) with the model uncertainty-based
greedy sampler. Runtimes are plotted in Figure 10. The latter was run on CPU only, since it was non-trivial to access the
GPU within a PyTorch data loader. We ran a forward pass to obtain an updated selection score every 10 batches, so this
accrued over 10 times the number of forward passes, all on CPU. With proper engineering, this model-based sampler is
expected to be much more efficient than reported. Still, it is faster than nearly all baselines.
Table 10 and Figure 9 show that the current implementations of SEA can generalize to graphs up to 4× larger than those
seen during training. With respect to larger graphs, there are two minor issues with the current implementation. We set
an insufficient maximum subset positional embedding size of 500, and we did not sample random starting subset indices
to ensure that higher-order embeddings are updated equally. We anticipate that increasing the limit on the number of
subsets and ensuring that all embeddings are sufficiently learned will improve the generalization capacity on larger graphs.


Table 7. Comparison of global statistics (SHD). Discretization thresholds for SHD were obtained by computing the pth quantile of the
computed values, where p = 1 − (E/N ).

N E Model Linear NN add. NN non-add. Sigmoid Polynomial


C ORR 10.6±2.8 10.2±4.6 12.0±1.9 11.1±4.3 9.9±2.8
10 10 D-C ORR 10.4±2.6 9.8±4.7 12.2±2.6 10.8±3.3 10.2±3.2
I NV C OV 11.0±2.8 11.4±5.5 13.6±2.9 11.4±4.1 10.9±3.5
C ORR 39.2±2.4 38.0±1.8 38.2±0.7 38.8±3.3 38.2±2.0
10 40 D-C ORR 38.8±2.0 38.8±1.5 37.0±0.6 38.9±3.2 38.0±2.0
I NV C OV 35.8±2.3 39.2±1.5 37.6±2.7 40.7±2.2 38.4±1.2
C ORR 113.0±4.9 132.2±18.0 144.6±5.2 106.5±11.5 110.3±6.1
100 100 D-C ORR 113.8±5.3 133.2±17.9 144.2±6.7 108.5±11.9 109.5±5.7
I NV C OV 124.4±8.1 130.0±17.2 158.8±6.2 112.3±14.8 106.3±4.6
C ORR 580.4±24.5 666.0±13.5 626.2±23.4 516.5±18.5 562.5±20.1
100 400 D-C ORR 578.2±24.7 665.4±15.4 626.6±21.9 522.3±17.6 557.2±20.4
I NV C OV 557.0±11.7 667.8±15.4 639.0±9.7 514.7±23.1 539.4±18.4

Nonetheless, our current model already obtains reasonable performance on larger graphs, out of the box.
Due to the scope of this project and computing resources, we did not train very “big” models in the modern sense. There is
much space to scale, both in terms of model architecture and the datasets covered. Table 11 probes the generalization limits
of the two implementations of SEA in this paper.
We identified that our models, trained primarily on additive noise, achieve reasonable performance, but do not generalize
reliably to causal mechanisms with multiplicative noise. For example, we tested additional datasets with the following
mechanisms (same format as B.1).
• Sigmoid mix: y = Σ_{i=1}^d Wi · sigmoid(Xi) × E

• Polynomial mix: y = (W0 + XW1 + X²W2) × E

We anticipate that incorporating these data into the training set would alleviate some of this gap (just as training on synthetic
mRNA data enabled us to perform well there, despite its non-standard data distributions). However, we did not have time to
test this hypothesis empirically.
DCDI learns a new generative model over each dataset, and its more powerful, deep sigmoidal flow variant seems to perform
well in some (but not all) of these harder cases.
Tables 12 and 13 report the full results on N = 100 graphs.
Tables 14 and 15 report our results on scale-free graphs.


Table 8. Causal discovery results on simulated mRNA data. Each setting encompasses 5 distinct scale-free graphs. Data were generated
via SERGIO (Dibaeinia & Sinha, 2020).

N E Model mAP ↑ AUC ↑ SHD ↓ EdgeAcc ↑


D CDI -G 0.48 0.73 16.1 0.70
D CDI -D SF 0.63 0.84 18.5 0.81
10 10
D CD -F G 0.59 0.82 81.0 0.79
S EA (F CI ) 0.92 0.98 1.9 0.92
D CDI -G 0.32 0.57 26.2 0.59
D CDI -D SF 0.44 0.64 25.7 0.64
10 20
D CD -F G 0.43 0.69 73.0 0.67
S EA (F CI ) 0.76 0.90 8.8 0.85
D CDI -G 0.48 0.86 37.3 0.90
D CDI -D SF 0.45 0.92 51.9 0.94
20 20
D CD -F G 0.34 0.87 361.0 0.66
S EA (F CI ) 0.54 0.94 16.6 0.83
D CDI -G 0.31 0.65 54.7 0.72
D CDI -D SF 0.40 0.71 54.6 0.74
20 20
D CD -F G 0.36 0.77 343.0 0.67
S EA (F CI ) 0.50 0.85 31.4 0.78

Table 9. Comparison between heuristics-based sampler (random and inverse covariance) vs. model confidence-based sampler. Details
regarding the samplers can be found in B.6. The suffix -L indicates the greedy confidence-based sampler. Each setting encompasses
5 distinct Erdős-Rényi graphs. The symbol † indicates that SEA was not pretrained on this setting. Bold indicates best of all models
considered (including baselines not pictured).

N E Model Linear NN add. NN non-add. Sigmoid† Polynomial†


mAP↑ EA↑ SHD↓ mAP↑ EA↑ SHD↓ mAP↑ EA↑ SHD↓ mAP↑ EA↑ SHD↓ mAP↑ EA↑ SHD↓
S EA ( F ) 0.97 0.92 1.6 0.95 0.92 2.4 0.92 0.94 2.8 0.83 0.76 3.7 0.69 0.71 6.7
S EA ( G ) 0.99 0.94 1.2 0.94 0.88 2.6 0.91 0.93 3.2 0.85 0.84 4.0 0.70 0.79 5.8
10 10
S EA ( F )- L 0.97 0.93 1.0 0.95 0.87 2.4 0.92 0.98 3.4 0.84 0.77 3.9 0.70 0.79 5.8
S EA ( G )- L 0.98 0.93 1.4 0.94 0.91 2.8 0.91 0.94 4.0 0.88 0.84 3.6 0.70 0.80 5.8
S EA ( F ) 0.90 0.87 14.4 0.91 0.94 11.2 0.87 0.86 16.0 0.81 0.85 22.7 0.81 0.92 33.4
S EA ( G ) 0.94 0.91 12.8 0.91 0.95 10.4 0.89 0.89 17.2 0.81 0.87 24.5 0.89 0.93 29.5
10 40
S EA ( F )- L 0.91 0.90 15.6 0.91 0.92 15.8 0.88 0.86 14.2 0.81 0.84 23.2 0.82 0.93 33.8
S EA ( G )- L 0.93 0.91 13.4 0.91 0.93 10.4 0.88 0.85 16.2 0.79 0.83 25.5 0.90 0.94 28.3
S EA ( F ) 0.97 0.92 3.2 0.94 0.97 3.2 0.84 0.93 7.2 0.84 0.85 7.6 0.71 0.80 10.2
S EA ( G ) 0.97 0.89 3.0 0.94 0.95 3.4 0.83 0.94 7.8 0.84 0.83 8.1 0.69 0.78 10.1
20 20
S EA ( F )- L 0.97 0.92 2.8 0.93 0.95 3.8 0.85 0.94 6.8 0.85 0.85 7.5 0.67 0.78 9.9
S EA ( G )- L 0.97 0.90 2.6 0.94 0.98 3.4 0.83 0.97 7.0 0.84 0.84 7.9 0.67 0.79 10.6
S EA ( F ) 0.86 0.93 29.6 0.55 0.90 73.6 0.72 0.93 51.8 0.77 0.85 42.8 0.61 0.89 61.8
S EA ( G ) 0.89 0.92 26.8 0.58 0.88 71.4 0.73 0.92 50.6 0.76 0.84 45.0 0.65 0.89 60.1
20 80
S EA ( F )- L 0.86 0.92 32.0 0.55 0.90 74.0 0.74 0.93 49.2 0.76 0.87 41.8 0.59 0.88 62.3
S EA ( G )- L 0.89 0.92 28.4 0.58 0.89 71.6 0.75 0.92 49.4 0.75 0.85 45.7 0.65 0.88 60.6


Table 10. Scaling to synthetic graphs, larger than those seen in training. Each setting encompasses 5 distinct Erdős-Rényi graphs. For all
analysis in this table, we took T = 500 subsets of nodes, with b = 500 examples per batch. Here, the mean AUC values are artificially
high due to the high negative rates, as actual edges scale linearly as N , while the number of possible edges scales quadratically.

N Model Linear, E = N Linear, E = 4N


mAP ↑ AUC ↑ SHD ↓ EdgeAcc ↑ mAP ↑ AUC ↑ SHD ↓ EdgeAcc ↑
I NV C OV 0.43±0.0 0.99±0.0 116.8±7 — 0.30±0.0 0.93±0.0 511.8±11 —
C ORR 0.42±0.0 0.99±0.0 113.0±5 — 0.19±0.0 0.80±0.0 579.4±25 —
100
S EA (F CI ) 0.97±0.0 1.00±0.0 11.6±4 0.93±0.0 0.88±0.0 0.98±0.0 129.0±10 0.94±0.0
S EA (G IES ) 0.97±0.0 1.00±0.0 12.8±5 0.91±0.0 0.91±0.0 0.99±0.0 104.6±6 0.95±0.0
I NV C OV 0.45±0.0 1.00±0.0 218.4±11 — 0.33±0.0 0.96±0.0 999.6±23 —
C ORR 0.42±0.0 0.99±0.0 223.0±8 — 0.18±0.0 0.86±0.0 1183.5±25 —
200
S EA (F CI ) 0.91±0.0 1.00±0.0 49.9±5 0.87±0.0 0.82±0.0 0.97±0.0 327.4±52 0.92±0.0
S EA (G IES ) 0.95±0.0 1.00±0.0 35.4±6 0.91±0.0 0.86±0.0 0.98±0.0 271.9±50 0.92±0.0
I NV C OV 0.46±0.0 1.00±0.0 308.3±20 — 0.35±0.0 0.98±0.0 1444.7±56 —
C ORR 0.42±0.0 1.00±0.0 326.2±21 — 0.20±0.0 0.89±0.0 1710.4±82 —
300
S EA (F CI ) 0.80±0.0 1.00±0.0 121.1±14 0.78±0.0 0.70±0.0 0.95±0.0 693.1±67 0.86±0.0
S EA (G IES ) 0.88±0.0 1.00±0.0 88.9±11 0.84±0.0 0.78±0.0 0.96±0.0 556.1±71 0.87±0.0
I NV C OV 0.47±0.0 1.00±0.0 417.7±7 — 0.36±0.0 0.98±0.0 1882.7±28 —
C ORR 0.42±0.0 1.00±0.0 445.4±14 — 0.20±0.0 0.91±0.0 2269.3±52 —
400
S EA (F CI ) 0.49±0.2 0.93±0.1 313.9±107 0.61±0.1 0.56±0.1 0.90±0.1 1103.1±190 0.75±0.1
S EA (G IES ) 0.70±0.1 0.99±0.0 225.9±57 0.71±0.1 0.70±0.0 0.94±0.0 871.6±44 0.80±0.0
I NV C OV 0.47±0.0 1.00±0.0 504.5±19 — 0.38±0.0 0.99±0.0 2299.8±34 —
C ORR 0.42±0.0 1.00±0.0 543.3±18 — 0.21±0.0 0.93±0.0 2789.5±78 —
500
S EA (F CI ) 0.27±0.1 0.90±0.1 757.6±297 0.51±0.0 0.29±0.1 0.86±0.1 1823.6±273 0.56±0.1
S EA (G IES ) 0.41±0.2 0.98±0.0 485.3±170 0.57±0.1 0.48±0.1 0.92±0.0 1653.5±505 0.67±0.0

Figure 9. mAP on graphs larger than seen during training. Due to an insufficient maximum number of subset embeddings, we were only
able to sample 500 batches, which appears to be too few for larger graphs. These values correspond to the numbers in Table 10.


Table 11. Generalization limits of our current implementations. Each setting represents 5 distinct graphs. Our models were not pretrained
on multiplicative noise, so they do not generalize reliably to these cases. While much slower, DCDI variants learn the data distribution
from scratch each time, so they seem to perform well in some of these cases. We anticipate that training on multiplicative data would
alleviate this generalization gap, but we did not have time to test this empirically.

N E Model Sigmoid mix† Polynomial mix† Sigmoid mix (SF)† Polynomial mix (SF)†
mAP ↑ EA ↑ SHD ↓ mAP ↑ EA ↑ SHD ↓ mAP ↑ EA ↑ SHD ↓ mAP ↑ EA ↑ SHD ↓
D CDI -G 0.84 0.81 0.9 0.11 0.00 10.2 0.67 0.92 15.2 0.12 0.00 10.6
D CDI -D SF 0.96 0.91 0.3 0.39 0.32 7.3 0.81 0.98 13.6 0.39 0.36 9.8
D CD -F G 0.52 0.61 18.0 0.11 0.00 10.2 0.57 0.66 15.5 0.12 0.00 10.6
D IFF A N 0.14 0.31 19.9 0.11 0.09 14.0 0.10 0.22 17.8 0.11 0.14 17.2
D ECI 0.17 0.43 19.1 0.11 0.07 14.1 0.20 0.53 15.0 0.12 0.06 13.6
10 10
I NV C OV 0.34 0.51 12.1 0.18 0.51 16.4 0.31 0.56 11.4 0.16 0.50 16.1
F CI -AVG 0.56 0.54 9.1 0.11 0.00 10.2 0.53 0.52 8.1 0.12 0.01 10.6
G IES -AVG 0.91 0.88 3.1 0.22 0.38 10.2 0.93 0.88 2.3 0.22 0.37 10.6
S EA (F CI ) 0.58 0.51 5.7 0.25 0.53 10.2 0.67 0.56 4.5 0.20 0.46 10.6
S EA (G IES ) 0.41 0.39 8.1 0.18 0.51 10.2 0.48 0.48 5.4 0.21 0.46 10.6
D CDI -G 0.58 0.35 25.0 0.44 0.00 39.8 0.76 0.83 22.0 0.31 0.00 28.2
D CDI -D SF 0.79 0.46 12.2 0.48 0.09 35.6 0.92 0.88 15.8 0.35 0.07 26.4
D CD -F G 0.60 0.40 26.8 0.44 0.00 39.8 0.52 0.46 22.1 0.31 0.00 28.2
D IFF A N 0.38 0.32 31.8 0.40 0.32 30.0 0.26 0.30 35.0 0.28 0.29 29.3
D ECI 0.46 0.39 27.5 0.43 0.03 39.5 0.35 0.51 24.9 0.30 0.05 30.1
10 40
I NV C OV 0.48 0.51 38.7 0.48 0.54 40.3 0.40 0.51 32.5 0.34 0.48 38.2
F CI -AVG 0.44 0.18 39.7 0.44 0.01 39.8 0.43 0.30 25.2 0.32 0.01 28.2
G IES -AVG 0.43 0.37 38.2 0.50 0.48 39.8 0.48 0.45 22.8 0.36 0.39 28.2
S EA (F CI ) 0.56 0.58 28.6 0.81 0.85 39.7 0.62 0.67 17.3 0.58 0.85 28.2
S EA (G IES ) 0.58 0.65 30.2 0.80 0.86 39.5 0.64 0.69 16.7 0.60 0.81 28.2
D CDI -G 0.57 0.72 6.2 0.06 0.01 21.7 0.47 0.98 45.4 0.06 0.01 20.7
D CDI -D SF 0.91 0.94 1.1 0.21 0.46 34.6 0.56 0.97 40.6 0.29 0.80 83.0
D CD -F G 0.50 0.68 62.1 0.06 0.00 24.3 0.63 0.81 256.6 0.06 0.11 186.4
D IFF A N 0.08 0.26 47.5 0.06 0.15 51.5 0.08 0.30 42.6 0.06 0.22 53.7
D ECI 0.17 0.58 38.0 0.06 0.05 36.0 0.18 0.60 32.9 0.06 0.03 38.9
20 20
I NV C OV 0.31 0.59 22.3 0.09 0.52 35.9 0.24 0.44 24.5 0.07 0.50 35.4
F CI -AVG 0.64 0.69 17.4 0.06 0.01 21.3 0.58 0.66 15.3 0.06 0.00 21.0
G IES -AVG 0.82 0.75 8.5 0.15 0.44 21.3 0.83 0.78 7.8 0.12 0.42 21.0
S EA (F CI ) 0.61 0.60 8.5 0.12 0.55 21.3 0.63 0.56 9.0 0.11 0.55 21.1
S EA (G IES ) 0.41 0.48 13.0 0.11 0.55 21.6 0.50 0.49 12.3 0.11 0.52 21.3
D CDI -G 0.60 0.59 32.3 0.21 0.05 86.4 0.54 0.84 78.6 0.18 0.14 74.8
D CDI -D SF 0.89 0.81 8.4 0.24 0.25 102.3 0.65 0.89 60.1 0.27 0.60 201.6
D CD -F G 0.46 0.72 222.2 0.21 0.00 81.8 0.52 0.76 202.2 0.18 0.15 225.9
D IFF A N 0.18 0.30 151.1 0.19 0.31 130.8 0.15 0.30 137.0 0.16 0.31 127.9
D ECI 0.31 0.43 70.5 0.20 0.03 89.0 0.25 0.41 66.5 0.17 0.02 79.0
20 80
I NV C OV 0.30 0.51 98.1 0.22 0.51 115.4 0.27 0.53 89.1 0.19 0.49 108.1
F CI -AVG 0.37 0.30 76.0 0.21 0.01 79.3 0.38 0.34 59.6 0.18 0.01 66.5
G IES -AVG 0.54 0.64 68.4 0.23 0.37 79.3 0.55 0.65 50.7 0.19 0.36 66.5
S EA (F CI ) 0.61 0.64 52.6 0.41 0.87 79.3 0.62 0.70 36.7 0.34 0.82 66.7
S EA (G IES ) 0.53 0.64 58.1 0.43 0.86 79.0 0.59 0.74 41.4 0.35 0.82 66.8


Table 12. Causal discovery results on synthetic datasets with 100 nodes, continuous metrics. Each setting encompasses 5 distinct
Erdős-Rényi graphs. The symbol † indicates that the model was not trained on this setting. All standard deviations were within 0.1.

N E Model Linear NN add. NN non-add. Sigmoid† Polynomial†


mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑
D CD -F G 0.11 0.75 0.12 0.71 0.18 0.73 0.20 0.72 0.06 0.60
100 100 I NV C OV 0.40 0.99 0.22 0.94 0.16 0.87 0.40 0.97 0.36 0.90
S EA (F CI ) 0.96 1.00 0.83 0.97 0.75 0.97 0.79 0.97 0.56 0.88
S EA (G IES ) 0.97 1.00 0.82 0.98 0.74 0.96 0.80 0.97 0.54 0.85
D CD -F G 0.05 0.59 0.07 0.64 0.10 0.72 0.13 0.72 0.12 0.64
100 400 I NV C OV 0.25 0.91 0.09 0.62 0.14 0.77 0.27 0.86 0.20 0.67
S EA (F CI ) 0.90 0.99 0.28 0.82 0.60 0.92 0.69 0.92 0.38 0.80
S EA (G IES ) 0.91 0.99 0.27 0.82 0.61 0.92 0.69 0.91 0.38 0.78

Table 13. Causal discovery results on synthetic datasets with 100 nodes, discrete metrics. Each setting encompasses 5 distinct Erdős-Rényi
graphs. The symbol † indicates that the model was not trained on this setting.

N E Model Linear NN add. NN non-add. Sigmoid† Polynomial†


EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓
D CD -F G 0.63 3075.8 0.58 2965.0 0.60 2544.4 0.59 3808.0 0.34 1927.9
100 100 I NV C OV — 124.4 — 130.0 — 158.8 — 112.3 — 106.3
S EA (F CI ) 0.91 13.4 0.90 34.4 0.91 47.2 0.78 40.3 0.69 59.2
S EA (G IES ) 0.91 13.6 0.93 32.8 0.91 45.8 0.78 38.6 0.68 60.3
D CD -F G 0.46 3068.2 0.60 3428.8 0.70 3510.8 0.67 3601.8 0.53 3316.7
100 400 I NV C OV — 557.0 — 667.8 — 639.0 — 514.7 — 539.4
S EA (F CI ) 0.93 122.0 0.90 361.2 0.91 273.2 0.87 226.9 0.82 327.0
S EA (G IES ) 0.94 116.6 0.91 364.4 0.92 266.8 0.87 218.3 0.84 328.0

Figure 10. Runtime for heuristics-based greedy sampler vs. model uncertainty-based greedy sampler (suffix -L). For sampling, the model
was run on CPU only, due to the difficulty of invoking GPU in the PyTorch data sampler.


Table 14. Causal discovery results on synthetic scale-free datasets, continuous metrics. Each setting encompasses 5 distinct scale-free
graphs. The symbol † indicates that SEA was not pretrained on this setting. Details regarding baselines can be found in B.2.

N E Model Linear NN add. NN non-add. Sigmoid† Polynomial†


mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑ mAP ↑ AUC ↑
D CDI -G 0.54±0.2 0.90±0.0 0.59±0.1 0.88±0.1 0.69±0.1 0.89±0.0 0.48±0.2 0.77±0.1 0.50±0.2 0.73±0.1
D CDI -D SF 0.70±0.2 0.92±0.1 0.71±0.2 0.88±0.1 0.36±0.1 0.83±0.1 0.46±0.2 0.75±0.1 0.49±0.1 0.76±0.1
D CD -F G 0.56±0.1 0.76±0.1 0.47±0.1 0.72±0.1 0.50±0.2 0.73±0.1 0.44±0.2 0.68±0.1 0.57±0.2 0.75±0.1
D IFF A N 0.25±0.1 0.73±0.1 0.15±0.0 0.66±0.0 0.16±0.1 0.62±0.1 0.31±0.1 0.75±0.1 0.24±0.3 0.63±0.1
D ECI 0.17±0.1 0.65±0.1 0.17±0.0 0.67±0.0 0.20±0.0 0.72±0.1 0.27±0.1 0.73±0.1 0.49±0.3 0.82±0.1
10 10
I NV C OV 0.38±0.1 0.57±0.2 0.26±0.1 0.50±0.1 0.29±0.1 0.52±0.1 0.27±0.1 0.46±0.1 0.08±0.0 0.13±0.1
F CI -AVG 0.56±0.2 0.80±0.1 0.51±0.2 0.80±0.2 0.43±0.2 0.74±0.1 0.60±0.2 0.82±0.1 0.34±0.1 0.68±0.1
G IES -AVG 0.87±0.1 0.98±0.0 0.61±0.1 0.94±0.0 0.69±0.1 0.94±0.0 0.75±0.1 0.96±0.0 0.71±0.1 0.91±0.1
S EA (F CI ) 0.94±0.1 0.99±0.0 0.93±0.0 0.98±0.0 0.93±0.1 0.98±0.0 0.81±0.1 0.97±0.0 0.76±0.1 0.91±0.1
S EA (G IES ) 0.95±0.0 0.99±0.0 0.94±0.0 0.98±0.0 0.92±0.0 0.98±0.0 0.85±0.1 0.98±0.0 0.74±0.1 0.90±0.1
D CDI -G 0.70±0.0 0.85±0.0 0.74±0.1 0.85±0.1 0.88±0.0 0.91±0.0 0.56±0.1 0.66±0.1 0.53±0.1 0.64±0.1
D CDI -D SF 0.74±0.1 0.87±0.1 0.73±0.1 0.84±0.0 0.71±0.1 0.90±0.0 0.56±0.1 0.69±0.1 0.51±0.0 0.63±0.1
D CD -F G 0.37±0.0 0.58±0.0 0.45±0.1 0.61±0.1 0.45±0.2 0.58±0.1 0.49±0.1 0.63±0.1 0.63±0.1 0.73±0.1
D IFF A N 0.29±0.1 0.50±0.1 0.25±0.0 0.38±0.1 0.28±0.0 0.46±0.1 0.31±0.1 0.53±0.1 0.27±0.0 0.44±0.1
D ECI 0.30±0.0 0.51±0.1 0.41±0.0 0.65±0.1 0.33±0.0 0.51±0.0 0.38±0.1 0.60±0.1 0.59±0.1 0.77±0.1
10 40
I NV C OV 0.36±0.0 0.48±0.1 0.34±0.0 0.49±0.1 0.37±0.0 0.48±0.0 0.39±0.0 0.54±0.1 0.26±0.0 0.34±0.0
F CI -AVG 0.47±0.1 0.64±0.0 0.41±0.1 0.60±0.0 0.40±0.0 0.58±0.0 0.48±0.1 0.64±0.1 0.41±0.1 0.59±0.1
G IES -AVG 0.43±0.1 0.68±0.1 0.43±0.1 0.63±0.1 0.44±0.1 0.61±0.1 0.49±0.1 0.69±0.1 0.59±0.1 0.71±0.1
S EA (F CI ) 0.93±0.0 0.96±0.0 0.84±0.0 0.92±0.0 0.81±0.0 0.90±0.0 0.81±0.0 0.89±0.0 0.73±0.1 0.84±0.0
S EA (G IES ) 0.92±0.0 0.96±0.0 0.84±0.1 0.93±0.0 0.83±0.0 0.90±0.0 0.79±0.0 0.88±0.0 0.78±0.1 0.87±0.0
D CDI -G 0.41±0.0 0.95±0.0 0.50±0.1 0.94±0.0 0.69±0.1 0.96±0.0 0.37±0.0 0.83±0.1 0.37±0.1 0.77±0.1
D CDI -D SF 0.48±0.0 0.95±0.0 0.55±0.1 0.93±0.0 0.33±0.1 0.90±0.0 0.37±0.1 0.79±0.1 0.35±0.1 0.82±0.1
D CD -F G 0.51±0.1 0.87±0.1 0.39±0.1 0.83±0.1 0.48±0.1 0.84±0.0 0.56±0.1 0.84±0.1 0.50±0.1 0.84±0.1
D IFF A N 0.27±0.1 0.80±0.0 0.11±0.0 0.65±0.1 0.11±0.0 0.66±0.1 0.26±0.2 0.77±0.1 0.12±0.0 0.69±0.1
D ECI 0.13±0.1 0.69±0.1 0.15±0.1 0.71±0.1 0.15±0.0 0.73±0.0 0.15±0.1 0.71±0.1 0.25±0.1 0.79±0.1
20 20
I NV C OV 0.30±0.0 0.57±0.1 0.26±0.1 0.53±0.1 0.21±0.0 0.53±0.0 0.24±0.1 0.49±0.1 0.03±0.0 0.10±0.1
F CI -AVG 0.63±0.1 0.84±0.1 0.44±0.1 0.78±0.0 0.43±0.1 0.79±0.0 0.60±0.1 0.86±0.1 0.47±0.1 0.78±0.1
G IES -AVG 0.82±0.1 0.99±0.0 0.58±0.1 0.95±0.0 0.57±0.1 0.96±0.0 0.75±0.1 0.98±0.0 0.61±0.1 0.90±0.0
S EA (F CI ) 0.95±0.0 1.00±0.0 0.91±0.0 0.98±0.0 0.87±0.1 0.98±0.0 0.84±0.1 0.98±0.0 0.70±0.1 0.92±0.1
S EA (G IES ) 0.93±0.0 1.00±0.0 0.91±0.1 0.98±0.0 0.88±0.1 0.98±0.0 0.82±0.1 0.98±0.0 0.70±0.1 0.91±0.1
D CDI -G 0.62±0.1 0.88±0.0 0.61±0.1 0.89±0.0 0.76±0.1 0.94±0.0 0.44±0.1 0.76±0.0 0.36±0.0 0.60±0.0
D CDI -D SF 0.58±0.0 0.87±0.0 0.55±0.1 0.86±0.0 0.58±0.0 0.92±0.0 0.43±0.0 0.78±0.0 0.35±0.0 0.66±0.1
D CD -F G 0.38±0.1 0.70±0.1 0.30±0.0 0.69±0.0 0.48±0.1 0.80±0.1 0.48±0.1 0.75±0.0 0.53±0.1 0.73±0.1
D IFF A N 0.18±0.0 0.55±0.1 0.15±0.0 0.44±0.1 0.16±0.0 0.53±0.1 0.19±0.0 0.56±0.1 0.15±0.0 0.38±0.0
D ECI 0.21±0.0 0.58±0.0 0.24±0.0 0.64±0.0 0.26±0.1 0.66±0.1 0.30±0.0 0.68±0.0 0.41±0.0 0.75±0.0
20 80
I NV C OV 0.27±0.0 0.52±0.0 0.22±0.0 0.51±0.0 0.25±0.0 0.54±0.0 0.27±0.0 0.51±0.1 0.12±0.0 0.30±0.0
F CI -AVG 0.31±0.0 0.63±0.0 0.30±0.0 0.62±0.0 0.30±0.1 0.62±0.0 0.41±0.1 0.68±0.0 0.32±0.1 0.62±0.0
G IES -AVG 0.51±0.0 0.87±0.0 0.43±0.1 0.78±0.0 0.47±0.0 0.81±0.0 0.52±0.1 0.82±0.0 0.47±0.0 0.73±0.0
S EA (F CI ) 0.92±0.0 0.98±0.0 0.64±0.1 0.89±0.0 0.71±0.1 0.90±0.0 0.73±0.1 0.90±0.0 0.59±0.1 0.81±0.0
S EA (G IES ) 0.92±0.0 0.98±0.0 0.63±0.1 0.89±0.0 0.73±0.1 0.91±0.0 0.77±0.1 0.92±0.0 0.62±0.1 0.84±0.0


Table 15. Causal discovery results on synthetic scale-free datasets, discrete metrics. Each setting encompasses 5 distinct scale-free graphs.
The symbol † indicates that SEA was not pretrained on this setting. Details regarding baselines can be found in B.2.

N E Model Linear NN add. NN non-add. Sigmoid† Polynomial†


EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓ EdgeAcc ↑ SHD ↓
D CDI -G 0.85±0.0 16.6±1 0.82±0.2 17.4±1 0.81±0.0 16.2±2 0.71±0.2 16.9±5 0.70±0.2 16.6±3
D CDI -D SF 0.89±0.1 16.2±2 0.75±0.1 15.4±2 0.78±0.2 16.8±2 0.80±0.2 18.1±2 0.77±0.2 18.0±2
D CD -F G 0.60±0.1 16.4±4 0.57±0.1 22.2±4 0.57±0.2 20.0±3 0.47±0.3 17.9±4 0.59±0.2 16.8±4
D IFF A N 0.55±0.1 9.2±4 0.49±0.1 14.6±4 0.36±0.2 11.6±4 0.59±0.1 7.7±3 0.42±0.2 14.8±6
D ECI 0.51±0.2 17.4±4 0.55±0.1 17.4±4 0.58±0.1 12.0±1 0.62±0.2 12.6±5 0.74±0.2 8.2±6
10 10
I NV C OV 0.60±0.2 9.8±2 0.54±0.2 12.0±2 0.54±0.2 10.8±2 0.40±0.1 11.9±2 0.14±0.1 17.4±2
F CI -AVG 0.62±0.2 8.4±2 0.51±0.3 7.8±1 0.45±0.2 8.0±2 0.61±0.2 9.2±3 0.34±0.2 9.6±2
G IES -AVG 0.83±0.1 2.2±2 0.60±0.1 6.0±2 0.75±0.2 4.2±2 0.72±0.1 4.9±2 0.73±0.2 5.7±2
S EA (F CI ) 0.87±0.1 1.2±1 0.93±0.1 2.0±1 0.96±0.1 2.2±2 0.70±0.2 4.3±2 0.82±0.2 5.3±2
S EA (G IES ) 0.85±0.1 1.4±0 0.96±0.1 1.8±1 0.94±0.1 2.0±1 0.86±0.1 3.4±2 0.83±0.1 5.0±2
D CDI -G 0.79±0.1 24.0±3 0.77±0.1 27.8±4 0.82±0.1 19.6±2 0.65±0.1 31.4±2 0.61±0.2 32.6±3
D CDI -D SF 0.84±0.0 22.8±4 0.78±0.0 24.4±3 0.83±0.1 20.4±2 0.71±0.1 31.6±2 0.62±0.1 33.3±3
D CD -F G 0.36±0.1 24.8±4 0.41±0.1 25.2±4 0.38±0.2 25.6±7 0.41±0.1 23.2±6 0.54±0.1 18.5±3
D IFF A N 0.40±0.2 29.8±7 0.28±0.1 37.0±2 0.38±0.1 32.6±2 0.45±0.1 28.0±6 0.33±0.1 32.7±5
D ECI 0.43±0.1 27.8±3 0.66±0.1 22.6±3 0.48±0.1 28.6±2 0.52±0.1 22.2±3 0.66±0.1 13.3±3
10 40
I NV C OV 0.42±0.1 36.2±8 0.50±0.1 35.2±3 0.48±0.1 39.2±4 0.59±0.1 33.7±4 0.34±0.1 50.8±5
F CI -AVG 0.33±0.1 23.8±2 0.28±0.1 27.2±2 0.25±0.1 27.6±3 0.36±0.1 26.1±2 0.24±0.1 27.3±2
G IES -AVG 0.46±0.1 21.8±2 0.50±0.1 24.4±1 0.48±0.2 25.2±5 0.52±0.1 22.7±3 0.64±0.1 22.5±2
S EA (F CI ) 0.88±0.1 7.0±2 0.95±0.0 15.2±4 0.91±0.1 15.0±2 0.87±0.1 13.9±2 0.91±0.0 19.4±3
S EA (G IES ) 0.88±0.1 6.6±2 0.98±0.0 14.0±5 0.88±0.0 14.4±3 0.87±0.0 14.1±3 0.93±0.1 19.1±3
D CDI -G 0.95±0.1 40.4±1 0.92±0.0 44.8±7 0.96±0.1 39.8±6 0.88±0.1 41.1±3 0.79±0.1 38.4±5
D CDI -D SF 0.95±0.1 40.4±2 0.92±0.1 42.4±7 0.90±0.0 42.2±7 0.84±0.1 41.1±4 0.83±0.1 49.3±17
D CD -F G 0.68±0.1 252.2±21 0.77±0.1 182.8±46 0.78±0.1 181.2±24 0.70±0.1 251.3±42 0.69±0.1 278.2±65
D IFF A N 0.67±0.1 23.6±12 0.40±0.1 42.2±20 0.42±0.1 34.0±10 0.59±0.1 22.6±11 0.50±0.1 46.8±13
D ECI 0.50±0.2 42.0±5 0.54±0.1 43.0±10 0.57±0.1 40.0±12 0.51±0.1 34.7±7 0.65±0.1 25.3±6
20 20
I NV C OV 0.54±0.1 20.6±1 0.54±0.1 24.8±4 0.52±0.0 26.6±4 0.50±0.1 23.9±5 0.13±0.0 38.2±5
F CI -AVG 0.67±0.2 13.8±1 0.52±0.1 17.4±1 0.53±0.1 17.8±4 0.65±0.1 16.7±4 0.50±0.1 18.9±4
G IES -AVG 0.82±0.1 6.4±3 0.71±0.0 12.6±2 0.68±0.2 13.2±3 0.75±0.1 11.4±4 0.73±0.1 13.4±4
S EA (F CI ) 0.91±0.1 2.6±2 0.96±0.0 4.6±2 0.93±0.1 6.4±3 0.82±0.1 6.7±3 0.76±0.1 10.3±3
S EA (G IES ) 0.85±0.1 4.0±2 0.95±0.0 3.6±2 0.93±0.1 6.2±4 0.80±0.1 7.9±4 0.82±0.2 9.9±3
D CDI -G 0.83±0.0 93.0±9 0.89±0.1 104.2±6 0.91±0.0 67.8±7 0.78±0.1 82.5±7 0.60±0.1 79.7±5
D CDI -D SF 0.84±0.0 103.2±7 0.85±0.1 94.8±10 0.89±0.0 63.8±7 0.78±0.1 84.4±7 0.64±0.1 82.6±5
D CD -F G 0.63±0.1 188.2±22 0.70±0.1 187.2±13 0.78±0.1 190.2±23 0.71±0.1 217.3±36 0.71±0.1 234.7±30
D IFF A N 0.42±0.1 110.6±16 0.30±0.1 144.6±10 0.40±0.1 118.8±12 0.41±0.1 99.7±21 0.21±0.1 149.4±10
D ECI 0.33±0.1 72.2±9 0.48±0.1 81.4±9 0.48±0.1 67.0±8 0.50±0.1 60.4±11 0.58±0.1 47.2±7
20 80
I NV C OV 0.50±0.0 85.4±4 0.48±0.0 94.2±3 0.53±0.0 86.6±3 0.51±0.1 87.2±10 0.33±0.0 116.8±3
F CI -AVG 0.26±0.0 58.2±5 0.22±0.0 58.4±4 0.26±0.1 55.0±7 0.37±0.0 59.3±6 0.23±0.1 62.9±3
G IES -AVG 0.65±0.0 48.4±5 0.71±0.1 53.6±3 0.68±0.1 48.0±3 0.64±0.1 53.6±5 0.61±0.1 54.9±5
S EA (F CI ) 0.92±0.0 19.0±5 0.93±0.0 48.4±4 0.88±0.1 38.4±4 0.85±0.1 37.1±5 0.86±0.1 49.8±5
S EA (G IES ) 0.92±0.0 17.6±3 0.93±0.0 49.2±8 0.89±0.0 37.2±4 0.88±0.0 35.1±6 0.89±0.0 48.1±5
