source: branches/ali/GDE/SATIVA/sativa/epac/config.py

Last change on this file was 14661, checked in by akozlov, 9 years ago
File size: 14.8 KB
Line 
1#!/usr/bin/env python
2
3import os
4import sys
5import glob
6import shutil
7import datetime
8import time
9import logging
10import random
11import multiprocessing
12import ConfigParser
13
14from epac.version import SATIVA_BUILD,SATIVA_RELEASE_DATE,SATIVA_RAXML_VER
15
16class DefaultedConfigParser(ConfigParser.SafeConfigParser):
17    def get_param(self, section, option, ctype=str, default=None):
18        if default is None:
19            ret = self.get(section, option)
20        else:
21            confdict = self.__dict__.get('_sections')
22            sec = confdict.get(section)
23            if sec:
24                if ctype == bool:
25                    ret = self.getboolean(section, option)
26                else:
27                    ret = sec.get(option, default)
28            else:
29                ret = default
30        return ctype(ret)
31
32class EpacConfig:
33    # this prefix will be added to every sequence name in reference to prevent
34    # name clashes with query sequences, which are coded with numbers
35    REF_SEQ_PREFIX = "r_";
36    QUERY_SEQ_PREFIX = "q_";
37
38    CAT_LOWER_THRES   = 100
39    CAT_GAMMA_THRES   = 500
40    GAMMA_UPPER_THRES = 10000
41    EPA_HEUR_THRES    = 1000
42   
43    SATIVA_INFO = \
44    """%s %s, released on %s. Last version: https://github.com/amkozlov/sativa
45By A.Kozlov and J.Zhang, the Exelixis Lab. Based on RAxML %s by A.Stamatakis.\n"""\
46    % ("%s", SATIVA_BUILD, SATIVA_RELEASE_DATE, SATIVA_RAXML_VER)
47   
48   
49    @staticmethod
50    def strip_prefix(seq_name, prefix):
51        if seq_name.startswith(prefix):
52            plen = len(prefix)
53            return seq_name[plen:]
54        else:
55            return seq_name
56       
57    @staticmethod
58    def strip_ref_prefix(seq_name):
59        return EpacConfig.strip_prefix(seq_name, EpacConfig.REF_SEQ_PREFIX)
60       
61    @staticmethod
62    def strip_query_prefix(seq_name):
63        return EpacConfig.strip_prefix(seq_name, EpacConfig.QUERY_SEQ_PREFIX)
64
65    def __init__(self, args=None): 
66        self.basepath = os.path.dirname(os.path.abspath(__file__))
67        self.epac_home = os.path.abspath(os.path.join(self.basepath, os.pardir)) + "/"
68
69        self.set_defaults()
70
71        if not args:
72            return
73       
74        self.verbose = args.verbose
75        self.debug = args.debug
76        self.restart = args.restart
77        self.refjson_fname = args.ref_fname
78               
79        self.rand_seed = args.rand_seed
80       
81        timestamp = "%d" % (time.time()*1000) #datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
82        if args.output_name:
83            self.name = args.output_name
84        else:
85            self.name = str(random.randint(1, 99999))
86
87        self.output_dir = args.output_dir
88
89        if args.temp_dir:
90            self.temp_dir = args.temp_dir
91        else:
92            self.temp_dir = os.path.join(self.epac_home, "tmp")
93
94        if self.restart:
95            tmpdirs = glob.glob(os.path.join(self.temp_dir, self.name + "_*"))
96            if len(tmpdirs) > 0:
97                tmpdirs.sort(key=os.path.getmtime, reverse=True)
98                self.temp_dir = tmpdirs[0]
99            else:
100                self.exit_user_error("ERROR: Cannot resume execution: temp directory for the previous run not found in %s" % self.temp_dir)
101        else:
102            temp_name = self.name + "_" + timestamp
103            self.temp_dir = os.path.join(self.temp_dir, temp_name)
104           
105        self.raxml_outdir = self.temp_dir
106        self.raxml_outdir_abs = os.path.abspath(self.raxml_outdir)
107
108        self.init_logger()
109       
110        if not args.config_fname:
111            args.config_fname = os.path.join(self.epac_home, "sativa.cfg")
112       
113        self.config_path = os.path.dirname(os.path.abspath(args.config_fname))
114        self.read_from_file(args.config_fname)
115
116        # command line setting has preference over config file and default
117        if args.num_threads:
118            self.num_threads = args.num_threads       
119        self.check_raxml()   
120       
121        if not self.restart:
122            os.mkdir(self.temp_dir)
123
124    def set_defaults(self):
125        self.muscle_home = self.epac_home + "/epac/bin" + "/"
126        self.hmmer_home = self.epac_home + "/epac/bin" + "/"
127        self.raxml_home = self.epac_home + "/epac/bin" + "/"
128        self.raxml_exec = "raxmlHPC-PTHREADS-SSE3"
129        self.raxml_model = "AUTO"
130        self.raxml_remote_host = ""
131        self.raxml_remote_call = False       
132        self.run_on_cluster = False
133        self.cluster_epac_home = self.epac_home
134        self.cluster_qsub_script = ""
135        self.epa_load_optmod = True
136        self.epa_use_heuristic = "AUTO"
137        self.epa_heur_rate = 0.01
138        self.min_confidence = 0.2
139        self.num_threads = multiprocessing.cpu_count()
140        self.compress_patterns = False
141        self.use_bfgs = False
142        self.save_memory = False
143        self.taxa_ident_thres = 0.6
144        self.debug = False
145        self.restart = False
146        self.verbose = False
147        self.log = logging.getLogger('epac')
148       
149    def init_logger(self):
150        self.log_fname = self.out_fname("%NAME%.log")
151        if self.verbose or self.debug:
152           log_lvl = logging.DEBUG
153        else:
154           log_lvl = logging.INFO
155
156        # configure logger object
157        self.log.setLevel(logging.DEBUG)
158        formatter = logging.Formatter('%(message)s')
159
160        # add console handler
161        ch = logging.StreamHandler()
162        ch.setLevel(log_lvl)
163        ch.setFormatter(formatter)       
164        self.log.addHandler(ch)
165
166        # add file handler
167        if self.restart:
168            logf_mode = "a"
169        else:
170            logf_mode = "w"
171        fh = logging.FileHandler(self.log_fname, mode=logf_mode)
172        fh.setLevel(log_lvl)
173        fh.setFormatter(formatter)       
174        self.log.addHandler(fh)
175
176    def resolve_auto_settings(self, tree_size):
177        if self.raxml_model == "AUTO":
178            if tree_size > EpacConfig.CAT_GAMMA_THRES:
179                self.raxml_model = "GTRCAT"
180            else:
181                self.raxml_model = "GTRGAMMA"
182        elif self.raxml_model == "GTRCAT" and tree_size < EpacConfig.CAT_LOWER_THRES:
183            print "WARNING: You're using GTRCAT model on a very small dataset (%d taxa), which might lead to unreliable results!" % tree_size
184            print "Please consider switching to GTRGAMMA model.\n"
185        elif self.raxml_model == "GTRGAMMA" and tree_size > EpacConfig.GAMMA_UPPER_THRES:
186            print "WARNING: You're using GTRGAMMA model on a very large dataset (%d taxa), which might lead to numerical issues!" % tree_size
187            print "In case of problems, please consider switching to GTRCAT model.\n"
188
189        if self.epa_use_heuristic == "AUTO": 
190            if tree_size > EpacConfig.EPA_HEUR_THRES:
191                self.epa_use_heuristic = "TRUE"
192                self.epa_heur_rate = 0.5 * float(EpacConfig.EPA_HEUR_THRES) / tree_size
193            else:
194                self.epa_use_heuristic = "FALSE"
195
196    def resolve_relative_path(self, rpath):
197        if rpath.startswith("/"):
198            return rpath
199        else:
200            return os.path.join(self.config_path, rpath)
201       
202    def check_raxml(self):
203        self.raxml_exec_full = self.raxml_home + self.raxml_exec
204        if self.raxml_remote_host in ["", "localhost"]:
205            self.raxml_remote_call = False
206            # if raxml_home is empty, raxml binary must be on PATH; otherwise check if file exists
207            if self.raxml_home: 
208                if not os.path.isdir(self.raxml_home):
209                    self.exit_user_error("RAxML home directory not found: %s" % self.raxml_home)
210                elif not os.path.isfile(self.raxml_exec_full):
211                    self.exit_user_error("RAxML executable not found: %s" % self.raxml_exec_full)
212        else:
213            self.raxml_remote_call = True
214        self.raxml_cmd = [self.raxml_exec_full, "-w", self.raxml_outdir_abs]
215        if self.num_threads > 1:
216            self.raxml_cmd += ["-T", str(self.num_threads)]
217       
218    def read_from_file(self, config_fname):
219        if not os.path.exists(config_fname):
220            self.exit_user_error("ERROR: Config file not found: %s" % config_fname)
221
222        parser = DefaultedConfigParser() #ConfigParser.SafeConfigParser()
223        parser.read(config_fname)
224       
225        self.raxml_home = parser.get_param("raxml", "raxml_home", str, self.raxml_home)
226        if self.raxml_home:
227            self.raxml_home = self.resolve_relative_path(self.raxml_home + "/")
228        self.raxml_exec = parser.get_param("raxml", "raxml_exec", str, self.raxml_exec)
229        self.raxml_remote_host = parser.get_param("raxml", "raxml_remote_host", str, self.raxml_remote_host)
230
231        self.raxml_model = parser.get_param("raxml", "raxml_model", str, self.raxml_model).upper()
232        self.num_threads = parser.get_param("raxml", "raxml_threads", int, self.num_threads)
233
234        self.epa_use_heuristic = parser.get_param("raxml", "epa_use_heuristic", str, self.epa_use_heuristic).upper()
235        self.epa_heur_rate = parser.get_param("raxml", "epa_heur_rate", float, self.epa_heur_rate)
236        self.epa_load_optmod = parser.get_param("raxml", "epa_load_optmod", bool, self.epa_load_optmod)
237
238        self.hmmer_home = self.resolve_relative_path(parser.get_param("hmmer", "hmmer_home", str, self.hmmer_home))
239        self.muscle_home = self.resolve_relative_path(parser.get_param("muscle", "muscle_home", str, self.muscle_home))
240       
241        self.run_on_cluster = parser.get_param("cluster", "run_on_cluster", bool, self.run_on_cluster)
242        self.cluster_epac_home = parser.get_param("cluster", "cluster_epac_home", str, self.cluster_epac_home) + "/"
243        self.cluster_qsub_script = parser.get_param("cluster", "cluster_qsub_script", str, self.cluster_qsub_script)
244
245        self.min_confidence = parser.get_param("assignment", "min_confidence", float, self.min_confidence)
246
247        return parser
248
249    def subst_name(self, in_str):
250        """Replace %NAME% macros with an actual EPAC run name. Used to
251        generate unique run-specific identifiers (filenames, RAxML job names etc)"""
252        return in_str.replace("%NAME%", self.name)
253   
254    def tmp_fname(self, fname):
255        return os.path.join(self.temp_dir, self.subst_name(fname))
256
257    def out_fname(self, fname):
258        return os.path.join(self.output_dir, self.subst_name(fname))
259       
260    def clean_tempdir(self):
261        if not self.debug and os.path.isdir(self.temp_dir):
262            shutil.rmtree(self.temp_dir)
263   
264    def exit_fatal_error(self, msg=None):
265        if msg:
266            self.log.error(msg)
267        sys.exit(13)   
268
269    def exit_user_error(self, msg=None):
270        if msg:
271            self.log.error(msg)
272        self.clean_tempdir()
273        sys.exit(14)
274
275    def print_version(self, progname):
276        self.log.info(EpacConfig.SATIVA_INFO % progname)   
277        if self.restart:
278            self.log.info("Resuming %s execution using files from previous run found in: %s\n" % (progname, self.temp_dir))
279   
280class EpacTrainerConfig(EpacConfig):
281   
282    def __init__(self, args=None):
283        EpacConfig.__init__(self, args)
284        if args:
285            self.taxonomy_fname = args.taxonomy_fname
286            self.align_fname = args.align_fname
287            self.no_hmmer = args.no_hmmer
288            self.dup_rank_names  = args.dup_rank_names
289            self.wrong_rank_count  = args.wrong_rank_count
290            self.mfresolv_method = args.mfresolv_method
291            self.taxcode_name = args.taxcode_name
292            self.rep_num = args.rep_num
293       
294    def set_defaults(self):
295        EpacConfig.set_defaults(self)
296        self.no_hmmer = False
297        # whether model parameters should be re-optimized from scratch on the best topology using "-f e"
298        self.reopt_model = False
299        self.dup_rank_names = "ignore"
300        self.wrong_rank_count = "ignore"
301        self.taxassign_method = "1"
302        self.mfresolv_method = "thorough"
303        self.taxcode_name = "bac"
304        self.rep_num = 1
305        # default settings below imply no taxonomy filtering,
306        # i.e. all sequences from taxonomy file will be included into reference tree
307        self.reftree_min_rank = 0
308        self.reftree_max_seqs_per_leaf = 1e6
309        self.reftree_clades_to_include=[]
310        self.reftree_clades_to_ignore=[]
311       
312
313    def read_from_file(self, config_fname):
314        parser = EpacConfig.read_from_file(self, config_fname)
315       
316        self.reftree_min_rank = parser.get_param("reftree", "min_rank", int, self.reftree_min_rank)
317        self.reftree_max_seqs_per_leaf = parser.get_param("reftree", "max_seqs_per_leaf", int, self.reftree_max_seqs_per_leaf)
318        clades_str = parser.get_param("reftree", "clades_to_include", str, "")
319        self.reftree_clades_to_include = self.parse_clades(clades_str)
320        clades_str = parser.get_param("reftree", "clades_to_ignore", str, "")
321        self.reftree_clades_to_ignore = self.parse_clades(clades_str)
322       
323    def parse_clades(self, clades_str):
324        clade_list = []
325        try:       
326            if clades_str:
327                clades = clades_str.split(",")
328                for clade in clades:
329                    toks = clade.split("|")
330                    clade_list += [(int(toks[0]), toks[1])]
331        except:
332            print "Invalid format in config parameter: clades_to_include"
333            sys.exit()
334
335        return clade_list
336       
337class EpacClassifierConfig(EpacConfig):
338
339    def __init__(self, args=None):
340        EpacConfig.__init__(self, args)
341
342        if args:
343            self.taxassign_method = args.taxassign_method
344            self.min_lhw = args.min_lhw
345            self.brlen_pv = args.brlen_pv
346        else:
347            self.taxassign_method = "1"
348            self.min_lhw = 0.
349            self.brlen_pv = 0.
350
351class SativaConfig(EpacTrainerConfig):
352   
353    def __init__(self, args):
354        args.no_hmmer = True
355        args.dup_rank_names = "ignore"
356        args.wrong_rank_count = "ignore"
357        args.taxassign_method = "1"
358
359        EpacTrainerConfig.__init__(self, args)
360
361        self.taxassign_method = args.taxassign_method
362        self.min_lhw = args.min_lhw
363        self.brlen_pv = args.brlen_pv
364        self.ranktest = args.ranktest
365        self.conf_cutoff = args.conf_cutoff
366        self.jplace_fname = args.jplace_fname
367       
368        self.output_interim_files = True
369        self.compress_patterns = True
370        self.use_bfgs = True
371        self.save_memory = False
372
373        if self.refjson_fname:
374            self.load_refjson = True
375        else:
376            self.load_refjson = False
377            if self.output_interim_files:
378                self.refjson_fname = self.out_fname("%NAME%.refjson")
379            else:
380                self.refjson_fname = self.tmp_fname("%NAME%.refjson")
381               
382        if self.restart and os.path.isfile(self.refjson_fname):
383            self.load_refjson = True
384       
385    def set_defaults(self):
386        EpacTrainerConfig.set_defaults(self)
387
388    def read_from_file(self, config_fname):
389        parser = EpacTrainerConfig.read_from_file(self, config_fname)
Note: See TracBrowser for help on using the repository browser.