File: TGraph.py

package info (click to toggle)
trinityrnaseq 2.11.0+dfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 417,528 kB
  • sloc: perl: 48,420; cpp: 17,749; java: 12,695; python: 3,124; sh: 1,030; ansic: 983; makefile: 688; xml: 62
file content (158 lines) | stat: -rwxr-xr-x 4,487 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/python3
# encoding: utf-8

from __future__ import (absolute_import, division,
                        print_function, unicode_literals)

import os, sys, re
import logging
import argparse
import collections
import numpy
import time

import TNode

logger = logging.getLogger(__name__)


class TGraph:
    """
    Per-gene graph of TNode objects.

    Nodes are registered in ``self.node_cache``, a dict keyed by the node's
    loc_node_id. Edges are not stored here but on the node objects
    themselves, through their add_next_node/add_prev_node and
    remove_next_node/remove_prev_node methods.
    """

    def __init__(self, gene_id):
        """
        Create an empty graph for the gene ``gene_id``.
        """
        # loc_node_id -> node object
        self.node_cache = dict()
        self.gene_id = gene_id

    def get_node(self, transcript_id, loc_node_id, node_seq):
        """
        Instantiates Node objects, and stores them in a graph.

        *** use this method for instantiating all Node objects ***

        If ``loc_node_id`` is already in the graph, ``transcript_id`` is added
        to that node and the cached node is returned. Otherwise a new
        ``TNode.TNode`` is created, cached, and returned.

        use clear_node_cache() to clear the graph

        Raises:
            RuntimeError: if ``node_seq`` is empty, or if the cached node with
                this id has a different sequence.
        """
        logger.debug("{}\t{}".format(loc_node_id, node_seq))

        if not node_seq:
            raise RuntimeError("Error, non-zero length node_seq required for parameter")

        if loc_node_id in self.node_cache:
            node_obj = self.node_cache[loc_node_id]
            if node_obj.seq != node_seq:
                errmsg = "ERROR: have conflicting node sequences for {} node_id: {}\n".format(
                    self.get_gene_id(), loc_node_id) + "{}\n vs. \n{}".format(node_obj.seq, node_seq)
                logger.critical(errmsg)
                raise RuntimeError(errmsg)
            # Record the transcript only after the sequences are known to
            # agree, so a rejected call leaves the cached node untouched.
            node_obj.add_transcripts(transcript_id)
            return node_obj

        # instantiate a new one
        node_obj = TNode.TNode(self, transcript_id, loc_node_id, node_seq)
        self.node_cache[loc_node_id] = node_obj
        return node_obj

    def get_all_nodes(self):
        """
        Return a new list of every node currently in the graph.
        """
        return list(self.node_cache.values())

    def clear_node_cache(self):
        """
        clears the graph (drops every cached node)
        """
        self.node_cache.clear()

    def clear_touch_settings(self):
        """
        clear the touch settings for each of the nodes
        """
        for node in self.get_all_nodes():
            node.clear_touch()

    def add_edges(self, from_nodes_list, to_nodes_list):
        """
        Link every node in ``from_nodes_list`` to every node in ``to_nodes_list``.

        Each link is recorded on both ends: as a next node on the source and
        as a prev node on the target.
        """
        # Materialize once. The inner loop re-iterates it for every source
        # node, and a one-shot iterator would be empty after the first pass.
        to_nodes = list(to_nodes_list)
        for from_node in from_nodes_list:
            for to_node in to_nodes:
                from_node.add_next_node(to_node)
                to_node.add_prev_node(from_node)

    def prune_edges(self, from_nodes_list, to_nodes_list):
        """
        Remove every link from ``from_nodes_list`` to ``to_nodes_list``, on both ends.
        """
        # Snapshot both inputs. Callers such as prune_node() pass a node's own
        # neighbour collection, which may be the live container that the
        # remove_* calls below shrink. Iterating it while it shrinks would
        # skip items (list) or raise (set/dict).
        from_nodes = list(from_nodes_list)
        to_nodes = list(to_nodes_list)
        for from_node in from_nodes:
            for to_node in to_nodes:
                from_node.remove_next_node(to_node)
                to_node.remove_prev_node(from_node)

    def prune_node(self, node):
        """
        Detach ``node`` from all its neighbours, mark it dead, and remove it
        from the graph.

        Raises:
            KeyError: if the node's loc id is not in the graph.
        """
        logger.debug("pruning node: {}".format(node))
        self.prune_edges(node.get_prev_nodes(), [node])
        self.prune_edges([node], node.get_next_nodes())
        node.dead = True
        self.node_cache.pop(node.get_loc_id())

    def retrieve_node(self, node_id):
        """
        does not instantiate, only retrieves.
        If node_id is not in the graph, returns None
        """
        return self.node_cache.get(node_id)

    def get_gene_id(self):
        """
        Return the gene id this graph was created for.
        """
        return self.gene_id

    def draw_graph(self, filename, max_len_show=50):
        """
        Write the graph to ``filename`` in Graphviz DOT format.

        Each node is labelled
        ``<gene_node_id>:Len<seq length>:T<topological order>:<seq>``.
        A sequence longer than ``max_len_show`` is abbreviated to its first
        and last ``max_len_show // 2`` characters, joined by "...".
        Each outgoing edge is written as ``<id>-><next id>``.
        """
        logger.debug("drawing graph: {}".format(filename))

        max_len_show_half = max_len_show // 2

        # `with` guarantees the file is closed even if a node accessor raises.
        with open(filename, 'w') as ofh:
            ofh.write("digraph G {\n")

            for node in self.node_cache.values():
                node_seq = node.get_seq()
                node_seq_len = len(node_seq)

                if node_seq_len > max_len_show:
                    # Use an explicit start index rather than [-half:], so
                    # that half == 0 yields "" instead of the whole sequence.
                    node_seq = (node_seq[:max_len_show_half] + "..."
                                + node_seq[node_seq_len - max_len_show_half:])

                ofh.write("{} [label=\"{}:Len{}:T{}:{}\"]\n".format(
                    node.get_id(), node.get_gene_node_id(), node_seq_len,
                    node.get_topological_order(), node_seq))

                for next_node in node.get_next_nodes():
                    ofh.write("{}->{}\n".format(node.get_id(), next_node.get_id()))

            ofh.write("}\n")  # close the digraph

    def __repr__(self):
        """
        One line per node, each produced by that node's toString().
        """
        return "".join(node.toString() + "\n" for node in self.node_cache.values())