1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
import argparse
import json
import matplotlib as mpl
mpl.use("Agg")
from matplotlib import gridspec
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np
import sys
sys.path.append('..')
from base.io_util import json_to_tree
def plot_tree(tree, figure_name, color_by_trait, initial_branch_width, tip_size):
    """Plot a BioPython Phylo tree in the BALTIC-style and save it to a file.

    Every node of the tree is read through ``node.attr`` (a dict that must
    contain ``"num_date"``, used as the x coordinate), ``node.yvalue`` (the y
    coordinate), ``node.up`` (the parent node, or None for the root) and
    ``node.clades`` (the children).

    Args:
        tree: root clade of the tree; must provide ``find_clades()`` and
            ``is_terminal()``.
        figure_name: path handed to matplotlib's ``savefig``.
        color_by_trait: key in ``node.attr`` whose numeric value selects the
            branch and tip color from the viridis colormap. Nodes lacking the
            trait (or carrying None) are drawn in black.
        initial_branch_width: extra line width given to a branch that
            subtends every tip; other internal branches get a share
            proportional to the number of tips below them.
        tip_size: marker size used for every tip circle.

    Raises:
        ValueError: if no node in the tree has a value for ``color_by_trait``,
            since the color scale cannot be built.

    Note:
        This function modifies the global ``mpl.rcParams`` (dpi and fonts).
    """
    mpl.rcParams['savefig.dpi'] = 120
    mpl.rcParams['figure.dpi'] = 100
    mpl.rcParams['font.weight'] = 300
    mpl.rcParams['axes.labelweight'] = 300
    mpl.rcParams['font.size'] = 14

    # Walk the tree once and reuse the node list for every later pass.
    clades = list(tree.find_clades())

    # Setup colors. Only nodes that actually carry the trait contribute to
    # the color scale, matching the per-node check in the plotting loop.
    trait_name = color_by_trait
    trait_values = [
        node.attr[trait_name]
        for node in clades
        if node.attr.get(trait_name) is not None
    ]
    if not trait_values:
        raise ValueError(
            "No node in the tree has a value for trait '{}'".format(trait_name)
        )

    norm = mpl.colors.Normalize(min(trait_values), max(trait_values))
    cmap = mpl.cm.viridis

    #
    # Setup the figure grid.
    #
    fig = plt.figure(figsize=(8, 6), facecolor='w')
    gs = gridspec.GridSpec(2, 1, height_ratios=[14, 1], width_ratios=[1], hspace=0.1, wspace=0.1)
    ax = fig.add_subplot(gs[0])
    colorbar_ax = fig.add_subplot(gs[1])

    total_tips = sum(1 for node in clades if node.is_terminal())

    # Setup arrays for tip and internal node coordinates.
    tip_circles_x = []
    tip_circles_y = []
    tip_circles_color = []
    tip_circle_sizes = []
    node_circles_x = []
    node_circles_y = []
    node_circles_color = []
    branch_line_segments = []
    branch_line_widths = []
    branch_line_colors = []

    for node in clades:
        x = node.attr["num_date"]
        y = node.yvalue

        # x position of the parent; the root has no parent.
        parent_x = None if node.up is None else node.up.attr["num_date"]

        # matplotlib won't plot Nones, like root.
        if x is None:
            x = 0.0
        if parent_x is None:
            parent_x = x

        # Default to black when the node has no usable trait value.
        color = 'k'
        trait_value = node.attr.get(trait_name)
        if trait_value is not None:
            color = cmap(norm(trait_value))

        branch_width = 2
        if node.is_terminal():
            tip_circle_sizes.append(tip_size)
            tip_circles_x.append(x)
            tip_circles_y.append(y)
            tip_circles_color.append(color)
        else:
            # Scale branch widths by the number of tips below this node.
            tips_below = sum(
                1 for child in node.find_clades() if child.is_terminal()
            )
            branch_width += initial_branch_width * tips_below / float(total_tips)

            # Nodes with a single child get an explicit marker.
            if len(node.clades) == 1:
                node_circles_x.append(x)
                node_circles_y.append(y)
                node_circles_color.append(color)

            # Vertical connector spanning the first and last child.
            ax.plot(
                [x, x],
                [node.clades[-1].yvalue, node.clades[0].yvalue],
                lw=branch_width,
                color=color,
                ls='-',
                zorder=9,
                solid_capstyle='round'
            )

        # Horizontal branch from the parent's x position to this node.
        branch_line_segments.append([(parent_x, y), (x, y)])
        branch_line_widths.append(branch_width)
        branch_line_colors.append(color)

    branch_lc = LineCollection(branch_line_segments, zorder=9)
    branch_lc.set_color(branch_line_colors)
    branch_lc.set_linewidth(branch_line_widths)
    branch_lc.set_linestyle("-")
    ax.add_collection(branch_lc)

    # Add circles for tips and internal nodes.
    tip_circle_sizes = np.array(tip_circle_sizes)
    # Colored circle for every tip, with a larger black circle underneath.
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes, facecolor=tip_circles_color, edgecolor='none', zorder=11)
    ax.scatter(tip_circles_x, tip_circles_y, s=tip_circle_sizes*2, facecolor='k', edgecolor='none', zorder=10)
    # Mark the single-child nodes collected above.
    ax.scatter(node_circles_x, node_circles_y, facecolor=node_circles_color, s=50, edgecolor='none', zorder=10, lw=2, marker='|')

    # Hide all axes spines except the bottom one.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.grid(axis='x', ls='-', color='grey')
    ax.tick_params(axis='y', size=0)
    ax.set_yticklabels([])

    cb1 = mpl.colorbar.ColorbarBase(
        colorbar_ax,
        cmap=cmap,
        norm=norm,
        orientation='horizontal'
    )
    cb1.set_label(color_by_trait)

    gs.tight_layout(fig)
    fig.savefig(figure_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: read an auspice tree JSON and plot it.
    cli = argparse.ArgumentParser()
    cli.add_argument("tree", help="auspice tree JSON")
    cli.add_argument("output", help="plotted tree figure")
    cli.add_argument("--colorby", default="num_date", help="trait in tree to color by")
    cli.add_argument("--branch_width", type=int, default=10, help="initial branch width")
    cli.add_argument("--tip_size", type=int, default=10, help="tip size")
    options = cli.parse_args()

    # Load the raw JSON document from disk.
    with open(options.tree, "r") as handle:
        tree_json = json.load(handle)

    # Convert JSON tree layout to a Biopython Clade instance.
    phylo_tree = json_to_tree(tree_json)

    # Plot the tree.
    plot_tree(
        tree=phylo_tree,
        figure_name=options.output,
        color_by_trait=options.colorby,
        initial_branch_width=options.branch_width,
        tip_size=options.tip_size,
    )
|