source: trunk/PROBE/PT_family.cxx

Last change on this file was 19061, checked in by westram, 2 years ago
  • fix loop vectorization checks failing with gcc 10.x
  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 15.2 KB
Line 
1// =============================================================== //
2//                                                                 //
3//   File      : PT_family.cxx                                     //
4//   Purpose   :                                                   //
5//                                                                 //
6//   Institute of Microbiology (Technical University Munich)       //
7//   http://www.arb-home.de/                                       //
8//                                                                 //
9// =============================================================== //
10
11#include "PT_rangeCheck.h"
12#include "pt_prototypes.h"
13
14#include <struct_man.h>
15#include <PT_server_prototypes.h>
16#include "PT_global_defs.h"
17#include "PT_complement.h"
18
19#include <arbdbt.h>
20
21#include <algorithm>
22#include <vector>
23#include <map>
24
25// overloaded functions to avoid problems with type-punning:
26inline void aisc_link(dll_public *dll, PT_family_list *family)   { aisc_link(reinterpret_cast<dllpublic_ext*>(dll), reinterpret_cast<dllheader_ext*>(family)); }
27
28struct TraversalHitLimit {
29    int id;    // "unique" for each traversal
30    int limit; // max hits allowed to each target seq for this traversal
31
32    TraversalHitLimit(int id_, int limit_)
33        : id(id_), limit(limit_)
34    { pt_assert(limit>0); }
35};
36
37class HitCounter {
38    int    trav_id;    // current traversal id
39    int    trav_hits;  // how many hits have been counted during traversal 'trav_id' ?
40    int    count;      // Counter for matches
41    double rel_count;  // match_count / (seqlen - oligolen + 1). seqlen depends on RelativeScoreScaling
42
43public:
44    HitCounter() : trav_id(-1), trav_hits(0), count(0), rel_count(0.0) {}
45
46    void inc(const TraversalHitLimit& traversal) {
47        if (traversal.id != trav_id) { // first hit during this traversal
48            trav_id  = traversal.id;
49            trav_hits = 1; // reset
50            count++;
51        }
52        else {
53            if (trav_hits<traversal.limit) {
54                trav_hits++;
55                count++;
56            }
57        }
58    }
59    void calc_rel_match(int max_poss_matches) {
60        rel_count = max_poss_matches>0 ? double(count)/max_poss_matches : 0;
61    }
62
63    int cmp_abs(const HitCounter& other) const { return count - other.count; }
64    int cmp_rel(const HitCounter& other) const { return double_cmp(rel_count, other.rel_count); }
65
66    int get_match_count() const { return count; }
67    const double& get_rel_match_count() const { return rel_count; }
68};
69
70class FamilyStat : virtual Noncopyable {
71    size_t                size;
72    HitCounter           *famstat;
73    TraversalHitLimit     trav_info;
74    RelativeScoreScaling  scaling;
75
76public:
77    FamilyStat(size_t size_, RelativeScoreScaling scaling_)
78        : size(size_),
79          famstat(new HitCounter[size]),
80          trav_info(-1, 1),
81          scaling(scaling_)
82    {}
83    ~FamilyStat() { delete [] famstat; }
84
85    void calc_rel_matches(int oligo_len, int sequence_length)  {
86        for (size_t i = 0; i < size; i++) {
87            int full_length = 0;
88            switch (scaling) {
89                case RSS_SOURCE:   full_length = sequence_length; break;
90                case RSS_TARGET:   full_length = psg.data[i].get_size(); break;
91                case RSS_BOTH_MIN: full_length = std::min(psg.data[i].get_size(), sequence_length); break;
92                case RSS_BOTH_MAX: full_length = std::max(psg.data[i].get_size(), sequence_length); break;
93            }
94            int max_poss_matches = full_length - oligo_len + 1; // @@@ wrong if target range is used!
95
96            famstat[i].calc_rel_match(max_poss_matches);
97        }
98    }
99
100    const HitCounter& hits(size_t idx) const { pt_assert(idx<size); return famstat[idx]; }
101
102    void limit_hits_for_next_traversal(int hit_limit) {
103        trav_info.id++;
104        trav_info.limit = hit_limit;
105    }
106    void count_match(size_t idx) { famstat[idx].inc(trav_info); }
107
108    int cmp_abs(int a, int b) const { int cmp = famstat[a].cmp_abs(famstat[b]); return cmp ? cmp : a-b; }
109    int cmp_rel(int a, int b) const { int cmp = famstat[a].cmp_rel(famstat[b]); return cmp ? cmp : a-b; }
110};
111
112class PT_Traversal {
113    static Range range;
114
115    const char *oligo;
116    int         height;
117    int         needed_positions;
118    int         accept_mismatches;
119
120    FamilyStat& fam_stat;
121
122    void count_match(const AbsLoc& match) const {
123        if (range.contains(match)) {
124            fam_stat.count_match(match.get_name());
125        }
126    }
127
128    bool at_end() const { return *oligo == PT_QU; }
129
130    bool too_many_mismatches() const { return accept_mismatches<0; }
131
132    bool did_match() const { return needed_positions <= 0 && !too_many_mismatches(); }
133    bool need_match() const { return needed_positions > 0 && !too_many_mismatches(); }
134    bool match_possible() const { return need_match() && !at_end(); }
135
136    void match_one_char(char c) {
137        pt_assert(match_possible()); // avoid unneeded calls
138
139        if (*oligo++ != c) accept_mismatches--;
140        needed_positions--;
141        height++;
142    }
143
144    void match_rest_and_mark(const ReadableDataLoc& loc) {
145        do match_one_char(loc[height]); while (match_possible());
146        if (did_match()) count_match(loc);
147    }
148
149    void mark_all(POS_TREE2 *pt) const;
150    inline void mark_chain_or_leaf(POS_TREE2 *pt) const;
151
152public:
153
154    static void restrictMatchesToRegion(int start, int end, int oligo_len) {
155        range  = Range(start, end, oligo_len);
156    }
157    static void unrestrictMatchesToRegion() {
158        range  = Range(-1, -1, -1);
159    }
160
161    PT_Traversal(const char *oligo_, int needed_positions_, int accept_mismatches_, FamilyStat& fam_stat_)
162        : oligo(oligo_),
163          height(0),
164          needed_positions(needed_positions_),
165          accept_mismatches(accept_mismatches_),
166          fam_stat(fam_stat_)
167    { }
168
169    void mark_matching(POS_TREE2 *pt) const;
170
171    int operator()(const DataLoc& loc) const {
172        //! Increment match_count for matched postree-tips
173        if (did_match()) count_match(loc);
174        else if (match_possible()) {
175            PT_Traversal(*this).match_rest_and_mark(ReadableDataLoc(loc)); // @@@ EXPENSIVE_CONVERSION
176        }
177        return 0;
178    }
179    int operator()(const AbsLoc& loc) const {
180        //! Increment match_count for matched postree-tips
181        if (did_match()) count_match(loc);
182        else if (match_possible()) {
183            PT_Traversal(*this).match_rest_and_mark(ReadableDataLoc(DataLoc(loc))); // @@@ VERY EXPENSIVE_CONVERSION (2)
184        }
185        return 0;
186    }
187};
188
189Range PT_Traversal::range(-1, -1, -1);
190
191inline void PT_Traversal::mark_chain_or_leaf(POS_TREE2 *pt) const {
192    pt_assert(pt);
193    switch (pt->get_type()) {
194        case PT2_LEAF:
195            (*this)(DataLoc(pt));
196            break;
197
198        case PT2_CHAIN: {
199            ChainIteratorStage2 entry(pt);
200            while (entry) {
201                (*this)(entry.at());
202                ++entry;
203            }
204            break;
205        }
206        case PT2_NODE:
207            pt_assert(0); // not called with chain or leaf
208            break;
209    }
210}
211
212void PT_Traversal::mark_matching(POS_TREE2 *pt) const {
213    //! Traverse pos(sub)tree through matching branches and increment 'match_count'
214    pt_assert(pt);
215    pt_assert(!too_many_mismatches());
216    pt_assert(!did_match());
217
218    if (pt->is_node()) {
219        for (int base = PT_N; base < PT_BASES; base++) {
220            POS_TREE2 *pt_son = PT_read_son(pt, (PT_base)base);
221            if (pt_son && !at_end()) {
222                PT_Traversal sub(*this);
223                sub.match_one_char(base);
224                if (!sub.too_many_mismatches()) {
225                    if (sub.did_match()) sub.mark_all(pt_son);
226                    else sub.mark_matching(pt_son);
227                }
228            }
229        }
230    }
231    else {
232        mark_chain_or_leaf(pt);
233    }
234}
235
236void PT_Traversal::mark_all(POS_TREE2 *pt) const {
237    pt_assert(pt);
238    pt_assert(!too_many_mismatches());
239    pt_assert(did_match());
240
241    if (pt->is_node()) {
242        for (int base = PT_N; base < PT_BASES; base++) {
243            POS_TREE2 *pt_son = PT_read_son(pt, (PT_base)base);
244            if (pt_son) mark_all(pt_son);
245        }
246    }
247    else {
248        mark_chain_or_leaf(pt);
249    }
250}
251
252struct oligo_cmp_abs {
253    const FamilyStat& fam_stat;
254    oligo_cmp_abs(const FamilyStat& fam_stat_) : fam_stat(fam_stat_) {}
255    bool operator()(int a, int b) { return fam_stat.cmp_abs(a, b) > 0; } // biggest scores 1st
256};
257
258struct oligo_cmp_rel {
259    const FamilyStat& fam_stat;
260    oligo_cmp_rel(const FamilyStat& fam_stat_) : fam_stat(fam_stat_) {}
261    bool operator()(int a, int b) { return fam_stat.cmp_rel(a, b) > 0; } // biggest scores 1st
262};
263
264static int make_PT_family_list(PT_family *ffinder, const FamilyStat& famStat) {
265    //!  Make sorted list of family members
266
267    // destroy old list
268    while (ffinder->fl) destroy_PT_family_list(ffinder->fl);
269
270    // Sort the data
271    std::vector<int> sorted;
272    sorted.resize(psg.data_count);
273
274    size_t matching_results = psg.data_count;
275    if (ffinder->min_score == 0) { // collect all hits
276        for (int i = 0; i < psg.data_count; i++) sorted[i] = i; // LOOP_VECTORIZED[!<5.0,!>8.0<10]
277    }
278    else {
279        int j = 0;
280        if (ffinder->sort_type == 0) { // filter by absolut score
281            double min_score = ffinder->min_score;
282            for (int i = 0; i < psg.data_count; i++) {
283                const HitCounter& ps = famStat.hits(i);
284                if (ps.get_match_count() >= min_score) {
285                    sorted[j++] = i;
286                }
287            }
288        }
289        else { // filter by relative score
290            double min_score_rel = double(ffinder->min_score)/100.0;
291            for (int i = 0; i < psg.data_count; i++) {
292                const HitCounter& ps = famStat.hits(i);
293                if (ps.get_rel_match_count()>min_score_rel) {
294                    sorted[j++] = i;
295                }
296            }
297        }
298        matching_results = j;
299    }
300
301    bool sort_all = ffinder->sort_max == 0 || ffinder->sort_max >= int(matching_results);
302
303    if (ffinder->sort_type == 0) { // sort by absolut score
304        if (sort_all) {
305            std::sort(sorted.begin(), sorted.end(), oligo_cmp_abs(famStat));
306        }
307        else {
308            std::partial_sort(sorted.begin(), sorted.begin() + ffinder->sort_max, sorted.begin() + matching_results, oligo_cmp_abs(famStat));
309        }
310    }
311    else { // sort by relative score
312        if (sort_all) {
313            std::sort(sorted.begin(), sorted.begin() + psg.data_count, oligo_cmp_rel(famStat));
314        }
315        else {
316            std::partial_sort(sorted.begin(), sorted.begin() + ffinder->sort_max, sorted.begin() + matching_results, oligo_cmp_rel(famStat));
317        }
318    }
319
320    // build new list
321    int real_hits = 0;
322
323    int end = (sort_all) ? matching_results : ffinder->sort_max;
324    for (int i = 0; i < end; i++) {
325        probe_input_data& pid = psg.data[sorted[i]];
326        const HitCounter& ps  = famStat.hits(sorted[i]);
327
328        if (ps.get_match_count() != 0) {
329            PT_family_list *fl = create_PT_family_list();
330
331            fl->name        = ARB_strdup(pid.get_shortname());
332            fl->matches     = ps.get_match_count();
333            fl->rel_matches = ps.get_rel_match_count();
334
335            aisc_link(&ffinder->pfl, fl);
336            real_hits++;
337        }
338    }
339
340    ffinder->list_size = real_hits;
341
342    return 0;
343}
344
345inline bool contains_ambiguities(char *oligo, int oligo_len) {
346    //! Check the oligo for ambiguities
347    for (int i = 0; i < oligo_len; i++) {
348        if (!is_std_base(oligo[i])) {
349            return true;
350        }
351    }
352    return false;
353}
354
355class oligo_comparator {
356    int oligo_len;
357public:
358    oligo_comparator(int len) : oligo_len(len) {}
359    bool operator()(const char *p1, const char *p2) const {
360        bool isless = false;
361        for (int o = 0; o<oligo_len; ++o) {
362            if (p1[o] != p2[o]) {
363                isless = p1[o]<p2[o];
364                break;
365            }
366        }
367        return isless;
368    }
369};
370
371typedef std::map<const char *, int, oligo_comparator> OligoMap;
372typedef OligoMap::const_iterator                      OligoIter;
373
374class OligoRegistry {
375    OligoMap oligos;
376public:
377    OligoRegistry(int oligo_len)
378        : oligos(oligo_comparator(oligo_len))
379    {}
380    void add(const char *seq) {
381        OligoMap::iterator found = oligos.find(seq);
382        if (found == oligos.end()) oligos[seq] = 1;
383        else found->second++;
384    }
385    OligoIter begin() { return oligos.begin(); }
386    OligoIter end() { return oligos.end(); }
387};
388
389int find_family(PT_family *ffinder, bytestring *species) {
390    //! make sorted list of family members of species
391
392    int oligo_len = ffinder->pr_len;
393
394    if (oligo_len<1) {
395        freedup(ffinder->ff_error, "minimum oligo length is 1");
396    }
397    else {
398        int mismatch_nr = ffinder->mis_nr;
399        int complement  = ffinder->complement; // any combination of: 1 = forward, 2 = reverse, 4 = reverse-complement, 8 = complement
400
401        char *sequence     = species->data; // sequence data passed by caller
402        int   sequence_len = probe_compress_sequence(sequence, species->size-1);
403
404        bool use_all_oligos = ffinder->only_A_probes == 0;
405
406        PT_Traversal::restrictMatchesToRegion(ffinder->range_start, ffinder->range_end, oligo_len);
407
408        FamilyStat famStat(psg.data_count, RelativeScoreScaling(ffinder->rel_scoring));
409
410        char *seq[4];
411        int   seq_count = 0;
412
413        // Note: loop-logic depends on order of ../AWTC/awtc_next_neighbours.hxx@FF_complement_dep
414        for (int cmode = 1; cmode <= 8; cmode *= 2) {
415            switch (cmode) {
416                case FF_FORWARD:
417                    break;
418                case FF_REVERSE:
419                case FF_COMPLEMENT:
420                    reverse_probe(sequence, sequence_len); // build reverse sequence
421                    break;
422                case FF_REVERSE_COMPLEMENT:
423                    complement_probe(sequence, sequence_len); // build complement sequence
424                    break;
425            }
426
427            if ((complement&cmode) != 0) {
428                char *s = ARB_alloc<char>(sequence_len+1);
429
430                memcpy(s, sequence, sequence_len);
431                s[sequence_len] = 0;
432
433                seq[seq_count++] = s;
434            }
435        }
436
437        OligoRegistry occurring_oligos(oligo_len);
438
439        for (int s = 0; s<seq_count; s++) {
440            char *last_oligo = seq[s]+sequence_len-oligo_len;
441            for (char *oligo = seq[s]; oligo < last_oligo; ++oligo) {
442                if (use_all_oligos || oligo[0] == PT_A) {
443                    if (!contains_ambiguities(oligo, oligo_len)) {
444                        occurring_oligos.add(oligo);
445                    }
446                }
447            }
448        }
449
450        for (OligoIter o = occurring_oligos.begin(); o != occurring_oligos.end(); ++o)  {
451            const char *oligo       = o->first;
452            int         occur_count = o->second;
453
454            famStat.limit_hits_for_next_traversal(occur_count);
455            PT_Traversal(oligo, oligo_len, mismatch_nr, famStat).mark_matching(psg.TREE_ROOT2());
456        }
457
458        famStat.calc_rel_matches(ffinder->pr_len, sequence_len);
459        make_PT_family_list(ffinder, famStat);
460
461        for (int s = 0; s<seq_count; s++) {
462            free(seq[s]);
463        }
464
465        PT_Traversal::unrestrictMatchesToRegion();
466    }
467    free(species->data);
468    return 0;
469}
Note: See TracBrowser for help on using the repository browser.