| 1 | #! /usr/bin/env python |
|---|
| 2 | import sys |
|---|
| 3 | import os |
|---|
| 4 | import json |
|---|
| 5 | import operator |
|---|
| 6 | import base64 |
|---|
| 7 | from subprocess import call |
|---|
| 8 | from ete2 import Tree, SeqGroup |
|---|
| 9 | from taxonomy_util import TaxCode |
|---|
| 10 | |
|---|
class EpaJsonParser:
    """Parse the RAxML-EPA JSON output file.

    The whole file is loaded into ``self.jdata`` on construction; each getter
    returns the raw JSON value stored under the corresponding key.
    """

    def __init__(self, jsonfin):
        """Load the EPA output stored at path ``jsonfin``."""
        # Context manager closes the handle right away instead of leaving it
        # to the garbage collector.
        with open(jsonfin) as fin:
            self.jdata = json.load(fin)

    def get_placement(self):
        """Return the value of the "placements" field."""
        return self.jdata["placements"]

    def get_tree(self):
        """Return the raw "tree" string, including its ``{...}`` annotations."""
        return self.jdata["tree"]

    def get_std_newick_tree(self):
        """Return the tree with ``{...}`` annotations rewritten as NHX comments.

        Every ``{`` becomes ``[&&NHX:B=`` and every ``}`` becomes ``]``; for
        example ``a:1{5}`` turns into ``a:1[&&NHX:B=5]``.
        """
        tree = self.jdata["tree"]
        return tree.replace("{", "[&&NHX:B=").replace("}", "]")

    def get_raxml_version(self):
        """Return ``metadata["raxml_version"]``."""
        return self.jdata["metadata"]["raxml_version"]

    def get_raxml_invocation(self):
        """Return ``metadata["invocation"]``."""
        return self.jdata["metadata"]["invocation"]
|---|
| 33 | |
|---|
class RefJsonChecker:
    """Validate the structure of an EPA-classifier reference JSON document.

    Supply either ``jsonfin`` (a path to load) or ``jdata`` (already parsed
    data); ``jsonfin`` wins when both are given. After a failed check,
    ``self.error`` holds a human-readable description of the first failure.
    """

    def __init__(self, jsonfin=None, jdata=None):
        if jsonfin is not None:
            with open(jsonfin) as fin:
                self.jdata = json.load(fin)
        else:
            self.jdata = jdata
        # Message describing the last failed check, or None.
        self.error = None

    def check_field(self, fname, ftype, fvals=None, fopt=False):
        """Check one top-level field of the document.

        Args:
            fname: key to look up in the document.
            ftype: type the value must be an instance of.
            fvals: optional container of allowed values; falsy means "any".
            fopt: if True, a missing field is accepted.

        Returns:
            True if the check passes; otherwise False with ``self.error`` set.
        """
        if fname not in self.jdata:
            if fopt:
                return True
            self.error = "Field not found: %s" % fname
            return False

        field = self.jdata[fname]
        if not isinstance(field, ftype):
            # Format arguments must be a tuple: without the parentheses the
            # ``%`` binds to ``fname`` alone and raises TypeError.
            self.error = "Field '%s' has a wrong type: %s (expected: %s)" % (
                fname, type(field).__name__, ftype.__name__)
            return False

        if fvals and field not in fvals:
            self.error = "Invalid value of field '%s': %s" % (fname, repr(field))
            return False

        return True

    def validate(self, ver="1.6"):
        """Validate the document against reference-format version ``ver``.

        Fields required by every version are checked first, then the fields
        introduced by each later version up to ``ver``. Checking stops at the
        first failure.

        Returns:
            A ``(valid, error)`` tuple; ``error`` is None when ``valid``.
        """
        nver = float(ver)
        self.error = None

        # Each entry is the argument tuple for check_field, in check order.
        checks = [
            ("tree", unicode),
            ("raxmltree", unicode),
            ("rate", float),
            ("node_height", dict),
            ("origin_taxonomy", dict),
            ("sequences", list),
            ("binary_model", unicode),
            ("hmm_profile", list, None, True),  # optional
        ]
        if nver >= 1.1:
            checks.append(("ratehet_model", unicode))  # e.g. "GTRGAMMA", "GTRCAT"
        if nver >= 1.2:
            checks.append(("tax_tree", unicode))
        if nver >= 1.3:
            checks.append(("taxcode", unicode, TaxCode.TAX_CODE_MAP))
        if nver >= 1.4:
            checks.append(("corr_seqid_map", dict))
            checks.append(("corr_ranks_map", dict))
        if nver >= 1.5:
            checks.append(("merged_ranks_map", dict))
        # The "taxonomy" field was replaced by "branch_tax_map" (with a new
        # format) in v1.6.
        if nver >= 1.6:
            checks.append(("branch_tax_map", dict))
        else:
            checks.append(("taxonomy", dict))

        # all() short-circuits, so self.error describes the first failure.
        valid = all(self.check_field(*check) for check in checks)
        return (valid, self.error)
|---|
| 109 | |
|---|
class RefJsonParser:
    """Parse the EPA-classifier reference JSON file.

    The file is loaded once into ``self.jdata``; the getters expose its
    fields. Version-dependent fields are resolved against the file's own
    "version" value.
    """

    def __init__(self, jsonfin):
        with open(jsonfin) as fin:
            self.jdata = json.load(fin)
        self.version = self.jdata["version"]
        # Numeric version, used for ">= 1.6"-style comparisons.
        self.nversion = float(self.version)
        # Lazily filled caches for the correction / merge maps.
        self.corr_seqid = None
        self.corr_ranks = None
        self.corr_seqid_reverse = None
        self.merged_ranks = None

    def validate(self):
        """Return RefJsonChecker's ``(valid, error)`` for the declared version."""
        jc = RefJsonChecker(jdata=self.jdata)
        return jc.validate(self.version)

    def get_version(self):
        return self.version

    def get_rate(self):
        return self.jdata["rate"]

    def get_node_height(self):
        return self.jdata["node_height"]

    def get_raxml_readable_tree(self, fout_name=None):
        """Return the "raxmltree" string, or write it to ``fout_name``.

        When ``fout_name`` is given the string is written there and None is
        returned.
        """
        tree_str = self.jdata["raxmltree"]
        if fout_name is None:
            return tree_str
        with open(fout_name, "w") as fout:
            fout.write(tree_str)

    def get_reftree(self, fout_name=None):
        """Return the "tree" field as an ete2 Tree (newick format 1).

        When ``fout_name`` is given, the raw string is written there instead
        and None is returned.
        """
        tree_str = self.jdata["tree"]
        if fout_name is None:
            return Tree(tree_str, format=1)
        with open(fout_name, "w") as fout:
            fout.write(tree_str)

    def get_tax_tree(self):
        """Return the "tax_tree" field parsed as an ete2 Tree (format 8)."""
        return Tree(self.jdata["tax_tree"], format=8)

    def get_outgroup(self):
        """Return the "outgroup" field parsed as an ete2 Tree (format 9)."""
        return Tree(self.jdata["outgroup"], format=9)

    def get_branch_tax_map(self):
        """Return "branch_tax_map" for format >= 1.6, else None."""
        if self.nversion >= 1.6:
            return self.jdata["branch_tax_map"]
        return None

    def get_taxonomy(self):
        """Return the legacy "taxonomy" field for format < 1.6, else None."""
        if self.nversion < 1.6:
            return self.jdata["taxonomy"]
        return None

    def get_origin_taxonomy(self):
        return self.jdata["origin_taxonomy"]

    def get_alignment(self, fout):
        """Write the reference sequences to ``fout`` as FASTA; return ``fout``."""
        with open(fout, "w") as fo:
            for entr in self.jdata["sequences"]:
                fo.write(">%s\n%s\n" % (entr[0], entr[1]))
        return fout

    def get_ref_alignment(self):
        """Return the reference sequences as an ete2 SeqGroup."""
        alignment = SeqGroup()
        for entr in self.jdata["sequences"]:
            alignment.set_seq(entr[0], entr[1])
        return alignment

    def get_alignment_list(self):
        return self.jdata["sequences"]

    def get_sequences_names(self):
        """Return the set of reference sequence names."""
        return set(entr[0] for entr in self.jdata["sequences"])

    def get_alignment_length(self):
        """Return the length of the first sequence, or 0 if there are none."""
        entries = self.jdata["sequences"]
        if not entries:
            return 0
        return len(entries[0][1])

    def get_hmm_profile(self, fout):
        """Write the optional HMM profile to ``fout`` and return ``fout``.

        Returns None (and writes nothing) if the file has no "hmm_profile".
        """
        if "hmm_profile" not in self.jdata:
            return None
        with open(fout, "w") as fo:
            # Stored lines keep their own line terminators.
            fo.writelines(self.jdata["hmm_profile"])
        return fout

    def get_binary_model(self, fout):
        """Decode the base64 "binary_model" field and write it to ``fout``."""
        model_str = self.jdata["binary_model"]
        with open(fout, "wb") as fo:
            fo.write(base64.b64decode(model_str))

    def get_ratehet_model(self):
        return self.jdata["ratehet_model"]

    def get_pattern_compression(self):
        """Return "pattern_compression", defaulting to False when absent."""
        return self.jdata.get("pattern_compression", False)

    def get_taxcode(self):
        return self.jdata["taxcode"]

    def get_corr_seqid_map(self):
        """Return (and cache) "corr_seqid_map", or {} when absent."""
        self.corr_seqid = self.jdata.get("corr_seqid_map", {})
        return self.corr_seqid

    def get_corr_ranks_map(self):
        """Return (and cache) "corr_ranks_map", or {} when absent."""
        self.corr_ranks = self.jdata.get("corr_ranks_map", {})
        return self.corr_ranks

    def get_merged_ranks_map(self):
        """Return (and cache) "merged_ranks_map", or {} when absent."""
        self.merged_ranks = self.jdata.get("merged_ranks_map", {})
        return self.merged_ranks

    def get_metadata(self):
        return self.jdata["metadata"]

    def get_field_string(self, field_name):
        """Return field ``field_name`` pretty-printed as JSON, or None.

        Surrounding double quotes are stripped, so string fields come back
        as their bare text.
        """
        if field_name not in self.jdata:
            return None
        return json.dumps(self.jdata[field_name], indent=4,
                          separators=(',', ': ')).strip("\"")

    def get_uncorr_seqid(self, new_seqid):
        """Map a corrected sequence id back to its original id.

        Ids without a mapping are returned unchanged.
        """
        if not self.corr_seqid:
            self.get_corr_seqid_map()
        return self.corr_seqid.get(new_seqid, new_seqid)

    def get_corr_seqid(self, old_seqid):
        """Map an original sequence id to its corrected id (reverse lookup).

        Ids without a mapping are returned unchanged.
        """
        if not self.corr_seqid_reverse:
            if not self.corr_seqid:
                self.get_corr_seqid_map()
            self.corr_seqid_reverse = dict(
                (old, new) for new, old in self.corr_seqid.items())
        return self.corr_seqid_reverse.get(old_seqid, old_seqid)

    def get_uncorr_ranks(self, ranks):
        """Return a new list with each rank mapped through "corr_ranks_map".

        Ranks without a mapping are kept as-is.
        """
        if not self.corr_ranks:
            self.get_corr_ranks_map()
        return [self.corr_ranks.get(rank, rank) for rank in ranks]
|---|
| 278 | |
|---|
class RefJsonBuilder:
    """Build an EPA-classifier reference JSON file (format version "1.6").

    If ``old_json`` is given, its ``jdata`` dict is used as the starting
    point (presumably a RefJsonParser). The dict is shared, not copied, so
    the setters below also modify ``old_json.jdata``.
    """

    def __init__(self, old_json=None):
        if old_json:
            # NOTE(review): aliases old_json's data rather than copying it.
            self.jdata = old_json.jdata
        else:
            self.jdata = {}
        # Always stamp the current format version, even on reused data.
        self.jdata["version"] = "1.6"

    def set_branch_tax_map(self, bid_ranks_map):
        self.jdata["branch_tax_map"] = bid_ranks_map

    def set_origin_taxonomy(self, orig_tax_map):
        self.jdata["origin_taxonomy"] = orig_tax_map

    def set_tax_tree(self, tr):
        """Store tree object ``tr`` serialized with ``write(format=8)``."""
        self.jdata["tax_tree"] = tr.write(format=8)

    def set_tree(self, tr):
        """Store newick string ``tr`` plus a RAxML-readable copy.

        The copy is ``tr`` parsed as an ete2 Tree (format 1) and re-written
        with format 5.
        """
        self.jdata["tree"] = tr
        self.jdata["raxmltree"] = Tree(tr, format=1).write(format=5)

    def set_outgroup(self, outgr):
        """Store tree object ``outgr`` serialized with ``write(format=9)``."""
        self.jdata["outgroup"] = outgr.write(format=9)

    def set_sequences(self, seqs):
        self.jdata["sequences"] = seqs

    def set_hmm_profile(self, fprofile):
        """Store the lines of file ``fprofile`` (terminators kept)."""
        with open(fprofile) as fp:
            lines = fp.readlines()
        self.jdata["hmm_profile"] = lines

    def set_rate(self, rate):
        self.jdata["rate"] = rate

    def set_nodes_height(self, height):
        self.jdata["node_height"] = height

    def set_binary_model(self, model_fname):
        """Store the contents of binary file ``model_fname`` as base64 text."""
        with open(model_fname, "rb") as fin:
            encoded = base64.b64encode(fin.read())
        # b64encode returns bytes on Python 3, which json cannot serialize;
        # the base64 alphabet is pure ASCII, so this decode is lossless.
        self.jdata["binary_model"] = encoded.decode("ascii")

    def set_ratehet_model(self, model):
        self.jdata["ratehet_model"] = model

    def set_pattern_compression(self, value):
        self.jdata["pattern_compression"] = value

    def set_taxcode(self, value):
        self.jdata["taxcode"] = value

    def set_corr_seqid_map(self, seqid_map):
        self.jdata["corr_seqid_map"] = seqid_map

    def set_corr_ranks_map(self, ranks_map):
        self.jdata["corr_ranks_map"] = ranks_map

    def set_merged_ranks_map(self, merged_ranks_map):
        self.jdata["merged_ranks_map"] = merged_ranks_map

    def set_metadata(self, metadata):
        self.jdata["metadata"] = metadata

    def dump(self, out_fname):
        """Write the document to ``out_fname`` as indented, key-sorted JSON.

        A "fields" entry listing every other top-level key (sorted) is
        regenerated on each call.
        """
        self.jdata.pop("fields", None)
        # Build a real list before inserting "fields": a dict view would be
        # live (listing "fields" itself) and is not JSON-serializable.
        self.jdata["fields"] = sorted(self.jdata.keys())
        with open(out_fname, "w") as fo:
            json.dump(self.jdata, fo, indent=4, sort_keys=True)
|---|
| 350 | |
|---|
| 351 | |
|---|
if __name__ == "__main__":
    # Command-line check: validate a reference JSON file for EPA-classifier.
    if len(sys.argv) < 2:
        print("usage: ./json_util.py jsonfile")
        sys.exit(1)
    jc = RefJsonChecker(jsonfin=sys.argv[1])
    # Validate against the format version the file declares, so files in
    # older formats are checked with their own field set.
    valid, err = jc.validate(jc.jdata.get("version", "1.6"))
    if valid:
        print("The json file is OK for EPA-classifer")
    else:
        print("!!!Invalid json file!!!")
        print(err)
        sys.exit(1)
|---|
| 361 | |
|---|