source: branches/ali/GDE/SATIVA/sativa/sativa.py

Last change on this file was 14565, checked in by akozlov, 9 years ago
  • Property svn:executable set to *
File size: 32.3 KB
Line 
1#! /usr/bin/env python
2
3import sys
4import os
5import time
6import glob
7import multiprocessing
8from operator import itemgetter
9from subprocess import call
10
11from epac.ete2 import Tree, SeqGroup
12from epac.argparse import ArgumentParser,RawDescriptionHelpFormatter
13from epac.config import SativaConfig,EpacConfig
14from epac.raxml_util import RaxmlWrapper, FileUtils
15from epac.json_util import RefJsonParser, RefJsonChecker, EpaJsonParser
16from epac.taxonomy_util import TaxCode, Taxonomy
17from epac.classify_util import TaxTreeHelper,TaxClassifyHelper
18import epa_trainer
19
20class LeaveOneTest:
21    def __init__(self, config):
22        self.cfg = config
23       
24        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
25        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
26        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
27        self.stats_fname = self.cfg.out_fname("%NAME%.stats")
28       
29        if os.path.isfile(self.mis_fname):
30            print "\nERROR: Output file already exists: %s" % self.mis_fname
31            print "Please specify a different job name using -n or remove old output files."
32            self.cfg.exit_user_error()
33
34        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
35        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
36        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
37        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
38        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
39
40        self.mislabels = []
41        self.mislabels_cnt = []
42        self.rank_mislabels = []
43        self.rank_mislabels_cnt = []
44        self.misrank_conf_map = {}
45       
46    def write_bid_tax_map(self, bid_tax_map, final):
47        if self.cfg.debug:
48            fname_suffix = "final" if final else "l1out"
49            bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
50            with open(bid_fname, "w") as outf:
51              for bid, bid_rec in bid_tax_map.iteritems():
52                outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]));   
53
54    def write_assignments(self, assign_map, final):
55        if self.cfg.debug:
56            fname_suffix = "final" if final else "l1out"
57            assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
58            with open(assign_fname, "w") as outf:
59                for seq_name in assign_map.iterkeys():
60                    ranks, lws = assign_map[seq_name]
61                    outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks), ";".join(["%.3f" % l for l in lws])))
62
63    def load_refjson(self, refjson_fname):
64        try:
65            self.refjson = RefJsonParser(refjson_fname)
66        except ValueError:
67            self.cfg.exit_user_error("ERROR: Invalid json file format!")
68           
69        #validate input json format
70        (valid, err) = self.refjson.validate()
71        if not valid:
72            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
73            self.cfg.exit_user_error()
74       
75        self.rate = self.refjson.get_rate()
76        self.node_height = self.refjson.get_node_height()
77        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
78        self.tax_tree = self.refjson.get_tax_tree()
79        self.cfg.compress_patterns = self.refjson.get_pattern_compression()
80
81        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
82        if not self.bid_taxonomy_map:
83            # old file format (before 1.6), need to rebuild this map from scratch
84            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
85            th.set_mf_rooted_tree(self.tax_tree)
86            th.set_bf_unrooted_tree(self.refjson.get_reftree())
87            self.bid_taxonomy_map = th.get_bid_taxonomy_map()
88           
89        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)
90
91        reftree_str = self.refjson.get_raxml_readable_tree()
92        self.reftree = Tree(reftree_str)
93        self.reftree_size = len(self.reftree.get_leaves())
94
95        # IMPORTANT: set EPA heuristic rate based on tree size!               
96        self.cfg.resolve_auto_settings(self.reftree_size)
97        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file       
98        if self.cfg.epa_load_optmod:
99            self.cfg.raxml_model = self.refjson.get_ratehet_model()
100
101        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
102        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)
103       
104        tax_code_name = self.refjson.get_taxcode()
105        self.tax_code = TaxCode(tax_code_name)
106       
107        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
108        self.tax_common_ranks = self.taxonomy.get_common_ranks()
109#        print "Common ranks: ", self.tax_common_ranks
110
111        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
112        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
113       
114    def run_epa_trainer(self):
115        epa_trainer.run_trainer(self.cfg)
116
117        if not os.path.isfile(self.cfg.refjson_fname):
118            self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
119            self.cfg.exit_fatal_error()
120       
121    def classify_seq(self, placement):
122        edges = placement["p"]
123        if len(edges) > 0:
124            return self.classify_helper.classify_seq(edges)
125        else:
126            print "ERROR: no placements! something is definitely wrong!"
127
128    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
129        mis_rec = None
130       
131        num_common_ranks = len(self.tax_common_ranks)
132        orig_rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
133        new_rank_level = Taxonomy.lowest_assigned_rank_level(ranks)
134        #if new_rank_level < 0 or (new_rank_level < num_common_ranks and orig_rank_level >= num_common_ranks):
135#        if new_rank_level < 0:
136        if len(ranks) == 0:
137            mis_rec = {}
138            mis_rec['name'] = seq_name
139            mis_rec['orig_level'] = -1
140            mis_rec['real_level'] = 0
141            mis_rec['level_name'] = "[NotIngroup]"
142            mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
143            mis_rec['orig_ranks'] = orig_ranks
144            mis_rec['ranks'] = []
145            mis_rec['lws'] = [1.0]
146            mis_rec['conf'] = mis_rec['lws'][0]
147        else:
148            mislabel_lvl = -1
149            min_len = min(len(orig_ranks),len(ranks))
150            for rank_lvl in range(min_len):
151                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
152                    mislabel_lvl = rank_lvl
153                    break
154
155            if mislabel_lvl >= 0:
156                real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
157                mis_rec = {}
158                mis_rec['name'] = seq_name
159                mis_rec['orig_level'] = mislabel_lvl
160                mis_rec['real_level'] = real_lvl
161                mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
162                mis_rec['inv_level'] = -1 * mis_rec['real_level']  # just for sorting
163                mis_rec['orig_ranks'] = orig_ranks
164                mis_rec['ranks'] = ranks
165                mis_rec['lws'] = lws
166                mis_rec['conf'] = lws[mislabel_lvl]
167   
168        if mis_rec:
169            self.mislabels.append(mis_rec)
170           
171        return mis_rec
172       
173    def filter_mislabels(self):
174        filtered_mis = []
175        for i in range(len(self.mislabels)):
176            if self.mislabels[i]['conf'] >= self.cfg.conf_cutoff:
177                filtered_mis.append(self.mislabels[i])
178       
179        self.mislabels = filtered_mis
180
181    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
182        mislabel_lvl = -1
183        min_len = min(len(orig_ranks),len(ranks))
184        for rank_lvl in range(min_len):
185            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
186                mislabel_lvl = rank_lvl
187                break
188
189        if mislabel_lvl >= 0:
190            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
191            mis_rec = {}
192            mis_rec['name'] = rank_name
193            mis_rec['orig_level'] = mislabel_lvl
194            mis_rec['real_level'] = real_lvl
195            mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
196            mis_rec['inv_level'] = -1 * real_lvl  # just for sorting
197            mis_rec['orig_ranks'] = orig_ranks
198            mis_rec['ranks'] = ranks
199            mis_rec['lws'] = lws
200            mis_rec['conf'] = lws[mislabel_lvl]
201            self.rank_mislabels.append(mis_rec)
202               
203            return mis_rec
204        else:
205            return None               
206
207    def mis_rec_to_string_old(self, mis_rec):
208        lvl = mis_rec['orig_level']
209        output = mis_rec['name'] + "\t"
210        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'], 
211            mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
212        output += ";".join(mis_rec['orig_ranks']) + "\n"
213        output += ";".join(mis_rec['ranks']) + "\n"
214        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
215        return output
216
217    def mis_rec_to_string(self, mis_rec):
218        lvl = mis_rec['orig_level']
219        uncorr_name = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
220        uncorr_orig_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
221        uncorr_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])
222        output = uncorr_name + "\t"
223     
224        if lvl >= 0:
225            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], 
226                uncorr_orig_ranks[lvl], uncorr_ranks[lvl], mis_rec['lws'][lvl])
227        else:
228            output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'], 
229                "NA", "NA", mis_rec['lws'][0])
230       
231        output += Taxonomy.lineage_str(uncorr_orig_ranks) + "\t"
232        output += Taxonomy.lineage_str(uncorr_ranks) + "\t"
233        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
234        if 'rank_conf' in mis_rec:
235            output += "\t%.3f" % mis_rec['rank_conf']
236        return output
237
238    def sort_mislabels(self):
239        self.mislabels = sorted(self.mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
240        for mis_rec in self.mislabels:
241            real_lvl = mis_rec["real_level"]
242            self.mislabels_cnt[real_lvl] += 1
243       
244        if self.cfg.ranktest:
245            self.rank_mislabels = sorted(self.rank_mislabels, key=itemgetter('inv_level', 'conf', 'name'), reverse=True)
246            for mis_rec in self.rank_mislabels:
247                real_lvl = mis_rec["real_level"]
248                self.rank_mislabels_cnt[real_lvl] += 1
249   
250    def write_stats(self, toFile=False):
251        self.cfg.log.info("Mislabeled sequences by rank:")
252        seq_sum = 0
253        rank_sum = 0
254        stats = []
255        for i in range(len(self.mislabels_cnt)):
256            if i > 0:
257                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
258            else:
259                rname = "[NotIngroup]"
260            if self.mislabels_cnt[i] > 0:
261                seq_sum += self.mislabels_cnt[i]
262#                    output = "%s:\t%d" % (rname, seq_sum)
263                output = "%s:\t%d" % (rname, self.mislabels_cnt[i])
264                if self.cfg.ranktest:
265                    rank_sum += self.rank_mislabels_cnt[i]
266                    output += "\t%d" % rank_sum
267                self.cfg.log.info(output) 
268                stats.append(output)
269
270        if toFile:
271            with open(self.stats_fname, "w") as fo_stat:
272                for line in stats:
273                    fo_stat.write(line + "\n")
274   
275    def write_mislabels(self, final=True):
276        if final:
277            out_fname = self.mis_fname
278        else:
279            out_fname = self.premis_fname
280       
281        with open(out_fname, "w") as fo_all:
282            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
283            if self.cfg.ranktest:
284                fields += ["HigherRankMisplacedConfidence"]
285            header = ";" + "\t".join(fields) + "\n"
286            fo_all.write(header)
287            if self.cfg.verbose and len(self.mislabels) > 0 and final:
288                print "Mislabeled sequences:\n"
289                print header
290            for mis_rec in self.mislabels:
291                output = self.mis_rec_to_string(mis_rec)  + "\n"
292                fo_all.write(output)
293                if self.cfg.verbose and final:
294                    print(output) 
295                   
296        if not final:
297            return
298
299        if self.cfg.ranktest:
300            with open(self.misrank_fname, "w") as fo_all:
301                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
302                header = ";" + "\t".join(fields)  + "\n"
303                fo_all.write(header)
304                if self.cfg.verbose  and len(self.rank_mislabels) > 0:
305                    print "\nMislabeled higher ranks:\n"
306                    print header
307                for mis_rec in self.rank_mislabels:
308                    output = self.mis_rec_to_string(mis_rec) + "\n"
309                    fo_all.write(output)
310                    if self.cfg.verbose:
311                        print(output) 
312                       
313        self.write_stats()
314   
315    def run_leave_subtree_out_test(self):
316        job_name = self.cfg.subst_name("l1out_rank_%NAME%")
317#        if self.jplace_fname:
318#            jp = EpaJsonParser(self.jplace_fname)
319#        else:       
320
321        #create file with subtrees
322        rank_tips = {}
323        rank_parent = {}
324        for node in self.tax_tree.traverse("postorder"):
325            if node.is_leaf() or node.is_root():
326                continue
327            tax_path = node.name
328            ranks = Taxonomy.split_rank_uid(tax_path)
329            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
330            if rank_lvl < 2:
331                continue
332               
333            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
334            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
335            if parent_lvl < 1:
336                continue
337           
338            rank_seqs = node.get_leaf_names()
339            rank_size = len(rank_seqs)
340            if rank_size < 2 or rank_size > self.reftree_size-4:
341                continue
342
343#            print rank_lvl, "\t", tax_path, "\t", rank_seqs, "\n"
344            rank_tips[tax_path] = node.get_leaf_names()
345            rank_parent[tax_path] = parent_ranks
346               
347        subtree_list = rank_tips.items()
348       
349        if len(subtree_list) == 0:
350            return 0
351           
352        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
353        with open(subtree_list_file, "w") as fout:
354            for rank_name, tips in subtree_list:
355                fout.write("%s\n" % " ".join(tips))
356       
357        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, 
358            mode="l1o_subtree", subtree_fname=subtree_list_file)
359
360        subtree_count = 0
361        for jp in jp_list:
362            placements = jp.get_placement()
363            for place in placements:
364                ranks, lws = self.classify_seq(place)
365                tax_path = subtree_list[subtree_count][0]
366                orig_ranks = Taxonomy.split_rank_uid(tax_path)
367                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
368                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
369                rank_name = orig_ranks[rank_level]
370                if not rank_name.startswith(rank_prefix):
371                    rank_name = rank_prefix + rank_name
372                parent_ranks = rank_parent[tax_path]
373#                print orig_ranks, "\n", parent_ranks, "\n", ranks, "\n"
374                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
375                if mis_rec:
376                    self.misrank_conf_map[tax_path] = mis_rec['conf']
377                subtree_count += 1
378
379        return subtree_count   
380       
381    def run_leave_seq_out_test(self):
382        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
383        placements = []
384        if self.cfg.jplace_fname:
385            if os.path.isdir(self.cfg.jplace_fname):
386                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
387            else:
388                jplace_fmask = self.cfg.jplace_fname
389
390            jplace_fname_list = glob.glob(jplace_fmask)
391            for jplace_fname in jplace_fname_list:
392                jp = EpaJsonParser(jplace_fname)
393                placements += jp.get_placement()
394               
395            config.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
396        else:       
397            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, mode="l1o_seq")
398            placements = jp.get_placement()
399            if self.cfg.output_interim_files:
400                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
401                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")
402       
403        seq_count = 0
404        l1out_ass = {}
405        for place in placements:
406            seq_name = place["n"][0]
407           
408            # get original taxonomic label
409#            orig_ranks = self.get_orig_ranks(seq_name)
410            orig_ranks =  self.taxtree_helper.get_seq_ranks_from_tree(seq_name)
411
412            # get EPA tax label
413            ranks, lws = self.classify_seq(place)
414            l1out_ass[seq_name] = (ranks, lws)
415           
416            # check if they match
417            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
418            # cross-check with higher rank mislabels
419            if self.cfg.ranktest and mis_rec:
420                rank_conf = 0
421                for lvl in range(2,len(orig_ranks)):
422                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
423                    if tax_path in self.misrank_conf_map:
424                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
425                mis_rec['rank_conf'] = rank_conf
426            seq_count += 1
427
428        self.write_assignments(l1out_ass, final=False)
429           
430        return seq_count   
431       
432    def run_final_epa_test(self):
433        self.reftree_outgroup = self.refjson.get_outgroup()
434
435        tmp_reftree = self.reftree.copy(method="newick") 
436        name2refnode = {}
437        for leaf in tmp_reftree.iter_leaves():
438            name2refnode[leaf.name] = leaf       
439
440        tmp_taxtree = self.tax_tree.copy(method="newick") 
441        name2taxnode = {}
442        for leaf in tmp_taxtree.iter_leaves():
443            name2taxnode[leaf.name] = leaf       
444
445        for mis_rec in self.mislabels:
446            rname = mis_rec['name']
447#            rname = EpacConfig.REF_SEQ_PREFIX + name
448
449            if rname in name2refnode:
450                name2refnode[rname].delete()
451            else:
452                print "Node not found in the reference tree: %s" % rname
453
454            if rname in name2taxnode:
455                name2taxnode[rname].delete()
456            else:
457                print "Node not found in the taxonomic tree: %s" % rname
458
459        # remove unifurcation at the root
460        if len(tmp_reftree.children) == 1:
461            tmp_reftree = tmp_reftree.children[0]
462           
463        self.mislabels = []
464
465        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
466        th.set_mf_rooted_tree(tmp_taxtree)
467           
468        epa_result = self.run_epa_once(tmp_reftree)
469       
470        reftree_epalbl_str = epa_result.get_std_newick_tree()       
471        placements = epa_result.get_placement()
472       
473        # update branchid-taxonomy mapping to account for possible changes in branch numbering
474        reftree_tax = Tree(reftree_epalbl_str)
475        th.set_bf_unrooted_tree(reftree_tax)
476        bid_tax_map = th.get_bid_taxonomy_map()
477       
478        self.write_bid_tax_map(bid_tax_map, final=True)
479
480        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)
481       
482#        newtax_fname = self.cfg.subst_name("newtax_%NAME%.tre")
483#        th.get_tax_tree().write(outfile=newtax_fname, format=3)
484
485        final_ass = {}
486        for place in placements:
487            seq_name = place["n"][0]
488
489            # get original taxonomic label
490            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)
491
492            # EXPERIMENTAL FEATURE - disabled for now!
493            # It could happen that certain ranks were present in the "original" reference tree, but
494            # are completely missing in the pruned tree (e.g., all seqs of a species were considered "suspicious"
495            # after the leave-one-out test and thus pruned)
496            # In this case, EPA has no chance to infer full original taxonomic annotation (=species) since the corresponding clade
497            # is now missing. To account for this fact, we amend the original taxonomic annotation and set ranks missing from 
498            # pruned tree to "Undefined".
499#            orig_ranks = th.strip_missing_ranks(orig_ranks)
500#            print orig_ranks
501
502            # get EPA tax label
503            ranks, lws = cl.classify_seq(place["p"])
504            final_ass[seq_name] = (ranks, lws)
505
506            #print seq_name, ": ", orig_ranks, "--->", ranks
507
508            # check if they match
509            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
510
511        self.write_assignments(final_ass, final=True)
512
513    def run_epa_once(self, reftree):
514        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
515        job_name = self.cfg.subst_name("final_epa_%NAME%")
516
517        reftree.write(outfile=reftree_fname)
518
519        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!!
520        optmod_fname=""
521        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)
522
523        if self.cfg.output_interim_files:
524            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
525            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)
526
527        return epa_result
528
529    def run_test(self):
530        self.raxml = RaxmlWrapper(self.cfg)
531
532#        config.log.info("Number of sequences in the reference: %d\n", self.reftree_size)
533
534        self.refjson.get_raxml_readable_tree(self.reftree_fname)
535        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)       
536        self.refjson.get_binary_model(self.optmod_fname)
537       
538        if self.cfg.ranktest:
539            config.log.info("Running the leave-one-rank-out test...\n")
540            subtree_count = self.run_leave_subtree_out_test()
541           
542        config.log.info("Running the leave-one-sequence-out test...\n")
543        self.run_leave_seq_out_test()
544
545        if len(self.mislabels) > 0:
546            config.log.info("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n", len(self.mislabels))
547            if self.cfg.debug:
548                self.write_mislabels(final=False)
549            self.run_final_epa_test()
550
551        self.filter_mislabels()
552        self.sort_mislabels()
553        self.write_mislabels()
554        config.log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels), (float(len(self.mislabels)) / self.reftree_size * 100))
555
556def parse_args():
557    parser = ArgumentParser(usage="%(prog)s -s ALIGNMENT -t TAXONOMY -x {BAC,BOT,ZOO,VIR} [options]",
558    description=EpacConfig.SATIVA_INFO % "SATIVA",
559    epilog="Example: sativa.py -s example/test.phy -t example/test.tax -x BAC",
560    formatter_class=RawDescriptionHelpFormatter)
561    parser.add_argument("-s", dest="align_fname",
562            help="""Reference alignment file (PHYLIP or FASTA). Sequences must be aligned,
563            their IDs must correspond to those in taxonomy file.""")
564    parser.add_argument("-t", dest="taxonomy_fname",
565            help="""Reference taxonomy file.""")
566    parser.add_argument("-x", dest="taxcode_name", choices=["bac", "bot", "zoo", "vir"], type = str.lower,
567            help="""Taxonomic code: BAC(teriological), BOT(anical), ZOO(logical), VIR(ological)""")
568    parser.add_argument("-n", dest="output_name", default=None,
569            help="""Job name, will be used as a prefix for output file names (default: taxonomy file name without extension)""")
570    parser.add_argument("-o", dest="output_dir", default=".",
571            help="""Output directory (default: current).""")
572    parser.add_argument("-T", dest="num_threads", type=int, default=multiprocessing.cpu_count(),
573            help="""Specify the number of CPUs (default: %d)""" % multiprocessing.cpu_count())
574    parser.add_argument("-N", dest="rep_num", type=int, default=1, 
575            help="""Number of RAxML tree searches (with distinct random seeds) to resolve multifurcation. Default: 1""")
576    parser.add_argument("-v", dest="verbose", action="store_true",
577            help="""Print additional info messages to the console.""")
578    parser.add_argument("-R", dest="restart", action="store_true",
579            help="""Resume execution after a premature termination (e.g., due to expired job time limit).
580Run name of the previous (terminated) job must be specified via -n option.""")
581    parser.add_argument("-c", dest="config_fname", default=None,
582            help="Config file name.")
583    parser.add_argument("-r", dest="ref_fname",
584            help="""Specify the reference alignment and taxonomy in refjson format.""")
585    parser.add_argument("-j", dest="jplace_fname", default=None,
586            help="""Do not call RAxML EPA, use existing .jplace file as input instead.
587            This could be also a directory with *.jplace files.""")
588    parser.add_argument("-p", dest="rand_seed", type=int, default=12345,
589            help="""Random seed to be used with RAxML. Default: 12345""")
590    parser.add_argument("-C", dest="conf_cutoff", type=float, default=0.,
591            help="""Confidence cut-off between 0 and 1. Default: 0\n""")
592    parser.add_argument("-P", dest="brlen_pv", type=float, default=0.,
593            help="""P-value for branch length Erlang test. Default: 0=off\n""")
594    parser.add_argument("-l", dest="min_lhw", type=float, default=0.,
595            help="""A value between 0 and 1, the minimal sum of likelihood weight of
596                    an assignment to a specific rank. This value represents a confidence
597                    measure of the assignment, assignments below this value will be discarded.
598                    Default: 0 to output all possbile assignments.""")
599    parser.add_argument("-m", dest="mfresolv_method", choices=["thorough", "fast", "ultrafast"],
600            default="thorough", help="""Method of multifurcation resolution:
601            thorough    use stardard constrainted RAxML tree search (default)
602            fast        use RF distance as search convergence criterion (RAxML -D option)
603            ultrafast   optimize model+branch lengths only (RAxML -f e option)""")
604    parser.add_argument("-debug", dest="debug", action="store_true",
605            help="""Debug mode, intermediate files will not be cleaned up.""")
606    parser.add_argument("-ranktest", dest="ranktest", action="store_true",
607            help="""Test for misplaced higher ranks.""")
608    parser.add_argument("-tmpdir", dest="temp_dir", default=None,
609            help="""Directory for temporary files.""")
610
611    args = parser.parse_args()
612    if len(sys.argv) == 1: 
613        parser.print_help()
614        sys.exit()
615    check_args(args, parser)
616    return args
617
618
619def check_args(args, parser):   
620    if args.ref_fname:
621        if args.align_fname:
622            print("WARNING: -r and -s options are mutually exclusive! Your alignment file will be ignored.\n")
623        if args.taxonomy_fname:
624            print("WARNING: -r and -t options are mutually exclusive! Your taxonomy file will be ignored.\n")
625        if args.taxcode_name:
626            print("WARNING: -r and -x options are mutually exclusive! The taxonomic code from reference file will be used.\n")
627    elif not args.align_fname or not args.taxonomy_fname or not args.taxcode_name:
628        print("ERROR: either reference in JSON format or taxonomy, alignment and taxonomic code name must be provided:\n")
629        parser.print_help()
630        sys.exit()
631   
632    if not os.path.exists(args.output_dir):
633        print("Output directory does not exists: %s" % args.output_dir)
634        sys.exit()
635
636    #check if taxonomy file exists
637    if args.taxonomy_fname and not os.path.isfile(args.taxonomy_fname):
638        print "ERROR: Taxonomy file not found: %s" % args.taxonomy_fname
639        sys.exit()
640
641    #check if alignment file exists
642    if args.align_fname and not os.path.isfile(args.align_fname):
643        print "ERROR: Alignment file not found: %s" % args.align_fname
644        sys.exit()
645
646    if args.ref_fname and not os.path.isfile(args.ref_fname):
647        print("Input reference json file does not exists: %s" % args.ref_fname)
648        sys.exit()
649   
650    if args.jplace_fname and not (os.path.isfile(args.jplace_fname) or os.path.isdir(args.jplace_fname)):
651        print("EPA placement file does not exists: %s" % args.jplace_fname)
652        sys.exit()
653
654    if args.min_lhw < 0 or args.min_lhw > 1.0:
655         args.min_lhw = 0.0
656   
657    if args.conf_cutoff < 0 or args.conf_cutoff > 1.0:
658         args.min_lhw = 0.0
659
660    sativa_home = os.path.dirname(os.path.abspath(__file__))
661    if not args.config_fname:
662        args.config_fname = os.path.join(sativa_home, "sativa.cfg")
663    if not args.temp_dir:
664        args.temp_dir = os.path.join(sativa_home, "tmp")
665    if not args.output_name:
666        if args.taxonomy_fname:
667            base_fname = args.taxonomy_fname
668        else:
669            base_fname = args.ref_fname
670        args.output_name = os.path.splitext(base_fname)[0]
671       
672def print_run_info(config):
673    print ""
674    config.print_version("SATIVA")
675   
676    call_str = " ".join(sys.argv)
677    config.log.info("SATIVA was called as follows:\n\n%s\n" % call_str)
678   
679    if config.verbose:
680        config.log.info("Mislabels search is running with the following parameters:")
681        if config.align_fname:
682            config.log.info(" Alignment:                        %s", config.align_fname)
683            config.log.info(" Taxonomy:                         %s", config.taxonomy_fname)
684        if config.load_refjson:
685            config.log.info(" Reference:                        %s", config.refjson_fname)
686        if config.jplace_fname:
687            config.log.info(" EPA jplace file:                  %s", config.jplace_fname)
688        #config.log.info(" Min likelihood weight:            %f", args.min_lhw)
689#        config.log.info(" Assignment method:                %s", args.method)
690        config.log.info(" Output directory:                 %s", os.path.abspath(config.output_dir))
691        config.log.info(" Job name / output files prefix:   %s", config.name)
692        config.log.info(" Model of rate heterogeneity:      %s", config.raxml_model)
693        config.log.info(" Confidence cut-off:               %f", config.conf_cutoff)
694#        config.log.info(" P-value for branch length test:   %g", config.brlen_pv)
695        config.log.info(" Number of threads:                %d", config.num_threads)
696        config.log.info("")
697
698    if config.debug:
699        config.log.debug("Running in DEBUG mode, temp files will be saved to: %s\n", os.path.abspath(config.temp_dir))
700
701if __name__ == "__main__":
702    args = parse_args()
703    config = SativaConfig(args)
704   
705    start_time = time.time()
706    trainer_time = 0
707   
708    t = LeaveOneTest(config)
709    print_run_info(config)
710
711    if config.load_refjson:
712        t.load_refjson(config.refjson_fname)
713    else:
714        config.log.info("*** STEP 1: Building the reference tree using provided alignment and taxonomic annotations ***\n")
715        tr_start_time = time.time() 
716        t.run_epa_trainer()
717        trainer_time = time.time() - tr_start_time
718        t.load_refjson(config.refjson_fname)
719        config.log.info("*** STEP 2: Searching for mislabels ***\n")
720   
721    l1out_start_time = time.time()
722   
723    t.run_test()
724   
725    config.clean_tempdir()
726       
727    l1out_time = time.time() - l1out_start_time
728
729    config.log.info("\nResults were saved to: %s", os.path.abspath(t.mis_fname))
730    config.log.info("Execution log was saved to: %s\n", os.path.abspath(config.log_fname))
731
732    elapsed_time = time.time() - start_time
733    config.log.info("Analysis completed successfully, elapsed time: %.0f seconds (%.0fs reftree, %.0fs leave-one-out)\n", elapsed_time, trainer_time, l1out_time)
Note: See TracBrowser for help on using the repository browser.