From 7cdf06f3b31c44b8a66baea66aa6f0254f064e38 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 27 Jul 2023 09:32:40 -0400 Subject: [PATCH 01/14] Fix multidomain representation Signed-off-by: Adam Li --- pywhy_graphs/algorithms/multidomain.py | 194 ++++++++++++++---- .../algorithms/tests/test_multidomain_algs.py | 77 ++++++- pywhy_graphs/classes/augmented.py | 112 +++++++++- 3 files changed, 341 insertions(+), 42 deletions(-) diff --git a/pywhy_graphs/algorithms/multidomain.py b/pywhy_graphs/algorithms/multidomain.py index d2493a2c1..730f963bf 100644 --- a/pywhy_graphs/algorithms/multidomain.py +++ b/pywhy_graphs/algorithms/multidomain.py @@ -1,11 +1,45 @@ from itertools import combinations -from typing import Optional, Set +from typing import Optional from warnings import warn from pywhy_graphs.classes import AugmentedGraph from pywhy_graphs.typing import Node +def get_all_snode_combinations(n_domains): + """Compute a mapping of domain pairs to all possible S-nodes. + + S-nodes are defined as ``('S', )``, where ```` is an integer + starting from 0. Each S-node is by construction mapped to a pair of domain + IDs. + + Parameters + ---------- + n_domains : int + The number of possible domains. + + Returns + ------- + s_node_domains : dict + A mapping of domain pairs to S-nodes. + """ + s_node_domains = dict() + + sdx = 0 + # add all the S-nodes representing differences across pairs of domains + # to every single node with S-nodes + for domains in combinations(range(1, n_domains + 1), 2): + source_domain, target_domain = sorted(domains) + + # now modify the function of the edge, S-nodes are pointing to + s_node = ("S", sdx) + s_node_domains[(source_domain, target_domain)] = s_node + + # increment the S-node counter + sdx += 1 + return s_node_domains + + def add_all_snode_combinations(G, n_domains: int, on_error="raise"): """Add all possible S-nodes to the graph given number of domains. @@ -28,16 +62,16 @@ def add_all_snode_combinations(G, n_domains: int, on_error="raise"): of the added S-nodes have any edges. """ G = G.copy() - s_node_domains = dict() - sdx = 0 + # compute the relevant S-node combinations + s_node_domains = get_all_snode_combinations(n_domains) + # add all the S-nodes representing differences across pairs of domains # to every single node with S-nodes - for domains in combinations(range(1, n_domains + 1), 2): - source_domain, target_domain = sorted(domains) + for (source_domain, target_domain), s_node in s_node_domains.items(): # now modify the function of the edge, S-nodes are pointing to - s_node = ("S", sdx) + s_node_domains[(source_domain, target_domain)] = s_node if s_node in G.s_nodes: if on_error == "raise": raise RuntimeError(f"There is already an S-node {s_node} in G!") @@ -45,21 +79,96 @@ def add_all_snode_combinations(G, n_domains: int, on_error="raise"): warn(f"There is already an S-node {s_node} in G!") G.add_node(s_node, domain_ids=(source_domain, target_domain)) - s_node_domains[(source_domain, target_domain)] = s_node # add S-nodes G.graph["S-nodes"][s_node] = (source_domain, target_domain) - # increment the S-node counter - sdx += 1 return G, s_node_domains -# XXX: does not work? +def get_connected_snodes(G, node): + """Get all the connected S-nodes to a node. + + Parameters + ---------- + G : AugmentedGraph + The augmented graph. + node : Node + The node to get the connected S-nodes for. + + Returns + ------- + connected_snodes : Set[Node] + Set of connected S-nodes. + """ + connected_snodes = set() + for s_node in G.s_nodes: + if G.has_edge(s_node, node): + connected_snodes.add(s_node) + return connected_snodes + + +def remove_snode_edge(G, snode, node, preserve_invariance=True): + """Remove an S-node edge from a selection diagram. + + The removal of an S-node edge implies invariances in the diagram + across different domains represented by the S-node. This invariance + may lead to other invariances in selection diagrams representing + more than 3 domain. + + Parameters + ---------- + G : AugmentedGraph + The augmented graph with S-nodes. + snode : Node + A S-node representing a possible difference across two domains. + node : Node + The to node of the S-node. + preserve_invariance : bool, optional + Whether or not to remove additional S-node edges that are required + to preserve the relative invariances, by default True. + + Returns + ------- + G : AugmentedGraph + Augmented graph with removed S-node edges. + """ + domain_ids = G.domain_ids + snode_domains = get_all_snode_combinations(len(domain_ids)) + + if snode not in snode_domains.values(): + raise RuntimeError(f"S-node {snode} is not a valid S-node!") + + # remove the edge + G.remove_edge(snode, node) + + # now compute the connected pairs of domains + if preserve_invariance: + # get all the other S-nodes not linked to node + other_snodes = set() + domain_pairs = [] + + # get all S-nodes with an edge to `node` + connected_snodes = get_connected_snodes(G, node) + + # find connected pairs of domain IDs that must be invariant + for domain_pair, snode_ in snode_domains.items(): + if snode_ not in connected_snodes or snode_ == snode: + other_snodes.add(snode_) + domain_pairs.append(domain_pair) + connected_domain_pairs = find_connected_pairs(domain_pairs, len(domain_ids)) + + # now remove all the S-node edges for S-nodes that are in the + # connected component + for domain_pair in connected_domain_pairs: + snode_ = snode_domains[domain_pair] + G.remove_edge(snode_, node) + return G + + def compute_invariant_domains_per_node( G: AugmentedGraph, node: Node, - all_poss_snodes: Optional[Set] = None, n_domains: Optional[int] = None, inconsistency="raise", ): @@ -75,9 +184,6 @@ def compute_invariant_domains_per_node( The augmented graph. node : Node The node in G to compute the invariant domains for. - all_poss_snodes : Optional[Set], optional - All possible S-nodes, by default None. If None, will infer based on the - number of domains. n_domains : int, optional The number of domains, by default None. If None, will infer based on the ``domain_ids`` attribute of G. @@ -96,41 +202,39 @@ def compute_invariant_domains_per_node( G : AugmentedGraph The augmented graph """ + G = G.copy() + # infer the number of domains based on the number of domain IDs in the augmented # graph so far if n_domains is None: n_domains = len(G.domain_ids) - # original S-nodes - orig_s_nodes = set(G.s_nodes) - # add now all relevant S-nodes considering the domains - if all_poss_snodes is None: - G_copy, s_node_domains = add_all_snode_combinations(G.copy(), n_domains, on_error="ignore") - all_poss_snodes = set(G_copy.s_nodes) - - remove_s_node = [] - for s_node in all_poss_snodes: - if s_node not in orig_s_nodes: - remove_s_node.append(s_node) - - # find all connected pairs - tuples = [] - for s_node in remove_s_node: - source_domain, target_domain = G.nodes(data=True)[s_node]["domain_ids"] - tuples.append((source_domain, target_domain)) - G.remove_node(s_node) + s_node_domains = get_all_snode_combinations(n_domains) + + # get all S-nodes with an edge to `node` + connected_snodes = get_connected_snodes(G, node) + + # get all the other S-nodes not linked to node + other_snodes = set() + domain_pairs = [] + + # find connected pairs of domain IDs that must be invariant + for domain_pair, snode_ in s_node_domains.items(): + if snode_ not in connected_snodes: + other_snodes.add(snode_) + domain_pairs.append(domain_pair) + connected_domain_pairs = find_connected_pairs(domain_pairs, n_domains) # now compute all invariant domains - connected_pairs = find_connected_pairs(tuples, n_domains) invariant_domains = set() - for domain_pair in connected_pairs: + for domain_pair in connected_domain_pairs: # remove all the S-nodes that are not in the connected component s_node = s_node_domains[domain_pair] - G.remove_edge(s_node, node) - # check if any S-nodes are not in the original - if s_node not in orig_s_nodes: + # if there is an S-node edge that is inconsistent with the invariant + # domains, then raise an error + if G.has_edge(s_node, node): if inconsistency == "raise": raise RuntimeError(f"Inconsistency in S-nodes for node {node}!") elif inconsistency == "warn": @@ -178,12 +282,28 @@ def get_sets(self): def find_connected_pairs(tuples, max_number): """Find connected pairs of domain tuples. + This is useful for removing S-nodes among a selection diagram that represents + more than 3 domains. For example, if we have 4 domains, we can have the following + S-nodes: (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4). However, if we removed + the S-node (1, 2), and (1, 3), we should also remove (2, 3) as it is connected. + + This is because the removal of (1, 2) and (1, 3) implies that domains 1 and 2 + are invariant for some node, and domains 1 and 3 are invariant for that node. + However, this also implies that domains 2 and 3 are invariant for that node + by transitivity. + + However, this is only required if we want the S-node to be "strict" as in + the invariance is definitely not valid when the S-node edge is not present. + If the S-node edge only implies that the invariance "could" not hold, then + we do not need to remove the additional S-nodes. + Parameters ---------- tuples : List of tuples List of tuples of domain ids (i, j). max_number : int The maximum number that can be in a domain id. + Assumes indexing starts at 1. Returns ------- diff --git a/pywhy_graphs/algorithms/tests/test_multidomain_algs.py b/pywhy_graphs/algorithms/tests/test_multidomain_algs.py index 4567f9d91..8466af7d9 100644 --- a/pywhy_graphs/algorithms/tests/test_multidomain_algs.py +++ b/pywhy_graphs/algorithms/tests/test_multidomain_algs.py @@ -1,5 +1,12 @@ +import pytest + from pywhy_graphs import AugmentedGraph -from pywhy_graphs.algorithms import add_all_snode_combinations, find_connected_pairs +from pywhy_graphs.algorithms import ( + add_all_snode_combinations, + compute_invariant_domains_per_node, + find_connected_pairs, + remove_snode_edge, +) def test_find_connected_domain_pairs(): @@ -41,3 +48,71 @@ def test_add_all_snode_combinations(): # Assert that the edges are added assert len(list(G.edges())) == len(s_node_domains) + + +def example_augmented_graph(n_domains=3): + # Create an example AugmentedGraph for testing + G = AugmentedGraph() + G.add_node("x") + + G, _ = add_all_snode_combinations(G, n_domains=n_domains) + for snode in G.s_nodes: + G.add_edge(snode, "x") + return G + + +def test_compute_invariant_domains_per_node_when_three_domains(): + n_domains = 3 + G = example_augmented_graph(n_domains=n_domains) + + # Compute the invariant domains for node "A" with 3 domains + # which should be none when the S-nodes are fully connected + G_result = compute_invariant_domains_per_node(G, "x", n_domains=n_domains) + assert G_result.nodes()["x"]["invariant_domains"] == set() + + # removing one S-node should only result in two pairwise invariant domains + snode = ("S", 0) + remove_snode_edge(G, snode, "x") + G_result = compute_invariant_domains_per_node(G, "x", n_domains=n_domains) + assert G_result.nodes()["x"]["invariant_domains"] == set(G.s_node_domain_ids[snode]) + + # if we remove another S-node without preserving the invariance, it should be caught + snode = ("S", 1) + G_copy = remove_snode_edge(G.copy(), snode, "x", preserve_invariance=False) + with pytest.raises(RuntimeError, match="Inconsistency in S-nodes"): + G_result = compute_invariant_domains_per_node(G_copy, "x", n_domains=n_domains) + + # removing another S-node should result in all domains being invariant + G_copy = remove_snode_edge(G.copy(), snode, "x", preserve_invariance=True) + G_result = compute_invariant_domains_per_node(G_copy, "x", n_domains=n_domains) + assert G_result.nodes()["x"]["invariant_domains"] == set(G.domain_ids) + + for snode in G.s_nodes: + assert not G_result.has_edge(snode, "x") + + +def test_compute_invariant_domains_per_node_when_many_domains(): + n_domains = 4 + G = example_augmented_graph(n_domains=n_domains) + + # Compute the invariant domains for node "A" with 3 domains + # which should be none when the S-nodes are fully connected + G_result = compute_invariant_domains_per_node(G, "x", n_domains=n_domains) + assert G_result.nodes()["x"]["invariant_domains"] == set() + + # map domain IDs + snode_domains = G.domain_ids_to_snodes + + # removing one S-node should only result in two pairwise invariant domains + snode = snode_domains[(1, 2)] + remove_snode_edge(G, snode, "x") + snode = snode_domains[(2, 4)] + remove_snode_edge(G, snode, "x") + + G_result = compute_invariant_domains_per_node(G, "x", n_domains=n_domains) + assert G_result.nodes()["x"]["invariant_domains"] == set([1, 2, 4]) + for snode in G.s_nodes: + if all(domain in [1, 2, 4] for domain in G.s_node_domain_ids[snode]): + assert not G_result.has_edge(snode, "x") + else: + assert G_result.has_edge(snode, "x") diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 3bf0c543f..6c4f0880f 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -10,10 +10,107 @@ from .pag import PAG +def create_augmented_diagram( + G, + intervention_targets: List[Set[Node]], + domain_ids: List[int] = None, +): + """Create an augmented causal diagram. + + Each additional F-node created is mapped back to a symmetric difference + set of intervention targets and a pair of domain ids. + + It is assumed that the intervention targets and domain IDs are in the same order + as the distributions of data available. + + Parameters + ---------- + G : Causal Graph + The causal graph before augmenting. + intervention_targets : List[Set[Node]] + Sets of intervention targets. All intervention targets must be nodes in the graph G. + domain_ids : List[int], optional + The corresponding domain indices of each intervention target, by default None, + which will assign them all to domain 1. + + Returns + ------- + G : Causal Graph + The augmented causal graph with additional nodes. + + Examples + -------- + >>> import pywhy_graphs as pg + >>> # add observational data in three different domains + >>> G = create_augmented_diagram(G, [{}, {}, {}], domain_ids=[1, 2, 3]) + >>> + >>> # add interventional data in the same domain + >>> G = create_augmented_diagram(G, [{'x'}, {'x', 'y'}, {}]) + >>> + >>> # add observational and interventional data in two different domains + >>> G = create_augmented_diagram(G, [{'x'}, {'x', 'y'}, {}], domain_ids=[1, 2, 1]) + """ + if domain_ids is None: + domain_ids = [1] * len(intervention_targets) + + # map augmented nodes to domains + node_domain_map = dict() + reverse_sigma_map = dict() + symmetric_diff_map = dict() + sigma_map = dict() + f_nodes = [] + + # create F-nodes, which is now all combinations of distributions choose 2 + k = 0 + seen_domain_pairs = dict() + seen_distr_pairs = dict() + + # compare every pair of distributions to now add interventions if necessary + for dataset_idx, source in enumerate(domain_ids): + for dataset_jdx, target in enumerate(domain_ids): + # perform memoization to avoid duplicate augmented nodes + domain_memo_key = frozenset([source, target]) + distr_memo_key = frozenset([dataset_idx, dataset_jdx]) + if dataset_jdx <= dataset_idx: + continue + if domain_memo_key in seen_domain_pairs and distr_memo_key in seen_distr_pairs: + continue + seen_domain_pairs[domain_memo_key] = None + seen_distr_pairs[distr_memo_key] = None + + # map each augmented-node to a tuple of distribution indices, or to a set of nodes + # representing the intervention targets + symm_diff = set(intervention_targets[dataset_idx]).symmetric_difference( + set(intervention_targets[dataset_jdx]) + ) + targets = frozenset(symm_diff) + + # if targets is the empty set + + # if targets == frozenset() and source == target: + # # the two interventional distributions are exactly the same + # logger.warn( + # f"Interventional distributions {dataset_idx} and {dataset_jdx} have " + # f"the same interventions within the same domain {source}." + # ) + # continue + + # create the F-node + f_node = ("F", k) + f_nodes.append(f_node) + + # map each F-node to a set of domain(s) + node_domain_map[f_node] = [source, target] + sigma_map[f_node] = [dataset_idx, dataset_jdx] + reverse_sigma_map[frozenset([dataset_idx, dataset_jdx])] = f_node + symmetric_diff_map[f_node] = targets + + k += 1 + + class AugmentedNodeMixin: graph: dict nodes: NodeView - domains: Set[int] = set() @abstractmethod def add_edge(self, u_of_edge, v_of_edge, edge_type="all", **attr): @@ -154,6 +251,16 @@ def s_nodes(self) -> List[Node]: """Return set of S-nodes.""" return list(self.graph["S-nodes"].keys()) + @property + def s_node_domain_ids(self) -> List[int]: + """Return a mapping of S-nodes to their domain ids.""" + return self.graph["S-nodes"] + + @property + def domain_ids_to_snodes(self) -> List[int]: + """Return a mapping of domain ids to their ocrresponding S-nodes.""" + return {v: k for k, v in self.graph["S-nodes"].items()} + def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): if isinstance(node_changes, str) or not isinstance(node_changes, Iterable): raise RuntimeError("The intervention set nodes must be an iterable set of node(s).") @@ -171,9 +278,6 @@ def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): f"there is already an augmented-node." ) - # add domains - self.domains.update(domain_ids) - # add a new S-node into the graph s_node_name = ("S", len(self.s_nodes)) self.add_node(s_node_name, domain_ids=domain_ids) From f95c04b022e1580a955d91dbb9244621cc6fbd8b Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 27 Jul 2023 09:44:26 -0400 Subject: [PATCH 02/14] Adding selection diagram example Signed-off-by: Adam Li --- doc/references.bib | 8 +++ examples/multiple-domains/README.txt | 4 ++ .../plot_selection_diagram.py | 66 +++++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 examples/multiple-domains/README.txt create mode 100644 examples/multiple-domains/plot_selection_diagram.py diff --git a/doc/references.bib b/doc/references.bib index 2ba181a6f..da6650a20 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -17,6 +17,14 @@ @article{bareinboim_causal_2016 pages = {7345--7352} } +@incollection{pearl2022external, + title={External validity: From do-calculus to transportability across populations}, + author={Pearl, Judea and Bareinboim, Elias}, + booktitle={Probabilistic and causal inference: The works of Judea Pearl}, + pages={451--482}, + year={2022} +} + @article{Colombo2012, author = {Diego Colombo and Marloes H. Maathuis and Markus Kalisch and Thomas S. Richardson}, title = {{Learning high-dimensional directed acyclic graphs with latent and selection variables}}, diff --git a/examples/multiple-domains/README.txt b/examples/multiple-domains/README.txt new file mode 100644 index 000000000..126bb3402 --- /dev/null +++ b/examples/multiple-domains/README.txt @@ -0,0 +1,4 @@ +Examples Representing Causal Selection Diagrams Over Multiple Domains +--------------------------------------------------------------------- + +Examples demonstrating how to represent causal diagrams for multiple domains. diff --git a/examples/multiple-domains/plot_selection_diagram.py b/examples/multiple-domains/plot_selection_diagram.py new file mode 100644 index 000000000..d4e4ba14c --- /dev/null +++ b/examples/multiple-domains/plot_selection_diagram.py @@ -0,0 +1,66 @@ +""" +.. _ex-selection-diagrams: + +========================================================= +An introduction to selection diagrams and how to use them +========================================================= + +Selection diagrams are causal graphical objects that allow the user and scientist +to represent causal models with multiple domains. This is useful for representing +domain-shifts, generalizability and invariances across different environments. + +This is a common problem in machine learning, where the goal is to learn a model +that generalizes to unseen data. In this case, the unseen data can be a different +domain, and the model needs to be invariant across domains. + +This short example will introduce selection diagrams, and how they are constructed +and different from regular causal graphs. +""" + +import matplotlib.pyplot as plt +import networkx as nx + +# %% +# Import the required libraries +# ----------------------------- +import numpy as np + +import pywhy_graphs as pg + +# %% +# Build a selection diagram +# ------------------------- +# Let us assume that there are only two domains in our causal model. +# +# A selection diagram fundamentally represents two different SCMs that represent +# the two different domains, but share some common variables and causal structure. +# Let M1 and M2 represent two different SCMs. Each SCM is a 4-tuple of the functionals, +# endogenous (observed) variables, exogenous (latent) variables and the probability +# distribution over the exogenous variables. +# +# :math:`M1 = \langle \mathcal{F}, V, U, P(u) \rangle` +# .. math:: +# V = \{W, X, Y, Z\} +# P(U) = P(U_W, U_X, U_Y, U_Z) +# \mathcal{F} = \begin{cases} +# W = f_W(U_W) \\ +# X = f_X(U_X) \\ +# Y = f_Y(W, X, U_Y) \\ +# Z = f_Z(X, Y, U_Z) +# \end{cases} +# +# :math:`M2 = \langle \mathcal{F'}, V, U', P'(u) \rangle` +# .. math:: +# P(U') = P(U_W', U_X', U_Y', U_Z') +# \mathcal{F'} = \begin{cases} +# W = f'_W(U_W) \\ +# X = f'_X(U_X) \\ +# Y = f'_Y(W, X, U_Y) \\ +# Z = f'_Z(X, Y, U_Z) +# \end{cases} +# +# +# +# The most general version of a selection diagram allows S-nodes to represent a +# change in graphical structure. We do not explore that generality in this example, +# or package :footcite:`pearl2022external`. From 853638b8ed426e23dca39751fb623e295339677a Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 27 Jul 2023 14:53:24 -0400 Subject: [PATCH 03/14] Adding citations Signed-off-by: Adam Li --- doc/conf.py | 1 + doc/references.bib | 8 ++ .../plot_selection_diagram.py | 89 +++++++++++++++++-- pywhy_graphs/classes/augmented.py | 14 +-- pywhy_graphs/functional/multidomain.py | 6 +- 5 files changed, 102 insertions(+), 16 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 3f198c8b6..79135aecc 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -303,6 +303,7 @@ def setup(app): "../examples/intro", "../examples/visualization", "../examples/simulations", + "../examples/multiple-domains", ] ), # "filename_pattern": "^((?!sgskip).)*$", diff --git a/doc/references.bib b/doc/references.bib index da6650a20..e3a8cb01f 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -108,6 +108,14 @@ @article{gerhardus2021characterization year = {2021} } +@techreport{li2023discovery, + author = "Li, A. and Jaber, A. and Bareinboim, E.", + title = "Causal Discovery from Observational and Interventional Data Across Multiple Environments", + year = "2023", + month = "May", + number = "R-98", + institution = "Causal Artificial Intelligence Lab, Columbia University" +} @inproceedings{Malinsky18a_svarfci, title = {Causal Structure Learning from Multivariate Time Series in Settings with Unmeasured Confounding}, diff --git a/examples/multiple-domains/plot_selection_diagram.py b/examples/multiple-domains/plot_selection_diagram.py index d4e4ba14c..f6f143c70 100644 --- a/examples/multiple-domains/plot_selection_diagram.py +++ b/examples/multiple-domains/plot_selection_diagram.py @@ -17,15 +17,14 @@ and different from regular causal graphs. """ -import matplotlib.pyplot as plt -import networkx as nx - # %% # Import the required libraries # ----------------------------- -import numpy as np +from pprint import pprint import pywhy_graphs as pg +from pywhy_graphs.algorithms import compute_invariant_domains_per_node, remove_snode_edge +from pywhy_graphs.viz import draw # %% # Build a selection diagram @@ -59,8 +58,88 @@ # Z = f'_Z(X, Y, U_Z) # \end{cases} # -# +# These two SCMs share the same causal structure, but the mechanisms for generating +# each variable may be different either due to different distributions over the +# exogenous variables, or different functional forms. The selection diagram encodes +# this information via an extra node, called the S-node, which represents the possibility +# of a difference in the data-generating mechanisms for the nodes it points to. The +# lack of an S-node pointing to a variable indicates that the data-generating mechanism +# for that variable is the same, or invariant across the two domains. This notion can +# be extended to N domains, where there are now :math:`\binom{N}{2}` S-nodes. # # The most general version of a selection diagram allows S-nodes to represent a # change in graphical structure. We do not explore that generality in this example, # or package :footcite:`pearl2022external`. +# +# We will now construct the selection diagram representing the two SCMs above. + +# %% + +G = pg.AugmentedGraph() +G.add_edges_from( + [ + ("W", "Y"), + ("X", "Y"), + ("X", "Z"), + ("Y", "Z"), + ], + edge_type=G.directed_edge_name, +) +G.add_s_node(domain_ids=(1, 2), node_changes=["W", "X", "Y", "Z"]) +G.add_s_node(domain_ids=(2, 3), node_changes=["W", "X", "Y", "Z"]) +G.add_s_node(domain_ids=(1, 3), node_changes=["W", "X", "Y", "Z"]) + +draw(G) + +# %% +# Imposing cross-domain invariances +# --------------------------------- +# The selection diagram above allows for the possibility of different data-generating +# mechanisms for each variable. Currently, the S-nodes points to every single +# node in the graph. Therefore, there is no invariance across domains. Simply put, +# the data-generating mechanisms for each variable can be different across domains. +# +# However, we may want to impose invariances across domains 1 and 2 for the variables +# W and X. This can be done by removing the S-node pointing to W and X corresponding +# to domain 1 and 2. + +# first, get the mapping from domain ids to s-nodes +domain_id_to_s_node = G.domain_ids_to_snodes + +# remove the edge S^{1, 2} -> W +G = remove_snode_edge(G, domain_id_to_s_node[frozenset(1, 2)], "W") +G = remove_snode_edge(G, domain_id_to_s_node[frozenset(1, 2)], "X") + +draw(G) + +# let's explicitly compute the invariant domains per node +G = compute_invariant_domains_per_node(G, "W") +pprint(G.nodes(data=True)) + +# %% +# Consistency in cross-domain invariances +# --------------------------------------- +# In :footcite:`li2023discovery`, it is noted that there may be inconsistencies +# when removing S-node edges. For example, if we remove the edge S^{1, 2} -> W, +# and then remove the edge S^{2, 3} -> W, then we should have removed the +# edge S^{1, 3} -> W. This is because the invariances are transitive. In pywhy-graphs, +# we have a function that automatically checks for these inconsistencies and removes them. +# The :func:`pywhy_graphs.algorithms.remove_snode_edge` function automatically does this. + +G = remove_snode_edge(G, domain_id_to_s_node[frozenset(2, 3)], "W") + +# now the S-node edge corresponding to S^{1, 3} -> W should be removed as well +draw(G) + +# %% +# Summary +# ------- +# In this example, we have seen how to construct a selection diagram. We have also +# seen how to model invariances across domains using S-nodes and the lack of S-node edges +# to certain nodes in the graph. + + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 6c4f0880f..0a93a890a 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -1,6 +1,6 @@ import collections from abc import abstractmethod -from typing import Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple from networkx.classes.reportviews import NodeView @@ -62,8 +62,8 @@ def create_augmented_diagram( # create F-nodes, which is now all combinations of distributions choose 2 k = 0 - seen_domain_pairs = dict() - seen_distr_pairs = dict() + seen_domain_pairs: Dict = dict() + seen_distr_pairs: Dict = dict() # compare every pair of distributions to now add interventions if necessary for dataset_idx, source in enumerate(domain_ids): @@ -237,14 +237,14 @@ def intervened_nodes(self): return nodes @property - def domain_ids(self) -> List[int]: + def domain_ids(self) -> Set[int]: """Return set of domain ids.""" domain_ids = set() for src, target in self.graph["S-nodes"].values(): domain_ids.add(src) domain_ids.add(target) - return list(domain_ids) + return domain_ids @property def s_nodes(self) -> List[Node]: @@ -257,7 +257,7 @@ def s_node_domain_ids(self) -> List[int]: return self.graph["S-nodes"] @property - def domain_ids_to_snodes(self) -> List[int]: + def domain_ids_to_snodes(self) -> Dict: """Return a mapping of domain ids to their ocrresponding S-nodes.""" return {v: k for k, v in self.graph["S-nodes"].items()} @@ -288,7 +288,7 @@ def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): # adding nodes to F-node container occurs last, because of the error checks # that occur in adding edges - self.graph["S-nodes"][s_node_name] = domain_ids + self.graph["S-nodes"][s_node_name] = frozenset(domain_ids) class AugmentedGraph(ADMG, AugmentedNodeMixin): diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index cede33323..d2b971efe 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -228,10 +228,9 @@ def generate_multidomain_noise_for_node( if check_s_node_consistency: # compute all possible S-nodes given the number of domains G, _ = add_all_snode_combinations(G, n_domains) - all_poss_snodes = set(G.s_nodes) # for each node with S-nodes and compute the invariant domains - G = compute_invariant_domains_per_node(G, node, all_poss_snodes, n_domains=n_domains) + G = compute_invariant_domains_per_node(G, node, n_domains=n_domains) else: if "invariant_domains" not in G.nodes()[node]: raise ValueError("Must specify invariant domains for node {}.".format(node)) @@ -331,11 +330,10 @@ def sample_multidomain_lin_functions( # compute all possible S-nodes given the number of domains G, s_node_domains = add_all_snode_combinations(G, n_domains) - all_poss_snodes = set(G.s_nodes) for node in G.nodes: if node in nodes_with_s_nodes: # for each node with S-nodes and compute the invariant domains - G = compute_invariant_domains_per_node(G, node, all_poss_snodes, n_domains=n_domains) + G = compute_invariant_domains_per_node(G, node, n_domains=n_domains) # now set a random function for each domain that is not invariant G = generate_multidomain_noise_for_node( From 48798e992562fff00c4f56a43f21bf68f571d500 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 27 Jul 2023 17:39:16 -0400 Subject: [PATCH 04/14] Adding completed example Signed-off-by: Adam Li --- examples/multiple-domains/plot_selection_diagram.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/multiple-domains/plot_selection_diagram.py b/examples/multiple-domains/plot_selection_diagram.py index f6f143c70..239524883 100644 --- a/examples/multiple-domains/plot_selection_diagram.py +++ b/examples/multiple-domains/plot_selection_diagram.py @@ -8,6 +8,8 @@ Selection diagrams are causal graphical objects that allow the user and scientist to represent causal models with multiple domains. This is useful for representing domain-shifts, generalizability and invariances across different environments. +For a detailed theoretical introduction to selection diagrams, see +:footcite:`bareinboim_causal_2016,pearl2022external`. This is a common problem in machine learning, where the goal is to learn a model that generalizes to unseen data. In this case, the unseen data can be a different From 4b6081b9805b98f4488c9ea7095fa6527ed0607d Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 00:00:31 -0700 Subject: [PATCH 05/14] Try again Signed-off-by: Adam Li --- .../plot_selection_diagram.py | 6 +- pywhy_graphs/classes/augmented.py | 58 ++++++++++--------- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/examples/multiple-domains/plot_selection_diagram.py b/examples/multiple-domains/plot_selection_diagram.py index 239524883..774e3c3fb 100644 --- a/examples/multiple-domains/plot_selection_diagram.py +++ b/examples/multiple-domains/plot_selection_diagram.py @@ -109,8 +109,8 @@ domain_id_to_s_node = G.domain_ids_to_snodes # remove the edge S^{1, 2} -> W -G = remove_snode_edge(G, domain_id_to_s_node[frozenset(1, 2)], "W") -G = remove_snode_edge(G, domain_id_to_s_node[frozenset(1, 2)], "X") +G = remove_snode_edge(G, domain_id_to_s_node[frozenset([1, 2])], "W") +G = remove_snode_edge(G, domain_id_to_s_node[frozenset([1, 2])], "X") draw(G) @@ -128,7 +128,7 @@ # we have a function that automatically checks for these inconsistencies and removes them. # The :func:`pywhy_graphs.algorithms.remove_snode_edge` function automatically does this. -G = remove_snode_edge(G, domain_id_to_s_node[frozenset(2, 3)], "W") +G = remove_snode_edge(G, domain_id_to_s_node[frozenset([2, 3])], "W") # now the S-node edge corresponding to S^{1, 3} -> W should be removed as well draw(G) diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 0a93a890a..8164f4917 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -80,20 +80,17 @@ def create_augmented_diagram( # map each augmented-node to a tuple of distribution indices, or to a set of nodes # representing the intervention targets - symm_diff = set(intervention_targets[dataset_idx]).symmetric_difference( - set(intervention_targets[dataset_jdx]) - ) - targets = frozenset(symm_diff) - - # if targets is the empty set - - # if targets == frozenset() and source == target: - # # the two interventional distributions are exactly the same - # logger.warn( - # f"Interventional distributions {dataset_idx} and {dataset_jdx} have " - # f"the same interventions within the same domain {source}." - # ) - # continue + if ( + intervention_targets[dataset_idx] is not None + and intervention_targets[dataset_jdx] is not None + and source == target + ): + symm_diff = set(intervention_targets[dataset_idx]).symmetric_difference( + set(intervention_targets[dataset_jdx]) + ) + targets = frozenset(symm_diff) + else: + targets = None # create the F-node f_node = ("F", k) @@ -107,6 +104,16 @@ def create_augmented_diagram( k += 1 + # get non-augmented nodes + non_aug_nodes = set(G.non_augmented_nodes) + for aug_node in f_nodes: + G.add_f_node( + aug_node, targets=symmetric_diff_map[aug_node], domain=node_domain_map[aug_node] + ) + for node in non_aug_nodes: + G.add_edge(aug_node, node, G.directed_edge_name) + return G, sigma_map + class AugmentedNodeMixin: graph: dict @@ -141,7 +148,7 @@ def _verify_augmentednode_dict(self): "There is a graph property named S-nodes already that is not of type dict." ) - def add_f_node(self, intervention_set: Set[Node], require_unique=True, domain=None): + def add_f_node(self, targets: Set[Node], require_unique=True, domain=None): """Add an F-node to the graph. Parameters @@ -153,26 +160,25 @@ def add_f_node(self, intervention_set: Set[Node], require_unique=True, domain=No then the intervention set is added to the graph, even if it is already an F-node. The default is True. domain : Optional[Set[int]], optional - The domain of the F-node. If None, then the domain is just set to 1. + The domains of the F-node. If None, then the domain is just set to {1}. """ - if isinstance(intervention_set, str) or not isinstance(intervention_set, Iterable): + if isinstance(targets, str) or not isinstance(targets, Iterable): raise RuntimeError("The intervention set nodes must be an iterable set of node(s).") if domain is None: domain = set([1]) # check that there are no duplicates and perform set conversion - orig_len = len(intervention_set) - intervention_set = frozenset(intervention_set) # type: ignore - if len(intervention_set) != orig_len: + orig_len = len(targets) + targets = frozenset(targets) # type: ignore + if len(targets) != orig_len: raise RuntimeError("The intervention set must be a set of unique nodes.") # check that the F-node intervention set has variables within the graph - if require_unique and intervention_set in self.intervention_sets: + if require_unique and targets in self.intervention_sets: raise RuntimeError( - f"You cannot add an F-node for {intervention_set} because " - f"there is already an F-node." + f"You cannot add an F-node for {targets} because " f"there is already an F-node." ) - for node in intervention_set: + for node in targets: if node not in self.nodes: raise RuntimeError( f"All intervention sets must be nodes already in the graph. {node} is not." @@ -183,12 +189,12 @@ def add_f_node(self, intervention_set: Set[Node], require_unique=True, domain=No self.add_node(f_node_name) # add edge between the F-node and its intervention set - for intervened_node in intervention_set: + for intervened_node in targets: self.add_edge(f_node_name, intervened_node, self.directed_edge_name) # adding nodes to F-node container occurs last, because of the error checks # that occur in adding edges - self.graph["F-nodes"][f_node_name]["targets"] = intervention_set + self.graph["F-nodes"][f_node_name]["targets"] = targets self.graph["F-nodes"][f_node_name]["domain"] = domain def add_f_nodes_from(self, intervention_sets: List[Set[Node]]): From abd30074a70bb98ae089c56d10106837a95bb157 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 09:36:12 -0700 Subject: [PATCH 06/14] Fix circleci Signed-off-by: Adam Li --- doc/api.rst | 1 + examples/multiple-domains/plot_selection_diagram.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 990414cfb..37e12f9f2 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -58,6 +58,7 @@ causal graph operations. find_connected_pairs add_all_snode_combinations compute_invariant_domains_per_node + remove_snode_edge Conversions between other package's causal graphs ================================================= diff --git a/examples/multiple-domains/plot_selection_diagram.py b/examples/multiple-domains/plot_selection_diagram.py index 774e3c3fb..6d384e9ce 100644 --- a/examples/multiple-domains/plot_selection_diagram.py +++ b/examples/multiple-domains/plot_selection_diagram.py @@ -40,7 +40,7 @@ # distribution over the exogenous variables. # # :math:`M1 = \langle \mathcal{F}, V, U, P(u) \rangle` -# .. math:: +# .. math:: # V = \{W, X, Y, Z\} # P(U) = P(U_W, U_X, U_Y, U_Z) # \mathcal{F} = \begin{cases} @@ -51,7 +51,7 @@ # \end{cases} # # :math:`M2 = \langle \mathcal{F'}, V, U', P'(u) \rangle` -# .. math:: +# .. math:: # P(U') = P(U_W', U_X', U_Y', U_Z') # \mathcal{F'} = \begin{cases} # W = f'_W(U_W) \\ From 482adf4db8087259bf593d93548ff936ff381914 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 09:37:38 -0700 Subject: [PATCH 07/14] Fix circleci Signed-off-by: Adam Li --- doc/whats_new/v0.2.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index 3c7bad424..7751aee51 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -25,7 +25,7 @@ Version 0.2 Changelog --------- -- +- |Feature| Add algorithms for interfacing with a selection diagram in ``pywhy_graphs.algorithms.multidomain``, by `Adam Li`_ (:pr:`88`) Code and Documentation Contributors ----------------------------------- From 2abfd12ac0dd457516984b48c230ea64c55a5440 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 20:07:37 -0700 Subject: [PATCH 08/14] Upgrade codebase for mypy Signed-off-by: Adam Li --- poetry.lock | 22 ++-- pyproject.toml | 8 +- pywhy_graphs/algorithms/generic.py | 9 +- pywhy_graphs/algorithms/multidomain.py | 1 - pywhy_graphs/algorithms/pag.py | 4 +- pywhy_graphs/algorithms/tests/test_cyclic.py | 2 +- pywhy_graphs/algorithms/tests/test_generic.py | 3 - pywhy_graphs/algorithms/tests/test_pag.py | 4 +- pywhy_graphs/array/api.py | 4 +- pywhy_graphs/classes/__init__.py | 2 +- pywhy_graphs/classes/augmented.py | 50 +++---- pywhy_graphs/classes/tests/test_augmented.py | 56 ++++++++ pywhy_graphs/classes/timeseries/conversion.py | 8 +- pywhy_graphs/classes/timeseries/mixededge.py | 4 +- pywhy_graphs/export/pcalg.py | 2 +- pywhy_graphs/export/tests/test_ananke.py | 2 - pywhy_graphs/functional/__init__.py | 2 +- pywhy_graphs/functional/additive.py | 12 +- pywhy_graphs/functional/base.py | 10 +- pywhy_graphs/functional/discrete.py | 6 +- pywhy_graphs/functional/linear.py | 30 +++-- pywhy_graphs/functional/multidomain.py | 4 +- pywhy_graphs/functional/tests/test_linear.py | 123 +++++++++++++++++- .../functional/tests/test_multidomain.py | 4 +- pywhy_graphs/functional/utils.py | 7 +- .../algorithms/causal/m_separation.py | 2 - .../algorithms/causal/mixed_edge_moral.py | 1 - .../algorithms/causal/tests/test_convert.py | 1 - .../causal/tests/test_m_separation.py | 3 - pywhy_graphs/simulate.py | 8 +- pywhy_graphs/viz/draw.py | 14 +- 31 files changed, 298 insertions(+), 110 deletions(-) create mode 100644 pywhy_graphs/classes/tests/test_augmented.py diff --git a/poetry.lock b/poetry.lock index f73fb4032..fee750b5f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1006,21 +1006,21 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] [[package]] name = "jedi" -version = "0.18.2" +version = "0.19.0" description = "An autocompletion tool for Python that can be used for text editors." optional = false python-versions = ">=3.6" files = [ - {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"}, - {file = "jedi-0.18.2.tar.gz", hash = "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"}, + {file = "jedi-0.19.0-py2.py3-none-any.whl", hash = "sha256:cb8ce23fbccff0025e9386b5cf85e892f94c9b822378f8da49970471335ac64e"}, + {file = "jedi-0.19.0.tar.gz", hash = "sha256:bcf9894f1753969cbac8022a8c2eaee06bfa3724e4192470aaffe7eb6272b0c4"}, ] [package.dependencies] -parso = ">=0.8.0,<0.9.0" +parso = ">=0.8.3,<0.9.0" [package.extras] docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] -qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] [[package]] @@ -1652,13 +1652,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.7.2" +version = "7.7.3" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.7.2-py3-none-any.whl", hash = "sha256:25e0cf2b663ee0cd5a90afb6b2f2940bf1abe5cc5bc995b88c8156ca65fa7ede"}, - {file = "nbconvert-7.7.2.tar.gz", hash = "sha256:36d3e7bf32f0c075878176cdeeb645931c994cbed5b747bc7a570ba8cd2321f3"}, + {file = "nbconvert-7.7.3-py3-none-any.whl", hash = "sha256:3022adadff3f86578a47fab7c2228bb3ca9c56a24345642a22f917f6168b48fc"}, + {file = "nbconvert-7.7.3.tar.gz", hash = "sha256:4a5996bf5f3cd16aa0431897ba1aa4c64842c2079f434b3dc6b8c4b252ef3355"}, ] [package.dependencies] @@ -1931,13 +1931,13 @@ files = [ [[package]] name = "pathspec" -version = "0.11.1" +version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.7" files = [ - {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, - {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, + {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, + {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 54c80e6a1..f7a840dd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,13 +61,13 @@ pandas = { version = "^1.1" } # needed for simulation optional = true [tool.poetry.group.style.dependencies] poethepoet = "^0.20.0" -mypy = "^0.971" -black = {extras = ["jupyter"], version = "^22.12.0"} +mypy = "^1.4.1" +black = {extras = ["jupyter"], version = "^23.7.0"} isort = "^5.12.0" -flake8 = "^5.0.4" +flake8 = "^6.0.0" bandit = "^1.7.4" pydocstyle = "^6.1.1" -codespell = "^2.1.0" +codespell = "^2.2.5" toml = "^0.10.2" [tool.poetry.group.docs] diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 6df210097..a7fc87060 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -1,4 +1,4 @@ -from typing import List, Set, Union +from typing import List, Optional, Set, Union import networkx as nx @@ -16,7 +16,9 @@ ] -def is_node_common_cause(G: nx.DiGraph, node: Node, exclude_nodes: List[Node] = None) -> bool: +def is_node_common_cause( + G: nx.DiGraph, node: Node, exclude_nodes: Optional[List[Node]] = None +) -> bool: """Check if a node is a common cause within the graph. Parameters @@ -515,7 +517,7 @@ def _shortest_valid_path( return (path_exists, path) -def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): +def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Optional[Set] = None): """Checks if an inducing path exists between two nodes. An inducing path is defined in :footcite:`Zhang2008`. @@ -592,7 +594,6 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): path_exists = False for elem in x_neighbors: - visited = {node_x} if elem not in visited: path_exists, temp_path = _shortest_valid_path( diff --git a/pywhy_graphs/algorithms/multidomain.py b/pywhy_graphs/algorithms/multidomain.py index 730f963bf..7af262d4f 100644 --- a/pywhy_graphs/algorithms/multidomain.py +++ b/pywhy_graphs/algorithms/multidomain.py @@ -69,7 +69,6 @@ def add_all_snode_combinations(G, n_domains: int, on_error="raise"): # add all the S-nodes representing differences across pairs of domains # to every single node with S-nodes for (source_domain, target_domain), s_node in s_node_domains.items(): - # now modify the function of the edge, S-nodes are pointing to s_node_domains[(source_domain, target_domain)] = s_node if s_node in G.s_nodes: diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 4f8f2a28b..98969736a 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -510,7 +510,7 @@ def uncovered_pd_path( def pds( - graph: PAG, node_x: Node, node_y: Node = None, max_path_length: Optional[int] = None + graph: PAG, node_x: Node, node_y: Optional[Node] = None, max_path_length: Optional[int] = None ) -> Set[Node]: """Find all PDS sets between node_x and node_y. @@ -711,7 +711,7 @@ def pds_path( for comp in biconn_comp: if (node_x, node_y) in comp or (node_y, node_x) in comp: # add all unique nodes in the biconnected component - for (x, y) in comp: + for x, y in comp: found_component.add(x) found_component.add(y) break diff --git a/pywhy_graphs/algorithms/tests/test_cyclic.py b/pywhy_graphs/algorithms/tests/test_cyclic.py index ec5eaa436..f47d2969f 100644 --- a/pywhy_graphs/algorithms/tests/test_cyclic.py +++ b/pywhy_graphs/algorithms/tests/test_cyclic.py @@ -83,7 +83,7 @@ def test_sigma_separated(): cyclic_G = pywhy_nx.MixedEdgeGraph(graphs=[cyclic_G], edge_types=["directed"]) cyclic_G.add_edge_type(nx.Graph(), edge_type="bidirected") - for (u, v) in combinations(cyclic_G.nodes, 2): + for u, v in combinations(cyclic_G.nodes, 2): other_nodes = set(cyclic_G.nodes) other_nodes.remove(u) other_nodes.remove(v) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index d0a52981f..9ed8ee1ab 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -43,7 +43,6 @@ def test_convert_to_latent_confounder(graph_func): def test_inducing_path(): - admg = ADMG() admg.add_edge("X", "Y", admg.directed_edge_name) @@ -93,7 +92,6 @@ def test_inducing_path(): def test_inducing_path_wihtout_LandS(): - admg = ADMG() admg.add_edge("X", "Y", admg.directed_edge_name) @@ -113,7 +111,6 @@ def test_inducing_path_wihtout_LandS(): def test_inducing_path_one_direction(): - admg = ADMG() admg.add_edge("A", "B", admg.directed_edge_name) diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index c09607a85..740283af8 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -179,7 +179,7 @@ def test_discriminating_path(): ) for u in pag.nodes: - for (a, c) in permutations(pag.neighbors(u), 2): + for a, c in permutations(pag.neighbors(u), 2): found_discriminating_path, disc_path, _ = discriminating_path( pag, u, a, c, max_path_length=100 ) @@ -193,7 +193,7 @@ def test_discriminating_path(): pag.remove_edge("x2", "x5", pag.directed_edge_name) pag.add_edge("x5", "x2", pag.bidirected_edge_name) for u in pag.nodes: - for (a, c) in permutations(pag.neighbors(u), 2): + for a, c in permutations(pag.neighbors(u), 2): found_discriminating_path, disc_path, _ = discriminating_path( pag, u, a, c, max_path_length=100 ) diff --git a/pywhy_graphs/array/api.py b/pywhy_graphs/array/api.py index 37f9548b0..e209ead77 100644 --- a/pywhy_graphs/array/api.py +++ b/pywhy_graphs/array/api.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set import numpy as np from numpy.typing import NDArray @@ -80,7 +80,7 @@ def get_summary_graph(arr: NDArray, arr_enum: str = "clearn"): def array_to_lagged_links( - arr: NDArray, arr_idx: List[Node] = None, include_weights: bool = True + arr: NDArray, arr_idx: Optional[List[Node]] = None, include_weights: bool = True ) -> Dict[Node, List[Set]]: """Convert a time-series 3D array to a dictionary of lagged links. diff --git a/pywhy_graphs/classes/__init__.py b/pywhy_graphs/classes/__init__.py index 69701d4ac..33cc2fa84 100644 --- a/pywhy_graphs/classes/__init__.py +++ b/pywhy_graphs/classes/__init__.py @@ -1,6 +1,6 @@ from . import timeseries from .admg import ADMG -from .augmented import AugmentedGraph, AugmentedPAG +from .augmented import AugmentedGraph, AugmentedPAG, compute_augmented_nodes from .cpdag import CPDAG from .pag import PAG from .timeseries import ( diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 8164f4917..c9418cf48 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -10,12 +10,11 @@ from .pag import PAG -def create_augmented_diagram( - G, +def compute_augmented_nodes( intervention_targets: List[Set[Node]], - domain_ids: List[int] = None, + domain_ids: Optional[List[int]] = None, ): - """Create an augmented causal diagram. + """Compute augmented nodes that would be added to a graph G. Each additional F-node created is mapped back to a symmetric difference set of intervention targets and a pair of domain ids. @@ -25,8 +24,6 @@ def create_augmented_diagram( Parameters ---------- - G : Causal Graph - The causal graph before augmenting. intervention_targets : List[Set[Node]] Sets of intervention targets. All intervention targets must be nodes in the graph G. domain_ids : List[int], optional @@ -35,8 +32,15 @@ def create_augmented_diagram( Returns ------- - G : Causal Graph - The augmented causal graph with additional nodes. + augmented_nodes : Set + Set of augmented nodes (i.e. F and S nodes). + symmetric_diff_map : Dict[Any, FrozenSet] + Mapping of augmented nodes to intervention targets, or distribution indices represented + by the node. + sigma_map : Dict[Any, FrozenSet] + Mapping of augmented nodes to distribution indices represented by the node. + node_domain_map : Dict[Any, FrozenSet] + Mapping of augmented nodes to domains. Examples -------- @@ -58,7 +62,7 @@ def create_augmented_diagram( reverse_sigma_map = dict() symmetric_diff_map = dict() sigma_map = dict() - f_nodes = [] + f_nodes = set() # create F-nodes, which is now all combinations of distributions choose 2 k = 0 @@ -83,7 +87,6 @@ def create_augmented_diagram( if ( intervention_targets[dataset_idx] is not None and intervention_targets[dataset_jdx] is not None - and source == target ): symm_diff = set(intervention_targets[dataset_idx]).symmetric_difference( set(intervention_targets[dataset_jdx]) @@ -94,7 +97,7 @@ def create_augmented_diagram( # create the F-node f_node = ("F", k) - f_nodes.append(f_node) + f_nodes.add(f_node) # map each F-node to a set of domain(s) node_domain_map[f_node] = [source, target] @@ -104,15 +107,15 @@ def create_augmented_diagram( k += 1 - # get non-augmented nodes - non_aug_nodes = set(G.non_augmented_nodes) - for aug_node in f_nodes: - G.add_f_node( - aug_node, targets=symmetric_diff_map[aug_node], domain=node_domain_map[aug_node] - ) - for node in non_aug_nodes: - G.add_edge(aug_node, node, G.directed_edge_name) - return G, sigma_map + # # get non-augmented nodes + # non_aug_nodes = set(G.non_augmented_nodes) + # for aug_node in f_nodes: + # G.add_f_node( + # aug_node, targets=symmetric_diff_map[aug_node], domain=node_domain_map[aug_node] + # ) + # for node in non_aug_nodes: + # G.add_edge(aug_node, node, G.directed_edge_name) + return f_nodes, symmetric_diff_map, sigma_map, node_domain_map class AugmentedNodeMixin: @@ -185,7 +188,7 @@ def add_f_node(self, targets: Set[Node], require_unique=True, domain=None): ) # add a new F-node into the graph - f_node_name = ("F", len(self.f_nodes)) + f_node_name = ("F", len(self.augmented_nodes)) self.add_node(f_node_name) # add edge between the F-node and its intervention set @@ -267,7 +270,7 @@ def domain_ids_to_snodes(self) -> Dict: """Return a mapping of domain ids to their ocrresponding S-nodes.""" return {v: k for k, v in self.graph["S-nodes"].items()} - def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): + def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None): if isinstance(node_changes, str) or not isinstance(node_changes, Iterable): raise RuntimeError("The intervention set nodes must be an iterable set of node(s).") @@ -285,7 +288,8 @@ def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): ) # add a new S-node into the graph - s_node_name = ("S", len(self.s_nodes)) + # Note: that we represent S-nodes as F-nodes + s_node_name = ("F", len(self.augmented_nodes)) self.add_node(s_node_name, domain_ids=domain_ids) # add edge between the F-node and its intervention set diff --git a/pywhy_graphs/classes/tests/test_augmented.py b/pywhy_graphs/classes/tests/test_augmented.py new file mode 100644 index 000000000..48194aee0 --- /dev/null +++ b/pywhy_graphs/classes/tests/test_augmented.py @@ -0,0 +1,56 @@ +import math + +from pywhy_graphs.classes import compute_augmented_nodes + + +def test_compute_augmented_nodes(): + domain_indices = [1, 2, 2] + intervention_targets = [{}, {}, {"x"}] + + # test augmented nodes + ( + augmented_nodes, + symmetric_diff_map, + sigma_map, + node_domain_map, + ) = compute_augmented_nodes(intervention_targets, domain_indices) + assert len(augmented_nodes) == math.comb(len(domain_indices), 2) + assert symmetric_diff_map == { + ("F", 2): frozenset({"x"}), + ("F", 1): frozenset({"x"}), + ("F", 0): frozenset(), + } + assert sigma_map == {("F", 2): [1, 2], ("F", 1): [0, 2], ("F", 0): [0, 1]} + assert node_domain_map == {("F", 2): [2, 2], ("F", 1): [1, 2], ("F", 0): [1, 2]} + + domain_indices = [1, 3, 5, 2, 2] + intervention_targets = [{}, {}, {3}, {2}, {3}] + + # test augmented nodes + ( + augmented_nodes, + symmetric_diff_map, + sigma_map, + node_domain_map, + ) = compute_augmented_nodes(intervention_targets, domain_indices) + assert len(augmented_nodes) == math.comb(len(domain_indices), 2) + for node, domains in node_domain_map.items(): + assert all(domain in domain_indices for domain in domains) + assert node in sigma_map + assert node in symmetric_diff_map + + domain_indices = [1, 10, 10] + intervention_targets = [{}, {}, {"x"}] + # test augmented nodes + ( + augmented_nodes, + symmetric_diff_map, + sigma_map, + node_domain_map, + ) = compute_augmented_nodes(intervention_targets, domain_indices) + assert len(augmented_nodes) == math.comb(len(domain_indices), 2) + for node, domains in node_domain_map.items(): + # the domain indices should always be part of the domain indices passed + assert all(domain in domain_indices for domain in domains) + assert node in sigma_map + assert node in symmetric_diff_map diff --git a/pywhy_graphs/classes/timeseries/conversion.py b/pywhy_graphs/classes/timeseries/conversion.py index 94ec50472..b84ea4c99 100644 --- a/pywhy_graphs/classes/timeseries/conversion.py +++ b/pywhy_graphs/classes/timeseries/conversion.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import numpy as np @@ -7,7 +7,7 @@ from .graph import StationaryTimeSeriesGraph -def tsgraph_to_numpy(G, var_order: List[Node] = None): +def tsgraph_to_numpy(G, var_order: Optional[List[Node]] = None): """Convert stationary timeseries graph to numpy array. Parameters @@ -44,7 +44,9 @@ def tsgraph_to_numpy(G, var_order: List[Node] = None): return ts_graph_arr -def numpy_to_tsgraph(arr, var_order: List[Node] = None, create_using=StationaryTimeSeriesGraph): +def numpy_to_tsgraph( + arr, var_order: Optional[List[Node]] = None, create_using=StationaryTimeSeriesGraph +): """Convert 3D numpy array into a stationary time-series graph. Parameters diff --git a/pywhy_graphs/classes/timeseries/mixededge.py b/pywhy_graphs/classes/timeseries/mixededge.py index b873b901e..7f72922e5 100644 --- a/pywhy_graphs/classes/timeseries/mixededge.py +++ b/pywhy_graphs/classes/timeseries/mixededge.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import pywhy_graphs.networkx as pywhy_nx @@ -178,7 +180,7 @@ class StationaryTimeSeriesMixedEdgeGraph(TimeSeriesMixedEdgeGraph): # supported graph types graph_types = (StationaryTimeSeriesGraph, StationaryTimeSeriesDiGraph) - def __init__(self, graphs=None, edge_types=None, max_lag: int = None, **attr): + def __init__(self, graphs=None, edge_types=None, max_lag: Optional[int] = None, **attr): super().__init__(graphs, edge_types, max_lag=max_lag, **attr) def set_stationarity(self, stationary: bool): diff --git a/pywhy_graphs/export/pcalg.py b/pywhy_graphs/export/pcalg.py index f5e81b198..9800af389 100644 --- a/pywhy_graphs/export/pcalg.py +++ b/pywhy_graphs/export/pcalg.py @@ -166,7 +166,7 @@ def graph_to_pcalg(causal_graph): # now map all values to their respective pcalg values seen_idx = dict() - for (idx, jdx) in np.argwhere(clearn_arr != 0): + for idx, jdx in np.argwhere(clearn_arr != 0): if (idx, jdx) in seen_idx or (jdx, idx) in seen_idx: continue diff --git a/pywhy_graphs/export/tests/test_ananke.py b/pywhy_graphs/export/tests/test_ananke.py index 66dc84318..ccffa5730 100644 --- a/pywhy_graphs/export/tests/test_ananke.py +++ b/pywhy_graphs/export/tests/test_ananke.py @@ -9,7 +9,6 @@ def dag(): - vertices = ["A", "B", "C", "D"] di_edges = [("A", "B"), ("B", "C"), ("C", "D")] graph = DAG(vertices=vertices, di_edges=di_edges) @@ -19,7 +18,6 @@ def dag(): def admg(): - vertices = ["A", "B", "C", "D"] di_edges = [("A", "B"), ("B", "C"), ("C", "D")] bi_edges = [("A", "C"), ("B", "D")] diff --git a/pywhy_graphs/functional/__init__.py b/pywhy_graphs/functional/__init__.py index 825023601..59117a28a 100644 --- a/pywhy_graphs/functional/__init__.py +++ b/pywhy_graphs/functional/__init__.py @@ -1,5 +1,5 @@ from .base import sample_from_graph -from .linear import apply_linear_soft_intervention, make_graph_linear_gaussian +from .linear import apply_linear_soft_intervention, make_random_linear_gaussian_graph from .multidomain import ( generate_multidomain_noise_for_node, make_graph_multidomain, diff --git a/pywhy_graphs/functional/additive.py b/pywhy_graphs/functional/additive.py index 111796511..a83f030df 100644 --- a/pywhy_graphs/functional/additive.py +++ b/pywhy_graphs/functional/additive.py @@ -42,6 +42,14 @@ def generate_edge_functions_for_node( directed_G = G.get_graphs("directed") else: directed_G = G + if edge_weight_lims is None: + edge_weight_lims_ = [1.0, 1.0] + else: + edge_weight_lims_ = edge_weight_lims + if edge_functions is None: + edge_functions_ = [lambda x: x] + else: + edge_functions_ = edge_functions rng = np.random.default_rng(random_state) # get all parents @@ -56,8 +64,8 @@ def generate_edge_functions_for_node( if parent == node: continue - weight = rng.uniform(low=edge_weight_lims[0], high=edge_weight_lims[1]) - func = rng.choice(edge_functions) + weight = rng.uniform(low=edge_weight_lims_[0], high=edge_weight_lims_[1]) + func = rng.choice(edge_functions_) node_function.append({"weight": weight, "func": func}) def parent_func(*args): diff --git a/pywhy_graphs/functional/base.py b/pywhy_graphs/functional/base.py index 1a36cf52c..32c762094 100644 --- a/pywhy_graphs/functional/base.py +++ b/pywhy_graphs/functional/base.py @@ -40,7 +40,7 @@ def add_parent_function(G: nx.DiGraph, node: Node, func: Callable) -> nx.DiGraph def add_noise_function( - G: nx.DiGraph, node: Node, distr_func: Callable, func: Callable = None + G: nx.DiGraph, node: Node, distr_func: Callable, func: Optional[Callable] = None ) -> nx.DiGraph: """Add function and distribution for a node's exogenous variable into the graph. @@ -120,7 +120,11 @@ def add_soft_intervention_function( def add_domain_shift_function( - G: AugmentedGraph, node: Node, s_node: Node, func: Callable = None, distr_func: Callable = None + G: AugmentedGraph, + node: Node, + s_node: Node, + func: Optional[Callable] = None, + distr_func: Optional[Callable] = None, ): """Add domain shift function for a node into the graph assuming invariant graph structure. @@ -368,7 +372,7 @@ def _check_input_func(func: Callable, parents=None): def _check_input_graph(G: nx.DiGraph): if not nx.is_directed_acyclic_graph(G): raise ValueError("The input graph must be a DAG.") - if not G.graph.get("functional", True): + if not G.graph.get("functional"): raise ValueError( "The input graph must be a functional graph. Please initialize " "the graph with functions." diff --git a/pywhy_graphs/functional/discrete.py b/pywhy_graphs/functional/discrete.py index cc21a2466..1c8f30716 100644 --- a/pywhy_graphs/functional/discrete.py +++ b/pywhy_graphs/functional/discrete.py @@ -184,9 +184,9 @@ def parent_func(*args): def make_random_discrete_graph( G: nx.DiGraph, - cardinality_lims: Dict[Any, List[int]] = None, - weight_lims: Dict[Any, List[int]] = None, - noise_ratio_lims: List[float] = None, + cardinality_lims: Optional[Dict[Any, List[int]]] = None, + weight_lims: Optional[Dict[Any, List[int]]] = None, + noise_ratio_lims: Optional[List[float]] = None, overwrite: bool = False, random_state=None, ) -> nx.DiGraph: diff --git a/pywhy_graphs/functional/linear.py b/pywhy_graphs/functional/linear.py index b5eb8b9c4..bb83c68d6 100644 --- a/pywhy_graphs/functional/linear.py +++ b/pywhy_graphs/functional/linear.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Set +from typing import Callable, List, Optional, Set import networkx as nx import numpy as np @@ -9,12 +9,12 @@ from .utils import _preprocess_parameter_inputs -def make_graph_linear_gaussian( +def make_random_linear_gaussian_graph( G: nx.DiGraph, - node_mean_lims: List[float] = None, - node_std_lims: List[float] = None, - edge_functions: List[Callable[[float], float]] = None, - edge_weight_lims: List[float] = None, + node_mean_lims: Optional[List[float]] = None, + node_std_lims: Optional[List[float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, + edge_weight_lims: Optional[List[float]] = None, random_state=None, ) -> nx.DiGraph: r"""Convert an existing DAG to a linear Gaussian graphical model. @@ -62,7 +62,7 @@ def make_graph_linear_gaussian( ------- G : NetworkX DiGraph NetworkX graph with the edge weights and functions set with node attributes - set with ``'parent_function'``, and ``'gaussian_noise_function'``. Moreover + set with ``'parent_function'``, and ``'exogenous_distribution'``. Moreover the graph attribute ``'linear_gaussian'`` is set to ``True``. """ G = G.copy() @@ -121,7 +121,7 @@ def generate_noise_for_node(G, node, node_mean_lims, node_std_lims, random_state def apply_linear_soft_intervention( - G, targets: Set[Node], type: str = "additive", random_state=None + G, targets: Set[Node], type: str = "additive", intervention_value=None, random_state=None ): """Applies a soft intervention to a linear Gaussian graph. @@ -133,6 +133,10 @@ def apply_linear_soft_intervention( The set of nodes to intervene on simultanenously. type : str, optional Type of intervention, by default "additive". + intervention_value : float, optional + The value of the intervention, by default None, which will add + a random value sampled from the standard normal distribution + to the exogenous noise of the target nodes. random_state : RandomState, optional Random seed, by default None. @@ -140,7 +144,7 @@ def apply_linear_soft_intervention( ------- G : Graph The functional linear causal graph with the intervention applied on the - target nodes. The perturbation occurs on the ``gaussian_noise_function`` + target nodes. The perturbation occurs on the ``exogenous_distribution`` of the target nodes. That is, the soft intervention, perturbs the exogenous noise of the target nodes. """ @@ -153,6 +157,12 @@ def apply_linear_soft_intervention( for target in targets: if type == "additive": - G.nodes[target]["gaussian_noise_function"]["mean"] += rng.uniform(low=-1, high=1) + orig_func = G.nodes[target]["exogenous_distribution"] + if intervention_value is None: + G.nodes[target]["exogenous_distribution"] = ( + lambda: orig_func() + rng.standard_normal() + ) + else: + G.nodes[target]["exogenous_distribution"] = lambda: orig_func() + intervention_value return G diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index d2b971efe..d3311e1f0 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -23,7 +23,7 @@ def make_graph_multidomain( n_invariances_to_try: int = 1, node_mean_lims: Optional[List[float]] = None, node_std_lims: Optional[List[float]] = None, - edge_functions: List[Callable[[float], float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, edge_weight_lims: Optional[List[float]] = None, random_state=None, ) -> nx.DiGraph: @@ -262,7 +262,7 @@ def sample_multidomain_lin_functions( G: AugmentedGraph, node_mean_lims: Optional[List[float]] = None, node_std_lims: Optional[List[float]] = None, - edge_functions: List[Callable[[float], float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, edge_weight_lims: Optional[List[float]] = None, random_state=None, ): diff --git a/pywhy_graphs/functional/tests/test_linear.py b/pywhy_graphs/functional/tests/test_linear.py index 0f8605b1e..b9d94e4cf 100644 --- a/pywhy_graphs/functional/tests/test_linear.py +++ b/pywhy_graphs/functional/tests/test_linear.py @@ -1,14 +1,21 @@ import networkx as nx import pytest +from scipy.stats import ttest_ind -from pywhy_graphs.functional import make_graph_linear_gaussian, sample_from_graph +from pywhy_graphs.functional import sample_from_graph +from pywhy_graphs.functional.additive import generate_edge_functions_for_node +from pywhy_graphs.functional.linear import ( + apply_linear_soft_intervention, + generate_noise_for_node, + make_random_linear_gaussian_graph, +) from pywhy_graphs.simulate import simulate_random_er_dag def test_make_linear_gaussian_graph(): G = simulate_random_er_dag(n_nodes=5, seed=12345, ensure_acyclic=True) - G = make_graph_linear_gaussian(G, random_state=12345) + G = make_random_linear_gaussian_graph(G, random_state=12345) assert all( key in nx.get_node_attributes(G, "parent_function") @@ -26,13 +33,117 @@ def test_make_linear_gaussian_graph_errors(): G = simulate_random_er_dag(n_nodes=2, seed=12345, ensure_acyclic=True) with pytest.raises(ValueError, match="must be a list of length 2."): - make_graph_linear_gaussian(G, node_mean_lims=[0], random_state=12345) + make_random_linear_gaussian_graph(G, node_mean_lims=[0], random_state=12345) with pytest.raises(ValueError, match="must be a list of length 2."): - make_graph_linear_gaussian(G, node_std_lims=[0], random_state=12345) + make_random_linear_gaussian_graph(G, node_std_lims=[0], random_state=12345) with pytest.raises(ValueError, match="must be a list of length 2."): - make_graph_linear_gaussian(G, edge_weight_lims=[0], random_state=12345) + make_random_linear_gaussian_graph(G, edge_weight_lims=[0], random_state=12345) with pytest.raises(ValueError, match="The input graph must be a DAG."): - make_graph_linear_gaussian(nx.cycle_graph(4, create_using=nx.DiGraph), random_state=12345) + make_random_linear_gaussian_graph( + nx.cycle_graph(4, create_using=nx.DiGraph), random_state=12345 + ) + + +def test_generate_noise_for_node_works(): + G = nx.DiGraph() + G.add_node("A") + + node_mean_lims = (0, 1) + node_std_lims = (0, 1) + + # before adding any functionals, we cannot sample from the graph + with pytest.raises(ValueError, match="The input graph must be a functional graph"): + sample_from_graph(G, n_samples=1, random_state=12345) + G.graph["functional"] = "linear_gaussian" + with pytest.raises(ValueError, match="does not have an exogenous function"): + sample_from_graph(G, n_samples=1, random_state=12345) + + # now when we add the exogenous function, we can sample from the graph + G = generate_noise_for_node(G, "A", node_mean_lims, node_std_lims) + + # Check if node attributes are set properly + assert "exogenous_distribution" in G.nodes["A"] + assert "exogenous_function" in G.nodes["A"] + + # Check if exogenous_distribution is a callable function + assert callable(G.nodes["A"]["exogenous_distribution"]) + + # Check if exogenous_function is a callable function + assert callable(G.nodes["A"]["exogenous_function"]) + + # sample from the graph should work, since there is only one node + sample_from_graph(G, n_samples=1, random_state=12345) + + +def test_apply_linear_soft_intervention(): + G = nx.DiGraph() + G.add_node("A") + G.add_node("B") + G.add_edge("A", "B") + + node_mean_lims = (0, 1) + node_std_lims = (0, 1) + + G = generate_noise_for_node(G, "A", node_mean_lims, node_std_lims) + G = generate_noise_for_node(G, "B", node_mean_lims, node_std_lims) + G = generate_edge_functions_for_node(G, "B", random_state=1234) + G.graph["functional"] = "linear_gaussian" + targets = {"B"} + + # Before intervening, the functions are the same + for target in targets: + assert ( + G.nodes[target]["exogenous_distribution"] + == G.copy().nodes[target]["exogenous_distribution"] + ) + + # Test additive intervention type + G_intervened = apply_linear_soft_intervention( + G.copy(), targets, intervention_value=5.0, type="additive" + ) + + # Check if the target nodes have modified exogenous_distribution functions + for target in targets: + assert ( + G.nodes[target]["exogenous_distribution"] + != G_intervened.nodes[target]["exogenous_distribution"] + ) + + # Check if the exogenous_distribution functions of non-target nodes remain unchanged + non_target_nodes = set(G.nodes) - targets + for node in non_target_nodes: + assert ( + G.nodes[node]["exogenous_distribution"] + == G_intervened.nodes[node]["exogenous_distribution"] + ) + + # now ensure that the two distributions are different and the same where they should be + # the node A is not intervened on, so the distributions should be the same + # while the node B is intervened on, so the distributions should be different + df_original = sample_from_graph(G, n_samples=1000, random_state=12345) + df_intervened = sample_from_graph(G_intervened, n_samples=1000, random_state=12345) + _, pvalue = ttest_ind(df_original["A"], df_intervened["A"]) + assert pvalue > 0.05 + _, pvalue = ttest_ind(df_original["B"], df_intervened["B"]) + assert pvalue < 0.05 + + +def test_apply_linear_soft_intervention_errors(): + targets = {"A", "B"} + + # Test intervention on a non-linear Gaussian graph + H = nx.DiGraph() + H.add_node("X") + H.add_node("Y") + H.add_edge("X", "Y") + + with pytest.raises(ValueError, match="The input graph must be a linear Gaussian graph."): + H.graph["linear_gaussian"] = False + apply_linear_soft_intervention(H, targets, type="additive") + + with pytest.raises(ValueError, match="All targets"): + H.graph["linear_gaussian"] = True + apply_linear_soft_intervention(H, {1, 2}, type="additive") diff --git a/pywhy_graphs/functional/tests/test_multidomain.py b/pywhy_graphs/functional/tests/test_multidomain.py index 0910b90ae..77306b00a 100644 --- a/pywhy_graphs/functional/tests/test_multidomain.py +++ b/pywhy_graphs/functional/tests/test_multidomain.py @@ -1,6 +1,6 @@ import pytest -from pywhy_graphs.functional import make_graph_linear_gaussian +from pywhy_graphs.functional import make_random_linear_gaussian_graph from pywhy_graphs.simulate import simulate_random_er_dag @@ -9,7 +9,7 @@ def test_make_linear_gaussian_graph(n_domains, n_invariances_to_try): G = simulate_random_er_dag(n_nodes=5, seed=12345, ensure_acyclic=True) # make linear graph SCM - G = make_graph_linear_gaussian(G, random_state=12345) + G = make_random_linear_gaussian_graph(G, random_state=12345) # make multidomain SCM # G = make_graph_multidomain( diff --git a/pywhy_graphs/functional/utils.py b/pywhy_graphs/functional/utils.py index 1bd424ff6..7c52ec957 100644 --- a/pywhy_graphs/functional/utils.py +++ b/pywhy_graphs/functional/utils.py @@ -1,8 +1,11 @@ import itertools +from typing import Optional import networkx as nx import numpy as np +from pywhy_graphs.typing import Node + def to_pgmpy_bayesian_network(G): """Convert a discrete graph to a pgmpy Bayesian network. @@ -113,7 +116,7 @@ def get_cpd(G, node): return cpd -def get_cardinality(G, node): +def get_cardinality(G: nx.DiGraph, node: Node): from pgmpy.factors.discrete import TabularCPD cpd: TabularCPD = get_cpd(G, node) @@ -217,7 +220,7 @@ def _preprocess_parameter_inputs( edge_functions, edge_weight_lims, multi_domain: bool = False, - n_domains: int = None, + n_domains: Optional[int] = None, ): """Helper function to preprocess common parameter inputs for sampling functional graphs. diff --git a/pywhy_graphs/networkx/algorithms/causal/m_separation.py b/pywhy_graphs/networkx/algorithms/causal/m_separation.py index 84510e505..48308cdd1 100644 --- a/pywhy_graphs/networkx/algorithms/causal/m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/m_separation.py @@ -110,7 +110,6 @@ def m_separated( G_bidirected = G.get_graphs(edge_type=bidirected_edge_name) while forward_deque or backward_deque: - if backward_deque: node = backward_deque.popleft() backward_visited.add(node) @@ -151,7 +150,6 @@ def m_separated( # Consider if *-> node <-* is opened due to conditioning on collider, # or descendant of collider if node in an_z: - if has_directed: # add <- edges to backward deque for x, _ in G_directed.in_edges(nbunch=node): diff --git a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py index 6fee0634e..eeec16306 100644 --- a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py +++ b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py @@ -61,7 +61,6 @@ def mixed_edge_moral_graph( G_a = nx.compose(G_a, G_bidirected) for component in nx.connected_components(G_bidirected): - for u, v in itertools.combinations(component, 2): G_a.add_edge(u, v) all_parents = {parent for node in component for parent in G_directed.predecessors(node)} diff --git a/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py b/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py index b289372a4..7e3ff570e 100644 --- a/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py +++ b/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py @@ -4,7 +4,6 @@ def test_m_separation(): - # 0 -> 1 -> 2 -> 3 -> 4; 2 -> 4; 2 <-> 3 digraph = nx.path_graph(4, create_using=nx.DiGraph) digraph.add_edge(2, 4) diff --git a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py index bc6aa56da..3aa19d26b 100644 --- a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py @@ -9,7 +9,6 @@ @pytest.fixture def fig5_vanderzander(): - nodes = ["V_1", "X", "V_2", "Y", "Z_1", "Z_2"] digraph = nx.DiGraph() @@ -35,7 +34,6 @@ def fig5_vanderzander(): @pytest.fixture def modified_fig5_vanderzander(): - nodes = ["V_1", "X", "V_2", "Y", "Z_1", "Z_2"] digraph = nx.DiGraph() @@ -239,7 +237,6 @@ def test_anterior(): def test_is_minimal_m_separator(fig5_vanderzander): - assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_1"}) assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_2"}) assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_2"}, r={"Z_1", "Z_2"}) diff --git a/pywhy_graphs/simulate.py b/pywhy_graphs/simulate.py index 1b1ebeb70..1ecc98837 100644 --- a/pywhy_graphs/simulate.py +++ b/pywhy_graphs/simulate.py @@ -11,7 +11,7 @@ def simulate_random_er_dag( - n_nodes: int, p: float = 0.5, seed: int = None, ensure_acyclic: bool = False + n_nodes: int, p: float = 0.5, seed: Optional[int] = None, ensure_acyclic: bool = False ): """Simulate a random Erdos-Renyi graph. @@ -109,7 +109,7 @@ def simulate_data_from_var( n_times: int = 1000, n_realizations: int = 1, var_names: Optional[List[Node]] = None, - random_state: int = None, + random_state: Optional[int] = None, ): """Simulate data from an already set VAR process. @@ -199,7 +199,7 @@ def simulate_linear_var_process( n_times: int = 1000, n_realizations: int = 1, weight_dist: Callable = scipy.stats.norm, - random_state: int = None, + random_state: Optional[int] = None, ): """Simulate a linear VAR process of a "stationary" causal graph. @@ -286,7 +286,7 @@ def simulate_linear_var_process( def simulate_var_process_from_summary_graph( - G: pywhy_nx.MixedEdgeGraph, max_lag=1, n_times=1000, random_state: int = None + G: pywhy_nx.MixedEdgeGraph, max_lag=1, n_times=1000, random_state: Optional[int] = None ): """Simulate a VAR(max_lag) process starting from a summary graph. diff --git a/pywhy_graphs/viz/draw.py b/pywhy_graphs/viz/draw.py index 7be5d561b..a4e4cce62 100644 --- a/pywhy_graphs/viz/draw.py +++ b/pywhy_graphs/viz/draw.py @@ -5,10 +5,10 @@ def _draw_circle_edges( dot, - directed_edges: List[Tuple] = None, - circle_edges: List[Tuple] = None, - undirected_edges: List[Tuple] = None, - bidirected_edges: List[Tuple] = None, + directed_edges: Optional[List[Tuple]] = None, + circle_edges: Optional[List[Tuple]] = None, + undirected_edges: Optional[List[Tuple]] = None, + bidirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw the PAG edges. @@ -52,7 +52,7 @@ def _draw_circle_edges( def _draw_un_edges( dot, - undirected_edges: List[Tuple] = None, + undirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw undirected edges.""" @@ -65,7 +65,7 @@ def _draw_un_edges( def _draw_bi_edges( dot, - bidirected_edges: List[Tuple] = None, + bidirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw bidirected edges.""" @@ -81,7 +81,7 @@ def draw( direction: Optional[str] = None, pos: Optional[dict] = None, name: Optional[str] = None, - shape="square", + shape: str = "square", **attrs, ): """Visualize the graph. From 0ef79ce341052b124f13c00419b64cc526f1c432 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 20:19:11 -0700 Subject: [PATCH 09/14] Update lock file Signed-off-by: Adam Li --- poetry.lock | 133 +++++++++++++++++++++++++++---------------------- pyproject.toml | 2 +- 2 files changed, 75 insertions(+), 60 deletions(-) diff --git a/poetry.lock b/poetry.lock index fee750b5f..7952c8f50 100644 --- a/poetry.lock +++ b/poetry.lock @@ -129,33 +129,44 @@ lxml = ["lxml"] [[package]] name = "black" -version = "22.12.0" +version = "23.7.0" description = "The uncompromising code formatter." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, ] [package.dependencies] click = ">=8.0.0" ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""} mypy-extensions = ">=0.4.3" +packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""} -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} [package.extras] @@ -720,19 +731,19 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "p [[package]] name = "flake8" -version = "5.0.4" +version = "6.0.0" description = "the modular source code checker: pep8 pyflakes and co" optional = false -python-versions = ">=3.6.1" +python-versions = ">=3.8.1" files = [ - {file = "flake8-5.0.4-py2.py3-none-any.whl", hash = "sha256:7a1cf6b73744f5806ab95e526f6f0d8c01c66d7bbe349562d22dfca20610b248"}, - {file = "flake8-5.0.4.tar.gz", hash = "sha256:6fbe320aad8d6b95cec8b8e47bc933004678dc63095be98528b7bdd2a9f510db"}, + {file = "flake8-6.0.0-py2.py3-none-any.whl", hash = "sha256:3833794e27ff64ea4e9cf5d410082a8b97ff1a06c16aa3d2027339cd0f1195c7"}, + {file = "flake8-6.0.0.tar.gz", hash = "sha256:c61007e76655af75e6785a931f452915b371dc48f56efd765247c8fe68f2b181"}, ] [package.dependencies] mccabe = ">=0.7.0,<0.8.0" -pycodestyle = ">=2.9.0,<2.10.0" -pyflakes = ">=2.5.0,<2.6.0" +pycodestyle = ">=2.10.0,<2.11.0" +pyflakes = ">=3.0.0,<3.1.0" [[package]] name = "fonttools" @@ -1577,43 +1588,47 @@ tests = ["pytest (>=4.6)"] [[package]] name = "mypy" -version = "0.971" +version = "1.4.1" description = "Optional static typing for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "mypy-0.971-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2899a3cbd394da157194f913a931edfd4be5f274a88041c9dc2d9cdcb1c315c"}, - {file = "mypy-0.971-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:98e02d56ebe93981c41211c05adb630d1d26c14195d04d95e49cd97dbc046dc5"}, - {file = "mypy-0.971-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:19830b7dba7d5356d3e26e2427a2ec91c994cd92d983142cbd025ebe81d69cf3"}, - {file = "mypy-0.971-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:02ef476f6dcb86e6f502ae39a16b93285fef97e7f1ff22932b657d1ef1f28655"}, - {file = "mypy-0.971-cp310-cp310-win_amd64.whl", hash = "sha256:25c5750ba5609a0c7550b73a33deb314ecfb559c350bb050b655505e8aed4103"}, - {file = "mypy-0.971-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d3348e7eb2eea2472db611486846742d5d52d1290576de99d59edeb7cd4a42ca"}, - {file = "mypy-0.971-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3fa7a477b9900be9b7dd4bab30a12759e5abe9586574ceb944bc29cddf8f0417"}, - {file = "mypy-0.971-cp36-cp36m-win_amd64.whl", hash = "sha256:2ad53cf9c3adc43cf3bea0a7d01a2f2e86db9fe7596dfecb4496a5dda63cbb09"}, - {file = "mypy-0.971-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:855048b6feb6dfe09d3353466004490b1872887150c5bb5caad7838b57328cc8"}, - {file = "mypy-0.971-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:23488a14a83bca6e54402c2e6435467a4138785df93ec85aeff64c6170077fb0"}, - {file = "mypy-0.971-cp37-cp37m-win_amd64.whl", hash = "sha256:4b21e5b1a70dfb972490035128f305c39bc4bc253f34e96a4adf9127cf943eb2"}, - {file = "mypy-0.971-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9796a2ba7b4b538649caa5cecd398d873f4022ed2333ffde58eaf604c4d2cb27"}, - {file = "mypy-0.971-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a361d92635ad4ada1b1b2d3630fc2f53f2127d51cf2def9db83cba32e47c856"}, - {file = "mypy-0.971-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b793b899f7cf563b1e7044a5c97361196b938e92f0a4343a5d27966a53d2ec71"}, - {file = "mypy-0.971-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d1ea5d12c8e2d266b5fb8c7a5d2e9c0219fedfeb493b7ed60cd350322384ac27"}, - {file = "mypy-0.971-cp38-cp38-win_amd64.whl", hash = "sha256:23c7ff43fff4b0df93a186581885c8512bc50fc4d4910e0f838e35d6bb6b5e58"}, - {file = "mypy-0.971-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1f7656b69974a6933e987ee8ffb951d836272d6c0f81d727f1d0e2696074d9e6"}, - {file = "mypy-0.971-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2022bfadb7a5c2ef410d6a7c9763188afdb7f3533f22a0a32be10d571ee4bbe"}, - {file = "mypy-0.971-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef943c72a786b0f8d90fd76e9b39ce81fb7171172daf84bf43eaf937e9f220a9"}, - {file = "mypy-0.971-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d744f72eb39f69312bc6c2abf8ff6656973120e2eb3f3ec4f758ed47e414a4bf"}, - {file = "mypy-0.971-cp39-cp39-win_amd64.whl", hash = "sha256:77a514ea15d3007d33a9e2157b0ba9c267496acf12a7f2b9b9f8446337aac5b0"}, - {file = "mypy-0.971-py3-none-any.whl", hash = "sha256:0d054ef16b071149917085f51f89555a576e2618d5d9dd70bd6eea6410af3ac9"}, - {file = "mypy-0.971.tar.gz", hash = "sha256:40b0f21484238269ae6a57200c807d80debc6459d444c0489a102d7c6a75fa56"}, -] - -[package.dependencies] -mypy-extensions = ">=0.4.3" + {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, + {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, + {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, + {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, + {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, + {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, + {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, + {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, + {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, + {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, + {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, + {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, + {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, + {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, + {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, + {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, + {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, + {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, + {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, + {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] python2 = ["typed-ast (>=1.4.0,<2)"] reports = ["lxml"] @@ -2304,13 +2319,13 @@ pybtex = ">=0.16" [[package]] name = "pycodestyle" -version = "2.9.1" +version = "2.10.0" description = "Python style guide checker" optional = false python-versions = ">=3.6" files = [ - {file = "pycodestyle-2.9.1-py2.py3-none-any.whl", hash = "sha256:d1735fc58b418fd7c5f658d28d943854f8a849b01a5d0a1e6f3f3fdd0166804b"}, - {file = "pycodestyle-2.9.1.tar.gz", hash = "sha256:2c9607871d58c76354b697b42f5d57e1ada7d261c261efac224b664affdc5785"}, + {file = "pycodestyle-2.10.0-py2.py3-none-any.whl", hash = "sha256:8a4eaf0d0495c7395bdab3589ac2db602797d76207242c17d470186815706610"}, + {file = "pycodestyle-2.10.0.tar.gz", hash = "sha256:347187bdb476329d98f695c213d7295a846d1152ff4fe9bacb8a9590b8ee7053"}, ] [[package]] @@ -2380,13 +2395,13 @@ pyparsing = ">=2.1.4" [[package]] name = "pyflakes" -version = "2.5.0" +version = "3.0.1" description = "passive checker of Python programs" optional = false python-versions = ">=3.6" files = [ - {file = "pyflakes-2.5.0-py2.py3-none-any.whl", hash = "sha256:4579f67d887f804e67edb544428f264b7b24f435b263c4614f384135cea553d2"}, - {file = "pyflakes-2.5.0.tar.gz", hash = "sha256:491feb020dca48ccc562a8c0cbe8df07ee13078df59813b83959cbdada312ea3"}, + {file = "pyflakes-3.0.1-py2.py3-none-any.whl", hash = "sha256:ec55bf7fe21fff7f1ad2f7da62363d749e2a470500eab1b555334b67aa1ef8cf"}, + {file = "pyflakes-3.0.1.tar.gz", hash = "sha256:ec8b276a6b60bd80defed25add7e439881c19e64850afd9b346283d4165fd0fd"}, ] [[package]] @@ -3657,5 +3672,5 @@ viz = ["pygraphviz"] [metadata] lock-version = "2.0" -python-versions = ">=3.8,<3.12" -content-hash = "36b8dbd6d9ec663dc72fad91c02356d139180933cc0647f1b7b387c0956c092d" +python-versions = ">=3.8.1,<3.12" +content-hash = "0e915c9796c05cd28e374ba952c47017bb6a73ce1fa04589eeb3b1db0d64af40" diff --git a/pyproject.toml b/pyproject.toml index f7a840dd1..b27d20fd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ include = ['pywhy_graphs*'] exclude = ['*tests'] [tool.poetry.dependencies] -python = ">=3.8,<3.12" +python = ">=3.8.1,<3.12" numpy = "^1.21.0" scipy = "^1.8.0" networkx = "^3.1" From 386705424b6b41d88a26475ec3f4efa2dad9b000 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 20:20:55 -0700 Subject: [PATCH 10/14] Fix ing unit tests Signed-off-by: Adam Li --- pywhy_graphs/classes/tests/test_graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pywhy_graphs/classes/tests/test_graph.py b/pywhy_graphs/classes/tests/test_graph.py index 4c1875a4f..2af5872ee 100644 --- a/pywhy_graphs/classes/tests/test_graph.py +++ b/pywhy_graphs/classes/tests/test_graph.py @@ -459,4 +459,6 @@ def test_add_s_nodes(self): assert G.s_nodes == [] G.add_s_node(domain_ids=(0, 1), node_changes={0, 1}) - assert G.s_nodes == [("S", 0)] + + # S-nodes are encoded with a similar pattern as F-nodes + assert G.s_nodes == [("F", 0)] From 084e740db1ac0212e6f04e7d991aedba82e3ccbe Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 28 Jul 2023 20:57:43 -0700 Subject: [PATCH 11/14] Fix CI Signed-off-by: Adam Li --- pywhy_graphs/classes/augmented.py | 3 ++- pywhy_graphs/classes/tests/test_graph.py | 2 +- pywhy_graphs/functional/tests/conftest.py | 3 +++ pywhy_graphs/functional/tests/test_base.py | 1 + 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index c9418cf48..b530a58ab 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -289,7 +289,8 @@ def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None # add a new S-node into the graph # Note: that we represent S-nodes as F-nodes - s_node_name = ("F", len(self.augmented_nodes)) + # s_node_name = ("F", len(self.augmented_nodes)) + s_node_name = ("S", len(self.s_nodes)) self.add_node(s_node_name, domain_ids=domain_ids) # add edge between the F-node and its intervention set diff --git a/pywhy_graphs/classes/tests/test_graph.py b/pywhy_graphs/classes/tests/test_graph.py index 2af5872ee..f1ce385c5 100644 --- a/pywhy_graphs/classes/tests/test_graph.py +++ b/pywhy_graphs/classes/tests/test_graph.py @@ -461,4 +461,4 @@ def test_add_s_nodes(self): G.add_s_node(domain_ids=(0, 1), node_changes={0, 1}) # S-nodes are encoded with a similar pattern as F-nodes - assert G.s_nodes == [("F", 0)] + assert G.s_nodes == [("S", 0)] diff --git a/pywhy_graphs/functional/tests/conftest.py b/pywhy_graphs/functional/tests/conftest.py index 9ae917bde..2afd1b339 100644 --- a/pywhy_graphs/functional/tests/conftest.py +++ b/pywhy_graphs/functional/tests/conftest.py @@ -8,3 +8,6 @@ def suppress_pgmpy_warnings(): warnings.filterwarnings("ignore", "DeprecationWarning:pgmpy.*") warnings.filterwarnings("ignore", "DeprecationWarning:numpy.*") warnings.filterwarnings("ignore", "DeprecationWarning:pkg_resources.*") + warnings.filterwarnings( + "ignore", category=UserWarning, message="Probability values don't exactly sum to 1." + ) diff --git a/pywhy_graphs/functional/tests/test_base.py b/pywhy_graphs/functional/tests/test_base.py index b95793e89..7c0f689ed 100644 --- a/pywhy_graphs/functional/tests/test_base.py +++ b/pywhy_graphs/functional/tests/test_base.py @@ -20,6 +20,7 @@ def test_check_input_graph(): G.add_node(3, exogenous_function="func3", exogenous_distribution="dist3") G.add_edge(1, 2) G.add_edge(3, 2) + G.graph["functional"] = "linear_gaussian" G_copy = G.copy() # Test a valid input graph From 14a5d673ddac3958422041c5fac30db1eeaaa2cf Mon Sep 17 00:00:00 2001 From: Adam Li Date: Sat, 29 Jul 2023 08:49:46 -0700 Subject: [PATCH 12/14] Fix ci Signed-off-by: Adam Li --- doc/api.rst | 4 ++-- doc/reference/functional/index.rst | 4 ++-- pywhy_graphs/functional/__init__.py | 2 +- pywhy_graphs/functional/multidomain.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 37e12f9f2..d1047d592 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -161,10 +161,10 @@ linear structural equation Gaussian models (SEMs). .. autosummary:: :toctree: generated/ - make_graph_linear_gaussian + make_random_linear_gaussian_graph apply_linear_soft_intervention set_node_attributes_with_G - make_graph_multidomain + make_random_multidomain_graph Visualization of causal graphs ============================== diff --git a/doc/reference/functional/index.rst b/doc/reference/functional/index.rst index 23b3f1f1a..ea016bcef 100644 --- a/doc/reference/functional/index.rst +++ b/doc/reference/functional/index.rst @@ -266,7 +266,7 @@ Linear functional graphs .. autosummary:: :toctree: ../../generated/ - make_graph_linear_gaussian + make_random_linear_gaussian_graph apply_linear_soft_intervention Multidomain @@ -306,4 +306,4 @@ Linear functional selection diagrams .. autosummary:: :toctree: ../../generated/ - make_graph_multidomain + make_random_multidomain_graph diff --git a/pywhy_graphs/functional/__init__.py b/pywhy_graphs/functional/__init__.py index 59117a28a..b668ba017 100644 --- a/pywhy_graphs/functional/__init__.py +++ b/pywhy_graphs/functional/__init__.py @@ -2,7 +2,7 @@ from .linear import apply_linear_soft_intervention, make_random_linear_gaussian_graph from .multidomain import ( generate_multidomain_noise_for_node, - make_graph_multidomain, + make_random_multidomain_graph, sample_multidomain_lin_functions, ) from .utils import set_node_attributes_with_G diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index d3311e1f0..7d0f70f16 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -16,7 +16,7 @@ from .linear import generate_noise_for_node -def make_graph_multidomain( +def make_random_multidomain_graph( G: nx.DiGraph, n_domains: int = 2, n_nodes_with_s_nodes: Union[int, Tuple[int]] = 1, @@ -181,7 +181,7 @@ def make_graph_multidomain( random_state=random_state, ) - G.graph["linear_gaussian"] = True + G.graph["functional"] = "linear_gaussian" G.graph["S-nodes"] = s_nodes G.graph["n_domains"] = n_domains return G From ec5af70403499b1c5d6bd2bba8cc31134bc9fe91 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 8 Aug 2023 12:33:55 -0400 Subject: [PATCH 13/14] WIP Signed-off-by: Adam Li --- doc/conf.py | 2 + doc/whats_new/v0.1.rst | 4 +- examples/simulations/README.txt | 5 + .../plot_discrete_causal_bayesian_network.py | 104 ++++++++++++++++++ .../plot_graphs_with_interventions.py | 0 .../plot_linear_gaussian_causal_graph.py | 22 ++++ pywhy_graphs/classes/augmented.py | 11 ++ pywhy_graphs/functional/base.py | 9 +- pywhy_graphs/functional/discrete.py | 3 + pywhy_graphs/functional/linear.py | 2 +- pywhy_graphs/functional/multidomain.py | 34 ++++++ pywhy_graphs/viz/draw.py | 4 +- 12 files changed, 194 insertions(+), 6 deletions(-) create mode 100644 examples/simulations/README.txt create mode 100644 examples/simulations/plot_discrete_causal_bayesian_network.py create mode 100644 examples/simulations/plot_graphs_with_interventions.py create mode 100644 examples/simulations/plot_linear_gaussian_causal_graph.py diff --git a/doc/conf.py b/doc/conf.py index 79135aecc..7606c3ea1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -236,6 +236,8 @@ def setup(app): "graphviz": ("https://graphviz.readthedocs.io/en/stable/", None), "sphinx-gallery": ("https://sphinx-gallery.github.io/stable/", None), "pgmpy": ("https://pgmpy.org/", None), + "dodiscover": ("https://pywhy.org/dodiscover/dev/", None), + "dowhy": ("https://pywhy.org/dowhy/dev/", None), } intersphinx_timeout = 5 diff --git a/doc/whats_new/v0.1.rst b/doc/whats_new/v0.1.rst index 9361f0b87..434d6cd01 100644 --- a/doc/whats_new/v0.1.rst +++ b/doc/whats_new/v0.1.rst @@ -40,8 +40,8 @@ Changelog - |Feature| Implement export/import functions to go to/from pywhy-graphs to pcalg and tetrad, by `Adam Li`_ (:pr:`60`) - |Feature| Implement export/import functions to go to/from pywhy-graphs to ananke-causal, by `Jaron Lee`_ (:pr:`63`) - |Feature| Implement pre-commit hooks for development, by `Jaron Lee`_ (:pr:`68`) -- |Feature| Implement a new submodule for converting graphs to a functional model, with :func:`pywhy_graphs.functional.make_graph_linear_gaussian`, by `Adam Li`_ (:pr:`75`) -- |Feature| Implement a multidomain linear functional graph, with :func:`pywhy_graphs.functional.make_graph_multidomain`, by `Adam Li`_ (:pr:`77`) +- |Feature| Implement a new submodule for converting graphs to a functional model, with :func:`pywhy_graphs.functional.make_random_linear_gaussian_graph`, by `Adam Li`_ (:pr:`75`) +- |Feature| Implement a multidomain linear functional graph, with :func:`pywhy_graphs.functional.make_random_multidomain_graph`, by `Adam Li`_ (:pr:`77`) - |Feature| Implement and test functions to find inducing paths between two nodes, `Aryan Roy`_ (:pr:`78`) - |Feature| Implement general functional API for sampling and generating a functional causal graph, by `Adam Li`_ (:pr:`82`) diff --git a/examples/simulations/README.txt b/examples/simulations/README.txt new file mode 100644 index 000000000..4c0f33ea5 --- /dev/null +++ b/examples/simulations/README.txt @@ -0,0 +1,5 @@ +Examples Simulating Data From Causal Diagrams +--------------------------------------------- + +Examples demonstrating how to simulate data stemming from causal diagrams in a variety of different +settings. diff --git a/examples/simulations/plot_discrete_causal_bayesian_network.py b/examples/simulations/plot_discrete_causal_bayesian_network.py new file mode 100644 index 000000000..cdf77039f --- /dev/null +++ b/examples/simulations/plot_discrete_causal_bayesian_network.py @@ -0,0 +1,104 @@ +""" +.. _ex-discrete-cbn: + +============================================================= +Discrete Causal Bayesian Networks and Simulated Discrete Data +============================================================= + +Discrete data arises commonly in many applications. For example, data is typically stored +in a table, which categorizes values into discrete values representing certain categories. +In survey data, answers are typically multiple choice. In medical settings, many symptoms +are rated on a scale of for example 1-5. Perhaps a data feature indicates whether or not a +certain disease is present or not, resulting in a binary variable. + +Even if these are discrete, the data is still generated by some unknown structural causal model, +which induces a causal diagram. Causal algorithms will typically need to generate different +causal models and then datasets with varying sample sizes to evaluate the algorithms. + +In this example, we illustrate how to generate discrete data from a random causal graph. + +For information on generating continuous data from a causal graph, one can see +:ref:`ex-linear-gaussian-graph`. +""" + +# %% +# Import the required libraries +# ----------------------------- +import networkx as nx +from pywhy_graphs.functional.discrete import make_random_discrete_graph +from pywhy_graphs.functional import sample_from_graph +from pywhy_graphs.viz import draw +from pgmpy.factors.discrete.CPD import TabularCPD + +# define a helper function to print the full CPD +def print_full(cpd): + backup = TabularCPD._truncate_strtable + TabularCPD._truncate_strtable = lambda self, x: x + print(cpd) + TabularCPD._truncate_strtable = backup + +# %% +# Construct the causal graph +# -------------------------- +# In order to generate the data, we start from a causal graph that informs us how +# data is generated. That is, each variable is a function of its exogenous noise distribution +# and its parent values. +edge_list = [ + ("A", "B"), + ("B", "C"), + ("C", "D"), + ("B", "D"), + ("X", "A"), + ("X", "C"), + ("C", "W"), +] +G = nx.DiGraph() + +G.add_edges_from(edge_list) + +draw(G) + +# %% +# Define functional relationships on the graph +# -------------------------------------------- +# In order to generate data, we need to define the full functional relationship +# between every node and its parents and also how to generate parent-less nodes. +# We leverage the :class:`pgmpy.factors.discrete.CPD.TabularCPD` abstraction to +# represent conditional probability distributions as tables. + +cardinality_lims = {node: [2, 4] for node in G.nodes} +weight_lims = {node: [1, 100] for node in G.nodes} +noise_ratio_lims = {node: [0.1, 0.1] for node in G.nodes} +seed = 1234 + +G = make_random_discrete_graph( + G, + cardinality_lims=cardinality_lims, + weight_lims=weight_lims, + noise_ratio_lims=noise_ratio_lims, + random_state=seed, + overwrite=True, +) + +print(G) + +# we can extract the conditional probability table for each node, which is a function of its parents +node_dict = G.nodes["C"] + +# We see that each node is fully defined given a conditional probability table, stored as a node +# attribute under the keyword 'cpd'. For more information on the CPD object, see pgmpy's documentation +# on :class:`pgmpy.factors.discrete.CPD.TabularCPD`. Note this is in contrast with what node attributes +# are required in general for simulating data from a causal graph in pywhy-graphs. +print_full(node_dict["cpd"]) + +# %% +# Sample data from the graph +# -------------------------- +# Now, we can sample data from the graph that is generated according to the causal diagram. +# This data can be used for instance to evaluate causal discovery algorithms from +# :mod:`dodiscover`, or causal estimation algorithms from :mod:`dowhy`. + +# now we sample from the graph the discrete dataset +df = sample_from_graph(G, n_samples=2000, n_jobs=1, random_state=seed) + +print(df.head()) diff --git a/examples/simulations/plot_graphs_with_interventions.py b/examples/simulations/plot_graphs_with_interventions.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/simulations/plot_linear_gaussian_causal_graph.py b/examples/simulations/plot_linear_gaussian_causal_graph.py new file mode 100644 index 000000000..02a50d01b --- /dev/null +++ b/examples/simulations/plot_linear_gaussian_causal_graph.py @@ -0,0 +1,22 @@ +""" +.. _ex-linear-gaussian-graph: + +===================================================== +Linear Gaussian Graphs and Generating Continuous Data +===================================================== + +Linear gaussian graphs are an important model. These are joint distributions +that follow the structure of a causal graph, where exogenous noise distributions are Gaussian +and nodes are linear combinations of their parents perturbed by the exogenous variable. + +Thus, each edge is associated with a weight of how the parent node is added to the current +node. + +In this example, we illustrate how to generate continuous data from a linear gaussian +causal graph. + +For information on generating discrete data from a causal graph, one can see +:ref:`ex-discrete-cbn`. Consider reading the user-guide, :ref:`functional-causal-graphical-models` +to understand how an arbitrary functional relationships are encoded in a causal graph. +""" + diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index b530a58ab..296d3ce62 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -135,6 +135,17 @@ def add_node(self, u, **attrs): def directed_edge_name(self) -> str: pass + @property + def n_domains(self): + domains = set() + for node_dict in self.nodes(data=True): + domain_ids = node_dict.get('domain_ids', None) + if domain_ids is not None: + domains.add(domain for domain in domain_ids) + + return len(domains) + + def _verify_augmentednode_dict(self): # verify validity of F nodes if "F-nodes" not in self.graph: diff --git a/pywhy_graphs/functional/base.py b/pywhy_graphs/functional/base.py index 32c762094..34565a882 100644 --- a/pywhy_graphs/functional/base.py +++ b/pywhy_graphs/functional/base.py @@ -230,10 +230,16 @@ def sample_from_graph( rng: np.random.Generator = np.random.default_rng(random_state) if hasattr(G, "get_graphs"): - directed_G = G.get_graphs("directed") + directed_G = nx.DiGraph() + directed_G.add_nodes_from(G.nodes(data=True)) + for edge in directed_G.edges: + G.add_edge(*edge, **directed_G.edges[edge]) + + directed_G.graph = G.graph.copy() else: directed_G = G + print('inside: ', directed_G.nodes(data=True)) # check input _check_input_graph(directed_G) @@ -378,6 +384,7 @@ def _check_input_graph(G: nx.DiGraph): "the graph with functions." ) for node in G.nodes: + print(G.nodes[node]) if G.nodes[node].get("exogenous_function", None) is None: raise ValueError(f"Node {node} does not have an exogenous function.") if G.nodes[node].get("exogenous_distribution", None) is None: diff --git a/pywhy_graphs/functional/discrete.py b/pywhy_graphs/functional/discrete.py index 1c8f30716..6a2d75fb8 100644 --- a/pywhy_graphs/functional/discrete.py +++ b/pywhy_graphs/functional/discrete.py @@ -18,6 +18,9 @@ def apply_discrete_soft_intervention( Linear functional causal graph. targets : Set[Node] The set of nodes to intervene on simultanenously. + weight_ranges : Optional[List], optional + The range of weights to sample from for each target node. If None, then + the range is [1, 5], by default None. random_state : RandomState, optional Random seed, by default None. diff --git a/pywhy_graphs/functional/linear.py b/pywhy_graphs/functional/linear.py index bb83c68d6..383a0fc3e 100644 --- a/pywhy_graphs/functional/linear.py +++ b/pywhy_graphs/functional/linear.py @@ -148,7 +148,7 @@ def apply_linear_soft_intervention( of the target nodes. That is, the soft intervention, perturbs the exogenous noise of the target nodes. """ - if not G.graph.get("linear_gaussian", True): + if not G.graph.get("functional", 'linear_gaussian'): raise ValueError("The input graph must be a linear Gaussian graph.") if not all(target in G.nodes for target in targets): raise ValueError(f"All targets {targets} must be in the graph: {G.nodes}.") diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index 7d0f70f16..8f264a200 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -16,6 +16,40 @@ from .linear import generate_noise_for_node +def apply_domain_shift(G, node, domain_ids, exogenous_distribution=None, random_state=None): + """Applies a domain shift to a node in a multi-domain selection diagram. + + Parameters + ---------- + G : AugmentedGraph + The graph to apply the domain shift to. + node : Node + The node to apply the domain shift to. + domain_ids : tuple of int + The domain pair to apply the domain shift to. The first element is the base domain + and must already exist. The second element is the new domain to add with respect + to the differences. + exogenous_distribution : Optional[Callable], optional + The new exogenous distribution to apply to the node. If None, then will use the + existing exogenous distribution. By default None. + """ + if exogenous_distribution is None: + rng = np.random.default_rng(random_state) + exogenous_distribution = lambda: rng.standard_normal() + + # determine which S-node the domain IDs corresond to + snode = G.domain_ids_to_snode[domain_ids] + + if not G.has_edge(snode, node): + raise RuntimeError(f'Node {node} does not have an S-node {snode} pointing to it for domain' + f'pairs {domain_ids}.') + + # now add a new exogenous distribution for the node + domain_id = domain_ids[1] + G.nodes[node]['domain'][domain_id]["exogenous_distribution"] = lambda: exogenous_distribution() + return G + + def make_random_multidomain_graph( G: nx.DiGraph, n_domains: int = 2, diff --git a/pywhy_graphs/viz/draw.py b/pywhy_graphs/viz/draw.py index a4e4cce62..c21aff556 100644 --- a/pywhy_graphs/viz/draw.py +++ b/pywhy_graphs/viz/draw.py @@ -135,8 +135,8 @@ def draw( # an edge case of drawing graphs is the undirected Markov network if hasattr(G, "undirected_edges"): undirected_edges = G.undirected_edges - elif isinstance(G, nx.Graph): - undirected_edges = G.edges() + elif isinstance(G, nx.Graph) and not G.is_directed(): + undirected_edges = G.edges() if hasattr(G, "bidirected_edges"): bidirected_edges = G.bidirected_edges From c8de7a30b525510bba69cec591880970366d3626 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 8 Aug 2023 13:33:29 -0400 Subject: [PATCH 14/14] Fix style Signed-off-by: Adam Li --- .../plot_discrete_causal_bayesian_network.py | 14 ++++++++----- .../plot_linear_gaussian_causal_graph.py | 1 - pywhy_graphs/classes/augmented.py | 3 +-- pywhy_graphs/functional/base.py | 2 +- pywhy_graphs/functional/linear.py | 2 +- pywhy_graphs/functional/multidomain.py | 12 ++++++----- pywhy_graphs/viz/draw.py | 21 ++++++++++++------- pywhy_graphs/viz/tests/test_draw.py | 13 ++++++++++++ 8 files changed, 45 insertions(+), 23 deletions(-) diff --git a/examples/simulations/plot_discrete_causal_bayesian_network.py b/examples/simulations/plot_discrete_causal_bayesian_network.py index cdf77039f..519190953 100644 --- a/examples/simulations/plot_discrete_causal_bayesian_network.py +++ b/examples/simulations/plot_discrete_causal_bayesian_network.py @@ -25,10 +25,12 @@ # Import the required libraries # ----------------------------- import networkx as nx -from pywhy_graphs.functional.discrete import make_random_discrete_graph +from pgmpy.factors.discrete.CPD import TabularCPD + from pywhy_graphs.functional import sample_from_graph +from pywhy_graphs.functional.discrete import make_random_discrete_graph from pywhy_graphs.viz import draw -from pgmpy.factors.discrete.CPD import TabularCPD + # define a helper function to print the full CPD def print_full(cpd): @@ -37,6 +39,7 @@ def print_full(cpd): print(cpd) TabularCPD._truncate_strtable = backup + # %% # Construct the causal graph # -------------------------- @@ -86,9 +89,10 @@ def print_full(cpd): node_dict = G.nodes["C"] # We see that each node is fully defined given a conditional probability table, stored as a node -# attribute under the keyword 'cpd'. For more information on the CPD object, see pgmpy's documentation -# on :class:`pgmpy.factors.discrete.CPD.TabularCPD`. Note this is in contrast with what node attributes -# are required in general for simulating data from a causal graph in pywhy-graphs. +# attribute under the keyword 'cpd'. For more information on the CPD object, see +# pgmpy's documentation on :class:`pgmpy.factors.discrete.CPD.TabularCPD`. Note this +# is in contrast with what node attributes are required in general for simulating data +# from a causal graph in pywhy-graphs. print_full(node_dict["cpd"]) # %% diff --git a/examples/simulations/plot_linear_gaussian_causal_graph.py b/examples/simulations/plot_linear_gaussian_causal_graph.py index 02a50d01b..7cdaac48f 100644 --- a/examples/simulations/plot_linear_gaussian_causal_graph.py +++ b/examples/simulations/plot_linear_gaussian_causal_graph.py @@ -19,4 +19,3 @@ :ref:`ex-discrete-cbn`. Consider reading the user-guide, :ref:`functional-causal-graphical-models` to understand how an arbitrary functional relationships are encoded in a causal graph. """ - diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 296d3ce62..ed979b0f7 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -139,12 +139,11 @@ def directed_edge_name(self) -> str: def n_domains(self): domains = set() for node_dict in self.nodes(data=True): - domain_ids = node_dict.get('domain_ids', None) + domain_ids = node_dict.get("domain_ids", None) if domain_ids is not None: domains.add(domain for domain in domain_ids) return len(domains) - def _verify_augmentednode_dict(self): # verify validity of F nodes diff --git a/pywhy_graphs/functional/base.py b/pywhy_graphs/functional/base.py index 34565a882..52182d4ee 100644 --- a/pywhy_graphs/functional/base.py +++ b/pywhy_graphs/functional/base.py @@ -239,7 +239,7 @@ def sample_from_graph( else: directed_G = G - print('inside: ', directed_G.nodes(data=True)) + print("inside: ", directed_G.nodes(data=True)) # check input _check_input_graph(directed_G) diff --git a/pywhy_graphs/functional/linear.py b/pywhy_graphs/functional/linear.py index 383a0fc3e..2d4bd1d77 100644 --- a/pywhy_graphs/functional/linear.py +++ b/pywhy_graphs/functional/linear.py @@ -148,7 +148,7 @@ def apply_linear_soft_intervention( of the target nodes. That is, the soft intervention, perturbs the exogenous noise of the target nodes. """ - if not G.graph.get("functional", 'linear_gaussian'): + if not G.graph.get("functional", "linear_gaussian"): raise ValueError("The input graph must be a linear Gaussian graph.") if not all(target in G.nodes for target in targets): raise ValueError(f"All targets {targets} must be in the graph: {G.nodes}.") diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index 8f264a200..177046f6a 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -37,16 +37,18 @@ def apply_domain_shift(G, node, domain_ids, exogenous_distribution=None, random_ rng = np.random.default_rng(random_state) exogenous_distribution = lambda: rng.standard_normal() - # determine which S-node the domain IDs corresond to + # determine which S-node the domain IDs correspond to snode = G.domain_ids_to_snode[domain_ids] if not G.has_edge(snode, node): - raise RuntimeError(f'Node {node} does not have an S-node {snode} pointing to it for domain' - f'pairs {domain_ids}.') - + raise RuntimeError( + f"Node {node} does not have an S-node {snode} pointing to it for domain" + f"pairs {domain_ids}." + ) + # now add a new exogenous distribution for the node domain_id = domain_ids[1] - G.nodes[node]['domain'][domain_id]["exogenous_distribution"] = lambda: exogenous_distribution() + G.nodes[node]["domain"][domain_id]["exogenous_distribution"] = lambda: exogenous_distribution() return G diff --git a/pywhy_graphs/viz/draw.py b/pywhy_graphs/viz/draw.py index c21aff556..4ffd7b44e 100644 --- a/pywhy_graphs/viz/draw.py +++ b/pywhy_graphs/viz/draw.py @@ -81,6 +81,7 @@ def draw( direction: Optional[str] = None, pos: Optional[dict] = None, name: Optional[str] = None, + node_order: Optional[List] = None, shape: str = "square", **attrs, ): @@ -123,6 +124,9 @@ def draw( if direction == "LR": dot.graph_attr["rankdir"] = direction + if node_order is None: + node_order = G.nodes + circle_edges = None directed_edges = None undirected_edges = None @@ -136,10 +140,17 @@ def draw( if hasattr(G, "undirected_edges"): undirected_edges = G.undirected_edges elif isinstance(G, nx.Graph) and not G.is_directed(): - undirected_edges = G.edges() + undirected_edges = G.edges() if hasattr(G, "bidirected_edges"): bidirected_edges = G.bidirected_edges + for v in node_order: + child = str(v) + if pos and pos.get(v) is not None: + dot.node(child, shape=shape, height=".5", width=".5", pos=f"{pos[v][0]},{pos[v][1]}!") + else: + dot.node(child, shape=shape, height=".5", width=".5") + # draw PAG edges and keep track of the circular endpoints found dot, found_circle_sibs = _draw_circle_edges( dot, @@ -157,14 +168,8 @@ def draw( # only need to draw directed edges now, but directed_G can be a nx.Graph if hasattr(directed_G, "predecessors"): - for v in G.nodes: + for v in node_order: child = str(v) - if pos and pos.get(v) is not None: - dot.node( - child, shape=shape, height=".5", width=".5", pos=f"{pos[v][0]},{pos[v][1]}!" - ) - else: - dot.node(child, shape=shape, height=".5", width=".5") for parent in directed_G.predecessors(v): if parent == v or not directed_G.has_edge(parent, v): diff --git a/pywhy_graphs/viz/tests/test_draw.py b/pywhy_graphs/viz/tests/test_draw.py index 1dfa19e1a..253aca135 100644 --- a/pywhy_graphs/viz/tests/test_draw.py +++ b/pywhy_graphs/viz/tests/test_draw.py @@ -78,6 +78,19 @@ def test_draw_pos_contains_more_nodes(): assert "pos=" not in re.search(r"\tz \[(.*)\]", dot_body_text).groups()[0] +def test_draw_does_not_show_undirected(): + graph = nx.DiGraph() + + graph.add_edge("x", "y") + graph.add_edge("y", "z") + + dot = draw(graph) + dot_body_text = "".join(dot.body) + + # there should not be a drawn undirected edges + assert "dir=none" not in dot_body_text + + def test_draw_pos_with_pag(): """ Ensure the Graphviz pos="x,y!" attribute is generated by the draw function