source: branches/ali/GDE/SATIVA/sativa/epac/classify_util.py

Last change on this file was 14565, checked in by akozlov, 9 years ago
File size: 14.3 KB
Line 
1#! /usr/bin/env python
2from taxonomy_util import Taxonomy
3from erlang import erlang
4import math
5
6class TaxTreeHelper:
7    def __init__(self, cfg, tax_map, tax_tree=None):
8        self.origin_taxonomy = tax_map
9        self.cfg = cfg
10        self.outgroup = None
11        self.mf_rooted_tree = None
12        self.bf_rooted_tree = None
13        self.tax_tree = tax_tree
14        self.bid_taxonomy_map = None
15        self.ranks_set = set()
16        if tax_tree:
17            self.init_taxnode_map()
18        else:
19            self.name2taxnode = {}
20   
21    def set_mf_rooted_tree(self, rt):
22        self.mf_rooted_tree = rt
23        self.save_outgroup()
24        self.bf_rooted_tree = None
25        self.tax_tree = None
26        self.bid_taxonomy_map = None
27   
28    def set_bf_unrooted_tree(self, ut):
29        self.restore_rooting(ut)
30       
31    def get_outgroup(self):
32        return self.outgroup
33
34    def set_outgroup(self, outgr):
35        self.outgroup = outgr
36
37    def get_tax_tree(self):
38        if not self.tax_tree:
39            self.label_bf_tree_with_ranks()
40        return self.tax_tree
41   
42    def get_bid_taxonomy_map(self, rebuild=False):
43        self.get_tax_tree()
44        if not self.bid_taxonomy_map or rebuild:
45            self.build_bid_taxonomy_map()
46        return self.bid_taxonomy_map
47
48    def init_taxnode_map(self):
49        self.name2taxnode = {}
50        for leaf in self.tax_tree.iter_leaves():
51            self.name2taxnode[leaf.name] = leaf
52       
53    def save_outgroup(self):
54        rt = self.mf_rooted_tree
55       
56        # remove unifurcation at the root
57        if len(rt.children) == 1:
58            rt = rt.children[0]
59
60        if len(rt.children) > 1:
61            outgr = rt.children[0]   
62            outgr_size = len(outgr.get_leaves())
63            for child in rt.children:
64                if child != outgr:
65                    child_size = len(child.get_leaves())
66                    if child_size < outgr_size:
67                        outgr = child
68                        outgr_size = child_size
69        else:
70            raise AssertionError("Invalid tree: unifurcation at the root node!")
71       
72        self.outgroup = outgr
73   
74    def restore_rooting(self, utree):
75        outgr_leaves = self.outgroup.get_leaf_names()
76        # check if outgroup consists of a single node - ETE considers it to be root, not leaf
77        if not outgr_leaves:
78            outgr_root = utree&outgr.name
79        elif len(outgr_leaves) == 1:
80            outgr_root = utree&outgr_leaves[0]
81        else:
82            # Even unrooted tree is "implicitely" rooted in ETE representation.
83            # If this pseudo-rooting happens to be within the outgroup, it cause problems
84            # in the get_common_ancestor() step (since common_ancestor = "root")
85            # Workaround: explicitely root the tree outside from outgroup subtree
86            for node in utree.iter_leaves():
87                if not node.name in outgr_leaves:
88                    tmp_root = node.up
89                    if not utree == tmp_root:
90                        utree.set_outgroup(tmp_root)
91                        break
92           
93            outgr_root = utree.get_common_ancestor(outgr_leaves)
94
95        # we could be so lucky that the RAxML tree is already correctly rooted :)
96        if outgr_root != utree:
97            utree.set_outgroup(outgr_root)
98
99        self.bf_rooted_tree = utree
100
101    def label_bf_tree_with_ranks(self):
102        """labeling inner tree nodes with taxonomic ranks"""
103        if not self.bf_rooted_tree:
104            raise AssertionError("self.bf_rooted_tree is not set: TaxTreeHelper.set_bf_unrooted_tree() must be called before!")
105           
106        for node in self.bf_rooted_tree.traverse("postorder"):
107            if node.is_leaf():
108                seq_ranks = self.origin_taxonomy[node.name]
109                rank_level = Taxonomy.lowest_assigned_rank_level(seq_ranks)
110                node.add_feature("rank_level", rank_level)
111                node.add_feature("ranks", seq_ranks)
112                node.name += "__" + seq_ranks[rank_level]
113            else:
114                if len(node.children) != 2:
115                    raise AssertionError("FATAL ERROR: tree is not bifurcating!")
116                lchild = node.children[0]
117                rchild = node.children[1]
118                rank_level = min(lchild.rank_level, rchild.rank_level)
119                while rank_level >= 0 and lchild.ranks[rank_level] != rchild.ranks[rank_level]:
120                    rank_level -= 1
121                node.add_feature("rank_level", rank_level)
122                node_ranks = [Taxonomy.EMPTY_RANK] * max(len(lchild.ranks),len(rchild.ranks)) 
123                if rank_level >= 0:
124                    node_ranks[0:rank_level+1] = lchild.ranks[0:rank_level+1]
125                    node.name = lchild.ranks[rank_level]
126                else:
127                    node.name = "Undefined"
128                    if hasattr(node, "B"):
129                        self.cfg.log.debug("INFO: empty taxonomic annotation for branch %s (child nodes have no common ranks)", node.B)
130               
131                node.add_feature("ranks", node_ranks)
132
133        self.tax_tree = self.bf_rooted_tree
134        self.init_taxnode_map()
135
136    def build_bid_taxonomy_map(self):
137        self.bid_taxonomy_map = {}
138        self.ranks_set = set([])
139        for node in self.tax_tree.traverse("postorder"):
140            if not node.is_root() and hasattr(node, "B"):
141                parent = node.up
142                branch_rdiff = Taxonomy.lowest_assigned_rank_level(node.ranks) - Taxonomy.lowest_assigned_rank_level(parent.ranks)
143                branch_rank_id = Taxonomy.get_rank_uid(node.ranks)
144                branch_len = node.dist
145                self.bid_taxonomy_map[node.B] = (branch_rank_id, branch_rdiff, branch_len)
146                self.ranks_set.add(branch_rank_id)
147#                if self.cfg.debug:
148#                  print node.ranks, parent.ranks, branch_diff
149
150    def get_seq_ranks_from_tree(self, seq_name):
151        if seq_name not in self.name2taxnode:
152            errmsg = "FATAL ERROR: Sequence %s is not found in the taxonomic tree!" % seq_name
153            self.cfg.exit_fatal_error(errmsg)
154
155        seq_node = self.name2taxnode[seq_name]
156        ranks = Taxonomy.split_rank_uid(seq_node.up.name)
157        return ranks
158
159    def strip_missing_ranks(self, ranks):
160        rank_level = len(ranks)
161        while not Taxonomy.get_rank_uid(ranks[0:rank_level]) in self.ranks_set and rank_level > 0:
162            rank_level -= 1
163       
164        return ranks[0:rank_level]   
165   
166class TaxClassifyHelper:
167    def __init__(self, cfg, bid_taxonomy_map, sp_rate = 0., node_height = []):
168        self.cfg = cfg
169        self.bid_taxonomy_map = bid_taxonomy_map
170        self.sp_rate = sp_rate
171        self.node_height = node_height
172        self.erlang = erlang()
173        # hardcoded for now
174        self.parent_lhw_coeff = 0.49
175
176    def classify_seq(self, edges, minlw = None):
177        if not minlw:
178            minlw = self.cfg.min_lhw
179
180        edges = self.erlang_filter(edges)
181        if len(edges) > 0:
182            if self.cfg.taxassign_method == "1":
183                ranks, lws = self.assign_taxonomy_maxsum(edges, minlw)
184            else:
185                ranks, lws = self.assign_taxonomy_maxlh(edges)
186            return ranks, lws
187        else:
188            return [], []     
189           
190    def erlang_filter(self, edges):
191        if self.cfg.brlen_pv == 0.:
192            return edges
193           
194        newedges = []
195        for edge in edges:
196            edge_nr = str(edge[0])
197            pendant_length = edge[4]
198            pv = self.erlang.one_tail_test(rate = self.sp_rate, k = int(self.node_height[edge_nr]), x = pendant_length)
199            if pv >= self.cfg.brlen_pv:
200                newedges.append(edge)
201#            else:
202#                self.cfg.log.debug("Edge ignored: [%s, %f], p = %.12f", edge_nr, pendant_length, pv)
203       
204        if len(newedges) == 0:
205            return newedges
206       
207        # adjust likelihood weights -> is there a better way ???       
208        sum_lh = 0
209        max_lh = float(newedges[0][1])
210        for edge in newedges:
211            lh = float(edge[1])
212            sum_lh += math.exp(lh - max_lh)
213                       
214        for edge in newedges:
215            lh = float(edge[1])
216            edge[2] = math.exp(lh - max_lh) / sum_lh
217
218        return newedges
219
220    # "all or none" filter: return empty set iff *all* brlens are below the threshold
221    def erlang_filter2(self, edges):
222        if self.cfg.brlen_pv == 0.:
223            return edges
224           
225        for edge in edges:
226            edge_nr = str(edge[0])
227            pendant_length = edge[4]
228            pv = self.erlang.one_tail_test(rate = self.sp_rate, k = int(self.node_height[edge_nr]), x = pendant_length)
229            if pv >= self.cfg.brlen_pv:
230                return edges
231               
232        return []
233     
234    def get_branch_ranks(self, br_id):
235        br_rec = self.bid_taxonomy_map[br_id]
236        br_rank_id = br_rec[0]
237        ranks = Taxonomy.split_rank_uid(br_rank_id)           
238        return ranks
239   
240    def assign_taxonomy_maxlh(self, edges):
241        #Calculate the sum of likelihood weight for each rank
242        taxonmy_sumlw_map = {}
243        for edge in edges:
244            edge_nr = str(edge[0])
245            lw = edge[2]
246            taxranks = self.get_branch_ranks(edge_nr)           
247            for rank in taxranks:
248                if rank == "-":
249                    taxonmy_sumlw_map[rank] = -1
250                elif rank in taxonmy_sumlw_map:
251                    oldlw = taxonmy_sumlw_map[rank]
252                    taxonmy_sumlw_map[rank] = oldlw + lw
253                else:
254                    taxonmy_sumlw_map[rank] = lw
255       
256        #Assignment using the max likelihood placement
257        ml_edge = edges[0]
258        edge_nr = str(ml_edge[0])
259        maxlw = ml_edge[2]
260        ml_ranks = self.get_branch_ranks(edge_nr)
261        ml_ranks_copy = []
262        for rk in ml_ranks:
263            ml_ranks_copy.append(rk)
264        lws = []
265        cnt = 0
266        for rank in ml_ranks:
267            lw = taxonmy_sumlw_map[rank]
268            if lw > 1.0:
269                lw = 1.0
270            lws.append(lw)
271            if rank == "-" and cnt > 0 :               
272                for edge in edges[1:]:
273                    edge_nr = str(edge[0])
274                    taxonomy = self.get_branch_ranks(edge_nr)
275                    newrank = taxonomy[cnt]
276                    newlw = taxonmy_sumlw_map[newrank]
277                    higherrank_old = ml_ranks[cnt -1]
278                    higherrank_new = taxonomy[cnt -1]
279                    if higherrank_old == higherrank_new and newrank!="-":
280                        ml_ranks_copy[cnt] = newrank
281                        lws[cnt] = newlw
282            cnt = cnt + 1
283           
284        return ml_ranks_copy, lws
285
286    def assign_taxonomy_maxsum(self, edges, minlw):
287        """this function sums up all LH-weights for each rank and takes the rank with the max. sum """
288        # in EPA result, each placement(=branch) has a "weight"
289        # since we are interested in taxonomic placement, we do not care about branch vs. branch comparisons,
290        # but only consider rank vs. rank (e. g. G1 S1 vs. G1 S2 vs. G1)
291        # Thus we accumulate weights for each rank, there are to measures:
292        # "own" weight  = sum of weight of all placements EXACTLY to this rank (e.g. for G1: G1 only)
293        # "total" rank  = own rank + own rank of all children (for G1: G1 or G1 S1 or G1 S2)
294        rw_own = {}
295        rw_total = {}
296       
297        ranks = [Taxonomy.EMPTY_RANK]
298       
299        for edge in edges:
300            br_id = str(edge[0])
301            lweight = edge[2]
302            lowest_rank = None
303            lowest_rank_lvl = None
304
305            if lweight == 0.:
306                continue
307
308            # accumulate weight for the current sequence               
309            br_rank_id, rdiff, brlen = self.bid_taxonomy_map[br_id]
310            ranks = Taxonomy.split_rank_uid(br_rank_id)
311            for i in range(len(ranks)):
312                rank = ranks[i]
313                rank_id = Taxonomy.get_rank_uid(ranks, i)
314                if rank != Taxonomy.EMPTY_RANK:
315                    rw_total[rank_id] = rw_total.get(rank_id, 0) + lweight
316                    lowest_rank_lvl = i
317                    lowest_rank = rank_id
318                else:
319                    break
320
321            if lowest_rank:
322                if rdiff > 0:
323                  # if ranks of 'upper' and 'lower' adjacent nodes of a branch are non-equal, split LHW among them
324                  parent_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - rdiff)
325                  rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight * (1 - self.parent_lhw_coeff)
326                  rw_own[parent_rank] = rw_own.get(parent_rank, 0) + lweight * self.parent_lhw_coeff
327                  # correct total lhw for all levels between "parent" and "lowest"
328                  # NOTE: all those intermediate ranks are in fact indistinguishable, e.g. a family which contains a single genus
329                  for r in range(rdiff):
330                    interim_rank = Taxonomy.get_rank_uid(ranks, lowest_rank_lvl - r)
331                    rw_total[interim_rank] = rw_total.get(interim_rank, 0) - lweight * self.parent_lhw_coeff
332                else:
333                  rw_own[lowest_rank] = rw_own.get(lowest_rank, 0) + lweight
334#            else:
335#                self.cfg.log.debug("WARNING: no annotation for branch %s", br_id)
336           
337        # if all branches have empty ranks only, just return this placement
338        if len(rw_total) == 0:
339            return ranks, [1.] * len(ranks)
340       
341        # we assign the sequence to a rank, which has the max "own" weight AND
342        # whose "total" weight is greater than a confidence threshold
343        max_rw = 0.
344        ass_rank_id = None
345        for r in rw_own.iterkeys():
346            if rw_own[r] > max_rw and rw_total[r] >= minlw:
347                ass_rank_id = r
348                max_rw = rw_own[r] 
349        if not ass_rank_id:
350            ass_rank_id = max(rw_total.iterkeys(), key=(lambda key: rw_total[key]))
351
352        a_ranks = Taxonomy.split_rank_uid(ass_rank_id)
353       
354        # "total" weight is considered as confidence value for now
355        a_conf = [0.] * len(a_ranks)
356        for i in range(len(a_conf)):
357            rank = a_ranks[i]
358            if rank != Taxonomy.EMPTY_RANK:
359                rank_id = Taxonomy.get_rank_uid(a_ranks, i)
360                a_conf[i] = rw_total[rank_id]
361
362        return a_ranks, a_conf
363   
Note: See TracBrowser for help on using the repository browser.