source: tags/ms_r16q2/GDE/SATIVA/sativa/epa_trainer.py

Last change on this file was 14661, checked in by akozlov, 8 years ago
  • Property svn:executable set to *
File size: 35.5 KB
Line 
1#!/usr/bin/env python
2
3import sys
4import os
5import shutil
6import datetime
7import time
8import logging
9import multiprocessing
10from string import maketrans
11
12from epac.ete2 import Tree, SeqGroup
13from epac.argparse import ArgumentParser,RawTextHelpFormatter
14from epac.config import EpacConfig,EpacTrainerConfig
15from epac.raxml_util import RaxmlWrapper, FileUtils
16from epac.taxonomy_util import Taxonomy, TaxTreeBuilder
17from epac.json_util import RefJsonBuilder
18from epac.erlang import tree_param
19from epac.msa import hmmer
20from epac.classify_util import TaxTreeHelper
21
class InputValidator:
    """Validate a reference taxonomy against the corresponding alignment.

    The checks run by validate() may modify the taxonomy in place (renaming
    rank names / sequence IDs, merging indistinguishable taxa, closing taxonomy
    gaps). The applied corrections are kept as attributes (corr_ranks,
    corr_seqid, merged_ranks) so that callers can record them.
    """

    def __init__(self, config, input_tax, input_seqs, verbose=True):
        """
        config     -- EPAC configuration (logger, thresholds, error-exit helpers)
        input_tax  -- Taxonomy object to validate; modified in place by the checks
        input_seqs -- alignment whose sequence IDs must match those in the taxonomy
        verbose    -- report findings via the config logger (check_seq_ids also
                      aborts on missing sequences only in verbose mode)
        """
        self.cfg = config
        self.taxonomy = input_tax
        self.alignment = input_seqs
        self.verbose = verbose
        # results of the individual checks, filled in by the check_* methods;
        # dupseq_sets stays None until check_identical_seqs() has run
        self.dupseq_sets = None
        self.merged_ranks = None
        self.corr_seqid = {}
        self.corr_ranks = {}
        # '?' and 'N' are treated as gaps when comparing sequences for identity
        self.gaps_trantab = maketrans("?N", "--")

    def validate(self):
        """Run all input checks in order and apply the resulting fixes.

        Returns a tuple (corr_ranks, corr_seqid, merged_ranks):
          corr_ranks   -- {original rank name: normalized rank name}
          corr_seqid   -- {original sequence ID: normalized sequence ID}
          merged_ranks -- {merged rank ID: [original rank IDs]}
        """
        # the following two checks are obsolete and disabled by default
        self.check_tax_disbalance()
        self.check_tax_duplicates()

        self.check_seq_ids()
        self.check_invalid_chars()

        self.check_identical_seqs()
        self.check_identical_ranks()

        self.taxonomy.close_taxonomy_gaps()

        return self.corr_ranks, self.corr_seqid, self.merged_ranks

    def normalize_gaps(self, seq):
        """Return seq with every '?' and 'N' character replaced by a gap ('-')."""
        return seq.translate(self.gaps_trantab)

    def check_seq_ids(self):
        """Check that every taxonomy sequence ID is present in the alignment.

        Returns the list of missing sequence IDs (reference prefix stripped).
        """
        # NOTE(review): the user-error exit below only happens in verbose mode;
        # confirm that non-verbose callers are expected to inspect the result.
        self.mis_ids = []
        for sid in self.taxonomy.seq_ranks_map:
            unprefixed_sid = EpacConfig.strip_ref_prefix(sid)
            if not self.alignment.has_seq(unprefixed_sid):
                self.mis_ids.append(unprefixed_sid)

        if self.mis_ids and self.verbose:
            errmsg = "ERROR: Following %d sequence(s) are missing in your alignment file:\n%s\n\n" % (len(self.mis_ids), "\n".join(self.mis_ids))
            errmsg += "Please make sure sequence IDs in taxonomic annotation file and in alignment are identical!\n"
            self.cfg.exit_user_error(errmsg)

        return self.mis_ids

    def check_invalid_chars(self):
        """Normalize rank names and sequence IDs that contain illegal symbols.

        Returns (corr_ranks, corr_seqid), each mapping old name -> new name.
        """
        self.corr_ranks = self.taxonomy.normalize_rank_names()
        self.corr_seqid = self.taxonomy.normalize_seq_ids()

        if self.verbose:
            for old_rank in sorted(self.corr_ranks.keys()):
                self.cfg.log.debug("NOTE: Following rank name contains illegal symbols and was renamed: %s --> %s", old_rank, self.corr_ranks[old_rank])
            if self.corr_ranks:
                self.cfg.log.debug("")
            for old_sid in sorted(self.corr_seqid.keys()):
                self.cfg.log.debug("NOTE: Following sequence ID contains illegal symbols and was renamed: %s --> %s" , old_sid, self.corr_seqid[old_sid])
            if self.corr_seqid:
                self.cfg.log.debug("")

        return self.corr_ranks, self.corr_seqid

    def check_identical_seqs(self):
        """Find groups of identical reference sequences ('?'/'N' count as gaps).

        Only alignment entries that are present in the taxonomy are considered.
        Returns (dupseq_count, dupseq_sets): the number of redundant sequences
        and a list of groups (lists of sequence names) of identical sequences.
        """
        # bucket sequence names by the hash of their gap-normalized sequence
        seq_hash_map = {}
        for name, seq, _comment, _sid in self.alignment.iter_entries():
            ref_seq_name = EpacConfig.REF_SEQ_PREFIX + name
            ref_seq_name = self.corr_seqid.get(ref_seq_name, ref_seq_name)
            if ref_seq_name in self.taxonomy.seq_ranks_map:
                seq_hash = hash(self.normalize_gaps(seq))
                seq_hash_map.setdefault(seq_hash, []).append(name)

        self.dupseq_count = 0
        self.dupseq_sets = []
        for seq_ids in seq_hash_map.values():
            check_ids = seq_ids
            while len(check_ids) > 1:
                # compare actual sequence strings to account for hash collisions:
                # ids identical to the first one form a duplicate group, the
                # remaining (colliding) ids are re-checked in the next round
                seq1 = self.normalize_gaps(self.alignment.get_seq(check_ids[0]))
                dup_ids = [check_ids[0]]
                coll_ids = []
                for other_id in check_ids[1:]:
                    seq2 = self.normalize_gaps(self.alignment.get_seq(other_id))
                    if seq1 == seq2:
                        dup_ids.append(other_id)
                    else:
                        coll_ids.append(other_id)

                if len(dup_ids) > 1:
                    self.dupseq_sets.append(dup_ids)
                    self.dupseq_count += len(dup_ids) - 1

                check_ids = coll_ids

        if self.verbose:
            for dup_ids in self.dupseq_sets:
                self.cfg.log.debug("NOTE: Following sequences are identical: %s", ", ".join(dup_ids))
            if self.dupseq_count > 0:
                self.cfg.log.debug("\nNOTE: Found %d sequence duplicates", self.dupseq_count)

        return self.dupseq_count, self.dupseq_sets

    def check_identical_ranks(self):
        """Merge taxa that are indistinguishable because they share identical sequences.

        For each group of identical sequences, every rank whose number of
        sequences in the group exceeds cfg.taxa_ident_thres times its total
        sequence count qualifies; if more than one rank qualifies, they are
        merged into a single "__TAXCLUSTER<n>__" taxon via taxonomy.merge_ranks().
        Returns {merged rank ID: [original rank IDs]}.
        """
        # run the duplicate search only if it has not been run yet
        # (an empty list means "searched, nothing found")
        if self.dupseq_sets is None:
            self.check_identical_seqs()
        self.merged_ranks = {}
        for dup_ids in self.dupseq_sets:
            if len(dup_ids) < 2:
                continue
            # count how many of the identical sequences belong to each rank
            duprank_map = {}
            for seq_name in dup_ids:
                rank_id = self.taxonomy.seq_rank_id(seq_name)
                duprank_map[rank_id] = duprank_map.get(rank_id, 0) + 1
            if len(duprank_map) > 1 and self.cfg.debug:
                self.cfg.log.debug("Ranks sharing duplicates: %s\n", str(duprank_map))
            dup_ranks = [rank_id for rank_id, count in duprank_map.items()
                         if count > self.cfg.taxa_ident_thres * self.taxonomy.get_rank_seq_count(rank_id)]
            if len(dup_ranks) > 1:
                prefix = "__TAXCLUSTER%d__" % (len(self.merged_ranks) + 1)
                merged_rank_id = self.taxonomy.merge_ranks(dup_ranks, prefix)
                self.merged_ranks[merged_rank_id] = dup_ranks

        if self.verbose:
            merged_count = 0
            for merged_rank_id, dup_ranks in self.merged_ranks.items():
                dup_ranks_str = "\n".join([Taxonomy.rank_uid_to_lineage_str(rank_id) for rank_id in dup_ranks])
                self.cfg.log.warning("\nWARNING: Following taxa share >%.0f%% identical sequences and are thus considered indistinguishable:\n%s", self.cfg.taxa_ident_thres * 100, dup_ranks_str)
                merged_rank_str = Taxonomy.rank_uid_to_lineage_str(merged_rank_id)
                self.cfg.log.warning("For the purpose of mislabels identification, they were merged into one taxon:\n%s\n", merged_rank_str)
                merged_count += len(dup_ranks)

            if merged_count > 0:
                self.cfg.log.warning("WARNING: %d indistinguishable taxa have been merged into %d clusters.\n", merged_count, len(self.merged_ranks))

        return self.merged_ranks

    def check_tax_disbalance(self):
        """Detect taxonomy irregularities (more than 7 ranks or missing intermediate ranks).

        The reaction is selected by cfg.wrong_rank_count: "ignore" (no check),
        "autofix" (fix and report), "skip" (remove the affected sequences),
        anything else aborts via cfg.exit_user_error().
        """
        action = self.cfg.wrong_rank_count
        if action == "ignore":
            return
        errs = self.taxonomy.check_for_disbalance(action == "autofix")
        if not errs:
            return

        if action == "autofix":
            print("WARNING: %d sequences with invalid annotation (missing/redundant ranks) found and were fixed as follows:\n" % len(errs))
            for err in errs:
                print("Original:   %s\t%s"   % (err[0], err[1]))
                print("Fixed as:   %s\t%s\n" % (err[0], err[2]))
        elif action == "skip":
            print("WARNING: Following %d sequences with invalid annotation (missing/redundant ranks) were skipped:\n" % len(errs))
            for err in errs:
                self.taxonomy.remove_seq(err[0])
                print("%s\t%s" % err)
        else:  # abort
            print("ERROR: %d sequences with invalid annotation (missing/redundant ranks) found:\n" % len(errs))
            for err in errs:
                print("%s\t%s" % err)
            print("\nPlease fix them manually (add/remove ranks) and run the pipeline again (or use -wrong-rank-count autofix option)")
            print("NOTE: Only standard 7-level taxonomies are supported at the moment. Although missing trailing ranks (e.g. species) are allowed,")
            print("missing intermediate ranks (e.g. family) or sublevels (e.g. suborder) are not!\n")
            self.cfg.exit_user_error()

    def check_tax_duplicates(self):
        """Detect different ranks that carry the same name.

        The reaction is selected by cfg.dup_rank_names: "ignore" (no check),
        "autofix" (rename and report), "skip" (remove the affected sequences),
        anything else aborts via cfg.exit_user_error().
        """
        action = self.cfg.dup_rank_names
        if action == "ignore":
            return
        dups = self.taxonomy.check_for_duplicates(action == "autofix")
        if not dups:
            return

        if action == "autofix":
            print("WARNING: %d sequences with duplicate rank names found and were renamed as follows:\n" % len(dups))
            for dup in dups:
                print("Original:    %s\t%s"   %  (dup[0], dup[1]))
                print("Duplicate:   %s\t%s"   %  (dup[2], dup[3]))
                print("Renamed to:  %s\t%s\n" %  (dup[2], dup[4]))
        elif action == "skip":
            print("WARNING: Following %d sequences with duplicate rank names were skipped:\n" % len(dups))
            for dup in dups:
                self.taxonomy.remove_seq(dup[2])
                print("%s\t%s\n" % (dup[2], dup[3]))
        else:  # abort
            print("ERROR: %d sequences with duplicate rank names found:\n" % len(dups))
            for dup in dups:
                print("%s\t%s\n%s\t%s\n" % dup)
            print("Please fix (rename) them and run the pipeline again (or use -dup-rank-names autofix option)")
            self.cfg.exit_user_error()
214
class RefTreeBuilder:
    """Build the EPA reference file from a taxonomy and a reference alignment.

    build_ref_tree() drives the whole pipeline; every other public method is
    one stage of it and relies on attributes set by the preceding stages.
    """

    def __init__(self, config):
        """config -- EPAC trainer configuration (file names, RAxML settings, logger)."""
        self.cfg = config
        self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%")
        self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%")
        self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%")
        self.raxml_wrapper = RaxmlWrapper(config)

        # intermediate files written during the pipeline
        self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre")
        self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre")
        self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt")

    def load_alignment(self):
        """Load cfg.align_fname into self.input_seqs, trying each supported format in turn.

        Exits via cfg.exit_user_error() if none of the formats can be parsed.
        """
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except Exception:
                # not parseable in this format, try the next one
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs is None:
            self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file)

    def validate_taxonomy(self):
        """Validate taxonomy vs. alignment; fixes are applied to the taxonomy in place."""
        self.input_validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs)
        self.input_validator.validate()

    def build_multif_tree(self):
        """Build the multifurcating reference tree from the taxonomy.

        Sets reftree_ids, reftree_size and reftree_multif, and lets the config
        resolve settings that depend on the tree size.
        """
        c = self.cfg

        tb = TaxTreeBuilder(c, self.taxonomy)
        t, ids = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf, c.reftree_clades_to_include, c.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(ids)
        self.reftree_size = len(ids)
        self.reftree_multif = t

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        c.resolve_auto_settings(self.reftree_size)

        if c.debug:
            # list of sequence ids which comprise the reference tree
            refseq_fname = c.tmp_fname("%NAME%_seq_ids.txt")
            with open(refseq_fname, "w") as f:
                for sid in ids:
                    f.write("%s\n" % sid)

            # original tree with taxonomic ranks as internal node labels
            reftax_fname = c.tmp_fname("%NAME%_mfu_tax.tre")
            t.write(outfile=reftax_fname, format=8)

    def export_ref_alignment(self):
        """Write the reference alignment to a FASTA file (self.refalign_fname):
           1. Filter out sequences which are not part of the reference tree
           2. Add sequence name prefix (EpacConfig.REF_SEQ_PREFIX) and apply
              the sequence ID corrections made during validation
        The in-memory input alignment is released afterwards."""
        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        corr_seqid = self.input_validator.corr_seqid
        with open(self.refalign_fname, "w") as fout:
            for name, seq, _comment, _sid in self.input_seqs.iter_entries():
                seq_name = EpacConfig.REF_SEQ_PREFIX + name
                seq_name = corr_seqid.get(seq_name, seq_name)
                if seq_name in self.reftree_ids:
                    fout.write(">" + seq_name + "\n" + seq + "\n")

        # we do not need the original alignment anymore, so free its memory
        self.input_seqs = None

    def export_ref_taxonomy(self):
        """Store the taxonomy entries of reference-tree sequences in self.taxonomy_map."""
        self.taxonomy_map = {}
        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid in self.taxonomy_map:
                    ranks_str = self.taxonomy.seq_lineage_str(sid)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the outgroup used for re-rooting and write the multifurcating tree for RAxML."""
        rt = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(self.cfg, tax_map)
        self.taxtree_helper.set_mf_rooted_tree(rt)
        outgr = self.taxtree_helper.get_outgroup()
        outgr_size = len(outgr.get_leaves())
        outgr.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgr
        self.cfg.log.debug("Outgroup for rooting was saved to: %s, outgroup size: %d", self.outgr_fname, outgr_size)

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        # now we can safely unroot the tree and remove internal node labels to make it suitable for raxml
        rt.write(outfile=self.reftree_mfu_fname, format=9)

    def resolve_multif(self):
        """Convert the multifurcating tree into a strictly bifurcating one with RAxML.

        Runs a constrained ML search (or reuses an existing result on restart),
        then takes the model parameters from that run or, if cfg.reopt_model is
        set, re-optimizes them in a separate RAxML run. The resulting tree and
        model file are copied to reftree_bfu_fname / optmod_fname, and the
        tree log-likelihood is stored in reftree_loglh.
        A RAxML run without result is reported via cfg.exit_fatal_error().
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname)

        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname, "--no-seq-check", "-N", str(self.cfg.rep_num)]
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]
        if self.cfg.restart and self.raxml_wrapper.result_exists(self.mfresolv_job_name):
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(self.mfresolv_job_name)
            self.cfg.log.debug("\nUsing existing ML tree found in: %s\n", self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params)
            if self.cfg.mfresolv_method == "ultrafast":
                # store the result tree under the best-tree name, which is what is checked below
                self.raxml_wrapper.copy_result_tree(self.mfresolv_job_name, self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))

        if not self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):
            errmsg = "RAxML run failed (multifurcation resolution), please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)
            return

        if not self.cfg.reopt_model:
            self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name, self.reftree_bfu_fname)
            self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name, self.optmod_fname)
            self.invocation_raxml_optmod = ""
            job_name = self.mfresolv_job_name
        else:
            bfu_fname = self.raxml_wrapper.besttree_fname(self.mfresolv_job_name)
            job_name = self.optmod_job_name

            # RAxML call to optimize model parameters and write them down to the binary model file
            self.cfg.log.debug("\nOptimizing model parameters: \n")
            raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname, "-t", bfu_fname, "--no-seq-check"]
            if self.cfg.raxml_model.startswith("GTRCAT") and not self.cfg.compress_patterns:
                raxml_params += ["-H"]
            if self.cfg.restart and self.raxml_wrapper.result_exists(self.optmod_job_name):
                self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(self.optmod_job_name)
                self.cfg.log.debug("\nUsing existing optimized tree and parameters found in: %s\n", self.raxml_wrapper.result_fname(self.optmod_job_name))
            else:
                self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params)
            if self.raxml_wrapper.result_exists(self.optmod_job_name):
                self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname)
            else:
                errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                        % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                self.cfg.exit_fatal_error(errmsg)

        mod_name = "CAT" if self.cfg.raxml_model.startswith("GTRCAT") else "GAMMA"
        self.reftree_loglh = self.raxml_wrapper.get_tree_lh(job_name, mod_name)
        self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n", mod_name, self.reftree_loglh)

    def load_reduced_refalign(self):
        """Load the reduced reference alignment (FASTA or relaxed PHYLIP).

        Reports a failure to parse it in any format via cfg.exit_fatal_error().
        """
        # initialize first, so that failing in every format is reported below
        # instead of raising AttributeError
        self.reduced_refalign_seqs = None
        for fmt in ("fasta", "phylip_relaxed"):
            try:
                self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format=fmt)
                break
            except Exception:
                # not parseable in this format, try the next one
                continue
        if self.reduced_refalign_seqs is None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)

    def epa_branch_labeling(self):
        """Run a dummy EPA placement to label the branches of the reference tree.

        The labelled tree is needed to build the branch -> taxonomic rank mapping.
        A single all-'A' query sequence is appended to the reduced alignment.
        """
        # create alignment with a dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A" * self.refalign_width + "\n")

        # TODO always load model regardless of the config file settings?
        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name, self.lblalign_fname, self.reftree_bfu_fname, self.optmod_fname, mode="epa_mp")

        # check for the result before reading from it, so that a failed run is
        # reported with a pointer to the RAxML log
        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
            self.cfg.exit_fatal_error(errmsg)

        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

    def epa_post_process(self):
        """Map EPA branch labels to taxonomy (tax tree + branch-id -> ranks map)."""
        lbl_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if self.cfg.debug:
            self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
            with open(self.reftree_lbl_fname, "w") as outf:
                outf.write(self.reftree_lbl_str)
            with open(self.brmap_fname, "w") as outf:
                for bid, br_rec in self.bid_ranks_map.iteritems():
                    outf.write("%s\t%s\t%d\t%f\n" % (bid, br_rec[0], br_rec[1], br_rec[2]))

    def calc_node_heights(self):
        """Calculate node heights on the reference tree (used to define branch-length cutoff during classification step)
           Algorithm is as follows:
           Tip node or node resolved to Species level: height = 1
           Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1
           The result is stored in self.node_height_map (branch id -> height).
         """
        nh_map = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                continue
            if not hasattr(node, "B"):
                # In a rooted tree, there is always one more node/branch than in unrooted one
                # That's why one branch will be always not EPA-labelled after the rooting
                if not dummy_added:
                    node.B = "DDD"
                    dummy_added = True
                    species_rank = Taxonomy.EMPTY_RANK
                else:
                    errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                    self.cfg.exit_fatal_error(errmsg)
            else:
                # last element of the branch rank record is the species-level rank
                species_rank = self.bid_ranks_map[node.B][-1]
            bid = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                nh_map[bid] = 1
            else:
                # postorder traversal guarantees both children were processed already
                lchild = node.children[0]
                rchild = node.children[1]
                nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names found on nodes of the subtree at root."""
        rnames = set()
        for node in root.traverse("postorder"):
            rnames.update(node.ranks)
        return rnames

    def mono_index(self):
        """Return the rank names shared by the left and right subtree of the reference tree.

        Unifurcations at the top of reftree_tax are skipped. If the first
        non-unifurcating node does not have exactly two children, an error is
        printed and an empty set is returned.
        """
        children = self.reftree_tax.children
        # descend through unifurcations at the root
        while len(children) == 1:
            children = children[0].children
        if len(children) != 2:
            print("Error: input tree not bifurcating")
            return set()
        left, right = children
        return self.__get_all_rank_names(left) & self.__get_all_rank_names(right)

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced reference alignment and attach it to json_builder."""
        print("Building the HMMER profile...\n")

        # this workaround is needed because RAxML outputs the reduced
        # alignment in relaxed PHYLIP format, which is not supported by HMMER
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

    def write_json(self):
        """Assemble all reference data and write it to cfg.refjson_fname.

        NOTE(review): self.invocation_epac is not set in this class; presumably
        the caller assigns it before build_ref_tree() — confirm.
        """
        jw = RefJsonBuilder()

        jw.set_branch_tax_map(self.bid_ranks_map)
        jw.set_tree(self.reftree_lbl_str)
        jw.set_outgroup(self.reftree_outgroup)
        jw.set_ratehet_model(self.cfg.raxml_model)
        jw.set_tax_tree(self.reftree_multif)
        jw.set_pattern_compression(self.cfg.compress_patterns)
        jw.set_taxcode(self.cfg.taxcode_name)

        # corrections are stored in reverse direction (normalized name -> original name)
        jw.set_merged_ranks_map(self.input_validator.merged_ranks)
        corr_ranks_reverse = {new: old for old, new in self.input_validator.corr_ranks.items()}
        jw.set_corr_ranks_map(corr_ranks_reverse)
        corr_seqid_reverse = {new: old for old, new in self.input_validator.corr_seqid.items()}
        jw.set_corr_seqid_map(corr_seqid_reverse)

        mdata = { "ref_tree_size": self.reftree_size,
                  "ref_alignment_width": self.refalign_width,
                  "raxml_version": self.raxml_version,
                  "timestamp": str(datetime.datetime.now()),
                  "invocation_epac": self.invocation_epac,
                  "invocation_raxml_multif": self.invocation_raxml_multif,
                  "invocation_raxml_optmod": self.invocation_raxml_optmod,
                  "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
                  "reftree_loglh": self.reftree_loglh
                }
        jw.set_metadata(mdata)

        seqs = self.reduced_refalign_seqs.get_entries()
        jw.set_sequences(seqs)

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(jw)

        orig_tax = self.taxonomy_map
        jw.set_origin_taxonomy(orig_tax)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        tp = tree_param(tree=self.reftree_lbl_str, origin_taxonomy=orig_tax)
        jw.set_rate(tp.get_speciation_rate_fast())
        jw.set_nodes_height(self.node_height_map)

        jw.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        jw.dump(self.cfg.refjson_fname)

    def build_ref_tree(self):
        """Top-level function: run the whole reference-building pipeline."""
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n" , self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info("==> Loading reference alignment from file: %s ...\n" , self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n" , self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        self.cfg.log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n", self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()
        self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after  training: %s\n", repr(self.mono_index()))

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n", self.cfg.refjson_fname)
        self.write_json()
582
def parse_args():
    """Build the trainer's command-line parser and parse ``sys.argv``.

    If fewer than three arguments follow the script name, the help text is
    printed and the process exits via ``sys.exit()``.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = ArgumentParser(
        description="Build a reference tree for EPA taxonomic placement.",
        epilog="Example: ./epa_trainer.py -t example/training_tax.txt -s example/training_seq.fa -n myref",
        formatter_class=RawTextHelpFormatter)
    parser.add_argument("-t", dest="taxonomy_fname", required=True,
            help="""Reference taxonomy file.""")
    parser.add_argument("-s", dest="align_fname", required=True,
            help="""Reference alignment file. Sequences must be aligned, their IDs must correspond to those
in taxonomy file.""")
    parser.add_argument("-r", dest="ref_fname",
            help="""Reference output file. It will contain reference alignment, phylogenetic tree and other
information needed for taxonomic placement of query sequences.""")
    parser.add_argument("-T", dest="num_threads", type=int, default=None,
            help="""Specify the number of CPUs.  Default: %d""" % multiprocessing.cpu_count())
    parser.add_argument("-c", dest="config_fname", default=None,
            help="""Config file name.""")
    parser.add_argument("-o", dest="output_dir", default=".",
            help="""Output directory""")
    parser.add_argument("-n", dest="output_name", default=None,
            help="""Run name.""")
    parser.add_argument("-p", dest="rand_seed", type=int, default=None,
            help="""Random seed to be used with RAxML. Default: current system time.""")
    # RawTextHelpFormatter prints the help strings verbatim, so the column
    # layout inside the triple-quoted strings below is intentional.
    parser.add_argument("-m", dest="mfresolv_method", choices=["thorough", "fast", "ultrafast"],
            default="thorough", help="""Method of multifurcation resolution:
            thorough    use standard constrained RAxML tree search (default)
            fast        use RF distance as search convergence criterion (RAxML -D option)
            ultrafast   optimize model+branch lengths only (RAxML -f e option)""")
    parser.add_argument("-N", dest="rep_num", type=int, default=1,
            help="""Number of RAxML tree searches (with distinct random seeds). Default: 1""")
    # str.lower makes the taxonomic code case-insensitive (e.g. "BAC" -> "bac")
    parser.add_argument("-x", dest="taxcode_name", choices=["bac", "bot", "zoo", "vir"], type = str.lower,
            help="""Taxonomic code: BAC(teriological), BOT(anical), ZOO(logical), VIR(ological)""")
    parser.add_argument("-R", dest="restart", action="store_true",
            help="""Resume execution after a premature termination (e.g., due to expired job time limit).
Run name of the previous (terminated) job must be specified via -n option.""")
    parser.add_argument("-v", dest="verbose", action="store_true",
            help="""Print additional info messages to the console.""")
    parser.add_argument("-debug", dest="debug", action="store_true",
            help="""Debug mode, intermediate files will not be cleaned up.""")
    parser.add_argument("-no-hmmer", dest="no_hmmer", action="store_true",
            help="""Do not build HMMER profile.""")
    parser.add_argument("-dup-rank-names", dest="dup_rank_names", choices=["ignore", "abort", "skip", "autofix"],
            default="ignore", help="""Action to be performed if different ranks with same name are found:
            ignore      do nothing
            abort       report duplicates and exit
            skip        skip the corresponding sequences (exclude from reference)
            autofix     make name unique by concatenating it with the parent rank's name""")
    parser.add_argument("-wrong-rank-count", dest="wrong_rank_count", choices=["ignore", "abort", "skip", "autofix"],
            default="ignore", help="""Action to be performed if lineage has fewer (or more) than 7 ranks:
            ignore      do nothing
            abort       report lineages with wrong rank count and exit
            skip        skip the corresponding sequences (exclude from reference)
            autofix     try to guess which ranks should be added or removed (use with caution!)""")
    parser.add_argument("-tmpdir", dest="temp_dir", default=None,
            help="""Directory for temporary files.""")

    # Not even the two mandatory options (-t FILE -s FILE) were given:
    # show usage instead of a terse argparse error.
    if len(sys.argv) < 4:
        parser.print_help()
        sys.exit()

    args = parser.parse_args()

    return args
645 
def check_args(args):
    """Validate parsed command-line arguments and derive output file names.

    Side effects on ``args``:
      * ``output_name`` defaults to ``align_fname`` when empty;
      * ``ref_fname`` defaults to ``"<output_name>.refjson"`` and, when it has
        no directory component, is placed inside ``output_dir``.

    Prints an error message and terminates with exit status 1 if an input
    file is missing, the reference file already exists or cannot be created,
    or ``rep_num`` is outside 1..1000.
    """
    def _fail(*messages):
        # Report the problem and exit with a non-zero status so that calling
        # scripts can tell a failed run from a successful one.
        for msg in messages:
            print(msg)
        sys.exit(1)

    # both input files must exist
    if not os.path.isfile(args.taxonomy_fname):
        _fail("ERROR: Taxonomy file not found: %s" % args.taxonomy_fname)

    if not os.path.isfile(args.align_fname):
        _fail("ERROR: Alignment file not found: %s" % args.align_fname)

    if not args.output_name:
        args.output_name = args.align_fname

    if not args.ref_fname:
        args.ref_fname = "%s.refjson" % args.output_name

    # a bare file name is put into the requested output directory
    if args.output_dir and not os.path.dirname(args.ref_fname):
        args.ref_fname = os.path.join(args.output_dir, args.ref_fname)

    # never overwrite an existing reference silently
    if os.path.isfile(args.ref_fname):
        _fail("ERROR: Reference tree file already exists: %s" % args.ref_fname,
              "Please delete it explicitly if you want to overwrite.")

    # probe that the reference file can be created, then remove the probe;
    # only I/O errors are caught so Ctrl-C / SystemExit still propagate
    try:
        with open(args.ref_fname, "w"):
            pass
        os.remove(args.ref_fname)
    except (IOError, OSError):
        out_dir = os.path.split(os.path.abspath(args.ref_fname))[0]
        _fail("ERROR: cannot create output file: %s" % args.ref_fname,
              "Please check if directory %s exists and you have write permissions for it." % out_dir)

    if args.rep_num < 1 or args.rep_num > 1000:
        _fail("ERROR: Number of RAxML runs must be between 1 and 1000.")
685
def which(program, custom_path=None):
    """Locate an executable file, similar to the Unix ``which`` command.

    If ``program`` contains a directory component it is checked directly.
    Otherwise the directories in ``custom_path`` are searched first, followed
    by the entries of the ``PATH`` environment variable.

    Args:
        program: executable name or path.
        custom_path: optional list of extra directories to search first;
            ``None`` entries are ignored. The list is not modified.

    Returns:
        The path of the first matching executable file, or ``None``.
    """
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    fpath, _ = os.path.split(program)
    if fpath:
        # an explicit path: do not search any directories
        return program if is_exe(program) else None

    # Build a fresh list: extending custom_path in place would mutate the
    # caller's list (or a shared default) on every call.
    search_dirs = list(custom_path) if custom_path else []
    search_dirs.extend(os.environ.get("PATH", "").split(os.pathsep))

    for path in search_dirs:
        if path is None:
            # e.g. an unset tool home directory from the config
            continue
        exe_file = os.path.join(path.strip('"'), program)
        if is_exe(exe_file):
            return exe_file

    return None
704   
def check_dep(config):
    """Make sure the external tools required by this run are available.

    Only HMMER is checked, and only when profile building is enabled
    (``config.no_hmmer`` is false). ``hmmalign`` is looked up in
    ``config.hmmer_home`` and then on ``PATH``. If it is not found, the
    error is printed and ``config.exit_user_error()`` is called.
    """
    if config.no_hmmer:
        return

    if which("hmmalign", [config.hmmer_home]):
        return

    error_lines = (
        "ERROR: HMMER not found!",
        "Please either specify path to HMMER executables in the config file",
        "or call this script with -no-hmmer option to skip building HMMER profile.",
    )
    for line in error_lines:
        print(line)
    config.exit_user_error()
712           
def run_trainer(config):
    """Check external dependencies, then build the reference tree.

    The full command line (``sys.argv`` joined by spaces) is stored on the
    builder as ``invocation_epac`` before ``build_ref_tree()`` runs.
    """
    check_dep(config)

    tree_builder = RefTreeBuilder(config)
    command_line = " ".join(sys.argv)
    tree_builder.invocation_epac = command_line

    tree_builder.build_ref_tree()
718       
719# -------
720# MAIN
721# -------
if __name__ == "__main__":
    # Parse and sanity-check the command line, then load the configuration.
    args = parse_args()
    check_args(args)
    config = EpacTrainerConfig(args)

    print("")
    config.print_version("SATIVA-trainer")

    start_time = time.time()

    # Train the reference, then remove temporary files.
    run_trainer(config)
    config.clean_tempdir()

    # Report where the results ended up.
    refjson_path = os.path.abspath(config.refjson_fname)
    log_path = os.path.abspath(config.log_fname)
    config.log.info("Reference JSON was saved to: %s", refjson_path)
    config.log.info("Execution log was saved to: %s\n", log_path)

    elapsed_time = time.time() - start_time
    config.log.info("Training completed successfully, elapsed time: %.0f seconds\n", elapsed_time)
Note: See TracBrowser for help on using the repository browser.