"""Methods for working with Tree objects.
"""
import os
import re
import h5py
from scipy import interpolate
import requests
import networkx as nx
import numpy as np
import pandas as pd
import anytree
from anytree import search, PreOrderIter
from anytree.walker import Walker
from genepy3d.util import plot as pl
from genepy3d.io.base import CatmaidApiTokenAuth
[docs]
class Tree:
"""Tree in 3D. The tree structure is constructed based on ``anytree`` package.
We assume there are no duplicated nodes on the Tree.
Attributes:
nodes (anytree): structure of tree nodes.
id (int): tree ID (optional).
name (str): tree name (optional).
"""
def __init__(self,_nodes,_id=0,_name="GeNePy3D"):
self.nodes = _nodes
self.id = _id
self.name = _name
def _build_treenodes(nodes_tbl, connectors_tbl):
"""Building tree object using ``anytree`` package.
This is the support function for classmethod.
Args:
nodes_tbl (pandas dataframe): tree nodes table.
connectors_tbl (pandas dataframe): connectors table.
"""
nodes = {}
# get connector nodes
conlst = []
if connectors_tbl is not None:
conlst = connectors_tbl.index.values.astype('int') # it may happen similar (ske,treenode) for 2 different connector id
# build treenodes
for i in range(len(nodes_tbl)):
nodeid = int(nodes_tbl.index[i])
_x, _y, _z, _r, _structure_id = tuple(nodes_tbl.iloc[i][["x","y","z","r","structure_id"]])
nodes[nodeid] = anytree.Node(nodeid,x=_x,y=_y,z=_z,r=_r,structure_id=_structure_id)
# assign others features
for i in range(len(nodes_tbl)):
nodeid = int(nodes_tbl.index[i])
# assign their parents
parentnodeid = int(nodes_tbl.iloc[i]['parent_treenode_id'])
if parentnodeid != -1:
nodes[nodeid].parent = nodes[parentnodeid]
# assign connectors
if nodeid in conlst:
_connrel = connectors_tbl.loc[nodeid]['relation_id'] # could be "pre or post synaptic"
_connid = connectors_tbl.loc[nodeid]['connector_id']
else:
_connrel = "None"
_connid = -1
nodes[nodeid].connector_relation = _connrel
nodes[nodeid].connector_id = _connid
return nodes
[docs]
@classmethod
def from_table(cls,nodes_tbl,connectors_tbl=None,tree_id=0,tree_name='genepy3d'):
"""Build tree from tables.
Args:
nodes_tbl (pandas dataframe): tree nodes table.
connectors_tbl (pandas dataframe): connector table (optional).
tree_id (int): tree ID (optional).
tree_name (str): tree name (optional).
Returns:
A Tree object.
Notes:
- It needs at least ``nodes_tbl`` to create the tree object.
The ``nodes_tbl`` stores information of nodes and their linked parent-nodes.
It is the same as in the SWC file specification. Please see this [1] for more detail.
The column names must include:
- treenode_id: ID of node
- structure_id: ID of node structure (see the column `Structure Identifier` in [1] for the meaning).
- x, y, z: coordinates of node
- r: radius of node
- parent_treenode_id: ID of parent node
- The table ``connectors_tbl`` is optional.
It is used to describe the synaptic connection between neuronal trees. It is the same as connector table in CATMAID [2].
It must consist of columns:
- connector_id: ID of a connector
- treenode_id: ID of a treenode from a neuronal tree object
- relation_id: connectomic relation, must be 'presynaptic_to' or 'postsynaptic_to' (see [3] to understand about connectomic relation in CATMAID)
References:
.. [1] http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
.. [2] https://catmaid.readthedocs.io/en/stable/index.html
.. [3] https://catmaid.readthedocs.io/en/stable/tracing_neurons.html
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot it
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
"""
# check nodes_tbl
if isinstance(nodes_tbl,pd.DataFrame):
lbls = nodes_tbl.columns.values
if all(lbl in lbls for lbl in ['treenode_id','structure_id','x','y','z','r','parent_treenode_id']):
nodes_tbl_refined = nodes_tbl.dropna()
else:
raise Exception("can not find 'treenode_id','structure_id','x','y','z','r','parent_treenode_id' within the dataframe.")
else:
raise Exception("nodes_tbl must be pandas dataframe.")
# constraint some columns of nodes to int
nodes_tbl_refined['structure_id'] = nodes_tbl_refined['structure_id'].astype('int')
nodes_tbl_refined['treenode_id'] = nodes_tbl_refined['treenode_id'].astype('int')
nodes_tbl_refined['parent_treenode_id'] = nodes_tbl_refined['parent_treenode_id'].astype('int')
# reset index
nodes_tbl_refined.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
nodes_tbl_refined.set_index('treenode_id',inplace=True) # reset dataframe index
# check connectors_tbl
connectors_tbl_refined = None
if connectors_tbl is not None:
# check dfcon columns
if isinstance(connectors_tbl,pd.DataFrame):
lbls = connectors_tbl.columns.values
if all(lbl in lbls for lbl in ['treenode_id','connector_id','relation_id']):
connectors_tbl_refined = connectors_tbl.dropna()
else:
raise Exception("can not find 'treenode_id','connector_id','relation_id' within the dataframe.")
else:
raise Exception("connectors_tbl must be pandas dataframe.")
# constraint columns to int
connectors_tbl_refined['connector_id'] = connectors_tbl_refined['connector_id'].astype('int')
connectors_tbl_refined['treenode_id'] = connectors_tbl_refined['treenode_id'].astype('int')
# reset index
connectors_tbl_refined.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
connectors_tbl_refined.set_index("treenode_id",inplace=True) # reset dataframe index
# build tree node structure
nodes = cls._build_treenodes(nodes_tbl_refined,connectors_tbl_refined)
return cls(nodes,tree_id,tree_name)
[docs]
@classmethod
def from_swc(cls,filepath):
"""Build tree from a SWC file.
See this [1] for the description of SWC format.
Args:
filepath (str): path to swc file.
Returns:
A Tree object.
References:
.. [1] http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
Examples:
.. code-block:: python
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# An example of swc file can be downloaded from:
# https://neuromorpho.org/neuron_info.jsp?neuron_name=AD1202-PMC-L-4201-fiber129
swc_path = "file/to/swc/file"
neu = trees.Tree.from_swc(swc_path)
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neu.plot(ax)
"""
data = []
# extract info from file
f = open(filepath,'r')
for line in f:
if line[0]!='#':
tmp = []
elemens = re.split(r'\r|\n| |\s', line)
for ile in elemens:
if ile!='':
tmp.append(float(ile))
data.append(tmp)
f.close()
# build dataframe and cast columns types
if len(data)!=0:
tree_name = os.path.basename(filepath).split(".swc")[0]
tree_id = 0
dfneu = pd.DataFrame(data,columns=['treenode_id', 'structure_id', 'x', 'y', 'z', 'r', 'parent_treenode_id'])
dfneu.dropna(inplace=True)
# cast type
dfneu['treenode_id'] = dfneu['treenode_id'].astype('int')
dfneu['structure_id'] = dfneu['structure_id'].astype('int')
dfneu['parent_treenode_id'] = dfneu['parent_treenode_id'].astype('int')
# reset index
dfneu.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
dfneu.set_index('treenode_id',inplace=True) # reset dataframe index
# build tree node structure
nodes = cls._build_treenodes(dfneu,None)
return cls(nodes,tree_id,tree_name)
else:
raise ValueError("Errors when reading file.")
[docs]
@classmethod
def from_eswc(cls,filepath):
"""Build tree from ESWC file. The ESWC is used e.g. in Vaa3D-Neuron software [1].
ESWC is an extension of SWC with additional columns 'segment_id', 'level', 'mode', 'timestamp', 'featureval'.
Args:
filepath (str): path to eswc file.
Returns:
A Tree object.
Notes:
ESWC has extended columns compared to default SWC file, but we don't pass these extended columns into the Tree object.
We only pass the similar columns to the Tree object as in the SWC file.
References:
.. [1] https://github.com/Vaa3D/Vaa3D_Wiki/wiki/Vaa3DNeuron1.wiki
Examples:
Reading ESWC should be the same as reading SWC. Please see example from ``from_swc()``.
"""
data = []
# extract info from file
f = open(filepath,'r')
for line in f:
if line[0]!='#':
tmp = []
elemens = re.split(r'\r|\n| |\s', line)
for ile in elemens:
if ile!='':
tmp.append(float(ile))
data.append(tmp)
f.close()
# build dataframe and cast columns types
if len(data)!=0:
tree_name = os.path.basename(filepath).split(".eswc")[0]
tree_id = 0
if len(data[0])==12:
dfneu = pd.DataFrame(data,columns=['treenode_id', 'structure_id', 'x', 'y', 'z', 'r', 'parent_treenode_id','segment_id','level','mode','timestamp','featureval'])
elif len(data[0])==11:
dfneu = pd.DataFrame(data,columns=['treenode_id', 'structure_id', 'x', 'y', 'z', 'r', 'parent_treenode_id','segment_id','level','mode','timestamp'])
else:
raise Exception('The number of columns must be 11 or 12.')
dfneu = dfneu[['treenode_id', 'structure_id', 'x', 'y', 'z', 'r', 'parent_treenode_id']]
dfneu.dropna(inplace=True)
# cast type
dfneu['treenode_id'] = dfneu['treenode_id'].astype('int')
dfneu['structure_id'] = dfneu['structure_id'].astype('int')
dfneu['parent_treenode_id'] = dfneu['parent_treenode_id'].astype('int')
# reset index
dfneu.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
dfneu.set_index('treenode_id',inplace=True) # reset dataframe index
# build tree node structure
nodes = cls._build_treenodes(dfneu,None)
return cls(nodes,tree_id,tree_name)
else:
raise ValueError("Errors when reading file.")
[docs]
@classmethod
def from_ims(cls,filepath):
"""Build tree from Imaris file.
Args:
filepath (str): path to ims file.
Returns:
A Tree object.
"""
# read file
f = h5py.File(filepath,"r")
# import vertex
item = f["Scene8"]["Content"]["Filaments0"]["Vertex"]
item_arr = np.array([item[i].tolist() for i in range(len(item))])
vertex = pd.DataFrame(item_arr,columns=["x","y","z","r"])
# import edge
item = f["Scene8"]["Content"]["Filaments0"]["Edge"]
item_arr = np.array([item[i].tolist() for i in range(len(item))])
edge = pd.DataFrame(item_arr,columns=["vertex_begin","vertex_end"])
# build graph from vertices and edges
G=nx.Graph()
G.add_nodes_from(vertex.index.tolist())
G.add_edges_from(edge.values.tolist())
# import dendrite table
item = f["Scene8"]["Content"]["Filaments0"]["DendriteSegment"]
item_arr = np.array([item[i].tolist() for i in range(len(item))])
dendrite = pd.DataFrame(item_arr,columns=["id_time","id","vertex_begin","vertex_end","edge_begin","edge_end"])
# import dendrite edge index table
item = f["Scene8"]["Content"]["Filaments0"]["DendriteSegmentEdge"]
item_arr = np.array([item[i].tolist() for i in range(len(item))])
dendrite_edge = pd.DataFrame(item_arr,columns=["index"])
# import dendrite vertex index table
item = f["Scene8"]["Content"]["Filaments0"]["DendriteSegmentVertex"]
item_arr = np.array([item[i].tolist() for i in range(len(item))])
dendrite_vertex = pd.DataFrame(item_arr,columns=["index"])
# compute dendrite fragments
parent = {}
big_vertex_lst = []
for ix in range(len(dendrite)):
subdf = dendrite.iloc[ix]
# get IDs of edge begin and pair of vertices making the edge begin
vertex_begin = dendrite_vertex.loc[subdf["vertex_begin"]][0] # get vertex begin
edge_begin = dendrite_edge.loc[subdf["edge_begin"]].values[0] # get edge begin
vertices_begin = edge.loc[edge_begin].values # get pair of vertices making the edge begin
vertex_begin_partner = np.setdiff1d(vertices_begin,vertex_begin)[0] # get the second vertex from the pair
# get IDs of edge end and pair of vertices making the edge end
vertex_end = dendrite_vertex.loc[subdf["vertex_end"]-1][0]
edge_end = dendrite_edge.loc[subdf["edge_end"]-1].values[0]
vertices_end = edge.loc[edge_end].values
vertex_end_partner = np.setdiff1d(vertices_end,vertex_end)[0]
# check all path starting from vertex begin to vertex end
# then get only path containing both edge begin and edge end
vertex_lst = []
for path in nx.all_simple_paths(G, source=vertex_begin, target=vertex_end):
if (vertex_begin_partner == path[1]) & (vertex_end_partner == path[-2]):
vertex_lst = path
# create parent-children link
if len(vertex_lst)!=0:
for i in range(1,len(vertex_lst)):
child_vertex = vertex_lst[i]
parent_vertex = vertex_lst[i-1]
parent[child_vertex] = parent_vertex
big_vertex_lst += vertex_lst
# build dendritic dataframe
big_vertex_lst = np.unique(big_vertex_lst)
parent_vertex_lst = []
for key in big_vertex_lst:
if key in parent.keys():
parent_vertex_lst.append(parent[key])
else:
parent_vertex_lst.append(-1)
df = vertex.loc[big_vertex_lst].copy()
df["treenode_id"] = big_vertex_lst
df["parent_treenode_id"] = parent_vertex_lst
df["structure_id"] = 0
return cls.from_table(df)
[docs]
@classmethod
def from_catmaid_server(cls,catmaid_host,token,project_id,neuron_id):
"""Build tree from CATMAID [1].
This function queries from the CATMAID server a neuronal trace given by a ``project_id`` and ``neuron_id``.
The neuronal trace data is fetched using CATMAID API [2].
Args:
catmaid_host (str): address of CATMAID server
token (str): authenticated string
project_id (int): project ID
neuron_id (int): neuron ID
Returns:
A Tree object.
References:
.. [1] https://catmaid.readthedocs.io/en/stable/index.html
.. [2] https://catmaid.readthedocs.io/en/stable/api.html
Examples:
.. code-block:: python
from genepy3d.obj import trees
# Replace these fake values with your own
catmaid_host = 'https://your-catmaid-host.com'
token = "9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b"
project_id = 1
neuron_id = 1
# Retreive neuronal tree from from CATMAID host
neuron = trees.Tree.from_catmaid_server(catmaid_host,token,project_id,neuron_id)
"""
subneu, subcon = [], []
linkrequest = catmaid_host+'{}/skeleton/{}/json'.format(project_id,neuron_id)
res = requests.get(linkrequest,auth=CatmaidApiTokenAuth(token))
if res.status_code!=200:
raise ValueError('something wrong: check again your host, token, project id or neuron id.')
else:
if isinstance(res.json(),dict):
raise ValueError('something wrong: check again your project id or neuron id.')
else: # should be a list
res = np.array(res.json())
# get neuron name
try:
neuron_name = res[0]
except:
neuron_name = 'genepy3d'
# get neuron info
try:
for iske in res[1]:
if iske[1] is not None:
iske_sub = [iske[0], 0, iske[3], iske[4], iske[5], iske[6], iske[1]]
else:
iske_sub = [iske[0], 0, iske[3], iske[4], iske[5], iske[6], -1]
subneu.append(iske_sub)
except:
subneu = []
# get connector info
try:
for icon in res[3]:
relation_type = 'presynaptic_to'
if icon[2]==1:
relation_type = 'postsynaptic_to'
icon_sub = [icon[1], icon[0], relation_type]
subcon.append(icon_sub)
except:
subcon = []
# create dfneu dataframe
if len(subneu)==0:
raise ValueError('neuron is empty!')
else:
dfneu = pd.DataFrame(subneu,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# remove nan
dfneu.dropna(inplace=True)
# cast type
dfneu['treenode_id'] = dfneu['treenode_id'].astype('int')
dfneu['structure_id'] = dfneu['structure_id'].astype('int')
dfneu['parent_treenode_id'] = dfneu['parent_treenode_id'].astype('int')
# reset index
dfneu.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
dfneu.set_index('treenode_id',inplace=True) # reset dataframe index
# create dfcon dataframe
if len(subcon)==0:
dfcon = None
else:
dfcon = pd.DataFrame(subcon,columns=['connector_id','treenode_id','relation_id'])
# remove nan
dfcon.dropna(inplace=True)
# constraint columns to int
dfcon['connector_id'] = dfcon['connector_id'].astype('int')
dfcon['treenode_id'] = dfcon['treenode_id'].astype('int')
# reset index
dfcon.drop_duplicates(subset=["treenode_id"],inplace=True) # make sure unique index
dfcon.set_index("treenode_id",inplace=True) # reset dataframe index
# build tree node structure
nodes = cls._build_treenodes(dfneu,dfcon)
return cls(nodes,neuron_id,neuron_name)
[docs]
def get_root(self):
"""Return root node ID.
Returns:
List of root ID (list[int]).
"""
# may contain many roots
return list(filter(lambda nodeid: self.nodes[nodeid].parent is None, self.nodes.keys()))
[docs]
def get_parent(self,nodeid):
"""Return parent id of node id.
Args:
nodeid (int): a node ID.
Returns:
Node ID (int).
"""
return self.nodes[nodeid].parent.name
[docs]
def get_children(self,nodeid):
"""Return children ids of nodeid.
Args:
nodeid (int): a node ID.
Returns:
List of node ID (list[int]).
"""
return [node.name for node in self.nodes[nodeid].children]
[docs]
def get_preorder_nodes(self,rootid=None):
"""Return list of nodes in tree. The order of nodes in the list follows the preorder traversal [1].
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
List of node ID (list[int]).
References:
.. [1] https://en.wikipedia.org/wiki/Tree_traversal
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
return [node.name for node in PreOrderIter(self.nodes[_rootid])]
[docs]
def get_number_nodes(self,rootid=None):
"""Return the total number of nodes in tree.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Total number of nodes (int).
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
return len(self.get_preorder_nodes(_rootid))
[docs]
def get_leaves(self,rootid=None):
"""Return list of leaf node IDs.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
list of int.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
return [leaf.name for leaf in self.nodes[_rootid].leaves]
[docs]
def get_branchingnodes(self,rootid=None):
"""Return list of branching node IDs.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
list of int.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
return [n.name for n in search.findall(self.nodes[_rootid],filter_=lambda node: len(node.children)>1)]
[docs]
def get_connectors(self,rootid=None):
"""Return list of connector node IDs.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas dataframe whose indices are node IDs and values are corresponding connector types, connector ids.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
connector_nodes = np.array([[node.name,node.connector_relation,node.connector_id] for node in PreOrderIter(self.nodes[_rootid])])
df = pd.DataFrame({"relation":connector_nodes[:,1],"id":connector_nodes[:,2].astype(np.int64)},
index=connector_nodes[:,0].astype(np.int64))
subdf = df[(df['relation']!='None')&(df['id']!=-1)]
return subdf
[docs]
def get_coordinates(self, nodeid=None, rootid=None):
"""Return coordinates of given nodes.
Args:
nodes (int | list of int): list of node IDs. If None, then return the coordinates of all nodes from the given root ID.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas dataframe whose indices are node IDs and values are corresponding coordinates.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if nodeid is None:
nodelst = [node.name for node in PreOrderIter(self.nodes[_rootid])]
elif nodeid is not None:
if isinstance(nodeid,(int,np.integer)): # only one item.
nodelst = [nodeid]
elif isinstance(nodeid,(list,np.ndarray)):
nodelst = nodeid
else:
raise Exception("node_id must be array-like.")
coors = np.array([[self.nodes[i].x,self.nodes[i].y,self.nodes[i].z] for i in nodelst])
df = pd.DataFrame({'nodeid':nodelst,'x':coors[:,0],'y':coors[:,1],'z':coors[:,2]})
df.set_index('nodeid',inplace=True)
return df
[docs]
def get_features(self,features=["x","y","z"],nodeid=None,rootid=None):
"""Return list of features from given nodes.
Args:
features (list of str): list of node-wise features, e.g. "x", "y", "z", "r", "structure_id", "connector_relation". If None, then return all features.
nodes (int | list of int): list of node IDs. If None, then return the coordinates of all nodes from the given root ID.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas dataframe whose indices are node IDs and values are corresponding features.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if nodeid is None:
nodelst = [node.name for node in PreOrderIter(self.nodes[_rootid])]
elif nodeid is not None:
if isinstance(nodeid,(int,np.integer)): # only one item.
nodelst = [nodeid]
elif isinstance(nodeid,(list,np.ndarray)):
nodelst = nodeid
else:
raise Exception("node_id must be array-like.")
feats = []
for i in nodelst:
node = self.nodes[i]
feats.append([i]+[getattr(node,f) for f in features])
feats = np.array(feats)
column_names = ["nodeid"] + features
df = pd.DataFrame(feats,columns=column_names)
df.set_index('nodeid',inplace=True)
return df
[docs]
def get_bbox(self,nodeid=None,rootid=None):
"""Return bounding box coordinates covering a list of nodes from the tree.
Args:
nodeid (list[int]): list of node IDs. If None, then return bounding box of the entire tree.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
List of [(Xmin,Xmax),(Ymin,Ymax),(Zmin,Zmax)]
"""
coors = self.get_coordinates(nodeid,rootid)
xmin, xmax = coors['x'].min(), coors['x'].max()
ymin, ymax = coors['y'].min(), coors['y'].max()
zmin, zmax = coors['z'].min(), coors['z'].max()
return [(xmin,xmax),(ymin,ymax),(zmin,zmax)]
[docs]
def compute_length(self,nodeid=None,rootid=None):
"""Return the total length of the tree or subpart of the tree.
Args:
nodeid (list[int]): list of node IDs. If None, then return the total length of the entire tree.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
int.
Notes:
If you provide a list of nodes. We consider it as a curve, then we return the curve length.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,5.,5.,5.,1.,1]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot it
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
plt.tight_layout();
# Total length of the entire neuron
total_length = neuron.compute_length()
print(total_length)
# Get a segment from root to the first branching point
segment_lst = neuron.decompose_segments()
print(segment_lst)
segment = segment_lst['0_1']
# Length of that segment
segment_length = neuron.compute_length(segment)
print(segment_length)
"""
from genepy3d.obj.curves import Curve
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if nodeid is not None: # a list of nodes
return Curve(self.get_coordinates(nodeid).values).compute_length()
else:
segments = self.decompose_segments(_rootid)
total_length = 0
for seg in segments.values():
total_length += Curve(self.get_coordinates(seg).values).compute_length()
return total_length
[docs]
def compute_spine(self,rootid=None):
"""Return spine node IDs (longest branch from the root to a leaf).
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
list of int.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
segments = self.decompose_leaves(_rootid)
segment_lengths = [self.compute_length(seg) for seg in segments.values()]
return list(segments.values())[np.argmax(segment_lengths)]
[docs]
def compute_strahler_order(self, nodeid=None, rootid=None):
"""Return Strahler order for given node IDs [1].
Args:
nodeid (int|list[int]): list of node IDs. If None, return all the Strahler order for all nodes from the given root ID.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas serie whose indices are node IDs and values are corresponding strahler orders.
References:
.. [1] https://en.wikipedia.org/wiki/Strahler_number
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
try:
# check if strahler orders already computed
_ = self.nodes[_rootid].strahler
except:
preorder_lst = np.array([node.name for node in PreOrderIter(self.nodes[_rootid])])
for i in preorder_lst[-1::-1]:
n = self.nodes[i]
if n.is_leaf:
n.strahler = 1
else:
children_strahler = np.array([[k.name, k.strahler] for k in n.children])
if len(children_strahler)==1:
n.strahler = children_strahler[0,1]
else:
nbmax = len(np.argwhere(children_strahler[:,1]==np.max(children_strahler[:,1])).flatten())
if nbmax>=2:
n.strahler = np.max(children_strahler[:,1]) + 1
else:
n.strahler = np.max(children_strahler[:,1])
if nodeid is None:
nodelst = [node.name for node in PreOrderIter(self.nodes[_rootid])]
elif nodeid is not None:
if isinstance(nodeid,int): # only one item.
nodelst = [nodeid]
elif isinstance(nodeid,(list,np.ndarray)):
nodelst = nodeid
else:
raise Exception("nodeid must be array-like.")
strahler_orders = [self.nodes[i].strahler for i in nodelst]
return pd.Series(strahler_orders,name='strahler_order',index=nodelst)
# def compute_orientation(self,nodeid=None,nb_avg=1,rootid=None):
# """Compute orientation vectors at given nodes.
# The orientation vector depends on node type (e.g. root, branching node, leaf, normal node).
# Args:
# nodeid (int|array of int): list of node IDs.
# nb_avg (int): number of neighboring nodes used to average the orientation.
# """
# # internal funcs
# def walk_to_parent(_nid,nb_steps,_rootid):
# target = _nid
# target_lst = []
# it = 0
# while(it < nb_steps):
# it = it + 1
# target = self.get_parent(target)
# if target is not None:
# target_lst.append(target)
# else:
# break
# if target in self.get_branchingnodes(_rootid):
# break
# return target_lst
# def walk_to_children(_nid,nb_steps):
# target = [_nid]
# target_lst = []
# it = 0
# while(it < nb_steps):
# it = it + 1
# target = self.get_children(target[0])
# if len(target)==1:
# target_lst.append(target[0])
# else:
# break
# return target_lst
# if rootid is None:
# _rootid = self.get_root()[0]
# else:
# _rootid = rootid
# # get list of nodes
# if nodeid is None:
# nodeidlst = self.get_preorder_nodes(_rootid)
# elif isinstance(nodeid,(int,np.integer)):
# nodeidlst = [nodeid]
# else:
# nodeidlst = nodeid
# dic = {}
# for nid in nodeidlst:
# item = {}
# source_coors = self.get_coordinates(nid,_rootid).values.flatten()
# if nid==_rootid: # root
# # vectors toward children
# children = self.get_children(nid)
# for child in children:
# target_lst = [child]
# target_lst = target_lst + walk_to_children(child,nb_avg-1)
# target_coors = self.get_coordinates(target_lst,_rootid).values
# mean_target_coors = np.mean(target_coors,axis=0)
# item[str(nid)+"-"+str(target_lst[0])] = geo.vector2points(source_coors,mean_target_coors)
# elif nid in self.get_branchingnodes(_rootid): # branchingnodes
# # vector toward parent
# target_lst = walk_to_parent(nid,nb_avg,_rootid)
# target_coors = self.get_coordinates(target_lst,_rootid).values
# mean_target_coors = np.mean(target_coors,axis=0)
# item[str(nid)+"-"+str(target_lst[0])] = geo.vector2points(source_coors,mean_target_coors)
# # vectors toward children
# children = self.get_children(nid)
# for child in children:
# target_lst = [child]
# target_lst = target_lst + walk_to_children(child,nb_avg-1)
# target_coors = self.get_coordinates(target_lst,_rootid).values
# mean_target_coors = np.mean(target_coors,axis=0)
# item[str(nid)+"-"+str(target_lst[0])] = geo.vector2points(source_coors,mean_target_coors)
# else: # leaf or normal node
# # vector toward parent
# target_lst = walk_to_parent(nid,nb_avg,_rootid)
# target_coors = self.get_coordinates(target_lst,_rootid).values
# mean_target_coors = np.mean(target_coors,axis=0)
# item[str(nid)+"-"+str(target_lst[0])] = geo.vector2points(source_coors,mean_target_coors)
# dic[nid] = item
# return dic
[docs]
def path(self,target,source=None,rootid=None):
"""Find list of nodes going from source node to target node.
Args:
target (int): target node ID.
source (int): source node ID. If None, then we take root as the source.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
list of traveled node IDs included source and target node IDs.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,5.,5.,5.,1.,1]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Get ID of a leaf
leaf_node = neuron.get_leaves()[0]
print(leaf_node)
# Path from root to the leaf
print(neuron.path(leaf_node))
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if source==None:
source = _rootid
w = Walker()
res = w.walk(self.nodes[source],self.nodes[target])
return [n.name for n in res[0]] + [res[1].name] + [n.name for n in res[2]]
[docs]
def copy(self,rootid=None):
"""Make a copy of the tree.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
A Tree object.
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
return self.extract_subtrees(nodeid=_rootid,rootid=_rootid)
[docs]
def prune_leaves(self,nodeid=None,length=None,rootid=None):
"""Prune leaf branches based on its length.
Args:
nodeid (list [int]): list of leaves IDs used to examine the pruning. If None, then take all leaves into account.
length (float): length threshold for pruning the leaf branches from the leaves given by nodeid. If None, then we eliminate all the leaf branches.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
A new Tree object after pruning.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot it
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
plt.tight_layout();
# Prune all the leaf branches
neuron_pruned = neuron.prune_leaves()
# Plot pruned tree
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron_pruned.plot(ax)
plt.tight_layout();
# Prune only leaf branches whose length < 5 microns
neuron_pruned = neuron.prune_leaves(length=5)
# Plot pruned tree
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron_pruned.plot(ax)
plt.tight_layout();
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if length is None:
_length = np.inf
else:
_length = length
segments = self.decompose_segments(_rootid)
if nodeid is None:
leafnodes = self.get_leaves(_rootid)
else:
leafnodes = nodeid
subnodes = {}
nodelst = []
for seg in segments.values():
contain_leaf = (np.sum([leaf in seg for leaf in leafnodes])>=1)
seg_len = self.compute_length(seg)
if (~contain_leaf) | (contain_leaf & (seg_len >= _length)):
# copy node properties
for nid in seg:
if nid not in nodelst:
ref_node = self.nodes[nid]
new_node = anytree.Node(nid, connector_id=ref_node.connector_id,
connector_relation=ref_node.connector_relation,
r=ref_node.r, x=ref_node.x, y=ref_node.y, z=ref_node.z,
structure_id=ref_node.structure_id)
subnodes[nid] = new_node
nodelst.append(nid)
# assign node relationship
for nid in nodelst:
try:
parent_id = self.nodes[nid].parent.name
subnodes[nid].parent = subnodes[parent_id]
except:
subnodes[nid].parent = None # root node
return Tree(subnodes,self.id,self.name)
[docs]
def decompose_segments(self,rootid=None):
"""Decompose tree in segments separating by the branching nodes.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
A dictionary whose each item is a decomposed segment, where
- key is a string represents the first node and the last node of the decomposed segment.
- value is a list of nodes IDs of the decomposed segment.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
# Plot node ID
coors = neuron.get_coordinates()
for nodeid in coors.index:
tmp = coors.loc[nodeid]
ax.text(tmp.x-0.5,tmp.y+0.1,tmp.z+0.1,nodeid,fontsize=12)
ax.axis('off');
plt.tight_layout();
# Decompose into segments separating by branching nodes
print(neuron.decompose_segments())
# The output should be as follows:
# {'2_4': [2, 4], '2_6': [2, 5, 6], '1_3': [1, 3], '0_1': [0, 1], '1_2': [1, 2]}
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
segments = {}
controlnodes = self.get_leaves(_rootid) + self.get_branchingnodes(_rootid) + [_rootid]
for it in range(len(controlnodes)):
nodeid = controlnodes[it]
branch = [n.name for n in self.nodes[nodeid].path]
idx = len(branch)-2
while(idx != -1):
if branch[idx] in controlnodes:
subbranch = branch[idx:]
segments[str(branch[idx])+"_"+str(nodeid)] = subbranch
break
idx = idx - 1
return segments
[docs]
def decompose_leaves(self,rootid=None):
"""Decompose tree into segments starting from root to every leaf.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
A dictionary whose each item is a decomposed segment, where
- key is a string represents the leaf node of the decomposed segment.
- value is a list of nodes IDs of the decomposed segment.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
# Plot node ID
coors = neuron.get_coordinates()
for nodeid in coors.index:
tmp = coors.loc[nodeid]
ax.text(tmp.x-0.5,tmp.y+0.1,tmp.z+0.1,nodeid,fontsize=12)
ax.axis('off');
plt.tight_layout();
# Decompose into segments from root to every leaves
print(neuron.decompose_leaves())
# The output should be as follows:
# {4: [0, 1, 2, 4], 6: [0, 1, 2, 5, 6], 3: [0, 1, 3]}
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
dic = {}
for leaf in self.get_leaves(_rootid):
dic[leaf] = self.path(target=leaf,source=_rootid)
return dic
[docs]
def decompose_spines(self,rootid=None):
"""Decompose tree into list of spines (longest branches) starting from the root node.
We first extract the longest branch from the root to a leaf.
Then for every branching node on that longest branch, we extract the subtree from that, and repeat the process by extracting the longest branch of that subtree.
The process is repeated until no branch can be extracted.
Args:
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
A dictionary whose each item is a decomposed segment, where
- key is a string represents the first node and the last node of the decomposed segment.
- value is a list of nodes IDs of the decomposed segment.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax)
# Plot node ID
coors = neuron.get_coordinates()
for nodeid in coors.index:
tmp = coors.loc[nodeid]
ax.text(tmp.x-0.5,tmp.y+0.1,tmp.z+0.1,nodeid,fontsize=12)
ax.axis('off');
plt.tight_layout();
# Decompose into segments of longest branches
print(neuron.decompose_spines())
# The output should be as follows:
# {'0_6': [0, 1, 2, 5, 6], '1_3': [1, 3], '2_4': [2, 4]}
"""
# maximal number of recursions using in deepcopy()
# sys.setrecursionlimit(recursion_limit)
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
data = {}
# first check if there're many branches at root
mychildren = [item.name for item in self.nodes[_rootid].children]
if len(mychildren)>1: # more than one children
subtrees = self.extract_subtrees(nodeid=_rootid,separate_children=True,rootid=_rootid)
for subtree in subtrees:
subdata = subtree.decompose_spines()
data = {**data, **subdata}
else:
spinenodes = self.compute_spine(_rootid)
spinename = str(spinenodes[0])+'_'+str(spinenodes[-1])
data[spinename] = spinenodes
branchingnodes = self.get_branchingnodes(_rootid)
if len(branchingnodes)!=0:
spinebranchingnodes = list(filter(lambda node : node in spinenodes, branchingnodes))
for node in spinebranchingnodes:
subtrees = self.extract_subtrees(nodeid=node,separate_children=True,rootid=_rootid)
for subtree in subtrees:
firstchild = subtree.nodes[node].children[0].name
if firstchild not in spinenodes:
subdata = subtree.decompose_spines()
data = {**data, **subdata}
return data
[docs]
def resample(self,sampling_length=None,rootid=None,spline_order=1,spline_smoothness=0,decompose_method="branching"):
"""Resample the tree by a given sampling length. Spline interpolation is used for resampling.
The tree is first decomposed into segments either by "branching" (``decompose_segments()``) or by "spine" (``decompose_spines()``).
Then, each segment is resampled by the given sampling length.
Args:
sampling_length (float): sampling length. If None, then we resample the tree by the same number of nodes.
rootid (int): root node ID. If None, then the first rootid will be taken.
spline_order (uint): Spline order. value of 1 means linear interpolation.
spline_smoothness (float or None): Parameter to control the smoothness of spline interpolation. The parameter is similar to the ``s`` parameter from this [1].
decompose_method (str): choose the mode of decomposition: "branching" or "spine".
Returns:
A new Tree object with resampled nodes.
References:
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.splprep.html
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron, only showing the node points
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax,show_nodes=True,show_leaves=False,show_root=False,show_branchingnodes=False)
print("Nb. of nodes before resampling:",neuron.get_number_nodes())
# Resampling by sampling length = 1 micron
neuron_resampled = neuron.resample(sampling_length=1)
# Plot resampled neuron, only showing the node points
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron_resampled.plot(ax,show_nodes=True,show_leaves=False,show_root=False,show_branchingnodes=False)
print("Nb. of nodes after resampling:",neuron_resampled.get_number_nodes())
"""
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
branching_nodes = self.get_branchingnodes(_rootid)
newnodes = {} # new resampled nodes
# TODO: smatter handling of ID (e.g. remove ancient IDs, new_ID_generator reuse free ancient IDs)
nodeid_max = max(self.nodes.keys()) # for now new node ID is counted from old maximal ID
newnodeid = nodeid_max + 1
if decompose_method=="branching":
segments = self.decompose_segments(rootid=_rootid)
elif decompose_method=="spine":
segments = self.decompose_spines(rootid=_rootid)
brnode_link = [] # contain link between branching node and new closest resampled node.
else:
raise Exception("accept only 'branching' or 'spine' decomposition")
for seg in segments.values():
# print(seg)
try:
# get segment coordinates
coors = self.get_coordinates(seg,rootid).values
# try to remove duplicates before interpolation
_, uix = np.unique(coors,axis=0,return_index=True)
unique_coors = coors[np.sort(uix)]
unique_seg = np.array(seg)[np.sort(uix)]
# # interpolate coordinates
# if spline_smoothness==True:
# s = None
# else:
# s = 0
if unique_coors.shape[0]<(spline_order+1): # linear interpolation
coefcoors, to = interpolate.splprep(unique_coors.T.tolist(),k=1,s=spline_smoothness)
else:
coefcoors, to = interpolate.splprep(unique_coors.T.tolist(),k=spline_order,s=spline_smoothness)
# interpolate radius
r = [self.nodes[nodeid].r for nodeid in unique_seg]
coefr = interpolate.splrep(to,r,k=1,s=0)
except:
return None
# # print("ok");
# # simply copy nodes to new nodes
# for i in range(len(seg)):
# if seg[i] not in newnodes.keys():
# node = self.nodes[seg[i]]
# newnodes[seg[i]] = anytree.Node(seg[i],connector_id=node.connector_id,
# connector_relation=node.connector_relation, r=node.r,
# x=node.x, y=node.y, z=node.z, structure_id=node.structure_id)
# if i != 0:
# newnodes[seg[i]].parent = newnodes[seg[i-1]]
# continue # go to next segment
# in case of success interpolation
if sampling_length == None:
n = len(seg)
else:
n = max(int(np.round(self.compute_length(unique_seg)/sampling_length)),2) + 1 # new sampled points, must have at least 2 points
tn = np.linspace(0,1,n)
xn, yn, zn = interpolate.splev(tn,coefcoors)
rn = interpolate.splev(tn,coefr)
# in case of spine decomposition
if decompose_method=="spine":
newix_flag = np.ones(len(tn))*-1
# newix_supp = []
for node in branching_nodes:
# index start from "1" not "0" since we don't consider branching node at begin.
idx = np.argwhere(node==np.array(unique_seg)[1:]).flatten()
if len(idx)!=0:
oldix = idx[0]+1
newix = np.argwhere(tn<to[oldix]).flatten()[-1]
newix_flag[newix] = node
# newix_supp.append([node,oldix,to[oldix],newix,tn[newix]])
# print(newix_flag)
# print(newix_supp)
# add two end nodes to newnodes
for iseg in [unique_seg[-1],unique_seg[0]]:
if iseg not in newnodes.keys():
node = self.nodes[iseg]
newnodes[iseg] = anytree.Node(iseg,connector_id=node.connector_id,
connector_relation=node.connector_relation,
r=node.r, x=node.x, y=node.y, z=node.z, structure_id=node.structure_id)
# add resampled nodes and link to parent
check_node = newnodes[unique_seg[-1]]
i = n - 2
while i != -1:
if i == 0:
if decompose_method=="spine":
if newix_flag[i]!=-1:
brnode_link.append([int(newix_flag[i]),unique_seg[0]])
new_node = newnodes[unique_seg[0]]
else:
if decompose_method=="spine":
if newix_flag[i]!=-1:
brnode_link.append([int(newix_flag[i]),newnodeid])
new_node = anytree.Node(newnodeid, connector_id=-1, connector_relation='None',
r=rn[i], x=xn[i], y=yn[i], z=zn[i], structure_id=0)
newnodes[newnodeid] = new_node
# if (newnodeid==696):
# print(unique_seg)
check_node.parent = new_node
check_node = new_node
newnodeid = newnodeid + 1
i = i - 1
if decompose_method=="spine":
for link in brnode_link:
mychildren = [item.name for item in newnodes[link[0]].children]
try:
for mychild in mychildren:
newnodes[mychild].parent = newnodes[link[1]]
except:
print(link)
print(mychildren)
print(newnodes[link[1]])
raise Exception("failed, segment {}".format(unique_seg))
for brnode in branching_nodes:
if((brnode != _rootid) & (newnodes[brnode].parent == None)):
del newnodes[brnode]
# new resampled neuron
return Tree(newnodes)
[docs]
def compute_angles(self,decomposed_method="branching",sigma=0,rootid=None):
"""Return angle of each tree node.
The tree is decomposed into segments. Then, the angle of node N is the angle between two vectors (N => N-1) and (N => N+1).
One node can have multiple angle values depending on the decomposition mode.
Args:
decomposed_method (str): "leaf", "spine" or "branching", method for decomposing the neuronal trace.
sigma (float): sigma used in Gaussian function for smoothing the tree branches.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas dataframe containing angles at every nodes.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron, only showing the node points
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax,show_nodes=True,show_leaves=False,show_root=False,show_branchingnodes=False)
# Plot node ID
coors = neuron.get_coordinates()
for nodeid in coors.index:
tmp = coors.loc[nodeid]
ax.text(tmp.x-0.5,tmp.y+0.1,tmp.z+0.1,nodeid,fontsize=12)
plt.tight_layout();
# Compute the node angles with x, y and z axes
# The default decomposition is "branching"
neuron.compute_angle()
"""
from genepy3d.obj.curves import Curve
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
nodeid_lst, segkey_lst = [],[]
# vecx_lst, vecy_lst, vecz_lst = [], [], []
# thetax_lst, thetay_lst, thetaz_lst = [], [], []
thetas_lst = []
# decompose tree into segments
if decomposed_method=="spine":
segments = self.decompose_spines(_rootid)
elif decomposed_method=="leaf":
segments = self.decompose_leaves(_rootid)
else:
segments = self.decompose_segments(_rootid)
for key,seg in segments.items():
# save node IDs
nodeid_lst = nodeid_lst + list(seg)
# save segment keys
segkey_lst += [key for _ in range(len(seg))]
# create curve
coors = self.get_coordinates(seg).values
crv = Curve(coors)
# smooth if need
if sigma > 0:
crv = crv.convolve_gaussian(sigma)
thetas = crv.compute_angles()
thetas_lst = thetas_lst + thetas.tolist()
# vecx_lst = vecx_lst + vecs[:,0].tolist()
# vecy_lst = vecy_lst + vecs[:,1].tolist()
# vecz_lst = vecz_lst + vecs[:,2].tolist()
# thetax_lst = thetax_lst + thetas[:,0].tolist()
# thetay_lst = thetay_lst + thetas[:,1].tolist()
# thetaz_lst = thetaz_lst + thetas[:,2].tolist()
# making dataframe
df = pd.DataFrame({"seg_key":segkey_lst,
"nodeid":nodeid_lst,
"theta":thetas_lst})
df.set_index(["seg_key","nodeid"],inplace=True)
return df
[docs]
def compute_angles_vector(self,decomposed_method="branching",sigma=0,rootid=None):
"""Return angles of each tree node with x, y and z axes.
Angle at at a node N is the angle between vector (N => N + 1) and three unit vectors of x, y and z axes.
The tree is decomposed into segments. Then, the angles of nodes on each segment are computed.
One node can have multiple angle values depending on the decomposition mode.
Args:
decomposed_method (str): "leaf", "spine" or "branching", method for decomposing the neuronal trace.
sigma (float): sigma used in Gaussian function for smoothing the tree branches.
rootid (int): root node ID. If None, then the first rootid will be taken.
Returns:
Pandas dataframe containing angles at every nodes.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Plot neuron, only showing the node points
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax,show_nodes=True,show_leaves=False,show_root=False,show_branchingnodes=False)
# Plot node ID
coors = neuron.get_coordinates()
for nodeid in coors.index:
tmp = coors.loc[nodeid]
ax.text(tmp.x-0.5,tmp.y+0.1,tmp.z+0.1,nodeid,fontsize=12)
plt.tight_layout();
# Compute the node angles with x, y and z axes
# The default decomposition is "branching"
neuron.compute_angle_axes()
"""
from genepy3d.obj.curves import Curve
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
nodeid_lst, segkey_lst = [],[]
vecx_lst, vecy_lst, vecz_lst = [], [], []
thetax_lst, thetay_lst, thetaz_lst = [], [], []
# decompose tree into segments
if decomposed_method=="spine":
segments = self.decompose_spines(_rootid)
elif decomposed_method=="leaf":
segments = self.decompose_leaves(_rootid)
else:
segments = self.decompose_segments(_rootid)
for key,seg in segments.items():
# save node IDs
nodeid_lst = nodeid_lst + list(seg)
# save segment keys
segkey_lst += [key for _ in range(len(seg))]
# create curve
coors = self.get_coordinates(seg).values
crv = Curve(coors)
# smooth if need
if sigma > 0:
crv = crv.convolve_gaussian(sigma)
thetas = crv.compute_angles_vector()
thetax_lst = thetax_lst + thetas[:,0].tolist()
thetay_lst = thetay_lst + thetas[:,1].tolist()
thetaz_lst = thetaz_lst + thetas[:,2].tolist()
# making dataframe
df = pd.DataFrame({"seg_key":segkey_lst,
"nodeid":nodeid_lst,
"thetax":thetax_lst,
"thetay":thetay_lst,
"thetaz":thetaz_lst})
df.set_index(["seg_key","nodeid"],inplace=True)
return df
[docs]
def compute_local_3d_scale_sigma(self,sig_lst,dim_param=None,decomposed_method="leaf",rootid=None,postprocess='mean'):
"""Compute local 3d scale of neuron from a list of sigma of Gaussian.
The tree is decomposed into segments. Then, local 3d scale is computed for each segment.
Each segment is considered as a Curve object. Please see ``compute_local_3d_scale_sigma()`` from ``Curve`` class for more detail.
One node can have multiple local 3d scales due to the decomposition. You can specify the ``postprocess`` to return e.g. the mean local 3d scale of the node.
Args:
sig_lst (list of float): list of sigma values of Gaussian.
dim_param (dic): parameters for dimension decomposition. If None, default parameters are used. See parameters of ``compute_local_3d_scale_sigma()`` from ``Curve`` class for the meaning.
decomposed_method (str): "leaf", "spine" or "branching", method for decomposing the neuronal tree.
rootid (int): root node ID. If None, then the first rootid will be taken.
postprocess (str): support "mean", "std", "max", "min" or None.
Returns:
Pandas dataframe containing local 3D scale and additional information: 1D, 2D, 3D flags, curvature and torsion.
Notes:
If ``postprocess = None``, then a Pandas dataframe containing local 3d scale and additional features (curvature, torsion, etc) of each decomposed segment is returned.
Otherwise, a Pandas serie containing only the local 3d scale of each node is returned.
Examples:
.. code-block:: python
import pandas as pd
from genepy3d.obj import trees
import matplotlib.pyplot as plt
# Generate a dummy nodes data
data = [[0,0,0.,0.,0.,1.,-1],
[1,0,1.,2.,3.,1.,0],
[2,0,4.,5.,6.,1.,1],
[3,0,7.,2.,2.,1.,1],
[4,0,6.,6.,6.,1.,2],
[5,0,3.,7.,6.,1.,2],
[6,0,5.,5.,9.,1.,5]]
# Dataframe from nodes data
node_tbl = pd.DataFrame(data,columns=['treenode_id','structure_id','x','y','z','r','parent_treenode_id'])
# Create a dummy neuron from node_tbl
neuron = trees.Tree.from_table(node_tbl)
# Resample the neuron with sampling length = 1 simply to get more points
neuron = neuron.resample(sampling_length=1)
# Plot neuron, only showing the node points
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax,show_nodes=True,show_leaves=False,show_root=False,show_branchingnodes=False)
plt.tight_layout();
# Averaged local 3D scale of each node for sigma from 0 => 100
l3ds_tbl = neuron.compute_local_3d_scale_sigma(sig_lst=np.arange(100),postprocess='mean');
# Plot neuron whose nodes are colored by the local 3d scale
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
neuron.plot(ax,weights=l3ds_tbl,
show_cbar=True,point_args={"cmap":"rainbow"},
show_nodes=False,show_leaves=False,show_root=False,show_branchingnodes=False)
plt.tight_layout();
"""
from genepy3d.obj.curves import Curve
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
nodeid_lst, ls_lst, plane_line_lst, line_lst, threed_lst = [], [], [], [], []
curvature_lst, torsion_lst = [], []
segkey_lst = []
# decompose tree into segments
if decomposed_method=="spine":
segments = self.decompose_spines(_rootid)
elif decomposed_method=="leaf":
segments = self.decompose_leaves(_rootid)
else:
segments = self.decompose_segments(_rootid)
for key,seg in segments.items():
# create curve
coors = self.get_coordinates(seg).values
crv = Curve(coors)
# intrinsic dimension decomposition
try:
if dim_param is not None:
dim_param["return_dim_results"] = True
ls_res, intrinsic_res = crv.compute_local_3d_scale_sigma(sig_lst,**dim_param)
else: # use default params
ls_res, intrinsic_res = crv.compute_local_3d_scale_sigma(sig_lst,return_dim_results=True)
except:
raise Exception(seg)
# save segment keys
segkey_lst += [key for _ in range(len(seg))]
# curvature and torsion
curvature_lst += crv.compute_curvature().tolist()
torsion_lst += crv.compute_torsion().tolist()
# save node IDs
nodeid_lst = nodeid_lst + list(seg)
# local 3D scale
ls_lst = ls_lst + list(ls_res)
# if(len(np.argwhere(np.array(ls_res)>=70).flatten())==len(ls_res)):
# print(seg)
# 1D, 2D, 3D flags
pl_flag = np.zeros((len(seg),len(sig_lst)))
l_flag = np.zeros((len(seg),len(sig_lst)))
for i in range(len(sig_lst)):
res = intrinsic_res[i]
for plids in res["planeline_pred"]:
pl_flag[plids[0]:plids[1]+1,i]=1.
for lids in res["line_pred"]:
l_flag[lids[0]:lids[1]+1,i]=1.
threed_flag = (pl_flag + np.ones(pl_flag.shape))%2 # xor bit (3D is the opposite of planeline)
plane_line_lst = plane_line_lst + pl_flag.tolist()
line_lst = line_lst + l_flag.tolist()
threed_lst = threed_lst + threed_flag.tolist()
# making dataframe
df = pd.DataFrame({"nodeid":nodeid_lst,
"seg_key":segkey_lst,
"local_scale":ls_lst,
"plane_line_flag":plane_line_lst,
"line_flag":line_lst,
"threed_flag":threed_lst,
"curvature":curvature_lst,"torsion":torsion_lst})
# df.drop_duplicates("nodeid",keep=False) # drop branching nodes
df.set_index("nodeid",inplace=True)
if postprocess == "mean":
return df.groupby("nodeid")['local_scale'].mean()
elif postprocess == "std":
return df.groupby("nodeid")['local_scale'].std()
elif postprocess == "min":
return df.groupby("nodeid")['local_scale'].min()
elif postprocess == "max":
return df.groupby("nodeid")['local_scale'].max()
else:
return df
[docs]
def compute_local_3d_scale_radius(self,r_lst,dim_param=None,decomposed_method="leaf",rootid=None,postprocess='mean'):
"""Compute local 3d scale of neuron from a list of radius of curvatures.
This is the advanced approach to compute the local 3d scale. Please see this [1] for more detail.
The tree is decomposed into segments. Then, local 3d scale is computed for each segment.
Each segment is considered as a Curve object. Please see ``compute_local_3d_scale_radius()`` from ``Curve`` class for more detail.
One node can have multiple local 3d scales due to the decomposition. You can specify the ``postprocess`` to return e.g. the mean local 3d scale of the node.
Args:
r_lst (list of float): list of radius of curvatures. The radius must be > 0.
dim_param (dic): parameters for dimension decomposition. If None, default parameters are used. See parameters of ``compute_local_3d_scale_radius()`` from ``Curve`` class for the meaning.
decomposed_method (str): "leaf", "spine" or "branching", method for decomposing the neuronal tree.
rootid (int): root node ID. If None, then the first rootid will be taken.
postprocess (str): support "mean", "std", "max", "min" or None.
Returns:
Pandas dataframe containing local 3D scale and additional information: 1D, 2D, 3D flags, curvature and torsion.
Notes:
- If ``postprocess = None``, then a Pandas dataframe containing local 3d scale and additional features (curvature, torsion, etc) of each decomposed segment is returned. Otherwise, a Pandas serie containing only the local 3d scale of each node is returned.
- Running this method is much longer than ``compute_local_3d_scale_sigma()``.
References:
.. [1] Phan MS, Matho K, Beaurepaire E, Livet J, Chessel A.
nAdder: A scale-space approach for the 3D analysis of neuronal traces. 2022.
PLOS Computational Biology 18(7): e1010211. DOI: 10.1371/journal.pcbi.1010211
"""
from genepy3d.obj.curves import Curve
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
nodeid_lst, ls_lst, plane_line_lst, line_lst, threed_lst = [], [], [], [], []
curvature_lst, torsion_lst = [], []
segkey_lst = []
# decompose tree into segments
if decomposed_method=="spine":
segments = self.decompose_spines(_rootid)
elif decomposed_method=="leaf":
segments = self.decompose_leaves(_rootid)
else:
segments = self.decompose_segments(_rootid)
for key,seg in segments.items():
# create curve
coors = self.get_coordinates(seg).values
crv = Curve(coors)
# intrinsic dimension decomposition
try:
if dim_param is not None:
dim_param["return_dim_results"] = True
ls_res, intrinsic_res = crv.compute_local_3d_scale_radius(r_lst,**dim_param)
else: # use default params
ls_res, intrinsic_res = crv.compute_local_3d_scale_radius(r_lst,return_dim_results=True)
except:
raise Exception(seg)
# save segment keys
segkey_lst += [key for _ in range(len(seg))]
# curvature and torsion
curvature_lst += crv.compute_curvature().tolist()
torsion_lst += crv.compute_torsion().tolist()
# save node IDs
nodeid_lst = nodeid_lst + list(seg)
# local 3D scale
ls_lst = ls_lst + list(ls_res)
# if(len(np.argwhere(np.array(ls_res)>=70).flatten())==len(ls_res)):
# print(seg)
# 1D, 2D, 3D flags
pl_flag = np.zeros((len(seg),len(r_lst)))
l_flag = np.zeros((len(seg),len(r_lst)))
for i in range(len(r_lst)):
res = intrinsic_res[i]
for plids in res["planeline_pred"]:
pl_flag[plids[0]:plids[1]+1,i]=1.
for lids in res["line_pred"]:
l_flag[lids[0]:lids[1]+1,i]=1.
threed_flag = (pl_flag + np.ones(pl_flag.shape))%2 # xor bit (3D is the opposite of planeline)
plane_line_lst = plane_line_lst + pl_flag.tolist()
line_lst = line_lst + l_flag.tolist()
threed_lst = threed_lst + threed_flag.tolist()
# making dataframe
df = pd.DataFrame({"nodeid":nodeid_lst,
"seg_key":segkey_lst,
"local_scale":ls_lst,
"plane_line_flag":plane_line_lst,
"line_flag":line_lst,
"threed_flag":threed_lst,
"curvature":curvature_lst,"torsion":torsion_lst})
# df.drop_duplicates("nodeid",keep=False) # drop branching nodes
df.set_index("nodeid",inplace=True)
if postprocess == "mean":
return df.groupby("nodeid")['local_scale'].mean()
elif postprocess == "std":
return df.groupby("nodeid")['local_scale'].std()
elif postprocess == "min":
return df.groupby("nodeid")['local_scale'].min()
elif postprocess == "max":
return df.groupby("nodeid")['local_scale'].max()
else:
return df
# def compute_local_3d_scale(self,r_lst,dim_param=None,decomposed_method="leaf",rootid=None):
# """Compute 3d local scale of neuron.
# This function is **DEPRECATED**. Use compute_local_3d_scale_radius().
# Args:
# r_lst (list): list of scales (radius of curvature)
# dim_param (dic): parameters for dimension decomposition
# decomposed_method (str): "leaf", "spine" or "branching", method for cutting neuronal trace
# rootid (int): ID of root
# Returns:
# Pandas dataframe containing local 3D scale, 1D, 2D, 3D flags, curvature and torsion
# """
# from genepy3d.obj.curves import Curve
# if rootid is None:
# _rootid = self.get_root()[0]
# else:
# _rootid = rootid
# nodeid_lst, ls_lst, plane_line_lst, line_lst, threed_lst = [], [], [], [], []
# curvature_lst, torsion_lst = [], []
# segkey_lst = []
# # decompose tree into segments
# if decomposed_method=="spine":
# segments = self.decompose_spines(_rootid)
# elif decomposed_method=="leaf":
# segments = self.decompose_leaves(_rootid)
# else:
# segments = self.decompose_segments(_rootid)
# for key,seg in segments.items():
# # create curve
# coors = self.get_coordinates(seg).values
# crv = Curve(coors)
# # intrinsic dimension decomposition
# try:
# if dim_param is not None:
# ls_res, intrinsic_res = crv.compute_local_3d_scale(r_lst,
# dim_param["eps_seg_len"],
# dim_param["eps_crv_len"],
# dim_param["sig_step"],
# dim_param["eps_kappa"],
# dim_param["eps_tau"],
# return_dim_results=True)
# else: # use default params
# ls_res, intrinsic_res = crv.compute_local_3d_scale(r_lst,return_dim_results=True)
# except:
# raise Exception(seg)
# # save segment keys
# segkey_lst += [key for _ in range(len(seg))]
# # curvature and torsion
# curvature_lst += crv.compute_curvature().tolist()
# torsion_lst += crv.compute_torsion().tolist()
# # save node IDs
# nodeid_lst = nodeid_lst + list(seg)
# # local 3D scale
# ls_lst = ls_lst + list(ls_res)
# # if(len(np.argwhere(np.array(ls_res)>=70).flatten())==len(ls_res)):
# # print(seg)
# # 1D, 2D, 3D flags
# pl_flag = np.zeros((len(seg),len(r_lst)))
# l_flag = np.zeros((len(seg),len(r_lst)))
# for i in range(len(r_lst)):
# res = intrinsic_res[i]
# for plids in res["planeline_pred"]:
# pl_flag[plids[0]:plids[1]+1,i]=1.
# for lids in res["line_pred"]:
# l_flag[lids[0]:lids[1]+1,i]=1.
# threed_flag = (pl_flag + np.ones(pl_flag.shape))%2 # xor bit (3D is the opposite of planeline)
# plane_line_lst = plane_line_lst + pl_flag.tolist()
# line_lst = line_lst + l_flag.tolist()
# threed_lst = threed_lst + threed_flag.tolist()
# # making dataframe
# df = pd.DataFrame({"nodeid":nodeid_lst,
# "seg_key":segkey_lst,
# "local_scale":ls_lst,
# "plane_line_flag":plane_line_lst,
# "line_flag":line_lst,
# "threed_flag":threed_lst,
# "curvature":curvature_lst,"torsion":torsion_lst})
# # df.drop_duplicates("nodeid",keep=False) # drop branching nodes
# df.set_index("nodeid",inplace=True)
# return df
[docs]
def summary(self):
"""Return a brief summary of the tree.
Returns:
A Pandas Serie contains
- number of nodes
- number of leaves
- number of branching nodes
- number of connectors
"""
index = ['id',
'name',
'root',
'nb_nodes',
'nb_leaves',
'nb_branchingnodes',
'nb_connectors']
data = [self.id,
self.name,
self.get_root(),
[len(self.get_preorder_nodes(_rootid)) for _rootid in self.get_root()],
# len(list(PreOrderIter(self.nodes[self.get_root()]))),
[len(self.get_leaves(_rootid)) for _rootid in self.get_root()],
[len(self.get_branchingnodes(_rootid)) for _rootid in self.get_root()],
[len(self.get_connectors(_rootid)) for _rootid in self.get_root()]]
return pd.Series(data,index=index)
[docs]
def to_curve(self,nodeid=None):
"""Convert a segment of tree given by the list of node ID to a Curve object.
Args:
nodeid (list of int): List of node ID. If None, then consider all nodes of the tree.
Returns:
A Curve object.
"""
from genepy3d.obj.curves import Curve
return Curve(self.get_coordinates(nodeid).values)
[docs]
def to_points(self,nodeid=None):
"""Convert a segment of tree given by the list of node ID to a Points object.
Args:
nodeid (list of int): List of node ID. If None, then consider all nodes of the tree.
Returns:
A Points object.
"""
from genepy3d.obj.points import Points
return Points(self.get_coordinates(nodeid).values)
[docs]
def plot(self,ax,projection='3d',rootid=None,spine_only=False,
show_root=True,show_leaves=True,show_branchingnodes=True,show_connectors=True,show_nodes=False,
root_args={},leaves_args={},branchingnodes_args={},connectors_args={},
weights=None,weights_display_type="c",point_args = {},show_cbar=False,cbar_args={},
line_args={},scales=(1.,1.,1.),equal_axis=True):
"""Plot tree using matplotlib.
Args:
ax: plot axis.
projection (str): we support *3d, xy, xz, yz* modes.
spine_only (bool): if True, then only plot tree spine.
show_root (bool): if True, then display root node.
show_leaves (bool): if True, then display leaves nodes.
show_branchingnodes (bool): if True, then display branching nodes.
show_connectors (bool): if True, then display connector nodes.
show_nodes (bool): if True, then display all nodes.
root_args (dic): plot params for root node.
leaves_args (dic): plot params for leaves nodes.
branchingnodes_args (dic): plot params for branching nodes.
connectors_args (dic): plot params for connectors nodes.
weights (pandas serie): serie indexed by node ID and contains corresponding weights
weights_display_type (str): "s" for point size or "c" for point color.
point_args (dic): point plot params.
line_args (dic): line plot params.
scales (tuple of float): use to set x, y and z scales.
equal_axis (bool): fix equal axes.
"""
_root_args={'s':50,'c':'red'}
for key,val in root_args.items():
_root_args[key] = val
_leaves_args={'s':8,'c':'blue'}
for key,val in leaves_args.items():
_leaves_args[key] = val
_branchingnodes_args={'s':20,'c':'magenta'}
for key,val in branchingnodes_args.items():
_branchingnodes_args[key] = val
_connectors_args={'s':70,'c':'red','alpha':0.7}
for key,val in connectors_args.items():
_connectors_args[key] = val
_point_args = {"s":10,"alpha":0.8,"c":"k","cmap":"viridis"}
for key,val in point_args.items():
_point_args[key] = val
_line_args={'alpha':0.8,'c':'k'}
for key,val in line_args.items():
_line_args[key] = val
if rootid is None:
_rootid = self.get_root()[0]
else:
_rootid = rootid
if show_root==True:
coors = self.get_coordinates(_rootid).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_root_args)
if show_nodes==True:
coors = self.get_coordinates().values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_point_args)
if spine_only==True:
spine_nodes = self.compute_spine(_rootid)
coors = self.get_coordinates(spine_nodes).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_line(ax,projection,x,y,z,scales,_line_args)
if weights is not None:
weight_args = _point_args.copy()
weight_args[weights_display_type] = weights.loc[spine_nodes]
if "vmin" not in weight_args.keys():
weight_args["vmin"] = weights.min()
if "vmax" not in weight_args.keys():
weight_args["vmax"] = weights.max()
plo = pl.plot_point(ax,projection,x,y,z,scales,weight_args)
if (show_cbar == True):
if "shrink" in cbar_args:
shrink = cbar_args["shrink"]
else:
shrink = 0.7
cbar = ax.figure.colorbar(plo,shrink=shrink)
if "ticks" in cbar_args:
cbar.set_ticks(cbar_args["ticks"])
if "ticklabels" in cbar_args:
cbar.set_ticklabels(cbar_args["ticklabels"])
if show_leaves==True:
coors = self.get_coordinates(spine_nodes[-1]).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_leaves_args)
else:
segments = self.decompose_segments(_rootid)
cbar_flag = True
for seg_nodes in segments.values():
coors = self.get_coordinates(seg_nodes).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
segment_args = _line_args.copy()
pl.plot_line(ax,projection,x,y,z,scales,segment_args)
if weights is not None:
weight_args = _point_args.copy()
weight_args[weights_display_type] = weights.loc[seg_nodes]
if "vmin" not in weight_args.keys():
weight_args["vmin"] = weights.min()
if "vmax" not in weight_args.keys():
weight_args["vmax"] = weights.max()
plo = pl.plot_point(ax,projection,x,y,z,scales,weight_args)
if (show_cbar == True) & (cbar_flag == True):
if "shrink" in cbar_args:
shrink = cbar_args["shrink"]
else:
shrink = 0.7
cbar = ax.figure.colorbar(plo,shrink=shrink)
if "ticks" in cbar_args:
cbar.set_ticks(cbar_args["ticks"])
if "ticklabels" in cbar_args:
cbar.set_ticklabels(cbar_args["ticklabels"]);
cbar_flag = False
if show_leaves==True:
leaves_nodes = self.get_leaves(_rootid)
coors = self.get_coordinates(leaves_nodes).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_leaves_args)
if show_branchingnodes==True:
inter_nodes = self.get_branchingnodes(_rootid)
if len(inter_nodes)!=0:
coors = self.get_coordinates(inter_nodes).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_branchingnodes_args)
if show_connectors==True:
connectors_nodes = self.get_connectors(_rootid).index.values
if len(connectors_nodes)!=0:
coors = self.get_coordinates(connectors_nodes).values
x, y, z = coors[:,0], coors[:,1], coors[:,2]
pl.plot_point(ax,projection,x,y,z,scales,_connectors_args)
if equal_axis==True:
if projection != '3d':
ax.axis('equal')
else:
param = pl.fix_equal_axis(self.get_coordinates(rootid=_rootid).values / np.array(scales))
ax.set_xlim(param['xmin'],param['xmax'])
ax.set_ylim(param['ymin'],param['ymax'])
ax.set_zlim(param['zmin'],param['zmax'])
if projection != '3d':
if projection=='xy':
xlbl, ylbl = 'X', 'Y'
elif projection=='xz':
xlbl, ylbl = 'X', 'Z'
else:
xlbl, ylbl = 'Y', 'Z'
ax.set_xlabel(xlbl)
ax.set_ylabel(ylbl)
else:
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')