Tree Inference
pyggdrasil.tree_inference.CellAttachmentStrategy
Bases: Enum
Enum representing valid strategies for attaching cells to the mutation tree.
Allowed values
- UNIFORM_INCLUDE_ROOT: every node in the tree, including the root, is equally likely to have a cell attached
- UNIFORM_EXCLUDE_ROOT: every non-root node in the tree is equally likely to have a cell attached
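The difference between the two strategies can be illustrated with a minimal sketch. It is not the library's implementation: `AttachmentStrategy` and `sample_attachments` are hypothetical stand-ins, and the root is assumed to be the node with the largest index.

```python
import random
from enum import Enum


class AttachmentStrategy(Enum):
    """Hypothetical stand-in for CellAttachmentStrategy."""

    UNIFORM_INCLUDE_ROOT = "UNIFORM_INCLUDE_ROOT"
    UNIFORM_EXCLUDE_ROOT = "UNIFORM_EXCLUDE_ROOT"


def sample_attachments(n_nodes, n_cells, strategy, seed=0):
    """Draw one attachment node per cell, uniformly over the allowed nodes.

    Assumption: the root is the node with the largest index (n_nodes - 1).
    """
    rng = random.Random(seed)
    root = n_nodes - 1
    if strategy is AttachmentStrategy.UNIFORM_INCLUDE_ROOT:
        candidates = list(range(n_nodes))
    else:
        # Exclude the root from the set of possible attachment nodes.
        candidates = [node for node in range(n_nodes) if node != root]
    return [rng.choice(candidates) for _ in range(n_cells)]


with_root = sample_attachments(5, 1000, AttachmentStrategy.UNIFORM_INCLUDE_ROOT)
without_root = sample_attachments(5, 1000, AttachmentStrategy.UNIFORM_EXCLUDE_ROOT)
print("root used (include):", 4 in with_root)
print("root used (exclude):", 4 in without_root)
```

With UNIFORM_EXCLUDE_ROOT no cell is ever attached to the root (the wild type), so every simulated cell carries at least one mutation.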
pyggdrasil.tree_inference.CellSimulationData
dataclass
Bases: TypedDict
Data class for Cell Simulation Data.
pyggdrasil.tree_inference.CellSimulationId
Bases: MutationDataId
Class representing a cell simulation id.
Note: the tree_id contains the number of mutations, i.e. nodes - 1.
__init__(seed, tree_id, n_cells, fpr, fnr, na_rate, observe_homozygous, strategy)
Initializes a cell simulation id.
from_str(str_id)
classmethod
Creates a CellSimulationId from a string representation of the id.
Args:
- str_id: str
pyggdrasil.tree_inference.CellSimulationModel
Bases: BaseModel
Model for Cell Simulation Parameters.
Note: used by gen_sim_data() in the simulation of cells and mutations.
realistic_cell_number(v)
Validate that the number of cells is realistic.
realistic_fnr(v)
Validate that the false negative rate is realistic.
realistic_fpr(v)
Validate that the false positive rate is realistic.
realistic_mutation_number(v)
Validate that the number of mutations is realistic.
realistic_na_rate(v)
Validate that the NA rate is realistic.
pyggdrasil.tree_inference.ErrorCombinations
Bases: Enum
Error Combinations for Cell Simulation and Tree Inference.
- Ideal: fpr=1e-6, fnr=1e-6
- Typical: fpr=1e-6, fnr=0.1
- Large: fpr=0.1, fnr=0.1
- Extreme: fpr=0.3, fnr=0.3
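As a rough illustration of how such named combinations can be held and unpacked, here is a minimal sketch. `ErrorCombinationsSketch` is a hypothetical stand-in; the member names and the (fpr, fnr) tuple layout are assumptions, and only the numeric values come from the list above.

```python
from enum import Enum


class ErrorCombinationsSketch(Enum):
    """Hypothetical stand-in: each value is assumed to be an (fpr, fnr) pair."""

    IDEAL = (1e-6, 1e-6)
    TYPICAL = (1e-6, 0.1)
    LARGE = (0.1, 0.1)
    EXTREME = (0.3, 0.3)


# Unpack one combination into its two rates.
fpr, fnr = ErrorCombinationsSketch.TYPICAL.value
print(f"fpr={fpr}, fnr={fnr}")
```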
pyggdrasil.tree_inference.ErrorRates = tuple[float, float]
module-attribute
pyggdrasil.tree_inference.evolve_tree_mcmc(init_tree, n_moves, rng, move_probs=MoveProbConfigOptions.DEFAULT.value)
Evolves a tree using the SCITE MCMC moves. The default move weights are used unless move_probs is given.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
init_tree | TreeNode | tree to evolve | required |
n_moves | int | number of moves to perform | required |
rng | JAXRandomKey | random number generator | required |
move_probs | MoveProbabilities | move probabilities to use | MoveProbConfigOptions.DEFAULT.value |
Returns:
Name | Type | Description |
---|---|---|
tree_ev | TreeNode | evolved tree |
pyggdrasil.tree_inference.evolve_tree_mcmc_all(init_tree, n_moves, rng, move_probs=MoveProbConfigOptions.DEFAULT.value)
Evolves a tree using the SCITE MCMC moves and returns the evolved trees in order of evolution. The default move weights are used unless move_probs is given.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
init_tree | TreeNode | tree to evolve | required |
n_moves | int | number of moves to perform | required |
rng | JAXRandomKey | random number generator | required |
move_probs | MoveProbabilities | move probabilities to use | MoveProbConfigOptions.DEFAULT.value |
Returns:
Name | Type | Description |
---|---|---|
trees | list[TreeNode] | evolved trees in order of evolution |
pyggdrasil.tree_inference.gen_sim_data(params, rng, tree_tn)
Generates cell mutation matrix for one tree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | CellSimulationModel | input parameters for the simulation, as produced by the parser of cell_simulation.py | required |
rng | JAXRandomKey | JAX random number generator | required |
tree_tn | TreeNode | tree to generate data for | required |
Returns: data: dict, a dictionary containing serialised data for the tree:
- adjacency_matrix: adjacency matrix of the tree
- perfect_mutation_mat: perfect mutation matrix
- noisy_mutation_mat: noisy mutation matrix (only if fpr > 0, fnr > 0, or na_rate > 0)
- root: root of the tree (TreeNode)
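The relation between the perfect and the noisy mutation matrix can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: `add_noise` and the `NA` code are hypothetical, entries are assumed binary, and homozygous observations are ignored.

```python
import random

NA = 3  # hypothetical code for a missing entry; the library's encoding may differ


def add_noise(perfect, fpr, fnr, na_rate, seed=0):
    """Return a noisy copy of a binary mutation matrix (list of rows).

    Assumed noise model: a 0 is observed as 1 with probability fpr,
    a 1 is observed as 0 with probability fnr, and any entry is
    replaced by NA with probability na_rate.
    """
    rng = random.Random(seed)
    noisy = []
    for row in perfect:
        new_row = []
        for entry in row:
            if entry == 0:
                value = 1 if rng.random() < fpr else 0  # false positive
            else:
                value = 0 if rng.random() < fnr else 1  # false negative
            if rng.random() < na_rate:
                value = NA  # missing observation
            new_row.append(value)
        noisy.append(new_row)
    return noisy


perfect = [[1, 0, 1], [0, 0, 1]]
print(add_noise(perfect, fpr=0.0, fnr=0.0, na_rate=0.0))
print(add_noise(perfect, fpr=0.1, fnr=0.1, na_rate=0.05, seed=42))
```

When all three rates are zero the noisy matrix equals the perfect one, which matches the note that a noisy matrix is only produced when at least one rate is positive.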
pyggdrasil.tree_inference.get_descendants(adj_matrix, labels, parent, include_parent=False)
Returns a list of labels representing the descendants of node parent. Uses boolean matrix exponentiation to find descendants.
Complexity
Naive: O(n^3 * (n-1)), where n is the number of nodes in the tree including the root.
TODO:
- Consider implementing the 'Exponentiation by Squaring Algorithm' for O(n^3 * log(m))
- Fix conditional exponentiation for exponent < n-1
Args:
- adj_matrix: a JAX array of shape (n, n) representing the adjacency matrix
- labels: a JAX array of shape (n,) representing the labels of the nodes
- parent: an integer representing the label of the node whose descendants we want to find
- include_parent: a boolean; if True, the parent is included in the list of descendants
Returns:
- a JAX array of integers representing the labels of the descendants of node parent, in the order of the nodes in the adjacency matrix, i.e. the order of the labels
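The idea of finding descendants through boolean matrix powers can be sketched in plain Python. This is a minimal illustration, not the library's JAX implementation: `bool_matmul` and `get_descendants_sketch` are hypothetical names, and the adjacency convention (row = parent, column = child) is an assumption.

```python
def bool_matmul(a, b):
    """Boolean product of two square matrices given as lists of lists."""
    n = len(a)
    return [
        [any(a[i][k] and b[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


def get_descendants_sketch(adj, labels, parent, include_parent=False):
    """Labels of all nodes reachable from `parent`, in matrix order.

    Assumption: adj[i][j] == 1 means node i is the parent of node j.
    """
    n = len(adj)
    a = [[bool(x) for x in row] for row in adj]
    reach = [row[:] for row in a]  # paths of length 1
    power = [row[:] for row in a]
    for _ in range(n - 2):  # add paths of length 2, ..., n-1
        power = bool_matmul(power, a)
        reach = [
            [reach[i][j] or power[i][j] for j in range(n)] for i in range(n)
        ]
    p = labels.index(parent)
    return [
        labels[j]
        for j in range(n)
        if reach[p][j] or (include_parent and j == p)
    ]


# Root 4 has children 0 and 3; node 0 has children 1 and 2.
adj = [
    [0, 1, 1, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [1, 0, 0, 1, 0],
]
labels = [0, 1, 2, 3, 4]
print(get_descendants_sketch(adj, labels, parent=0))
print(get_descendants_sketch(adj, labels, parent=4, include_parent=True))
```

Each boolean product costs O(n^3) and up to n-1 path lengths are accumulated, which is where the naive complexity stated above comes from.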
pyggdrasil.tree_inference.get_simulation_data(data)
Loads the mutation matrix from the JSON object of the simulation data output by cell_simulation.py.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | dict | data dictionary containing serialised data | required |
Returns:
Type | Description |
---|---|
CellSimulationData | tuple of: adjacency_matrix (TreeAdjacencyMatrix): adjacency matrix of the tree; perfect_mutation_mat (PerfectMutationMatrix): perfect mutation matrix; noisy_mutation_mat (MutationMatrix): noisy mutation matrix, may be None if the cell simulation was errorless; root (TreeNode): root of the tree. |
pyggdrasil.tree_inference.huntress_tree_inference(mutation_mat, false_positive_rate, false_negative_rate, n_threads=1)
Runs the HUNTRESS algorithm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mutation_mat | MutationMatrix | binary array with entries 0 or 1, depending on whether the mutation is present or not. Shape (n_sites, n_cells) | required |
false_positive_rate | float | false positive rate, in [0, 1) | required |
false_negative_rate | float | false negative rate, in [0, 1) | required |
n_threads | int | number of threads to be used, default 1 | 1 |
Returns:
Type | Description |
---|---|
Node | inferred tree. The root node (wildtype) has name |
Example
For a matrix of shape (n_cells, 4) an example tree can be
4
├── 0
│   ├── 1
│   └── 2
└── 3
pyggdrasil.tree_inference.make_tree(n_nodes, tree_type, seed=None)
Generates basic trees from the given parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_nodes | int | number of nodes in the tree | required |
tree_type | TreeType | type of the tree (STAR, RANDOM, DEEP) | required |
seed | Optional[int] | seed for the random number generator | None |
Returns: tree: TreeNode
pyggdrasil.tree_inference.McmcConfig
Bases: BaseModel
Config for MCMC sampler.
Attributes:
move_probs: MoveProbConfig
move probabilities for MCMC sampler
fpr: float
false positive rate
fnr: float
false negative rate
n_samples: int
number of samples to draw
burn_in: int
number of samples to discard as burn-in
thinning: int
thinning factor for samples
id()
String representation of MCMC config.
pyggdrasil.tree_inference.McmcConfigOptions
Bases: Enum
MCMC run configurations.
Implemented configurations are DEFAULT and TEST.
DEFAULT
move_probs=MoveProbConfigOptions.DEFAULT, fpr=1.24e-06, fnr=0.097, n_samples=12000, burn_in=0, thinning=1
TEST
move_probs=MoveProbConfigOptions.DEFAULT, fpr=1.24e-06, fnr=0.097, n_samples=100, burn_in=0, thinning=1
pyggdrasil.tree_inference.McmcRunId
Class representing an MCMC run id.
__init__(seed, data, init_tree_id, mcmc_config)
Initializes an MCMC run id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seed | int | int | required |
data | Union[CellSimulationId, MutationDataId] | Union[CellSimulationId, MutationDataId] | required |
init_tree_id | TreeId | TreeId | required |
mcmc_config | McmcConfig | McmcConfig | required |
pyggdrasil.tree_inference.mcmc_sampler(rng_key, init_tree, error_rates, move_probs, data, num_samples, out_fp, num_burn_in=0, thinning=0, iteration=0)
Sample mutation trees according to the SCITE model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rng_key | JAXRandomKey | random key for the MCMC sampler | required |
init_tree | Tree | initial tree to start the MCMC sampler from | required |
error_rates | ErrorRates | theta = (fpr, fnr) error rates | required |
move_probs | MoveProbabilities | probabilities for each move | required |
data | MutationMatrix | observed mutation matrix to calculate the log-probability of, given current tree | required |
num_samples | int | number of samples to return | required |
out_fp | Path | full path to output file (excluding file extension) | required |
num_burn_in | int | number of samples to discard before returning samples | 0 |
thinning | int | number of samples to discard between samples | 0 |
iteration | int | sample number in chain, for restarting | 0 |
Returns:
Type | Description |
---|---|
None | None |
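How num_burn_in and thinning interact can be sketched with a small helper that lists which chain iterations would be kept. This follows one reading of the parameter descriptions above and is not taken from the library: `kept_iterations` is a hypothetical name, burn-in is assumed to drop the first num_burn_in iterations, and thinning is assumed to skip that many iterations between two kept samples.

```python
def kept_iterations(num_samples, num_burn_in=0, thinning=0):
    """Indices of the chain iterations kept as samples.

    Assumptions: the first num_burn_in iterations are discarded, and
    `thinning` iterations are discarded between consecutive kept samples.
    """
    step = thinning + 1
    return [num_burn_in + k * step for k in range(num_samples)]


print(kept_iterations(num_samples=3))
print(kept_iterations(num_samples=3, num_burn_in=2, thinning=1))
```

Under these assumptions a run that returns num_samples samples needs num_burn_in + (num_samples - 1) * (thinning + 1) + 1 iterations in total.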
pyggdrasil.tree_inference.MoveProbabilities
dataclass
Move probabilities. The default values were taken from the paragraph "Combining the three MCMC moves" on page 14 of the SCITE paper supplement.
pyggdrasil.tree_inference.MoveProbConfig
Bases: BaseModel
Move probabilities for MCMC sampler.
id()
String representation of move probabilities.
move_prob_validator(field_values)
classmethod
Probabilities sum to 1.
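The sum-to-one check can be sketched with a plain dataclass. This is an illustration rather than the library's pydantic validator: `MoveProbsSketch` and `pick_move` are hypothetical names, and the default values are the ones listed for MoveProbConfigOptions.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class MoveProbsSketch:
    """Hypothetical stand-in holding the three SCITE move probabilities."""

    prune_and_reattach: float = 0.1
    swap_node_labels: float = 0.65
    swap_subtrees: float = 0.25

    def __post_init__(self):
        total = self.prune_and_reattach + self.swap_node_labels + self.swap_subtrees
        # Compare with a tolerance to allow for floating-point rounding.
        if not math.isclose(total, 1.0, abs_tol=1e-9):
            raise ValueError(f"move probabilities must sum to 1, got {total}")


def pick_move(probs, rng):
    """Choose one move name at random, weighted by its probability."""
    names = ["prune_and_reattach", "swap_node_labels", "swap_subtrees"]
    weights = [probs.prune_and_reattach, probs.swap_node_labels, probs.swap_subtrees]
    return rng.choices(names, weights=weights, k=1)[0]


default_probs = MoveProbsSketch()
optimal_probs = MoveProbsSketch(0.55, 0.4, 0.05)
print(pick_move(default_probs, random.Random(0)))
```

Validating at construction time means an invalid configuration fails before any sampling starts, instead of silently skewing the proposal distribution.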
pyggdrasil.tree_inference.MoveProbConfigOptions
Bases: Enum
Move probability configurations.
Implemented configurations are DEFAULT and OPTIMAL from the SCITE paper, supplement p.15.
Default values
prune_and_reattach=0.1, swap_node_labels=0.65, swap_subtrees=0.25
Optimal values
prune_and_reattach=0.55, swap_node_labels=0.4, swap_subtrees=0.05
(The optimal values find the ML tree up to 2 or 3 times faster.)
pyggdrasil.tree_inference.MutationDataId
Class representing a mutation data id.
In case we want to infer a tree from real data, we need to provide a mutation data id.
__init__(id)
Initializes a mutation data id.
pyggdrasil.tree_inference.MutationMatrix = Array
module-attribute
pyggdrasil.tree_inference.Tree
dataclass
For N mutations we use a tree with N+1 nodes, where the nodes at positions 0, ..., N-1 are "blank" and can be mapped to any of the mutations. The node N is the root node and should always be mapped to the wild type.
Attrs
- tree_topology: the topology of the tree encoded in the adjacency matrix. No self-loops, i.e. diagonal is all zeros. Shape (N+1, N+1)
- labels: maps nodes in the tree topology to the actual mutations. Note: the last position always maps to itself, as it's the root, and we use the convention that root has the largest index. Shape (N+1,)
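The layout described above can be made concrete for N = 4 mutations. This is a minimal sketch using plain lists instead of JAX arrays: `build_adjacency` is a hypothetical helper, and the convention that row = parent and column = child is an assumption.

```python
def build_adjacency(parents, n_nodes):
    """Build an (n_nodes, n_nodes) adjacency matrix from a child -> parent map.

    Assumed convention: adj[parent][child] == 1. The root has no entry.
    """
    adj = [[0] * n_nodes for _ in range(n_nodes)]
    for child, parent in parents.items():
        adj[parent][child] = 1
    return adj


n_mutations = 4
n_nodes = n_mutations + 1
root = n_mutations  # the root has the largest index

# Root 4 has children 0 and 3; node 0 has children 1 and 2.
tree_topology = build_adjacency({0: 4, 3: 4, 1: 0, 2: 0}, n_nodes)
labels = [0, 1, 2, 3, 4]  # identity mapping; the last position is the root

no_self_loops = all(tree_topology[i][i] == 0 for i in range(n_nodes))
root_has_no_parent = all(tree_topology[i][root] == 0 for i in range(n_nodes))
print(no_self_loops, root_has_no_parent, labels[-1] == root)
```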
print_topo()
Prints the tree in a human-readable format.
to_TreeNode()
Converts this Tree to a TreeNode. Returns the root node of the tree.
tree_from_tree_node(tree_node)
staticmethod
Converts a TreeNode to a Tree.
pyggdrasil.tree_inference.TreeId
Class representing a tree id.
A tree id is a unique identifier for a tree.
- tree_type: TreeType - type of tree
- n_nodes: int - number of nodes in the tree
- seed: int - seed used to generate the tree, not required for star tree
- cell_simulation_id: str - if the tree was generated from a cell simulation, i.e. Huntress
__init__(tree_type, n_nodes, seed=None, cell_simulation_id=None)
Initializes a tree id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree_type | TreeType | TreeType | required |
n_nodes | int | int | required |
seed | Optional[int] | int | None |
from_str(str_id)
classmethod
Creates a tree id from a string representation of the id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
str_id | str | str | required |
pyggdrasil.tree_inference.TreeType
Bases: Enum
Enum representing valid tree types implemented in pyggdrasil.
Allowed values
- RANDOM (random tree)
- STAR (star tree)
- DEEP (deep tree)
- HUNTRESS (Huntress tree, inferred from real or cell simulation data)
- MCMC (tree generated by evolving a tree with MCMC moves)