source: tags/arb-6.0/GDE/MAFFT/mafft-7.055-with-extensions/extensions/mxscarna_src/AlifoldMEA.cpp

Last change on this file was 10371, checked in by aboeckma, 11 years ago

updated mafft version. Added extensions (no svn ignore, yet)

File size: 4.4 KB
Line 
1#include "AlifoldMEA.h"
2
3namespace MXSCARNA{
4
5const int AlifoldMEA::TURN = 3;
6
7void
8AlifoldMEA::
9Run()
10{
11    makeProfileBPPMatrix(alignment);
12    Initialization();
13    DP();
14    TraceBack();
15}
16
17void
18AlifoldMEA::
19makeProfileBPPMatrix(const MultiSequence *Sequences)
20{
21    int length = Sequences->GetSequence(0)->GetLength();
22
23    Trimat<float> *consBppMat = new Trimat<float>(length + 1);
24    fill(consBppMat->begin(), consBppMat->end(), 0);
25
26    for(int i = 1; i <= length; i++) 
27        for (int j = i; j <= length; j++) 
28            bppMat.ref(i, j) = 0;
29
30
31    int number = Sequences->GetNumSequences();
32    for(int seqNum = 0; seqNum < number; seqNum++) {
33        SafeVector<int> *tmpMap = Sequences->GetSequence(seqNum)->GetMappingNumber();
34        int label = Sequences->GetSequence(seqNum)->GetLabel();
35        BPPMatrix *tmpBppMatrix = BPPMatrices[label];
36       
37        for(int i = 1; i <= length ; i++) {
38            int originI = tmpMap->at(i);
39            for(int j = i; j <= length; j++) {
40                int originJ = tmpMap->at(j);
41                if(originI != 0 && originJ != 0) {
42                    float tmpProb = tmpBppMatrix->GetProb(originI, originJ);
43                    bppMat.ref(i, j) += tmpProb;
44                }
45            }
46        }
47    }
48
49        /* compute the mean of base pairing probability  */
50    for(int i = 1; i <= length; i++) {
51        for(int j = i; j <= length; j++) {
52            bppMat.ref(i,j) = bppMat.ref(i,j)/(float)number;
53        }
54    }
55
56    for (int i = 1; i <= length; i++) {
57        float sum = 0;
58        for (int j = i; j <= length; j++) {
59            sum += bppMat.ref(i,j);
60        }
61        Qi[i] = 1 - sum;
62    }
63
64    for (int i = 1; i <= length; i++) {
65        float sum = 0;
66        for (int j = i; j >= 1; j--) {
67            sum += bppMat.ref(j, i);
68        }
69        Qj[i] = 1 - sum;
70    }
71}
72
73void
74AlifoldMEA::
75Initialization()
76{
77    int length = alignment->GetSequence(0)->GetLength();
78
79    for (int i = 1; i <= length; i++) {
80        for (int j = i; j <= length; j++) {
81            M.ref(i,j) = 0;
82            traceI.ref(i,j) = 0;
83            traceJ.ref(i,j) = 0;
84        }
85    }
86
87    for (int i = 1; i <= length; i++) {
88        M.ref(i,i)   = Qi[i]; 
89        traceI.ref(i,i) = 0;
90        traceJ.ref(i,i) = 0;
91    }
92
93    for (int i = 1; i <= length - 1; i++) {
94        M.ref(i, i+1) =  Qi[i+1];
95        traceI.ref(i,i + 1) = 0;
96        traceJ.ref(i,i + 1) = 0;
97    }
98
99    for (int i = 0; i <= length; i++) {
100        ssCons[i] = '.';
101    }
102}
103
104void
105AlifoldMEA::
106DP()
107{
108    float g    = BasePairConst; // see scarna.hpp
109    int length = alignment->GetSequence(0)->GetLength();
110   
111    for (int i = length - 1; i >= 1; i--) {
112        for (int j = i + TURN + 1; j <= length; j++) {
113            float qi       = Qi[i];
114            float qj       = Qj[j];
115            float p        = bppMat.ref(i,j);
116
117           
118            float maxScore = qi + M.ref(i+1, j);
119            int tmpI = i+1;
120            int tmpJ = j;
121           
122            float tmpScore = qj + M.ref(i, j-1);
123            if (tmpScore > maxScore) {
124                maxScore = tmpScore;
125                tmpI     = i;
126                tmpJ     = j - 1;
127            }
128           
129            tmpScore = g*2*p + M.ref(i+1, j-1);
130            if (tmpScore > maxScore) {
131                maxScore = tmpScore;
132                tmpI     = i + 1;
133                tmpJ     = j - 1;
134            }
135           
136            for (int k = i + 1; k < j - 1; k++) {
137                tmpScore = M.ref(i,k) + M.ref(k+1,j);
138                if (tmpScore > maxScore) {
139                    maxScore = tmpScore;
140                    tmpI = i;
141                    tmpJ = j;
142                }
143            }
144            M.ref(i,j)       = maxScore;
145            traceI.ref(i, j) = tmpI;
146            traceJ.ref(i, j) = tmpJ;
147        }
148    }
149}
150
151void
152AlifoldMEA::
153TraceBack()
154{
155
156    int length = alignment->GetSequence(0)->GetLength();
157    SafeVector<int> stackI((length + 1)*(length+1));
158    SafeVector<int> stackJ((length + 1)*(length+1));
159    int pt = 0;
160
161    stackI[pt] = traceI.ref(1, length);
162    stackJ[pt] = traceJ.ref(1, length);
163    ++pt;
164   
165    while(pt != 0) {
166        --pt;
167        int tmpI = stackI[pt];
168        int tmpJ = stackJ[pt];
169        int nextI = traceI.ref(tmpI, tmpJ);
170        int nextJ = traceJ.ref(tmpI, tmpJ);
171
172        if (tmpI < tmpJ) {
173            if (tmpI + 1  == nextI && tmpJ == nextJ) {
174                stackI[pt] = nextI;
175                stackJ[pt] = nextJ;
176                ++pt;
177            }
178            else if (tmpI == nextI && tmpJ - 1 == nextJ) {
179                stackI[pt] = nextI;
180                stackJ[pt] = nextJ;
181                ++pt;
182            }
183            else if (tmpI + 1 == nextI && tmpJ - 1== nextJ) {
184                stackI[pt] = nextI;
185                stackJ[pt] = nextJ;
186                ++pt;
187                ssCons[tmpI] = '(';
188                ssCons[tmpJ] = ')';
189            }
190            else if (tmpI == nextI && tmpJ == nextJ) {
191                float maxScore = IMPOSSIBLE;
192                int maxK = 0;
193
194                for (int k = tmpI + 1; k < tmpJ - 1; k++) {
195                    float tmpScore = M.ref(tmpI,k) + M.ref(k+1,tmpJ);
196                    if (tmpScore > maxScore) {
197                        maxScore = tmpScore;
198                        maxK = k;
199                    }
200                }
201                stackI[pt] = traceI.ref(tmpI, maxK);
202                stackJ[pt] = traceJ.ref(tmpI, maxK);
203                ++pt;
204                stackI[pt] = traceI.ref(maxK+1, tmpJ);
205                stackJ[pt] = traceJ.ref(maxK+1, tmpJ);
206                ++pt;
207            }
208        }
209    }
210}
211}
Note: See TracBrowser for help on using the repository browser.