source: branches/items/GDE/PROBCONS/probcons/ProbabilisticModel.h

Last change on this file was 10405, checked in by aboeckma, 11 years ago

added probcons

File size: 39.2 KB
Line 
1/////////////////////////////////////////////////////////////////
2// ProbabilisticModel.h
3//
4// Routines for (1) posterior probability computations
5//              (2) chained anchoring
6//              (3) maximum weight trace alignment
7/////////////////////////////////////////////////////////////////
8
9#ifndef PROBABILISTICMODEL_H
10#define PROBABILISTICMODEL_H
11
12#include <list>
13#include <cmath>
14#include <cstdio>
15#include "SafeVector.h"
16#include "ScoreType.h"
17#include "SparseMatrix.h"
18#include "MultiSequence.h"
19
20using namespace std;
21
22const int NumMatchStates = 1;                                    // note that in this version the number
23                                                                 // of match states is fixed at 1...will
24                                                                 // change in future versions
25const int NumMatrixTypes = NumMatchStates + NumInsertStates * 2;
26
27/////////////////////////////////////////////////////////////////
28// ProbabilisticModel
29//
30// Class for storing the parameters of a probabilistic model and
31// performing different computations based on those parameters.
32// In particular, this class handles the computation of
33// posterior probabilities that may be used in alignment.
34/////////////////////////////////////////////////////////////////
35
36class ProbabilisticModel {
37
38  float initialDistribution[NumMatrixTypes];               // holds the initial probabilities for each state
39  float transProb[NumMatrixTypes][NumMatrixTypes];         // holds all state-to-state transition probabilities
40  float matchProb[256][256];                               // emission probabilities for match states
41  float insProb[256][NumMatrixTypes];                      // emission probabilities for insert states
42
43 public:
44
45  /////////////////////////////////////////////////////////////////
46  // ProbabilisticModel::ProbabilisticModel()
47  //
48  // Constructor.  Builds a new probabilistic model using the
49  // given parameters.
50  /////////////////////////////////////////////////////////////////
51
52  ProbabilisticModel (const VF &initDistribMat, const VF &gapOpen, const VF &gapExtend,
53                      const VVF &emitPairs, const VF &emitSingle){
54
55    // build transition matrix
56    VVF transMat (NumMatrixTypes, VF (NumMatrixTypes, 0.0f));
57    transMat[0][0] = 1;
58    for (int i = 0; i < NumInsertStates; i++){
59      transMat[0][2*i+1] = gapOpen[2*i];
60      transMat[0][2*i+2] = gapOpen[2*i+1];
61      transMat[0][0] -= (gapOpen[2*i] + gapOpen[2*i+1]);
62      assert (transMat[0][0] > 0);
63      transMat[2*i+1][2*i+1] = gapExtend[2*i];
64      transMat[2*i+2][2*i+2] = gapExtend[2*i+1];
65      transMat[2*i+1][2*i+2] = 0;
66      transMat[2*i+2][2*i+1] = 0;
67      transMat[2*i+1][0] = 1 - gapExtend[2*i];
68      transMat[2*i+2][0] = 1 - gapExtend[2*i+1];
69    }
70
71    // create initial and transition probability matrices
72    for (int i = 0; i < NumMatrixTypes; i++){
73      initialDistribution[i] = LOG (initDistribMat[i]);
74      for (int j = 0; j < NumMatrixTypes; j++)
75        transProb[i][j] = LOG (transMat[i][j]);
76    }
77
78    // create insertion and match probability matrices
79    for (int i = 0; i < 256; i++){
80      for (int j = 0; j < NumMatrixTypes; j++)
81        insProb[i][j] = LOG (emitSingle[i]);
82      for (int j = 0; j < 256; j++)
83        matchProb[i][j] = LOG (emitPairs[i][j]);
84    }
85  }
86
87  /////////////////////////////////////////////////////////////////
88  // ProbabilisticModel::ComputeForwardMatrix()
89  //
90  // Computes a set of forward probability matrices for aligning
91  // seq1 and seq2.
92  //
93  // For efficiency reasons, a single-dimensional floating-point
94  // array is used here, with the following indexing scheme:
95  //
96  //    forward[i + NumMatrixTypes * (j * (seq2Length+1) + k)]
97  //    refers to the probability of aligning through j characters
98  //    of the first sequence, k characters of the second sequence,
99  //    and ending in state i.
100  /////////////////////////////////////////////////////////////////
101
102  VF *ComputeForwardMatrix (Sequence *seq1, Sequence *seq2) const {
103
104    assert (seq1);
105    assert (seq2);
106
107    const int seq1Length = seq1->GetLength();
108    const int seq2Length = seq2->GetLength();
109
110    // retrieve the points to the beginning of each sequence
111    SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
112    SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
113
114    // create matrix
115    VF *forwardPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
116    assert (forwardPtr);
117    VF &forward = *forwardPtr;
118
119    // initialization condition
120    forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] = 
121      initialDistribution[0] + matchProb[(unsigned char) iter1[1]][(unsigned char) iter2[1]];
122   
123    for (int k = 0; k < NumInsertStates; k++){
124      forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] = 
125        initialDistribution[2*k+1] + insProb[(unsigned char) iter1[1]][k];
126      forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] = 
127        initialDistribution[2*k+2] + insProb[(unsigned char) iter2[1]][k]; 
128    }
129   
130    // remember offset for each index combination
131    int ij = 0;
132    int i1j = -seq2Length - 1;
133    int ij1 = -1;
134    int i1j1 = -seq2Length - 2;
135
136    ij *= NumMatrixTypes;
137    i1j *= NumMatrixTypes;
138    ij1 *= NumMatrixTypes;
139    i1j1 *= NumMatrixTypes;
140
141    // compute forward scores
142    for (int i = 0; i <= seq1Length; i++){
143      unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
144      for (int j = 0; j <= seq2Length; j++){
145        unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
146
147        if (i > 1 || j > 1){
148          if (i > 0 && j > 0){
149            forward[0 + ij] = forward[0 + i1j1] + transProb[0][0];
150            for (int k = 1; k < NumMatrixTypes; k++)
151              LOG_PLUS_EQUALS (forward[0 + ij], forward[k + i1j1] + transProb[k][0]);
152            forward[0 + ij] += matchProb[c1][c2];
153          }
154          if (i > 0){
155            for (int k = 0; k < NumInsertStates; k++)
156              forward[2*k+1 + ij] = insProb[c1][k] +
157                LOG_ADD (forward[0 + i1j] + transProb[0][2*k+1],
158                         forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1]);
159          }
160          if (j > 0){
161            for (int k = 0; k < NumInsertStates; k++)
162              forward[2*k+2 + ij] = insProb[c2][k] +
163                LOG_ADD (forward[0 + ij1] + transProb[0][2*k+2],
164                         forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2]);
165          }
166        }
167
168        ij += NumMatrixTypes;
169        i1j += NumMatrixTypes;
170        ij1 += NumMatrixTypes;
171        i1j1 += NumMatrixTypes;
172      }
173    }
174
175    return forwardPtr;
176  }
177
178  /////////////////////////////////////////////////////////////////
179  // ProbabilisticModel::ComputeBackwardMatrix()
180  //
181  // Computes a set of backward probability matrices for aligning
182  // seq1 and seq2.
183  //
184  // For efficiency reasons, a single-dimensional floating-point
185  // array is used here, with the following indexing scheme:
186  //
187  //    backward[i + NumMatrixTypes * (j * (seq2Length+1) + k)]
188  //    refers to the probability of starting in state i and
189  //    aligning from character j+1 to the end of the first
190  //    sequence and from character k+1 to the end of the second
191  //    sequence.
192  /////////////////////////////////////////////////////////////////
193
194  VF *ComputeBackwardMatrix (Sequence *seq1, Sequence *seq2) const {
195
196    assert (seq1);
197    assert (seq2);
198
199    const int seq1Length = seq1->GetLength();
200    const int seq2Length = seq2->GetLength();
201    SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
202    SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
203
204    // create matrix
205    VF *backwardPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
206    assert (backwardPtr);
207    VF &backward = *backwardPtr;
208
209    // initialization condition
210    for (int k = 0; k < NumMatrixTypes; k++)
211      backward[NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1) + k] = initialDistribution[k];
212
213    // remember offset for each index combination
214    int ij = (seq1Length+1) * (seq2Length+1) - 1;
215    int i1j = ij + seq2Length + 1;
216    int ij1 = ij + 1;
217    int i1j1 = ij + seq2Length + 2;
218
219    ij *= NumMatrixTypes;
220    i1j *= NumMatrixTypes;
221    ij1 *= NumMatrixTypes;
222    i1j1 *= NumMatrixTypes;
223
224    // compute backward scores
225    for (int i = seq1Length; i >= 0; i--){
226      unsigned char c1 = (i == seq1Length) ? '~' : (unsigned char) iter1[i+1];
227      for (int j = seq2Length; j >= 0; j--){
228        unsigned char c2 = (j == seq2Length) ? '~' : (unsigned char) iter2[j+1];
229
230        if (i < seq1Length && j < seq2Length){
231          const float ProbXY = backward[0 + i1j1] + matchProb[c1][c2];
232          for (int k = 0; k < NumMatrixTypes; k++)
233            LOG_PLUS_EQUALS (backward[k + ij], ProbXY + transProb[k][0]);
234        }
235        if (i < seq1Length){
236          for (int k = 0; k < NumInsertStates; k++){
237            LOG_PLUS_EQUALS (backward[0 + ij], backward[2*k+1 + i1j] + insProb[c1][k] + transProb[0][2*k+1]);
238            LOG_PLUS_EQUALS (backward[2*k+1 + ij], backward[2*k+1 + i1j] + insProb[c1][k] + transProb[2*k+1][2*k+1]);
239          }
240        }
241        if (j < seq2Length){
242          for (int k = 0; k < NumInsertStates; k++){
243            LOG_PLUS_EQUALS (backward[0 + ij], backward[2*k+2 + ij1] + insProb[c2][k] + transProb[0][2*k+2]);
244            LOG_PLUS_EQUALS (backward[2*k+2 + ij], backward[2*k+2 + ij1] + insProb[c2][k] + transProb[2*k+2][2*k+2]);
245          }
246        }
247
248        ij -= NumMatrixTypes;
249        i1j -= NumMatrixTypes;
250        ij1 -= NumMatrixTypes;
251        i1j1 -= NumMatrixTypes;
252      }
253    }
254
255    return backwardPtr;
256  }
257
258  /////////////////////////////////////////////////////////////////
259  // ProbabilisticModel::ComputeTotalProbability()
260  //
261  // Computes the total probability of an alignment given
262  // the forward and backward matrices.
263  /////////////////////////////////////////////////////////////////
264
265  float ComputeTotalProbability (int seq1Length, int seq2Length,
266                                 const VF &forward, const VF &backward) const {
267
268    // compute total probability
269    float totalForwardProb = LOG_ZERO;
270    float totalBackwardProb = LOG_ZERO;
271    for (int k = 0; k < NumMatrixTypes; k++){
272      LOG_PLUS_EQUALS (totalForwardProb,
273                       forward[k + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
274                       backward[k + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
275    }
276
277    totalBackwardProb = 
278      forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] +
279      backward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)];
280
281    for (int k = 0; k < NumInsertStates; k++){
282      LOG_PLUS_EQUALS (totalBackwardProb,
283                       forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] +
284                       backward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)]);
285      LOG_PLUS_EQUALS (totalBackwardProb,
286                       forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] +
287                       backward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)]);
288    }
289
290    //    cerr << totalForwardProb << " " << totalBackwardProb << endl;
291   
292    return (totalForwardProb + totalBackwardProb) / 2;
293  }
294
295  /////////////////////////////////////////////////////////////////
296  // ProbabilisticModel::ComputePosteriorMatrix()
297  //
298  // Computes the posterior probability matrix based on
299  // the forward and backward matrices.
300  /////////////////////////////////////////////////////////////////
301
302  VF *ComputePosteriorMatrix (Sequence *seq1, Sequence *seq2,
303                              const VF &forward, const VF &backward) const {
304
305    assert (seq1);
306    assert (seq2);
307
308    const int seq1Length = seq1->GetLength();
309    const int seq2Length = seq2->GetLength();
310
311    float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
312                                               forward, backward);
313
314    // compute posterior matrices
315    VF *posteriorPtr = new VF((seq1Length+1) * (seq2Length+1)); assert (posteriorPtr);
316    VF &posterior = *posteriorPtr;
317
318    int ij = 0;
319    VF::iterator ptr = posterior.begin();
320
321    for (int i = 0; i <= seq1Length; i++){
322      for (int j = 0; j <= seq2Length; j++){
323        *(ptr++) = EXP (min (LOG_ONE, forward[ij] + backward[ij] - totalProb));
324        ij += NumMatrixTypes;
325      }
326    }
327
328    posterior[0] = 0;
329
330    return posteriorPtr;
331  }
332
333  /*
334  /////////////////////////////////////////////////////////////////
335  // ProbabilisticModel::ComputeExpectedCounts()
336  //
337  // Computes the expected counts for the various transitions.
338  /////////////////////////////////////////////////////////////////
339
340  VVF *ComputeExpectedCounts () const {
341
342    assert (seq1);
343    assert (seq2);
344
345    const int seq1Length = seq1->GetLength();
346    const int seq2Length = seq2->GetLength();
347    SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
348    SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
349
350    // compute total probability
351    float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
352                                               forward, backward);
353
354    // initialize expected counts
355    VVF *countsPtr = new VVF(NumMatrixTypes + 1, VF(NumMatrixTypes, LOG_ZERO)); assert (countsPtr);
356    VVF &counts = *countsPtr;
357
358    // remember offset for each index combination
359    int ij = 0;
360    int i1j = -seq2Length - 1;
361    int ij1 = -1;
362    int i1j1 = -seq2Length - 2;
363
364    ij *= NumMatrixTypes;
365    i1j *= NumMatrixTypes;
366    ij1 *= NumMatrixTypes;
367    i1j1 *= NumMatrixTypes;
368
369    // compute expected counts
370    for (int i = 0; i <= seq1Length; i++){
371      unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
372      for (int j = 0; j <= seq2Length; j++){
373        unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
374
375        if (i > 0 && j > 0){
376          for (int k = 0; k < NumMatrixTypes; k++)
377            LOG_PLUS_EQUALS (counts[k][0],
378                             forward[k + i1j1] + transProb[k][0] +
379                             matchProb[c1][c2] + backward[0 + ij]);
380        }
381        if (i > 0){
382          for (int k = 0; k < NumInsertStates; k++){
383            LOG_PLUS_EQUALS (counts[0][2*k+1],
384                             forward[0 + i1j] + transProb[0][2*k+1] +
385                             insProb[c1][k] + backward[2*k+1 + ij]);
386            LOG_PLUS_EQUALS (counts[2*k+1][2*k+1],
387                             forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
388                             insProb[c1][k] + backward[2*k+1 + ij]);
389          }
390        }
391        if (j > 0){
392          for (int k = 0; k < NumInsertStates; k++){
393            LOG_PLUS_EQUALS (counts[0][2*k+2],
394                             forward[0 + ij1] + transProb[0][2*k+2] +
395                             insProb[c2][k] + backward[2*k+2 + ij]);
396            LOG_PLUS_EQUALS (counts[2*k+2][2*k+2],
397                             forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
398                             insProb[c2][k] + backward[2*k+2 + ij]);
399          }
400        }
401
402        ij += NumMatrixTypes;
403        i1j += NumMatrixTypes;
404        ij1 += NumMatrixTypes;
405        i1j1 += NumMatrixTypes;
406      }
407    }
408
409    // scale all expected counts appropriately
410    for (int i = 0; i < NumMatrixTypes; i++)
411      for (int j = 0; j < NumMatrixTypes; j++)
412        counts[i][j] -= totalProb;
413
414  }
415  */
416
417  /////////////////////////////////////////////////////////////////
418  // ProbabilisticModel::ComputeNewParameters()
419  //
420  // Computes a new parameter set based on the expected counts
421  // given.
422  /////////////////////////////////////////////////////////////////
423
424  void ComputeNewParameters (Sequence *seq1, Sequence *seq2,
425                             const VF &forward, const VF &backward,
426                             VF &initDistribMat, VF &gapOpen,
427                             VF &gapExtend, VVF &emitPairs, VF &emitSingle, bool enableTrainEmissions) const {
428   
429    assert (seq1);
430    assert (seq2);
431
432    const int seq1Length = seq1->GetLength();
433    const int seq2Length = seq2->GetLength();
434    SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
435    SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
436
437    // compute total probability
438    float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
439                                               forward, backward);
440   
441    // initialize expected counts
442    VVF transCounts (NumMatrixTypes, VF (NumMatrixTypes, LOG_ZERO));
443    VF initCounts (NumMatrixTypes, LOG_ZERO);
444    VVF pairCounts (256, VF (256, LOG_ZERO));
445    VF singleCounts (256, LOG_ZERO);
446   
447    // remember offset for each index combination
448    int ij = 0;
449    int i1j = -seq2Length - 1;
450    int ij1 = -1;
451    int i1j1 = -seq2Length - 2;
452
453    ij *= NumMatrixTypes;
454    i1j *= NumMatrixTypes;
455    ij1 *= NumMatrixTypes;
456    i1j1 *= NumMatrixTypes;
457
458    // compute initial distribution posteriors
459    initCounts[0] = LOG_ADD (forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] +
460                             backward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)],
461                             forward[0 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
462                             backward[0 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
463    for (int k = 0; k < NumInsertStates; k++){
464      initCounts[2*k+1] = LOG_ADD (forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] +
465                                   backward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)],
466                                   forward[2*k+1 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
467                                   backward[2*k+1 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
468      initCounts[2*k+2] = LOG_ADD (forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] +
469                                   backward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)],
470                                   forward[2*k+2 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
471                                   backward[2*k+2 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
472    }
473
474    // compute expected counts
475    for (int i = 0; i <= seq1Length; i++){
476      unsigned char c1 = (i == 0) ? '~' : (unsigned char) toupper(iter1[i]);
477      for (int j = 0; j <= seq2Length; j++){
478        unsigned char c2 = (j == 0) ? '~' : (unsigned char) toupper(iter2[j]);
479
480        if (i > 0 && j > 0){
481          if (enableTrainEmissions && i == 1 && j == 1){
482            LOG_PLUS_EQUALS (pairCounts[c1][c2],
483                             initialDistribution[0] + matchProb[c1][c2] + backward[0 + ij]);
484            LOG_PLUS_EQUALS (pairCounts[c2][c1],
485                             initialDistribution[0] + matchProb[c2][c1] + backward[0 + ij]);
486          }
487
488          for (int k = 0; k < NumMatrixTypes; k++){
489            LOG_PLUS_EQUALS (transCounts[k][0],
490                             forward[k + i1j1] + transProb[k][0] +
491                             matchProb[c1][c2] + backward[0 + ij]);
492            if (enableTrainEmissions && i != 1 || j != 1){
493              LOG_PLUS_EQUALS (pairCounts[c1][c2],
494                               forward[k + i1j1] + transProb[k][0] +
495                               matchProb[c1][c2] + backward[0 + ij]);
496              LOG_PLUS_EQUALS (pairCounts[c2][c1],
497                               forward[k + i1j1] + transProb[k][0] +
498                               matchProb[c2][c1] + backward[0 + ij]);
499            }
500          }
501        }
502        if (i > 0){
503          for (int k = 0; k < NumInsertStates; k++){
504            LOG_PLUS_EQUALS (transCounts[0][2*k+1],
505                             forward[0 + i1j] + transProb[0][2*k+1] +
506                             insProb[c1][k] + backward[2*k+1 + ij]);
507            LOG_PLUS_EQUALS (transCounts[2*k+1][2*k+1],
508                             forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
509                             insProb[c1][k] + backward[2*k+1 + ij]);
510            if (enableTrainEmissions){
511              if (i == 1 && j == 0){
512                LOG_PLUS_EQUALS (singleCounts[c1],
513                                 initialDistribution[2*k+1] + insProb[c1][k] + backward[2*k+1 + ij]);
514              }
515              else {
516                LOG_PLUS_EQUALS (singleCounts[c1],
517                                 forward[0 + i1j] + transProb[0][2*k+1] +
518                                 insProb[c1][k] + backward[2*k+1 + ij]);
519                LOG_PLUS_EQUALS (singleCounts[c1],
520                                 forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
521                                 insProb[c1][k] + backward[2*k+1 + ij]);
522              }
523            }
524          }
525        }
526        if (j > 0){
527          for (int k = 0; k < NumInsertStates; k++){
528            LOG_PLUS_EQUALS (transCounts[0][2*k+2],
529                             forward[0 + ij1] + transProb[0][2*k+2] +
530                             insProb[c2][k] + backward[2*k+2 + ij]);
531            LOG_PLUS_EQUALS (transCounts[2*k+2][2*k+2],
532                             forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
533                             insProb[c2][k] + backward[2*k+2 + ij]);
534            if (enableTrainEmissions){
535              if (i == 0 && j == 1){
536                LOG_PLUS_EQUALS (singleCounts[c2],
537                                 initialDistribution[2*k+2] + insProb[c2][k] + backward[2*k+2 + ij]);
538              }
539              else {
540                LOG_PLUS_EQUALS (singleCounts[c2],
541                                 forward[0 + ij1] + transProb[0][2*k+2] +
542                                 insProb[c2][k] + backward[2*k+2 + ij]);
543                LOG_PLUS_EQUALS (singleCounts[c2],
544                                 forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
545                                 insProb[c2][k] + backward[2*k+2 + ij]);
546              }
547            }
548          }
549        }
550     
551        ij += NumMatrixTypes;
552        i1j += NumMatrixTypes;
553        ij1 += NumMatrixTypes;
554        i1j1 += NumMatrixTypes;
555      }
556    }
557
558    // scale all expected counts appropriately
559    for (int i = 0; i < NumMatrixTypes; i++){
560      initCounts[i] -= totalProb;
561      for (int j = 0; j < NumMatrixTypes; j++)
562        transCounts[i][j] -= totalProb;
563    }
564    if (enableTrainEmissions){
565      for (int i = 0; i < 256; i++){
566        for (int j = 0; j < 256; j++)
567          pairCounts[i][j] -= totalProb;
568        singleCounts[i] -= totalProb;
569      }
570    }
571
572    // compute new initial distribution
573    float totalInitDistribCounts = 0;
574    for (int i = 0; i < NumMatrixTypes; i++)
575      totalInitDistribCounts += exp (initCounts[i]); // should be 2
576    initDistribMat[0] = min (1.0f, max (0.0f, (float) exp (initCounts[0]) / totalInitDistribCounts));
577    for (int k = 0; k < NumInsertStates; k++){
578      float val = (exp (initCounts[2*k+1]) + exp (initCounts[2*k+2])) / 2;
579      initDistribMat[2*k+1] = initDistribMat[2*k+2] = min (1.0f, max (0.0f, val / totalInitDistribCounts));
580    }
581
582    // compute total counts for match state
583    float inMatchStateCounts = 0;
584    for (int i = 0; i < NumMatrixTypes; i++)
585      inMatchStateCounts += exp (transCounts[0][i]);
586    for (int i = 0; i < NumInsertStates; i++){
587
588      // compute total counts for gap state
589      float inGapStateCounts =
590        exp (transCounts[2*i+1][0]) +
591        exp (transCounts[2*i+1][2*i+1]) +
592        exp (transCounts[2*i+2][0]) +
593        exp (transCounts[2*i+2][2*i+2]);
594
595      gapOpen[2*i] = gapOpen[2*i+1] =
596        (exp (transCounts[0][2*i+1]) +
597         exp (transCounts[0][2*i+2])) /
598        (2 * inMatchStateCounts);
599
600      gapExtend[2*i] = gapExtend[2*i+1] =
601        (exp (transCounts[2*i+1][2*i+1]) +
602         exp (transCounts[2*i+2][2*i+2])) /
603        inGapStateCounts;
604    }
605
606    if (enableTrainEmissions){
607      float totalPairCounts = 0;
608      float totalSingleCounts = 0;
609      for (int i = 0; i < 256; i++){
610        for (int j = 0; j <= i; j++)
611          totalPairCounts += exp (pairCounts[j][i]);
612        totalSingleCounts += exp (singleCounts[i]);
613      }
614     
615      for (int i = 0; i < 256; i++) if (!islower ((char) i)){
616        int li = (int)((unsigned char) tolower ((char) i));
617        for (int j = 0; j <= i; j++) if (!islower ((char) j)){
618          int lj = (int)((unsigned char) tolower ((char) j));
619          emitPairs[i][j] = emitPairs[i][lj] = emitPairs[li][j] = emitPairs[li][lj] = 
620            emitPairs[j][i] = emitPairs[j][li] = emitPairs[lj][i] = emitPairs[lj][li] = exp(pairCounts[j][i]) / totalPairCounts;
621        }
622        emitSingle[i] = emitSingle[li] = exp(singleCounts[i]) / totalSingleCounts;
623      }
624    }
625  }
626   
627  /////////////////////////////////////////////////////////////////
628  // ProbabilisticModel::ComputeAlignment()
629  //
630  // Computes an alignment based on the given posterior matrix.
631  // This is done by finding the maximum summing path (or
632  // maximum weight trace) through the posterior matrix.  The
633  // final alignment is returned as a pair consisting of:
634  //    (1) a string (e.g., XXXBBXXXBBBBBBYYYYBBB) where X's and
635  //        denote insertions in one of the two sequences and
636  //        B's denote that both sequences are present (i.e.
637  //        matches).
638  //    (2) a float indicating the sum achieved
639  /////////////////////////////////////////////////////////////////
640
641  pair<SafeVector<char> *, float> ComputeAlignment (int seq1Length, int seq2Length,
642                                                    const VF &posterior) const {
643
644    float *twoRows = new float[(seq2Length+1)*2]; assert (twoRows);
645    float *oldRow = twoRows;
646    float *newRow = twoRows + seq2Length + 1;
647
648    char *tracebackMatrix = new char[(seq1Length+1)*(seq2Length+1)]; assert (tracebackMatrix);
649    char *tracebackPtr = tracebackMatrix;
650
651    VF::const_iterator posteriorPtr = posterior.begin() + seq2Length + 1;
652
653    // initialization
654    for (int i = 0; i <= seq2Length; i++){
655      oldRow[i] = 0;
656      *(tracebackPtr++) = 'L';
657    }
658
659    // fill in matrix
660    for (int i = 1; i <= seq1Length; i++){
661
662      // initialize left column
663      newRow[0] = 0;
664      posteriorPtr++;
665      *(tracebackPtr++) = 'U';
666
667      // fill in rest of row
668      for (int j = 1; j <= seq2Length; j++){
669        ChooseBestOfThree (*(posteriorPtr++) + oldRow[j-1], newRow[j-1], oldRow[j],
670                           'D', 'L', 'U', &newRow[j], tracebackPtr++);
671      }
672
673      // swap rows
674      float *temp = oldRow;
675      oldRow = newRow;
676      newRow = temp;
677    }
678
679    // store best score
680    float total = oldRow[seq2Length];
681    delete [] twoRows;
682
683    // compute traceback
684    SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
685    int r = seq1Length, c = seq2Length;
686    while (r != 0 || c != 0){
687      char ch = tracebackMatrix[r*(seq2Length+1) + c];
688      switch (ch){
689      case 'L': c--; alignment->push_back ('Y'); break;
690      case 'U': r--; alignment->push_back ('X'); break;
691      case 'D': c--; r--; alignment->push_back ('B'); break;
692      default: assert (false);
693      }
694    }
695
696    delete [] tracebackMatrix;
697
698    reverse (alignment->begin(), alignment->end());
699
700    return make_pair(alignment, total);
701  }
702
703  /////////////////////////////////////////////////////////////////
704  // ProbabilisticModel::ComputeAlignmentWithGapPenalties()
705  //
706  // Similar to ComputeAlignment() except with gap penalties.
707  /////////////////////////////////////////////////////////////////
708
709  pair<SafeVector<char> *, float> ComputeAlignmentWithGapPenalties (MultiSequence *align1,
710                                                                    MultiSequence *align2,
711                                                                    const VF &posterior, int numSeqs1,
712                                                                    int numSeqs2,
713                                                                    float gapOpenPenalty,
714                                                                    float gapContinuePenalty) const {
715    int seq1Length = align1->GetSequence(0)->GetLength();
716    int seq2Length = align2->GetSequence(0)->GetLength();
717    SafeVector<SafeVector<char>::iterator > dataPtrs1 (align1->GetNumSequences());
718    SafeVector<SafeVector<char>::iterator > dataPtrs2 (align2->GetNumSequences());
719
720    // grab character data
721    for (int i = 0; i < align1->GetNumSequences(); i++)
722      dataPtrs1[i] = align1->GetSequence(i)->GetDataPtr();
723    for (int i = 0; i < align2->GetNumSequences(); i++)
724      dataPtrs2[i] = align2->GetSequence(i)->GetDataPtr();
725
726    // the number of active sequences at any given column is defined to be the
727    // number of non-gap characters in that column; the number of gap opens at
728    // any given column is defined to be the number of gap characters in that
729    // column where the previous character in the respective sequence was not
730    // a gap
731    SafeVector<int> numActive1 (seq1Length+1), numGapOpens1 (seq1Length+1);
732    SafeVector<int> numActive2 (seq2Length+1), numGapOpens2 (seq2Length+1);
733
734    // compute number of active sequences and gap opens for each group
735    for (int i = 0; i < align1->GetNumSequences(); i++){
736      SafeVector<char>::iterator dataPtr = align1->GetSequence(i)->GetDataPtr();
737      numActive1[0] = numGapOpens1[0] = 0;
738      for (int j = 1; j <= seq1Length; j++){
739        if (dataPtr[j] != '-'){
740          numActive1[j]++;
741          numGapOpens1[j] += (j != 1 && dataPtr[j-1] != '-');
742        }
743      }
744    }
745    for (int i = 0; i < align2->GetNumSequences(); i++){
746      SafeVector<char>::iterator dataPtr = align2->GetSequence(i)->GetDataPtr();
747      numActive2[0] = numGapOpens2[0] = 0;
748      for (int j = 1; j <= seq2Length; j++){
749        if (dataPtr[j] != '-'){
750          numActive2[j]++;
751          numGapOpens2[j] += (j != 1 && dataPtr[j-1] != '-');
752        }
753      }
754    }
755
756    VVF openingPenalty1 (numSeqs1+1, VF (numSeqs2+1));
757    VF continuingPenalty1 (numSeqs1+1);
758    VVF openingPenalty2 (numSeqs1+1, VF (numSeqs2+1));
759    VF continuingPenalty2 (numSeqs2+1);
760
761    // precompute penalties
762    for (int i = 0; i <= numSeqs1; i++)
763      for (int j = 0; j <= numSeqs2; j++)
764        openingPenalty1[i][j] = i * (gapOpenPenalty * j + gapContinuePenalty * (numSeqs2 - j));
765    for (int i = 0; i <= numSeqs1; i++)
766      continuingPenalty1[i] = i * gapContinuePenalty * numSeqs2;
767    for (int i = 0; i <= numSeqs2; i++)
768      for (int j = 0; j <= numSeqs1; j++)
769        openingPenalty2[i][j] = i * (gapOpenPenalty * j + gapContinuePenalty * (numSeqs1 - j));
770    for (int i = 0; i <= numSeqs2; i++)
771      continuingPenalty2[i] = i * gapContinuePenalty * numSeqs1;
772
773    float *twoRows = new float[6*(seq2Length+1)]; assert (twoRows);
774    float *oldRowMatch = twoRows;
775    float *newRowMatch = twoRows + (seq2Length+1);
776    float *oldRowInsertX = twoRows + 2*(seq2Length+1);
777    float *newRowInsertX = twoRows + 3*(seq2Length+1);
778    float *oldRowInsertY = twoRows + 4*(seq2Length+1);
779    float *newRowInsertY = twoRows + 5*(seq2Length+1);
780
781    char *tracebackMatrix = new char[3*(seq1Length+1)*(seq2Length+1)]; assert (tracebackMatrix);
782    char *tracebackPtr = tracebackMatrix;
783
784    VF::const_iterator posteriorPtr = posterior.begin() + seq2Length + 1;
785
786    // initialization
787    for (int i = 0; i <= seq2Length; i++){
788      oldRowMatch[i] = oldRowInsertX[i] = (i == 0) ? 0 : LOG_ZERO;
789      oldRowInsertY[i] = (i == 0) ? 0 : oldRowInsertY[i-1] + continuingPenalty2[numActive2[i]];
790      *(tracebackPtr) = *(tracebackPtr+1) = *(tracebackPtr+2) = 'Y';
791      tracebackPtr += 3;
792    }
793
794    // fill in matrix
795    for (int i = 1; i <= seq1Length; i++){
796
797      // initialize left column
798      newRowMatch[0] = newRowInsertY[0] = LOG_ZERO;
799      newRowInsertX[0] = oldRowInsertX[0] + continuingPenalty1[numActive1[i]];
800      posteriorPtr++;
801      *(tracebackPtr) = *(tracebackPtr+1) = *(tracebackPtr+2) = 'X';
802      tracebackPtr += 3;
803
804      // fill in rest of row
805      for (int j = 1; j <= seq2Length; j++){
806
807        // going to MATCH state
808        ChooseBestOfThree (oldRowMatch[j-1],
809                           oldRowInsertX[j-1],
810                           oldRowInsertY[j-1],
811                           'M', 'X', 'Y', &newRowMatch[j], tracebackPtr++);
812        newRowMatch[j] += *(posteriorPtr++);
813
814        // going to INSERT X state
815        ChooseBestOfThree (oldRowMatch[j] + openingPenalty1[numActive1[i]][numGapOpens2[j]],
816                           oldRowInsertX[j] + continuingPenalty1[numActive1[i]],
817                           oldRowInsertY[j] + openingPenalty1[numActive1[i]][numGapOpens2[j]],
818                           'M', 'X', 'Y', &newRowInsertX[j], tracebackPtr++);
819
820        // going to INSERT Y state
821        ChooseBestOfThree (newRowMatch[j-1] + openingPenalty2[numActive2[j]][numGapOpens1[i]],
822                           newRowInsertX[j-1] + openingPenalty2[numActive2[j]][numGapOpens1[i]],
823                           newRowInsertY[j-1] + continuingPenalty2[numActive2[j]],
824                           'M', 'X', 'Y', &newRowInsertY[j], tracebackPtr++);
825      }
826
827      // swap rows
828      float *temp;
829      temp = oldRowMatch; oldRowMatch = newRowMatch; newRowMatch = temp;
830      temp = oldRowInsertX; oldRowInsertX = newRowInsertX; newRowInsertX = temp;
831      temp = oldRowInsertY; oldRowInsertY = newRowInsertY; newRowInsertY = temp;
832    }
833
834    // store best score
835    float total;
836    char matrix;
837    ChooseBestOfThree (oldRowMatch[seq2Length], oldRowInsertX[seq2Length], oldRowInsertY[seq2Length],
838                       'M', 'X', 'Y', &total, &matrix);
839
840    delete [] twoRows;
841
842    // compute traceback
843    SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
844    int r = seq1Length, c = seq2Length;
845    while (r != 0 || c != 0){
846
847      int offset = (matrix == 'M') ? 0 : (matrix == 'X') ? 1 : 2;
848      char ch = tracebackMatrix[(r*(seq2Length+1) + c) * 3 + offset];
849      switch (matrix){
850      case 'Y': c--; alignment->push_back ('Y'); break;
851      case 'X': r--; alignment->push_back ('X'); break;
852      case 'M': c--; r--; alignment->push_back ('B'); break;
853      default: assert (false);
854      }
855      matrix = ch;
856    }
857
858    delete [] tracebackMatrix;
859
860    reverse (alignment->begin(), alignment->end());
861
862    return make_pair(alignment, 1.0f);
863  }
864
865  /////////////////////////////////////////////////////////////////
866  // ProbabilisticModel::ComputeViterbiAlignment()
867  //
868  // Computes the highest probability pairwise alignment using the
869  // probabilistic model.  The final alignment is returned as a
870  //  pair consisting of:
871  //    (1) a string (e.g., XXXBBXXXBBBBBBYYYYBBB) where X's and
872  //        denote insertions in one of the two sequences and
873  //        B's denote that both sequences are present (i.e.
874  //        matches).
875  //    (2) a float containing the log probability of the best
876  //        alignment (not used)
877  /////////////////////////////////////////////////////////////////
878
879  pair<SafeVector<char> *, float> ComputeViterbiAlignment (Sequence *seq1, Sequence *seq2) const {
880   
881    assert (seq1);
882    assert (seq2);
883   
884    const int seq1Length = seq1->GetLength();
885    const int seq2Length = seq2->GetLength();
886   
887    // retrieve the points to the beginning of each sequence
888    SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
889    SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
890   
891    // create viterbi matrix
892    VF *viterbiPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
893    assert (viterbiPtr);
894    VF &viterbi = *viterbiPtr;
895
896    // create traceback matrix
897    VI *tracebackPtr = new VI (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), -1);
898    assert (tracebackPtr);
899    VI &traceback = *tracebackPtr;
900
901    // initialization condition
902    for (int k = 0; k < NumMatrixTypes; k++)
903      viterbi[k] = initialDistribution[k];
904
905    // remember offset for each index combination
906    int ij = 0;
907    int i1j = -seq2Length - 1;
908    int ij1 = -1;
909    int i1j1 = -seq2Length - 2;
910
911    ij *= NumMatrixTypes;
912    i1j *= NumMatrixTypes;
913    ij1 *= NumMatrixTypes;
914    i1j1 *= NumMatrixTypes;
915
916    // compute viterbi scores
917    for (int i = 0; i <= seq1Length; i++){
918      unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
919      for (int j = 0; j <= seq2Length; j++){
920        unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
921
922        if (i > 0 && j > 0){
923          for (int k = 0; k < NumMatrixTypes; k++){
924            float newVal = viterbi[k + i1j1] + transProb[k][0] + matchProb[c1][c2];
925            if (viterbi[0 + ij] < newVal){
926              viterbi[0 + ij] = newVal;
927              traceback[0 + ij] = k;
928            }
929          }
930        }
931        if (i > 0){
932          for (int k = 0; k < NumInsertStates; k++){
933            float valFromMatch = insProb[c1][k] + viterbi[0 + i1j] + transProb[0][2*k+1];
934            float valFromIns = insProb[c1][k] + viterbi[2*k+1 + i1j] + transProb[2*k+1][2*k+1];
935            if (valFromMatch >= valFromIns){
936              viterbi[2*k+1 + ij] = valFromMatch;
937              traceback[2*k+1 + ij] = 0;
938            }
939            else {
940              viterbi[2*k+1 + ij] = valFromIns;
941              traceback[2*k+1 + ij] = 2*k+1;
942            }
943          }
944        }
945        if (j > 0){
946          for (int k = 0; k < NumInsertStates; k++){
947            float valFromMatch = insProb[c2][k] + viterbi[0 + ij1] + transProb[0][2*k+2];
948            float valFromIns = insProb[c2][k] + viterbi[2*k+2 + ij1] + transProb[2*k+2][2*k+2];
949            if (valFromMatch >= valFromIns){
950              viterbi[2*k+2 + ij] = valFromMatch;
951              traceback[2*k+2 + ij] = 0;
952            }
953            else {
954              viterbi[2*k+2 + ij] = valFromIns;
955              traceback[2*k+2 + ij] = 2*k+2;
956            }
957          }
958        }
959
960        ij += NumMatrixTypes;
961        i1j += NumMatrixTypes;
962        ij1 += NumMatrixTypes;
963        i1j1 += NumMatrixTypes;
964      }
965    }
966
967    // figure out best terminating cell
968    float bestProb = LOG_ZERO;
969    int state = -1;
970    for (int k = 0; k < NumMatrixTypes; k++){
971      float thisProb = viterbi[k + NumMatrixTypes * ((seq1Length+1)*(seq2Length+1) - 1)] + initialDistribution[k];
972      if (bestProb < thisProb){
973        bestProb = thisProb;
974        state = k;
975      }
976    }
977    assert (state != -1);
978
979    delete viterbiPtr;
980
981    // compute traceback
982    SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
983    int r = seq1Length, c = seq2Length;
984    while (r != 0 || c != 0){
985      int newState = traceback[state + NumMatrixTypes * (r * (seq2Length+1) + c)];
986     
987      if (state == 0){ c--; r--; alignment->push_back ('B'); }
988      else if (state % 2 == 1){ r--; alignment->push_back ('X'); }
989      else { c--; alignment->push_back ('Y'); }
990     
991      state = newState;
992    }
993
994    delete tracebackPtr;
995
996    reverse (alignment->begin(), alignment->end());
997   
998    return make_pair(alignment, bestProb);
999  }
1000
1001  /////////////////////////////////////////////////////////////////
1002  // ProbabilisticModel::BuildPosterior()
1003  //
1004  // Builds a posterior probability matrix needed to align a pair
1005  // of alignments.  Mathematically, the returned matrix M is
1006  // defined as follows:
1007  //    M[i,j] =     sum          sum      f(s,t,i,j)
1008  //             s in align1  t in align2
1009  // where
1010  //                  [  P(s[i'] <--> t[j'])
1011  //                  [       if s[i'] is a letter in the ith column of align1 and
1012  //                  [          t[j'] it a letter in the jth column of align2
1013  //    f(s,t,i,j) =  [
1014  //                  [  0    otherwise
1015  //
1016  /////////////////////////////////////////////////////////////////
1017
1018  VF *BuildPosterior (MultiSequence *align1, MultiSequence *align2,
1019                      const SafeVector<SafeVector<SparseMatrix *> > &sparseMatrices,
1020                      float cutoff = 0.0f) const {
1021    const int seq1Length = align1->GetSequence(0)->GetLength();
1022    const int seq2Length = align2->GetSequence(0)->GetLength();
1023
1024    VF *posteriorPtr = new VF((seq1Length+1) * (seq2Length+1), 0); assert (posteriorPtr);
1025    VF &posterior = *posteriorPtr;
1026    VF::iterator postPtr = posterior.begin();
1027
1028    // for each s in align1
1029    for (int i = 0; i < align1->GetNumSequences(); i++){
1030      int first = align1->GetSequence(i)->GetLabel();
1031      SafeVector<int> *mapping1 = align1->GetSequence(i)->GetMapping();
1032
1033      // for each t in align2
1034      for (int j = 0; j < align2->GetNumSequences(); j++){
1035        int second = align2->GetSequence(j)->GetLabel();
1036        SafeVector<int> *mapping2 = align2->GetSequence(j)->GetMapping();
1037
1038        if (first < second){
1039
1040          // get the associated sparse matrix
1041          SparseMatrix *matrix = sparseMatrices[first][second];
1042         
1043          for (int ii = 1; ii <= matrix->GetSeq1Length(); ii++){
1044            SafeVector<PIF>::iterator row = matrix->GetRowPtr(ii);
1045            int base = (*mapping1)[ii] * (seq2Length+1);
1046            int rowSize = matrix->GetRowSize(ii);
1047           
1048            // add in all relevant values
1049            for (int jj = 0; jj < rowSize; jj++)
1050              posterior[base + (*mapping2)[row[jj].first]] += row[jj].second;
1051           
1052            // subtract cutoff
1053            for (int jj = 0; jj < matrix->GetSeq2Length(); jj++)
1054              posterior[base + (*mapping2)[jj]] -= cutoff;
1055          }
1056
1057        } else {
1058
1059          // get the associated sparse matrix
1060          SparseMatrix *matrix = sparseMatrices[second][first];
1061         
1062          for (int jj = 1; jj <= matrix->GetSeq1Length(); jj++){
1063            SafeVector<PIF>::iterator row = matrix->GetRowPtr(jj);
1064            int base = (*mapping2)[jj];
1065            int rowSize = matrix->GetRowSize(jj);
1066           
1067            // add in all relevant values
1068            for (int ii = 0; ii < rowSize; ii++)
1069              posterior[base + (*mapping1)[row[ii].first] * (seq2Length + 1)] += row[ii].second;
1070           
1071            // subtract cutoff
1072            for (int ii = 0; ii < matrix->GetSeq2Length(); ii++)
1073              posterior[base + (*mapping1)[ii] * (seq2Length + 1)] -= cutoff;
1074          }
1075
1076        }
1077       
1078
1079        delete mapping2;
1080      }
1081
1082      delete mapping1;
1083    }
1084
1085    return posteriorPtr;
1086  }
1087};
1088
1089#endif
Note: See TracBrowser for help on using the repository browser.