| 1 | #! /usr/bin/env python |
|---|
| 2 | from taxonomy_util import Taxonomy |
|---|
| 3 | from erlang import erlang |
|---|
| 4 | import math |
|---|
| 5 | |
|---|
class TaxTreeHelper:
    """Keep a reference tree, its rooting and its taxonomic labelling in sync.

    State kept on the instance:
      * ``mf_rooted_tree`` -- rooted tree from which the outgroup is derived
        (see ``save_outgroup``);
      * ``outgroup`` -- smallest subtree below the root of ``mf_rooted_tree``;
      * ``bf_rooted_tree`` -- a tree re-rooted on that outgroup
        (see ``restore_rooting``);
      * ``tax_tree`` -- tree whose nodes carry ``rank_level``/``ranks``
        features (see ``label_bf_tree_with_ranks``);
      * ``bid_taxonomy_map`` -- branch id (node feature ``B``) ->
        ``(rank_uid, rank_level_diff, branch_length)``;
      * ``ranks_set`` -- all rank uids that occur in ``bid_taxonomy_map``;
      * ``name2taxnode`` -- leaf name -> leaf node of ``tax_tree``.

    ``tax_tree`` and ``bid_taxonomy_map`` are rebuilt lazily by the ``get_*``
    accessors after they have been invalidated.
    """

    def __init__(self, cfg, tax_map, tax_tree=None):
        """
        :param cfg: configuration object; ``cfg.log`` and
            ``cfg.exit_fatal_error`` are used by this class.
        :param tax_map: mapping from leaf (sequence) name to its list of ranks.
        :param tax_tree: optional, already labelled taxonomic tree; when given,
            the leaf-name lookup table is built from it immediately.
        """
        self.origin_taxonomy = tax_map
        self.cfg = cfg
        self.outgroup = None
        self.mf_rooted_tree = None
        self.bf_rooted_tree = None
        self.tax_tree = tax_tree
        self.bid_taxonomy_map = None
        self.ranks_set = set()
        if tax_tree is not None:
            self.init_taxnode_map()
        else:
            self.name2taxnode = {}

    def set_mf_rooted_tree(self, rt):
        """Set the rooted tree, derive its outgroup and invalidate derived trees/maps."""
        self.mf_rooted_tree = rt
        self.save_outgroup()
        self.bf_rooted_tree = None
        self.tax_tree = None
        self.bid_taxonomy_map = None

    def set_bf_unrooted_tree(self, ut):
        """Root ``ut`` on the saved outgroup and store it as ``bf_rooted_tree``."""
        self.restore_rooting(ut)

    def get_outgroup(self):
        return self.outgroup

    def set_outgroup(self, outgr):
        self.outgroup = outgr

    def get_tax_tree(self):
        """Return the taxonomically labelled tree, labelling ``bf_rooted_tree`` on first use."""
        if self.tax_tree is None:
            self.label_bf_tree_with_ranks()
        return self.tax_tree

    def get_bid_taxonomy_map(self, rebuild=False):
        """Return the branch-id -> taxonomy map, (re)building it if empty or ``rebuild`` is set."""
        self.get_tax_tree()
        if not self.bid_taxonomy_map or rebuild:
            self.build_bid_taxonomy_map()
        return self.bid_taxonomy_map

    def init_taxnode_map(self):
        """Rebuild ``name2taxnode``: leaf name -> leaf node of ``tax_tree``."""
        self.name2taxnode = {}
        for leaf in self.tax_tree.iter_leaves():
            self.name2taxnode[leaf.name] = leaf

    def save_outgroup(self):
        """Pick the child of the root with the fewest leaves as the outgroup.

        A single unifurcation at the root is skipped first. Ties are resolved
        in favour of the first such child.

        :raises AssertionError: if the (unwrapped) root has fewer than two children.
        """
        rt = self.mf_rooted_tree

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        if len(rt.children) <= 1:
            raise AssertionError("Invalid tree: unifurcation at the root node!")

        # min() returns the first minimal element, i.e. the same child the
        # original "strictly smaller" scan would select
        self.outgroup = min(rt.children, key=lambda child: len(child.get_leaves()))

    def restore_rooting(self, utree):
        """Root ``utree`` on the saved outgroup and store it as ``bf_rooted_tree``.

        The outgroup (set by ``save_outgroup``/``set_outgroup``) is located in
        ``utree`` by leaf name(s), and ``utree.set_outgroup`` is called on it
        unless it already is the root.
        """
        outgr_leaves = self.outgroup.get_leaf_names()
        if not outgr_leaves:
            # outgroup without leaf names: treat it as a single node and look
            # it up by its own name (comment in the original: ETE considers a
            # single node to be a root, not a leaf)
            outgr_root = utree & self.outgroup.name
        elif len(outgr_leaves) == 1:
            outgr_root = utree & outgr_leaves[0]
        else:
            # Even an unrooted tree is "implicitly" rooted in the ETE representation.
            # If this pseudo-rooting happens to be within the outgroup, it causes problems
            # in the get_common_ancestor() step (since common_ancestor = "root").
            # Workaround: explicitly root the tree outside of the outgroup subtree.
            outgr_leaf_set = set(outgr_leaves)
            for node in utree.iter_leaves():
                if node.name not in outgr_leaf_set:
                    tmp_root = node.up
                    if tmp_root != utree:
                        utree.set_outgroup(tmp_root)
                    break

            outgr_root = utree.get_common_ancestor(outgr_leaves)

        # we could be so lucky that the RAxML tree is already correctly rooted :)
        if outgr_root != utree:
            utree.set_outgroup(outgr_root)

        self.bf_rooted_tree = utree

    def label_bf_tree_with_ranks(self):
        """Label inner tree nodes with taxonomic ranks.

        Post-order pass over ``bf_rooted_tree``:
          * leaves get ``ranks`` from ``origin_taxonomy`` and ``rank_level`` =
            their lowest assigned rank level; their name gets
            ``"__" + <rank at that level>`` appended;
          * inner nodes (must have exactly two children) get the deepest rank
            level both children agree on, the shared rank prefix (padded with
            ``Taxonomy.EMPTY_RANK``) as ``ranks``, and the shared lowest rank
            (or ``"Undefined"``) as name.
        Afterwards the labelled tree becomes ``tax_tree``.

        :raises AssertionError: if ``bf_rooted_tree`` is not set or the tree
            is not bifurcating.
        """
        if self.bf_rooted_tree is None:
            raise AssertionError("self.bf_rooted_tree is not set: TaxTreeHelper.set_bf_unrooted_tree() must be called before!")

        for node in self.bf_rooted_tree.traverse("postorder"):
            if node.is_leaf():
                seq_ranks = self.origin_taxonomy[node.name]
                rank_level = Taxonomy.lowest_assigned_rank_level(seq_ranks)
                node.add_feature("rank_level", rank_level)
                node.add_feature("ranks", seq_ranks)
                node.name += "__" + seq_ranks[rank_level]
            else:
                if len(node.children) != 2:
                    raise AssertionError("FATAL ERROR: tree is not bifurcating!")
                lchild = node.children[0]
                rchild = node.children[1]
                # walk up from the shallower child's level until both children agree
                rank_level = min(lchild.rank_level, rchild.rank_level)
                while rank_level >= 0 and lchild.ranks[rank_level] != rchild.ranks[rank_level]:
                    rank_level -= 1
                node.add_feature("rank_level", rank_level)
                node_ranks = [Taxonomy.EMPTY_RANK] * max(len(lchild.ranks), len(rchild.ranks))
                if rank_level >= 0:
                    node_ranks[0:rank_level + 1] = lchild.ranks[0:rank_level + 1]
                    node.name = lchild.ranks[rank_level]
                else:
                    node.name = "Undefined"
                    if hasattr(node, "B"):
                        self.cfg.log.debug("INFO: empty taxonomic annotation for branch %s (child nodes have no common ranks)", node.B)

                node.add_feature("ranks", node_ranks)

        self.tax_tree = self.bf_rooted_tree
        self.init_taxnode_map()

    def build_bid_taxonomy_map(self):
        """Build ``bid_taxonomy_map`` and ``ranks_set`` from ``tax_tree``.

        For every non-root node that has a ``B`` (branch id) feature, store
        ``(rank_uid of node.ranks, lowest-rank-level difference node vs. parent,
        node.dist)`` under ``node.B``.
        """
        self.bid_taxonomy_map = {}
        self.ranks_set = set()
        for node in self.tax_tree.traverse("postorder"):
            if not node.is_root() and hasattr(node, "B"):
                parent = node.up
                branch_rdiff = Taxonomy.lowest_assigned_rank_level(node.ranks) - Taxonomy.lowest_assigned_rank_level(parent.ranks)
                branch_rank_id = Taxonomy.get_rank_uid(node.ranks)
                branch_len = node.dist
                self.bid_taxonomy_map[node.B] = (branch_rank_id, branch_rdiff, branch_len)
                self.ranks_set.add(branch_rank_id)

    def get_seq_ranks_from_tree(self, seq_name):
        """Return the ranks parsed from the name of the parent of leaf ``seq_name``.

        Calls ``cfg.exit_fatal_error`` when the leaf is unknown (presumably this
        does not return; if it does, the lookup below raises ``KeyError``).
        """
        if seq_name not in self.name2taxnode:
            errmsg = "FATAL ERROR: Sequence %s is not found in the taxonomic tree!" % seq_name
            self.cfg.exit_fatal_error(errmsg)

        seq_node = self.name2taxnode[seq_name]
        ranks = Taxonomy.split_rank_uid(seq_node.up.name)
        return ranks

    def strip_missing_ranks(self, ranks):
        """Return the longest prefix of ``ranks`` whose rank uid occurs in ``ranks_set``.

        Returns an empty list if no non-empty prefix is known.
        """
        rank_level = len(ranks)
        # test the bound first so that the uid of an empty prefix is never requested
        while rank_level > 0 and Taxonomy.get_rank_uid(ranks[0:rank_level]) not in self.ranks_set:
            rank_level -= 1

        return ranks[0:rank_level]
class TaxClassifyHelper:
    """Assign taxonomy to a query sequence from its placements ("edges").

    Each edge is indexed positionally:
      * ``edge[0]`` -- branch id (looked up as ``str`` in ``bid_taxonomy_map``),
      * ``edge[1]`` -- log-likelihood of the placement,
      * ``edge[2]`` -- likelihood weight (overwritten by ``erlang_filter``),
      * ``edge[4]`` -- pendant branch length.
    """

    def __init__(self, cfg, bid_taxonomy_map, sp_rate = 0., node_height = None):
        """
        :param cfg: configuration object (``min_lhw``, ``taxassign_method``,
            ``brlen_pv`` are read).
        :param bid_taxonomy_map: branch id -> ``(rank_uid, rank_level_diff, branch_length)``.
        :param sp_rate: rate passed to the Erlang branch-length test.
        :param node_height: mapping branch id (``str``) -> node height used as
            the Erlang shape ``k``; defaults to a fresh empty dict.
        """
        self.cfg = cfg
        self.bid_taxonomy_map = bid_taxonomy_map
        self.sp_rate = sp_rate
        # never share a mutable default between instances
        self.node_height = node_height if node_height is not None else {}
        self.erlang = erlang()
        # hardcoded for now
        self.parent_lhw_coeff = 0.49

    def classify_seq(self, edges, minlw = None):
        """Return ``(ranks, weights)`` for one sequence, or ``([], [])``.

        :param edges: placements of the sequence (see class docstring).
        :param minlw: minimum total weight for a rank to be chosen by the
            max-sum method; ``None`` means ``cfg.min_lhw``. An explicit ``0.``
            is honoured.
        """
        if minlw is None:
            minlw = self.cfg.min_lhw

        edges = self.erlang_filter(edges)
        if not edges:
            return [], []

        if self.cfg.taxassign_method == "1":
            ranks, lws = self.assign_taxonomy_maxsum(edges, minlw)
        else:
            ranks, lws = self.assign_taxonomy_maxlh(edges)
        return ranks, lws

    def erlang_filter(self, edges):
        """Drop edges whose pendant length fails the Erlang test; renormalise weights.

        Returns ``edges`` unchanged when ``cfg.brlen_pv == 0``. Otherwise keeps
        edges with p-value >= ``cfg.brlen_pv`` and rewrites their ``edge[2]``
        as ``exp(lh) / sum(exp(lh))`` over the kept edges (in place).
        """
        if self.cfg.brlen_pv == 0.:
            return edges

        newedges = []
        for edge in edges:
            edge_nr = str(edge[0])
            pendant_length = edge[4]
            pv = self.erlang.one_tail_test(rate = self.sp_rate, k = int(self.node_height[edge_nr]), x = pendant_length)
            if pv >= self.cfg.brlen_pv:
                newedges.append(edge)

        if not newedges:
            return newedges

        # Shift by the true maximum log-likelihood so every exponent is <= 0;
        # this cannot overflow regardless of edge order.
        max_lh = max(float(edge[1]) for edge in newedges)
        rel_lhs = [math.exp(float(edge[1]) - max_lh) for edge in newedges]
        sum_lh = sum(rel_lhs)

        for edge, rel_lh in zip(newedges, rel_lhs):
            edge[2] = rel_lh / sum_lh

        return newedges

    # "all or none" filter: return empty set iff *all* brlens are below the threshold
    def erlang_filter2(self, edges):
        """Return ``edges`` if at least one passes the Erlang test, else ``[]``."""
        if self.cfg.brlen_pv == 0.:
            return edges

        for edge in edges:
            edge_nr = str(edge[0])
            pendant_length = edge[4]
            pv = self.erlang.one_tail_test(rate = self.sp_rate, k = int(self.node_height[edge_nr]), x = pendant_length)
            if pv >= self.cfg.brlen_pv:
                return edges

        return []

    def get_branch_ranks(self, br_id):
        """Return the list of ranks annotated on branch ``br_id``."""
        br_rec = self.bid_taxonomy_map[br_id]
        br_rank_id = br_rec[0]
        ranks = Taxonomy.split_rank_uid(br_rank_id)
        return ranks

    def assign_taxonomy_maxlh(self, edges):
        """Assign the ranks of ``edges[0]``, with per-rank summed weights as confidence.

        ``edges[0]`` is taken as the maximum-likelihood placement (assumes the
        caller passes edges ordered best-first -- TODO confirm). Where that
        placement has an empty rank ``"-"`` below a non-empty one, the rank of
        another edge sharing the same parent rank is substituted.
        """
        # Sum of likelihood weights per rank *name* over all placements;
        # the empty rank "-" is pinned to -1.
        taxonmy_sumlw_map = {}
        for edge in edges:
            edge_nr = str(edge[0])
            lw = edge[2]
            taxranks = self.get_branch_ranks(edge_nr)
            for rank in taxranks:
                if rank == "-":
                    taxonmy_sumlw_map[rank] = -1
                elif rank in taxonmy_sumlw_map:
                    taxonmy_sumlw_map[rank] = taxonmy_sumlw_map[rank] + lw
                else:
                    taxonmy_sumlw_map[rank] = lw

        # Assignment using the max likelihood placement
        ml_edge = edges[0]
        ml_ranks = self.get_branch_ranks(str(ml_edge[0]))
        ml_ranks_copy = list(ml_ranks)
        lws = []
        for cnt, rank in enumerate(ml_ranks):
            lw = taxonmy_sumlw_map[rank]
            if lw > 1.0:
                lw = 1.0
            lws.append(lw)
            if rank == "-" and cnt > 0:
                # NOTE(review): the parent check below compares against the
                # original ml_ranks (not ml_ranks_copy), the *last* matching
                # edge wins, and newlw is not capped at 1.0 -- confirm intended.
                for edge in edges[1:]:
                    taxonomy = self.get_branch_ranks(str(edge[0]))
                    newrank = taxonomy[cnt]
                    newlw = taxonmy_sumlw_map[newrank]
                    higherrank_old = ml_ranks[cnt - 1]
                    higherrank_new = taxonomy[cnt - 1]
                    if higherrank_old == higherrank_new and newrank != "-":
                        ml_ranks_copy[cnt] = newrank
                        lws[cnt] = newlw

        return ml_ranks_copy, lws

    def assign_taxonomy_maxsum(self, edges, minlw):
        """Sum up all LH-weights for each rank and take the rank with the max. sum.

        In the EPA result each placement (= branch) has a weight. Since we are
        interested in taxonomic placement, we compare rank vs. rank (e.g.
        G1 S1 vs. G1 S2 vs. G1) rather than branch vs. branch, accumulating:
          * "own" weight -- weight of placements EXACTLY on this rank
            (e.g. for G1: G1 only);
          * "total" weight -- own weight plus that of all descendant ranks
            (for G1: G1 or G1 S1 or G1 S2).
        The chosen rank has the max own weight among ranks whose total weight
        is >= ``minlw``; if none qualifies, the rank with the max total weight.

        :returns: ``(ranks, confidences)``; confidences are total weights.
        """
        rw_own = {}
        rw_total = {}

        ranks = [Taxonomy.EMPTY_RANK]

        for edge in edges:
            br_id = str(edge[0])
            lweight = edge[2]
            lowest_rank = None
            lowest_rank_lvl = None

            if lweight == 0.:
                continue

            # accumulate weight for the current sequence
            br_rank_id, rdiff, _brlen = self.bid_taxonomy_map[br_id]
            ranks = Taxonomy.split_rank_uid(br_rank_id)
            for i, rank in enumerate(ranks):
                if rank == Taxonomy.EMPTY_RANK:
                    break
                rank_id = Taxonomy.get_rank_uid(ranks, i)
                rw_total[rank_id] = rw_total.get(rank_id, 0) + lweight
                lowest_rank_lvl = i
                lowest_rank = rank_id

            if lowest_rank:
                if rdiff > 0:
                    # if ranks of 'upper' and 'lower' adjacent nodes of a branch are non-equal, split LHW among them
                    # NOTE(review): if lowest_rank_lvl - rdiff < 0 (parent has no
                    # assigned rank), parent_rank may be missing from rw_total.
                    parent_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - rdiff)
                    rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight * (1 - self.parent_lhw_coeff)
                    rw_own[parent_rank] = rw_own.get(parent_rank, 0) + lweight * self.parent_lhw_coeff
                    # correct total lhw for all levels between "parent" and "lowest"
                    # NOTE: all those intermediate ranks are in fact indistinguishable, e.g. a family which contains a single genus
                    for r in range(rdiff):
                        interim_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - r)
                        rw_total[interim_rank] = rw_total.get(interim_rank, 0) - lweight * self.parent_lhw_coeff
                else:
                    rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight

        # if all branches have empty ranks only, just return this placement
        if len(rw_total) == 0:
            return ranks, [1.] * len(ranks)

        # we assign the sequence to a rank, which has the max "own" weight AND
        # whose "total" weight is greater than a confidence threshold
        max_rw = 0.
        ass_rank_id = None
        for r in rw_own:
            if rw_own[r] > max_rw and rw_total[r] >= minlw:
                ass_rank_id = r
                max_rw = rw_own[r]
        if ass_rank_id is None:
            ass_rank_id = max(rw_total, key=rw_total.get)

        a_ranks = Taxonomy.split_rank_uid(ass_rank_id)

        # "total" weight is considered as confidence value for now
        a_conf = [0.] * len(a_ranks)
        for i, rank in enumerate(a_ranks):
            if rank != Taxonomy.EMPTY_RANK:
                a_conf[i] = rw_total[Taxonomy.get_rank_uid(a_ranks, i)]

        return a_ranks, a_conf