source: branches/help/GDE/SATIVA/sativa/epac/json_util.py

Last change on this file was 14544, checked in by akozlov, 9 years ago
File size: 11.2 KB
Line 
1#! /usr/bin/env python
2import sys
3import os
4import json
5import operator
6import base64
7from subprocess import call
8from ete2 import Tree, SeqGroup
9from taxonomy_util import TaxCode
10
11class EpaJsonParser:
12    """This class parses the RAxML-EPA json output file"""
13    def __init__(self, jsonfin):
14        self.jdata = json.load(open(jsonfin))
15   
16    def get_placement(self):
17        return self.jdata["placements"]
18       
19    def get_tree(self):
20        return self.jdata["tree"]
21       
22    def get_std_newick_tree(self):
23        tree = self.jdata["tree"]
24        tree = tree.replace("{", "[&&NHX:B=")
25        tree = tree.replace("}", "]")
26        return tree
27
28    def get_raxml_version(self):
29        return self.jdata["metadata"]["raxml_version"]
30
31    def get_raxml_invocation(self):
32        return self.jdata["metadata"]["invocation"]
33
34class RefJsonChecker:
35    def __init__(self, jsonfin= None, jdata = None):
36        if jsonfin!=None:
37            self.jdata = json.load(open(jsonfin))
38        else:
39            self.jdata = jdata
40   
41    def check_field(self, fname, ftype, fvals=None, fopt=False):
42        if fname in self.jdata:
43            field = self.jdata[fname]
44            if isinstance(field, ftype):
45                if not fvals or field in fvals:
46                    return True
47                else:
48                    self.error = "Invalid value of field '%s': %s" % fname, repr(field)
49                    return False
50            else:
51              self.error = "Field '%s' has a wrong type: %s (expected: %s)" % fname, type(field).__name__, ftype.__name__
52              return False
53        else:
54            if fopt:
55                return True
56            else:
57                self.error = "Field not found: %s" % fname
58                return False
59   
60    def validate(self, ver = "1.6"):
61        nver = float(ver)
62       
63        self.error = None
64
65        valid = self.check_field("tree", unicode) \
66                and self.check_field("raxmltree", unicode) \
67                and self.check_field("rate", float) \
68                and self.check_field("node_height", dict) \
69                and self.check_field("origin_taxonomy", dict) \
70                and self.check_field("sequences", list) \
71                and self.check_field("binary_model", unicode) \
72                and self.check_field("hmm_profile", list, fopt=True) 
73               
74        # check v1.1 fields, if needed
75        if nver >= 1.1:
76            valid = valid and \
77                    self.check_field("ratehet_model", unicode)  # ["GTRGAMMA", "GTRCAT"]
78
79        # check v1.2 fields, if needed
80        if nver >= 1.2:
81            valid = valid and \
82                    self.check_field("tax_tree", unicode)
83
84        # check v1.3 fields, if needed
85        if nver >= 1.3:
86            valid = valid and \
87                    self.check_field("taxcode", unicode, TaxCode.TAX_CODE_MAP)
88       
89        # check v1.4 fields, if needed
90        if nver >= 1.4:
91            valid = valid \
92                    and self.check_field("corr_seqid_map", dict) \
93                    and self.check_field("corr_ranks_map", dict)
94
95        # check v1.5 fields, if needed
96        if nver >= 1.5:
97            valid = valid \
98                    and self.check_field("merged_ranks_map", dict)
99
100        # "taxonomy" field has been renamed and its format was changed in v1.6
101        if nver >= 1.6:
102            valid = valid \
103                    and self.check_field("branch_tax_map", dict)
104        else:
105            valid = valid \
106                    and self.check_field("taxonomy", dict)
107
108        return (valid, self.error)
109
110class RefJsonParser:
111    """This class parses the EPA Classifier reference json file"""
112    def __init__(self, jsonfin):
113        self.jdata = json.load(open(jsonfin))
114        self.version = self.jdata["version"]
115        self.nversion = float(self.version)
116        self.corr_seqid = None
117        self.corr_ranks = None
118        self.corr_seqid_reverse = None
119       
120    def validate(self):
121        jc = RefJsonChecker(jdata = self.jdata)
122        return jc.validate(self.version)
123   
124    def get_version(self):
125        return self.version
126
127    def get_rate(self):
128        return self.jdata["rate"]
129   
130    def get_node_height(self):
131        return self.jdata["node_height"]
132   
133    def get_raxml_readable_tree(self, fout_name = None):
134        tree_str = self.jdata["raxmltree"]
135        #t.unroot()
136        if fout_name != None:
137            with open(fout_name, "w") as fout:
138                fout.write(tree_str)
139        else:
140            return tree_str
141   
142    def get_reftree(self, fout_name = None):
143        tree_str = self.jdata["tree"]
144        if fout_name != None:
145            with open(fout_name, "w") as fout:
146                fout.write(tree_str)
147        else:
148            return Tree(tree_str, format=1)
149   
150    def get_tax_tree(self):
151        t = Tree(self.jdata["tax_tree"], format=8)
152        return t
153
154    def get_outgroup(self):
155        t = Tree(self.jdata["outgroup"], format=9)
156        return t
157
158    def get_branch_tax_map(self):
159        if self.nversion >= 1.6:
160            return self.jdata["branch_tax_map"]
161        else:
162            return None
163
164    def get_taxonomy(self):
165        if self.nversion < 1.6:
166            return self.jdata["taxonomy"]
167        else:
168            return None
169
170    def get_origin_taxonomy(self):
171        return self.jdata["origin_taxonomy"]
172   
173    def get_alignment(self, fout):
174        entries = self.jdata["sequences"]
175        with open(fout, "w") as fo:
176            for entr in entries:
177                fo.write(">%s\n%s\n" % (entr[0], entr[1]))
178
179        return fout
180   
181    def get_ref_alignment(self):
182        entries = self.jdata["sequences"]
183        alignment = SeqGroup()
184        for entr in entries:
185            alignment.set_seq(entr[0], entr[1])
186        return alignment
187   
188    def get_alignment_list(self):
189        return self.jdata["sequences"]
190   
191    def get_sequences_names(self):
192        nameset = set()
193        entries = self.jdata["sequences"]
194        for entr in entries:
195            nameset.add(entr[0])
196        return nameset
197   
198    def get_alignment_length(self):
199        entries = self.jdata["sequences"]
200        return len(entries[0][1])
201   
202    def get_hmm_profile(self, fout):
203        if "hmm_profile" in self.jdata:       
204            lines = self.jdata["hmm_profile"]
205            with open(fout, "w") as fo:
206                for line in lines:
207                    fo.write(line)
208            return fout
209        else:
210            return None
211
212    def get_binary_model(self, fout):
213        model_str = self.jdata["binary_model"]
214        with open(fout, "wb") as fo:
215            fo.write(base64.b64decode(model_str))
216
217    def get_ratehet_model(self):
218        return self.jdata["ratehet_model"]
219
220    def get_pattern_compression(self):
221        if "pattern_compression" in self.jdata:
222            return self.jdata["pattern_compression"]
223        else:
224            return False
225
226    def get_taxcode(self):
227        return self.jdata["taxcode"]
228       
229    def get_corr_seqid_map(self):
230        if "corr_seqid_map" in self.jdata:
231            self.corr_seqid = self.jdata["corr_seqid_map"]
232        else:
233            self.corr_seqid = {}
234        return self.corr_seqid
235
236    def get_corr_ranks_map(self):
237        if "corr_ranks_map" in self.jdata:
238            self.corr_ranks = self.jdata["corr_ranks_map"]
239        else:
240            self.corr_ranks = {}
241        return self.corr_ranks
242
243    def get_merged_ranks_map(self):
244        if "merged_ranks_map" in self.jdata:
245            self.merged_ranks = self.jdata["merged_ranks_map"]
246        else:
247            self.merged_ranks = {}
248        return self.merged_ranks
249
250    def get_metadata(self):
251        return self.jdata["metadata"]
252       
253    def get_field_string(self, field_name):
254        if field_name in self.jdata:
255            return json.dumps(self.jdata[field_name], indent=4, separators=(',', ': ')).strip("\"")
256        else:
257            return None
258           
259    def get_uncorr_seqid(self, new_seqid):
260        if not self.corr_seqid:
261            self.get_corr_seqid_map()
262        return self.corr_seqid.get(new_seqid, new_seqid)
263       
264    def get_corr_seqid(self, old_seqid):
265        if not self.corr_seqid_reverse:
266            if not self.corr_seqid:
267                self.get_corr_seqid_map()
268            self.corr_seqid_reverse = dict((reversed(item) for item in self.corr_seqid.items()))
269        return self.corr_seqid_reverse.get(old_seqid, old_seqid)
270
271    def get_uncorr_ranks(self, ranks):
272        if not self.corr_ranks:
273            self.get_corr_ranks_map()
274        uncorr_ranks = list(ranks)
275        for i in range(len(ranks)):
276            uncorr_ranks[i] = self.corr_ranks.get(ranks[i], ranks[i])
277        return uncorr_ranks       
278               
279class RefJsonBuilder:
280    """This class builds the EPA Classifier reference json file"""
281    def __init__(self, old_json=None):
282        if old_json:
283            self.jdata = old_json.jdata
284        else:
285            self.jdata = {}
286            self.jdata["version"] = "1.6"
287#            self.jdata["author"] = "Jiajie Zhang"
288       
289    def set_branch_tax_map(self, bid_ranks_map):
290        self.jdata["branch_tax_map"] = bid_ranks_map
291
292    def set_origin_taxonomy(self, orig_tax_map):
293        self.jdata["origin_taxonomy"] = orig_tax_map
294
295    def set_tax_tree(self, tr):
296        self.jdata["tax_tree"] = tr.write(format=8)
297
298    def set_tree(self, tr):
299        self.jdata["tree"] = tr
300        self.jdata["raxmltree"] = Tree(tr, format=1).write(format=5)
301       
302    def set_outgroup(self, outgr):
303        self.jdata["outgroup"] = outgr.write(format=9)
304
305    def set_sequences(self, seqs):   
306        self.jdata["sequences"] = seqs
307       
308    def set_hmm_profile(self, fprofile):   
309        with open(fprofile) as fp:
310            lines = fp.readlines()
311        self.jdata["hmm_profile"] = lines
312       
313    def set_rate(self, rate):   
314        self.jdata["rate"] = rate
315       
316    def set_nodes_height(self, height):   
317        self.jdata["node_height"] = height
318
319    def set_binary_model(self, model_fname): 
320        with open(model_fname, "rb") as fin:
321            model_str = base64.b64encode(fin.read())
322        self.jdata["binary_model"] = model_str
323
324    def set_ratehet_model(self, model):
325        self.jdata["ratehet_model"] = model
326
327    def set_pattern_compression(self, value):
328        self.jdata["pattern_compression"] = value
329
330    def set_taxcode(self, value):
331        self.jdata["taxcode"] = value
332
333    def set_corr_seqid_map(self, seqid_map):
334        self.jdata["corr_seqid_map"] = seqid_map
335
336    def set_corr_ranks_map(self, ranks_map):
337        self.jdata["corr_ranks_map"] = ranks_map
338
339    def set_merged_ranks_map(self, merged_ranks_map):
340        self.jdata["merged_ranks_map"] = merged_ranks_map
341
342    def set_metadata(self, metadata):   
343        self.jdata["metadata"] = metadata
344
345    def dump(self, out_fname):
346        self.jdata.pop("fields", 0)
347        self.jdata["fields"] = self.jdata.keys()
348        with open(out_fname, "w") as fo:
349            json.dump(self.jdata, fo, indent=4, sort_keys=True)               
350
351
352if __name__ == "__main__":
353    if len(sys.argv) < 2: 
354        print("usage: ./json_util.py jsonfile")
355        sys.exit()
356    jc = json_checker(jsonfin = sys.argv[1])
357    if jc.valid():
358        print("The json file is OK for EPA-classifer")
359    else:
360        print("!!!Invalid json file!!!")
361   
Note: See TracBrowser for help on using the repository browser.