1 | #! /usr/bin/env python |
---|
2 | from taxonomy_util import Taxonomy |
---|
3 | from erlang import erlang |
---|
4 | import math |
---|
5 | |
---|
class TaxTreeHelper:
    """Root a tree by a saved outgroup and label its nodes with taxonomic ranks.

    Typical call order, as enforced by the methods below:
    ``set_mf_rooted_tree()`` (remembers the outgroup), then
    ``set_bf_unrooted_tree()`` (re-roots the bifurcating tree at that
    outgroup), then ``get_tax_tree()`` / ``get_bid_taxonomy_map()``.
    Tree objects are expected to offer the ETE-style node API used here
    (``children``, ``up``, ``traverse``, ``set_outgroup``, ``&`` lookup, ...).
    """

    def __init__(self, cfg, tax_map, tax_tree=None):
        """Store the configuration and the sequence-name -> ranks mapping.

        cfg      -- configuration object (``cfg.log`` and
                    ``cfg.exit_fatal_error`` are used by this class)
        tax_map  -- mapping indexed by leaf name, yielding the rank list
        tax_tree -- optional, already labelled taxonomic tree
        """
        self.origin_taxonomy = tax_map
        self.cfg = cfg
        self.outgroup = None
        self.mf_rooted_tree = None
        self.bf_rooted_tree = None
        self.tax_tree = tax_tree
        self.bid_taxonomy_map = None
        self.ranks_set = set()
        if tax_tree:
            self.init_taxnode_map()
        else:
            self.name2taxnode = {}

    def set_mf_rooted_tree(self, rt):
        """Set the rooted tree, remember its outgroup and drop derived data."""
        self.mf_rooted_tree = rt
        self.save_outgroup()
        # everything derived from a previous tree is stale now
        self.bf_rooted_tree = None
        self.tax_tree = None
        self.bid_taxonomy_map = None

    def set_bf_unrooted_tree(self, ut):
        """Re-root the unrooted tree `ut` at the saved outgroup and keep it."""
        self.restore_rooting(ut)

    def get_outgroup(self):
        """Return the currently stored outgroup (None if not set)."""
        return self.outgroup

    def set_outgroup(self, outgr):
        """Explicitly set the outgroup subtree used by restore_rooting()."""
        self.outgroup = outgr

    def get_tax_tree(self):
        """Return the rank-labelled tree, labelling it on first access."""
        if not self.tax_tree:
            self.label_bf_tree_with_ranks()
        return self.tax_tree

    def get_bid_taxonomy_map(self, rebuild=False):
        """Return the branch-id -> taxonomy map, building it when missing.

        With rebuild=True the map is recomputed even if it already exists.
        """
        self.get_tax_tree()
        if not self.bid_taxonomy_map or rebuild:
            self.build_bid_taxonomy_map()
        return self.bid_taxonomy_map

    def init_taxnode_map(self):
        """Index the leaves of self.tax_tree by their current name."""
        self.name2taxnode = {}
        for leaf in self.tax_tree.iter_leaves():
            self.name2taxnode[leaf.name] = leaf

    def save_outgroup(self):
        """Remember the smallest root-level subtree of mf_rooted_tree as outgroup.

        Raises AssertionError if the (effective) root has fewer than two
        children.
        """
        rt = self.mf_rooted_tree

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        if len(rt.children) < 2:
            raise AssertionError("Invalid tree: unifurcation at the root node!")

        # min() keeps the first child among equally small ones
        self.outgroup = min(rt.children, key=lambda child: len(child.get_leaves()))

    def restore_rooting(self, utree):
        """Root `utree` at the saved outgroup and store it as bf_rooted_tree."""
        outgr_leaves = self.outgroup.get_leaf_names()
        # check if outgroup consists of a single node - ETE considers it to be root, not leaf
        if not outgr_leaves:
            # look the node up by the name of the saved outgroup itself
            outgr_root = utree & self.outgroup.name
        elif len(outgr_leaves) == 1:
            outgr_root = utree & outgr_leaves[0]
        else:
            # Even an unrooted tree is "implicitly" rooted in the ETE representation.
            # If this pseudo-rooting happens to be within the outgroup, it causes problems
            # in the get_common_ancestor() step (since common_ancestor = "root").
            # Workaround: explicitly root the tree outside of the outgroup subtree.
            outgr_leaf_set = set(outgr_leaves)
            for node in utree.iter_leaves():
                if node.name not in outgr_leaf_set:
                    tmp_root = node.up
                    if not utree == tmp_root:
                        utree.set_outgroup(tmp_root)
                        break

            outgr_root = utree.get_common_ancestor(outgr_leaves)

        # we could be so lucky that the RAxML tree is already correctly rooted :)
        if outgr_root != utree:
            utree.set_outgroup(outgr_root)

        self.bf_rooted_tree = utree

    def label_bf_tree_with_ranks(self):
        """labeling inner tree nodes with taxonomic ranks

        Leaves receive the ranks from the original taxonomy and get
        "__<lowest rank>" appended to their name. Every inner node receives
        the rank prefix shared by its two children. The labelled tree is then
        stored as self.tax_tree.

        Raises AssertionError if bf_rooted_tree is not set or if an inner
        node does not have exactly two children.
        """
        if not self.bf_rooted_tree:
            raise AssertionError("self.bf_rooted_tree is not set: TaxTreeHelper.set_bf_unrooted_tree() must be called before!")

        # postorder guarantees that children are labelled before their parent
        for node in self.bf_rooted_tree.traverse("postorder"):
            if node.is_leaf():
                seq_ranks = self.origin_taxonomy[node.name]
                rank_level = Taxonomy.lowest_assigned_rank_level(seq_ranks)
                node.add_feature("rank_level", rank_level)
                node.add_feature("ranks", seq_ranks)
                node.name += "__" + seq_ranks[rank_level]
            else:
                if len(node.children) != 2:
                    raise AssertionError("FATAL ERROR: tree is not bifurcating!")
                lchild = node.children[0]
                rchild = node.children[1]
                # walk up from the deeper common level until both children agree
                rank_level = min(lchild.rank_level, rchild.rank_level)
                while rank_level >= 0 and lchild.ranks[rank_level] != rchild.ranks[rank_level]:
                    rank_level -= 1
                node.add_feature("rank_level", rank_level)
                node_ranks = [Taxonomy.EMPTY_RANK] * max(len(lchild.ranks), len(rchild.ranks))
                if rank_level >= 0:
                    node_ranks[0:rank_level+1] = lchild.ranks[0:rank_level+1]
                    node.name = lchild.ranks[rank_level]
                else:
                    node.name = "Undefined"
                    if hasattr(node, "B"):
                        self.cfg.log.debug("INFO: empty taxonomic annotation for branch %s (child nodes have no common ranks)", node.B)

                node.add_feature("ranks", node_ranks)

        self.tax_tree = self.bf_rooted_tree
        self.init_taxnode_map()

    def build_bid_taxonomy_map(self):
        """Build the map: branch id -> (rank uid, rank-level diff to parent, branch length).

        Only non-root nodes carrying a "B" feature are included. The set of
        all rank uids seen is stored in self.ranks_set.
        """
        self.bid_taxonomy_map = {}
        self.ranks_set = set()
        for node in self.tax_tree.traverse("postorder"):
            if not node.is_root() and hasattr(node, "B"):
                parent = node.up
                branch_rdiff = Taxonomy.lowest_assigned_rank_level(node.ranks) - Taxonomy.lowest_assigned_rank_level(parent.ranks)
                branch_rank_id = Taxonomy.get_rank_uid(node.ranks)
                branch_len = node.dist
                self.bid_taxonomy_map[node.B] = (branch_rank_id, branch_rdiff, branch_len)
                self.ranks_set.add(branch_rank_id)

    def get_seq_ranks_from_tree(self, seq_name):
        """Return the ranks encoded in the name of the parent of leaf `seq_name`.

        Calls cfg.exit_fatal_error() if the sequence is not in the tree.
        """
        if seq_name not in self.name2taxnode:
            errmsg = "FATAL ERROR: Sequence %s is not found in the taxonomic tree!" % seq_name
            self.cfg.exit_fatal_error(errmsg)

        seq_node = self.name2taxnode[seq_name]
        ranks = Taxonomy.split_rank_uid(seq_node.up.name)
        return ranks

    def strip_missing_ranks(self, ranks):
        """Return the longest prefix of `ranks` whose rank uid is in self.ranks_set.

        An empty list is returned if no prefix matches.
        """
        rank_level = len(ranks)
        # check the bound first, so no uid is ever computed for an empty prefix
        while rank_level > 0 and Taxonomy.get_rank_uid(ranks[0:rank_level]) not in self.ranks_set:
            rank_level -= 1

        return ranks[0:rank_level]
---|
165 | |
---|
class TaxClassifyHelper:
    """Derive a taxonomic assignment for a sequence from its placement edges.

    Each edge is an indexable record; the fields used here are
    edge[0] = branch id, edge[1] = likelihood, edge[2] = likelihood weight
    and edge[4] = pendant branch length.
    """

    def __init__(self, cfg, bid_taxonomy_map, sp_rate = 0., node_height = None):
        """Store the classification parameters.

        cfg              -- configuration object (brlen_pv, min_lhw and
                            taxassign_method are read by this class)
        bid_taxonomy_map -- branch id -> (rank uid, rank diff, branch length)
        sp_rate          -- rate passed to the Erlang test
        node_height      -- container indexed by branch id (as string);
                            defaults to a fresh empty list per instance
        """
        self.cfg = cfg
        self.bid_taxonomy_map = bid_taxonomy_map
        self.sp_rate = sp_rate
        # never share one default container between instances
        self.node_height = [] if node_height is None else node_height
        self.erlang = erlang()
        # hardcoded for now
        self.parent_lhw_coeff = 0.49

    def classify_seq(self, edges, minlw = None):
        """Return (ranks, confidences) for the given placement edges.

        A falsy `minlw` (None or 0) falls back to cfg.min_lhw. Edges are
        first passed through erlang_filter(); if none survives, two empty
        lists are returned.
        """
        if not minlw:
            minlw = self.cfg.min_lhw

        edges = self.erlang_filter(edges)
        if len(edges) == 0:
            return [], []

        if self.cfg.taxassign_method == "1":
            ranks, lws = self.assign_taxonomy_maxsum(edges, minlw)
        else:
            ranks, lws = self.assign_taxonomy_maxlh(edges)
        return ranks, lws

    def _brlen_pvalue(self, edge):
        """Return the one-tailed Erlang p-value of the edge's pendant length."""
        edge_nr = str(edge[0])
        pendant_length = edge[4]
        return self.erlang.one_tail_test(rate = self.sp_rate, k = int(self.node_height[edge_nr]), x = pendant_length)

    def erlang_filter(self, edges):
        """Drop edges whose pendant length fails the Erlang test.

        With cfg.brlen_pv == 0 the input is returned untouched. Otherwise the
        surviving edges are returned and their likelihood weights (edge[2])
        are re-normalised IN PLACE so that they sum up to 1.
        """
        if self.cfg.brlen_pv == 0.:
            return edges

        newedges = [edge for edge in edges if self._brlen_pvalue(edge) >= self.cfg.brlen_pv]
        # rejected edges could be reported like this:
        #   self.cfg.log.debug("Edge ignored: [%s, %f], p = %.12f", edge_nr, pendant_length, pv)

        if len(newedges) == 0:
            return newedges

        # adjust likelihood weights -> is there a better way ???
        # Normalise relative to the best likelihood, so exp() can not overflow
        # even if the edges are not sorted by likelihood.
        lhs = [float(edge[1]) for edge in newedges]
        max_lh = max(lhs)
        rel_lhs = [math.exp(lh - max_lh) for lh in lhs]
        sum_lh = sum(rel_lhs)

        for edge, rel_lh in zip(newedges, rel_lhs):
            edge[2] = rel_lh / sum_lh

        return newedges

    # "all or none" filter: return empty set iff *all* brlens are below the threshold
    def erlang_filter2(self, edges):
        """Return `edges` unchanged if any edge passes the Erlang test, else []."""
        if self.cfg.brlen_pv == 0.:
            return edges

        for edge in edges:
            if self._brlen_pvalue(edge) >= self.cfg.brlen_pv:
                return edges

        return []

    def get_branch_ranks(self, br_id):
        """Return the rank list stored for branch `br_id` in bid_taxonomy_map."""
        br_rec = self.bid_taxonomy_map[br_id]
        br_rank_id = br_rec[0]
        ranks = Taxonomy.split_rank_uid(br_rank_id)
        return ranks

    def assign_taxonomy_maxlh(self, edges):
        """Assign the ranks of the first edge, filling its gaps from other edges.

        Returns (ranks, weights), where the weight of a rank is the sum of
        the likelihood weights of all edges carrying that rank name, capped
        at 1.0 (the placeholder "-" gets -1).
        """
        # Calculate the sum of likelihood weight for each rank
        taxonmy_sumlw_map = {}
        for edge in edges:
            edge_nr = str(edge[0])
            lw = edge[2]
            taxranks = self.get_branch_ranks(edge_nr)
            for rank in taxranks:
                if rank == "-":
                    taxonmy_sumlw_map[rank] = -1
                elif rank in taxonmy_sumlw_map:
                    taxonmy_sumlw_map[rank] = taxonmy_sumlw_map[rank] + lw
                else:
                    taxonmy_sumlw_map[rank] = lw

        # Assignment using the max likelihood placement (taken to be edges[0])
        ml_edge = edges[0]
        ml_ranks = self.get_branch_ranks(str(ml_edge[0]))
        ml_ranks_copy = list(ml_ranks)
        lws = []
        for cnt, rank in enumerate(ml_ranks):
            lw = taxonmy_sumlw_map[rank]
            if lw > 1.0:
                lw = 1.0
            lws.append(lw)
            if rank == "-" and cnt > 0:
                # try to fill the gap from other placements which agree
                # with the best one at the next higher level
                for edge in edges[1:]:
                    edge_nr = str(edge[0])
                    taxonomy = self.get_branch_ranks(edge_nr)
                    newrank = taxonomy[cnt]
                    newlw = taxonmy_sumlw_map[newrank]
                    higherrank_old = ml_ranks[cnt - 1]
                    higherrank_new = taxonomy[cnt - 1]
                    if higherrank_old == higherrank_new and newrank != "-":
                        ml_ranks_copy[cnt] = newrank
                        lws[cnt] = newlw

        return ml_ranks_copy, lws

    def assign_taxonomy_maxsum(self, edges, minlw):
        """this function sums up all LH-weights for each rank and takes the rank with the max. sum

        Returns (ranks, confidences); the confidence of a rank is its
        accumulated "total" weight.
        """
        # in EPA result, each placement(=branch) has a "weight"
        # since we are interested in taxonomic placement, we do not care about branch vs. branch comparisons,
        # but only consider rank vs. rank (e. g. G1 S1 vs. G1 S2 vs. G1)
        # Thus we accumulate weights for each rank, there are two measures:
        # "own" weight  = sum of weight of all placements EXACTLY to this rank (e.g. for G1: G1 only)
        # "total" weight = own weight + own weight of all children (for G1: G1 or G1 S1 or G1 S2)
        rw_own = {}
        rw_total = {}

        ranks = [Taxonomy.EMPTY_RANK]

        for edge in edges:
            br_id = str(edge[0])
            lweight = edge[2]
            lowest_rank = None
            lowest_rank_lvl = None

            if lweight == 0.:
                continue

            # accumulate weight for the current sequence
            br_rank_id, rdiff, brlen = self.bid_taxonomy_map[br_id]
            ranks = Taxonomy.split_rank_uid(br_rank_id)
            for i, rank in enumerate(ranks):
                rank_id = Taxonomy.get_rank_uid(ranks, i)
                if rank != Taxonomy.EMPTY_RANK:
                    rw_total[rank_id] = rw_total.get(rank_id, 0) + lweight
                    lowest_rank_lvl = i
                    lowest_rank = rank_id
                else:
                    break

            if lowest_rank:
                if rdiff > 0:
                    # if ranks of 'upper' and 'lower' adjacent nodes of a branch are non-equal, split LHW among them
                    parent_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - rdiff)
                    rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight * (1 - self.parent_lhw_coeff)
                    rw_own[parent_rank] = rw_own.get(parent_rank, 0) + lweight * self.parent_lhw_coeff
                    # correct total lhw for all levels between "parent" and "lowest"
                    # NOTE: all those intermediate ranks are in fact indistinguishable, e.g. a family which contains a single genus
                    for r in range(rdiff):
                        interim_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - r)
                        rw_total[interim_rank] = rw_total.get(interim_rank, 0) - lweight * self.parent_lhw_coeff
                else:
                    rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight
            # else:
            #     self.cfg.log.debug("WARNING: no annotation for branch %s", br_id)

        # if all branches have empty ranks only, just return this placement
        if len(rw_total) == 0:
            return ranks, [1.] * len(ranks)

        # we assign the sequence to a rank, which has the max "own" weight AND
        # whose "total" weight is greater than a confidence threshold
        max_rw = 0.
        ass_rank_id = None
        # plain dict iteration works on Python 2 and 3 (iterkeys() is Python 2 only)
        for r in rw_own:
            if rw_own[r] > max_rw and rw_total[r] >= minlw:
                ass_rank_id = r
                max_rw = rw_own[r]
        if not ass_rank_id:
            # nothing reached the threshold: fall back to the max "total" weight
            ass_rank_id = max(rw_total, key=(lambda key: rw_total[key]))

        a_ranks = Taxonomy.split_rank_uid(ass_rank_id)

        # "total" weight is considered as confidence value for now
        a_conf = [0.] * len(a_ranks)
        for i, rank in enumerate(a_ranks):
            if rank != Taxonomy.EMPTY_RANK:
                rank_id = Taxonomy.get_rank_uid(a_ranks, i)
                a_conf[i] = rw_total[rank_id]

        return a_ranks, a_conf
---|
363 | |
---|