File: treesum.py

package info (click to toggle)
python-dendropy 4.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 68,392 kB
  • ctags: 3,947
  • sloc: python: 41,840; xml: 1,400; makefile: 15
file content (455 lines) | stat: -rw-r--r-- 19,882 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
#! /usr/bin/env python

##############################################################################
##  DendroPy Phylogenetic Computing Library.
##
##  Copyright 2010-2015 Jeet Sukumaran and Mark T. Holder.
##  All rights reserved.
##
##  See "LICENSE.rst" for terms and conditions of usage.
##
##  If you use this work or any portion thereof in published work,
##  please cite it as:
##
##     Sukumaran, J. and M. T. Holder. 2010. DendroPy: a Python library
##     for phylogenetic computing. Bioinformatics 26: 1569-1571.
##
##############################################################################

"""
Tree summarization and consensus tree building.
"""

import math
import collections
import dendropy
from dendropy.datamodel import taxonmodel
from dendropy.calculate.statistics import mean_and_sample_variance

##############################################################################
## TreeSummarizer

class TreeSummarizer(object):
    "Summarizes a distribution of trees."

    def __init__(self, **kwargs):
        """
        __init__ kwargs:

            - ``support_as_labels`` (boolean)
            - ``support_as_edge_lengths`` (boolean)
            - ``support_as_percentages`` (boolean)
            - ``support_label_decimals`` (integer)
        """
        self.support_as_labels = kwargs.get("support_as_labels", True)
        self.support_as_edge_lengths = kwargs.get("support_as_edge_lengths", False)
        self.support_as_percentages = kwargs.get("support_as_percentages", False)
        self.add_node_metadata = kwargs.get("add_node_metadata", True)
        self.default_support_label_decimals = 4
        self.support_label_decimals = kwargs.get("support_label_decimals", self.default_support_label_decimals)
        self.weighted_splits = False

    def tree_from_splits(self,
            split_distribution,
            min_freq=0.5,
            rooted=None,
            include_edge_lengths=True):
        """Returns a consensus tree from splits in ``split_distribution``.

        If include_edge_length_var is True, then the sample variance of the
            edge length will also be calculated and will be stored as
            a length_var attribute.
        """
        taxon_namespace = split_distribution.taxon_namespace
        taxa_mask = taxon_namespace.all_taxa_bitmask()
        if self.weighted_splits:
            split_freqs = split_distribution.weighted_split_frequencies
        else:
            split_freqs = split_distribution.split_frequencies
        if rooted is None:
            if split_distribution.is_all_counted_trees_rooted():
                rooted = True
            elif split_distribution.is_all_counted_trees_strictly_unrooted:
                rooted = False
        #include_edge_lengths = self.support_as_labels and include_edge_lengths
        if self.support_as_edge_lengths and include_edge_lengths:
            raise Exception("Cannot map support as edge lengths if edge lengths are to be set on consensus tree")
        to_try_to_add = []
        _almost_one = lambda x: abs(x - 1.0) <= 0.0000001
        for s in split_freqs:
            freq = split_freqs[s]
            if (min_freq is None) or (freq >= min_freq) or (_almost_one(min_freq) and _almost_one(freq)):
                to_try_to_add.append((freq, s))
        to_try_to_add.sort(reverse=True)
        splits_for_tree = [i[1] for i in to_try_to_add]
        con_tree = dendropy.Tree.from_split_bitmasks(
                split_bitmasks=splits_for_tree,
                taxon_namespace=taxon_namespace,
                is_rooted=rooted)
        con_tree.encode_bipartitions()

        if include_edge_lengths:
            split_edge_lengths = {}
            for split, edges in split_distribution.split_edge_lengths.items():
                if len(edges) > 0:
                    mean, var = mean_and_sample_variance(edges)
                    elen = mean
                else:
                    elen = None
                split_edge_lengths[split] = elen
        else:
            split_edge_lengths = None

        for node in con_tree.postorder_node_iter():
            split = node.edge.bipartition.split_bitmask
            if split in split_freqs:
                self.map_split_support_to_node(node=node, split_support=split_freqs[split])
            if include_edge_lengths and split in split_distribution.split_edge_lengths:
                edges = split_distribution.split_edge_lengths[split]
                if len(edges) > 0:
                    mean, var = mean_and_sample_variance(edges)
                    elen = mean
                else:
                    elen = None
                node.edge.length = elen

        return con_tree

    def compose_support_label(self, split_support_freq):
        "Returns an appropriately composed and formatted support label."
        if self.support_as_percentages:
            if self.support_label_decimals <= 0:
                support_label = "%d" % round(split_support_freq * 100, 0)
            else:
                support_label_template = "%%0.%df" % self.support_label_decimals
                support_label = support_label_template % round(split_support_freq * 100,
                    self.support_label_decimals)
        else:
            if self.support_label_decimals <= 0:
                support_label_decimals = self.default_support_label_decimals
            else:
                support_label_decimals = self.support_label_decimals
            support_label_template = "%%0.%df" % support_label_decimals
            support_label = support_label_template % round(split_support_freq,
                support_label_decimals)
        return support_label

    def map_split_support_to_node(self,
            node,
            split_support,
            attr_name="support"):
        "Appropriately sets up a node."
        if self.support_as_percentages:
            support_value = split_support * 100
        else:
            support_value = split_support
        if self.support_as_labels:
            node.label = self.compose_support_label(split_support)
        if self.support_as_edge_lengths:
            node.edge.length = support_value
        if self.add_node_metadata and attr_name:
            setattr(node, attr_name, support_value)
            node.annotations.drop(name=attr_name)
            node.annotations.add_bound_attribute(attr_name,
                    real_value_format_specifier=".{}f".format(self.support_label_decimals))
        return node

    def map_split_support_to_tree(self, tree, split_distribution):
        "Maps splits support to the given tree."
        if self.weighted_splits:
            split_freqs = split_distribution.weighted_split_frequencies
        else:
            split_freqs = split_distribution.split_frequencies
        tree.reindex_taxa(taxon_namespace=split_distribution.taxon_namespace)
        assert tree.taxon_namespace is split_distribution.taxon_namespace
        tree.encode_bipartitions()
        for edge in tree.postorder_edge_iter():
            split = edge.bipartition.split_bitmask
            if split in split_freqs:
                split_support = split_freqs[split]
            else:
                split_support = 0.0
            self.map_split_support_to_node(edge.head_node, split_support)
        return tree

    def annotate_nodes_and_edges(self,
            tree,
            split_distribution,
            is_bipartitions_updated=False,):
        """
        Summarizes edge length and age information in ``split_distribution`` for
        each node on target tree ``tree``.
        This will result in each node in ``tree`` being decorated with the following attributes:
            ``age_mean``,
            ``age_median``,
            ``age_sd``,
            ``age_hpd95``,
            ``age_range``,
        And each edge in ``tree`` being decorated with the following attributes:
            ``length_mean``,
            ``length_median``,
            ``length_sd``,
            ``length_hpd95``,
            ``length_range``,
        These attributes will be added to the annotations dictionary to be persisted.
        """
        assert tree.taxon_namespace is split_distribution.taxon_namespace
        if not is_bipartitions_updated:
            tree.encode_bipartitions()
        split_edge_length_summaries = split_distribution.split_edge_length_summaries
        split_node_age_summaries = split_distribution.split_node_age_summaries
        fields = ['mean', 'median', 'sd', 'hpd95', 'quant_5_95', 'range']

        for edge in tree.preorder_edge_iter():
            split = edge.split_bitmask
            nd = edge.head_node
            for summary_name, summary_target, summary_src in [ ('length', edge, split_edge_length_summaries),
                                               ('age', nd, split_node_age_summaries) ]:
                if split in summary_src:
                    summary = summary_src[split]
                    for field in fields:
                        attr_name = summary_name + "_" + field
                        setattr(summary_target, attr_name, summary[field])
                        # clear annotations, which might be associated with either nodes
                        # or edges due to the way NEXUS/NEWICK node comments are parsed
                        nd.annotations.drop(name=attr_name)
                        edge.annotations.drop(name=attr_name)
                        summary_target.annotations.add_bound_attribute(attr_name)
                else:
                    for field in fields:
                        attr_name = summary_name + "_" + field
                        setattr(summary_target, attr_name, None)

    def summarize_node_ages_on_tree(self,
            tree,
            split_distribution,
            set_edge_lengths=True,
            collapse_negative_edges=False,
            allow_negative_edges=False,
            summarization_fn=None,
            is_bipartitions_updated=False):
        """
        Sets the ``age`` attribute of nodes on ``tree`` (a |Tree| object) to the
        result of ``summarization_fn`` applied to the vector of ages of the
        same node on the input trees (in ``split_distribution``, a
        `SplitDistribution` object) being summarized.
        ``summarization_fn`` should take an iterable of floats, and return a float. If |None|, it
        defaults to calculating the mean (``lambda x: float(sum(x))/len(x)``).
        If ``set_edge_lengths`` is |True|, then edge lengths will be set to so that the actual node ages
        correspond to the ``age`` attribute value.
        If ``collapse_negative_edges`` is True, then edge lengths with negative values will be set to 0.
        If ``allow_negative_edges`` is True, then no error will be raised if edges have negative lengths.
        """
        if summarization_fn is None:
            summarization_fn = lambda x: float(sum(x))/len(x)
        if is_bipartitions_updated:
            tree.encode_splits()
        #'height',
        #'height_median',
        #'height_95hpd',
        #'height_range',
        #'length',
        #'length_median',
        #'length_95hpd',
        #'length_range',
        for edge in tree.preorder_edge_iter():
            split = edge.bipartition.split_bitmask
            nd = edge.head_node
            if split in split_distribution.split_node_ages:
                ages = split_distribution.split_node_ages[split]
                nd.age = summarization_fn(ages)
            else:
                # default to age of parent if split not found
                nd.age = nd.parent_node.age
            ## force parent nodes to be at least as old as their oldest child
            if collapse_negative_edges:
                for child in nd.child_nodes():
                    if child.age > nd.age:
                        nd.age = child.age
        if set_edge_lengths:
            tree.set_edge_lengths_from_node_ages(allow_negative_edges=allow_negative_edges)
        return tree

    def summarize_edge_lengths_on_tree(self,
            tree,
            split_distribution,
            summarization_fn=None,
            is_bipartitions_updated=False):
        """
        Sets the lengths of edges on ``tree`` (a |Tree| object) to the mean
        lengths of the corresponding edges on the input trees (in
        ``split_distribution``, a `SplitDistribution` object) being
        summarized.
        ``summarization_fn`` should take an iterable of floats, and return a float. If |None|, it
        defaults to calculating the mean (``lambda x: float(sum(x))/len(x)``).
        """
        if summarization_fn is None:
            summarization_fn = lambda x: float(sum(x))/len(x)
        if not is_bipartitions_updated:
            tree.encode_bipartitions()
        for edge in tree.postorder_edge_iter():
            split = edge.bipartition.split_bitmask
            if (split in split_distribution.split_edge_lengths
                    and split_distribution.split_edge_lengths[split]):
                lengths = split_distribution.split_edge_lengths[split]
                edge.length = summarization_fn(lengths)
            elif (split in split_distribution.split_edge_lengths
                    and not split_distribution.split_edge_lengths[split]):
                # no input trees had any edge lengths for this split
                edge.length = None
            else:
                # split on target tree that was not found in any of the input
                # trees
                edge.length = 0.0
        return tree

        ## here we add the support values and/or edge lengths for the terminal taxa ##
        for node in leaves:
            if not is_rooted:
                split = node.edge.split_bitmask
            else:
                split = node.edge.leafset_bitmask
            self.map_split_support_to_node(node, 1.0)
            if include_edge_lengths:
                elen = split_distribution.split_edge_lengths.get(split, [0.0])
                if len(elen) > 0:
                    mean, var = mean_and_sample_variance(elen)
                    node.edge.length = mean
                    if include_edge_length_var:
                        node.edge.length_var = var
                else:
                    node.edge.length = None
                    if include_edge_length_var:
                        node.edge.length_var = None
        #if include_edge_lengths:
            #self.map_edge_lengths_to_tree(tree=con_tree,
            #        split_distribution=split_distribution,
            #        summarization_fn=summarization_fn,
            #        include_edge_length_var=False)
        return con_tree

    def count_splits_on_trees(self, tree_iterator, split_distribution=None, is_bipartitions_updated=False):
        """
        Given a list of trees file, a SplitsDistribution object (a new one, or,
        if passed as an argument) is returned collating the split data in the files.
        """
        if split_distribution is None:
            split_distribution = dendropy.SplitDistribution()
        taxon_namespace = split_distribution.taxon_namespace
        for tree_idx, tree in enumerate(tree_iterator):
            if taxon_namespace is None:
                assert(split_distribution.taxon_namespace is None)
                split_distribution.taxon_namespace = tree.taxon_namespace
                taxon_namespace = tree.taxon_namespace
            else:
                assert(taxon_namespace is tree.taxon_namespace)
            split_distribution.count_splits_on_tree(tree,
                    is_bipartitions_updated=is_bipartitions_updated)
        return split_distribution

    def consensus_tree(self, trees, min_freq=0.5, is_bipartitions_updated=False):
        """
        Returns a consensus tree of all trees in ``trees``. ``min_freq`` is the
        minimum frequency a split needs for it to be added to the consensus
        tree.

        The taxon namespace is taken from the first tree in ``trees``.
        """
        distribution = dendropy.SplitDistribution(
                taxon_namespace=trees[0].taxon_namespace)
        self.count_splits_on_trees(
                trees,
                split_distribution=distribution,
                is_bipartitions_updated=is_bipartitions_updated)
        return self.tree_from_splits(distribution, min_freq=min_freq)

##############################################################################
## Convenience Wrappers

def consensus_tree(trees, min_freq=0.5, is_bipartitions_updated=False, **kwargs):
    """
    Convenience wrapper: returns a consensus tree of all trees in ``trees``,
    built by a ``TreeSummarizer`` configured with ``kwargs``. ``min_freq`` is
    the minimum frequency a split needs for it to be added to the consensus
    tree.
    """
    return TreeSummarizer(**kwargs).consensus_tree(
            trees,
            min_freq=min_freq,
            is_bipartitions_updated=is_bipartitions_updated)

##############################################################################
## TopologyCounter

class TopologyCounter(object):
    """
    Tracks frequency of occurrences of topologies.

    Each tree is reduced to a hashable key by ``hash_topology``; the counter
    maps each key to the number of times it has been registered.
    """

    @staticmethod
    def hash_topology(tree):
        """
        Set of all splits on tree: default topology hash.

        Returns a frozenset of the items in ``tree.bipartition_encoding``.
        """
        return frozenset(tree.bipartition_encoding)

    def __init__(self):
        # topology hash -> number of times that topology has been registered
        self.topology_hash_map = {}
        # total number of trees registered, over all topologies
        self.total_trees_counted = 0

    def update_topology_hash_map(self,
            src_map):
        """
        Imports data from another counter.

        ``src_map`` maps topology hashes to counts; each count is added to the
        count held here for the same hash, and to the total number of trees.
        """
        for topology_hash in src_map:
            count = src_map[topology_hash]
            self.topology_hash_map[topology_hash] = (
                    self.topology_hash_map.get(topology_hash, 0) + count)
            self.total_trees_counted += count

    def count(self,
            tree,
            is_bipartitions_updated=False):
        """
        Logs/registers a tree.

        The bipartitions of ``tree`` are encoded first unless
        ``is_bipartitions_updated`` is True.
        """
        if not is_bipartitions_updated:
            tree.encode_bipartitions()
        topology = self.hash_topology(tree)
        self.topology_hash_map[topology] = self.topology_hash_map.get(topology, 0) + 1
        self.total_trees_counted += 1

    def calc_hash_freqs(self):
        """
        Returns an ordered dictionary (collections.OrderedDict) of topology hashes mapped
        to a tuple, (raw numbers of occurrences, proportion of total trees
        counted) in (descending) order of the proportion of occurrence.
        Topologies with equal counts keep the order in which they were first
        registered.
        """
        # Sort on the count alone. Sorting (count, hash) tuples would compare
        # the hashes themselves on ties: for frozensets that is a subset test
        # (not a total order), and for unorderable hashes it raises TypeError.
        ranked = sorted(
                self.topology_hash_map.items(),
                key=lambda item: item[1],
                reverse=True)
        t_freqs = collections.OrderedDict()
        for topology_hash, count in ranked:
            freq = float(count) / self.total_trees_counted
            t_freqs[topology_hash] = (count, freq)
        return t_freqs

    def calc_tree_freqs(self, taxon_namespace, is_rooted=False):
        """
        Returns an ordered dictionary (collections.OrderedDict) of DendroPy trees mapped
        to a tuple, (raw numbers of occurrences, proportion of total trees
        counted) in (descending) order of the proportion of occurrence.

        Each tree is rebuilt from its topology hash, which is passed as the
        bipartition encoding, using ``taxon_namespace`` and ``is_rooted``.
        """
        tree_freqs = collections.OrderedDict()
        for topology_hash, count_and_freq in self.calc_hash_freqs().items():
            tree = dendropy.Tree.from_bipartition_encoding(
                bipartition_encoding=topology_hash,
                taxon_namespace=taxon_namespace,
                is_rooted=is_rooted)
            tree_freqs[tree] = count_and_freq
        return tree_freqs

## TopologyCounter
##############################################################################