source: tags/arb_5.1/GDE/MOLPHY/njtree.c

Last change on this file was 1885, checked in by westram, 21 years ago

missing parameter type

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.9 KB
Line 
1/*
2 * njtree.c   Adachi, J.   1995.06.09
3 * Copyright (C) 1992-1995 J. Adachi & M. Hasegawa. All rights reserved.
4 */
5
6#include "protml.h"
7#include "prot_tml.h"
8
9#define ENJ 0
10
11Tree *
12new_njtree(maxspc, maxibrnch, numptrn, seqconint)
13int maxspc, maxibrnch, numptrn;
14imatrix seqconint;
15{
16        int n, i;
17        Tree *tr;
18        Node *dp, *up;
19
20        tr = (Tree *) malloc(sizeof(Tree));
21        if (tr == NULL) maerror("tr in new_njtree().");
22        tr->ebrnchp = (Node **) malloc((unsigned)maxspc * sizeof(Node *));
23        if (tr->ebrnchp == NULL) maerror("ebrnchp in new_njtree().");
24        tr->ibrnchp = (Node **) malloc((unsigned)maxibrnch * sizeof(Node *));
25        if (tr->ibrnchp == NULL) maerror("ibrnchp in new_njtree().");
26        tr->bturn = new_ivector(maxspc);
27        for (n = 0; n < maxspc; n++) {
28                tr->bturn[n] = n;
29                dp = (Node *) malloc(sizeof(Node));
30                if (dp == NULL) maerror("dp in new_njtree().");
31                up = (Node *) malloc(sizeof(Node));
32                if (up == NULL) maerror("up in new_njtree().");
33                dp->isop = NULL;
34                up->isop = NULL;
35                dp->kinp = up;
36                up->kinp = dp;
37                dp->descen = TRUE;
38                up->descen = FALSE;
39                dp->num = n;
40                up->num = n;
41                dp->length = 0.0;
42                up->length = 0.0;
43                dp->lklhdl = 0.0;
44                up->lklhdl = 0.0;
45                dp->paths = new_ivector(maxspc);
46                up->paths = dp->paths;
47                for (i = 0; i < maxspc; i++) dp->paths[i] = 0;
48                dp->paths[n] = 1;
49                dp->eprob = seqconint[n];
50                up->eprob = NULL;
51                dp->iprob = NULL;
52                up->iprob = new_dmatrix(numptrn, Tpmradix);
53                tr->ebrnchp[n] = dp;
54        }
55        for (n = 0; n < maxspc - 1; n++) {
56                tr->ebrnchp[n]->kinp->isop = tr->ebrnchp[n + 1]->kinp;
57        }
58        tr->ebrnchp[maxspc - 1]->kinp->isop = tr->ebrnchp[0]->kinp;
59        for (n = 0; n < maxibrnch; n++) {
60                dp = (Node *) malloc(sizeof(Node));
61                if (dp == NULL) maerror("dp in new_njtree().");
62                up = (Node *) malloc(sizeof(Node));
63                if (up == NULL) maerror("up in new_njtree().");
64                dp->isop = NULL;
65                up->isop = NULL;
66                dp->kinp = up;
67                up->kinp = dp;
68                dp->descen = TRUE;
69                up->descen = FALSE;
70                dp->num = n + maxspc;
71                up->num = n + maxspc;
72                dp->length = 0.0;
73                up->length = 0.0;
74                dp->lklhdl = 0.0;
75                up->lklhdl = 0.0;
76                dp->paths = new_ivector(maxspc);
77                up->paths = dp->paths;
78                for (i = 0; i < maxspc; i++) dp->paths[i] = 0;
79                dp->eprob = NULL;
80                up->eprob = NULL;
81                dp->iprob = new_dmatrix(numptrn, Tpmradix);
82                up->iprob = new_dmatrix(numptrn, Tpmradix);
83                tr->ibrnchp[n] = dp;
84        }
85        tr->rootp = tr->ebrnchp[maxspc - 1]->kinp;
86
87        return tr;
88} /*_ new_njtree */
89
90
91void
92free_njtree(tr, maxspc, maxibrnch)
93Tree *tr;
94int maxspc, maxibrnch;
95{
96        int n;
97        Node *dp, *up;
98
99        for (n = 0; n < maxspc; n++) {
100                dp = tr->ebrnchp[n];
101                up = dp->kinp;
102                free_ivector(dp->paths);
103                free_dmatrix(up->iprob);
104                free(up);
105                free(dp);
106        }
107        for (n = 0; n < maxibrnch; n++) {
108                dp = tr->ibrnchp[n];
109                up = dp->kinp;
110                free_ivector(dp->paths);
111                free_dmatrix(up->iprob);
112                free_dmatrix(dp->iprob);
113                free(up);
114                free(dp);
115        }
116        free(tr->bturn);
117        free(tr->ibrnchp);
118        free(tr->ebrnchp);
119        free(tr);
120} /*_ free_njtree */
121
122
123double
124emledis(dis, ip, kp)
125double dis;
126Node *ip, *kp;
127{
128        int i, j, k, it, numloop;
129        double sumlk, sumd1, sumd2, lkld1, lkld2, prod, vari;
130        double arc, arcold, arcdiff, arcpre;
131        dmattpmty tprob, tdif1, tdif2;
132        dmatrix oprob;
133        ivector dseqi;
134        dvector opb;
135#ifdef NUC
136        dvector tpb, td1, td2;
137        double pn0, pn1, pn2, pn3;
138#endif /* NUC */
139
140        oprob = ip->iprob;
141        dseqi = kp->eprob;
142        arc = arcpre = dis;
143
144        numloop = 30;
145        for (it = 0; it < numloop; it++) {
146                tdiffmtrx(arc, tprob, tdif1, tdif2);
147#ifdef NUC
148                tpb = *tprob; td1 = *tdif1; td2 = *tdif2;
149#endif  /* NUC */
150                lkld1 = lkld2 = 0.0;
151                for (k = 0; k < Numptrn; k++) {
152                        if ((j = dseqi[k]) >= 0) {
153                                opb = oprob[k];
154#ifndef NUC
155                                sumlk = sumd1 = sumd2 = 0.0;
156                                for (i = 0; i < Tpmradix; i++) {
157                                        prod = Freqtpm[i] * opb[i];
158                                        sumlk += prod * tprob[i][j];
159                                        sumd1 += prod * tdif1[i][j];
160                                        sumd2 += prod * tdif2[i][j];
161                                }
162#else                   /* NUC */
163                                pn0 = Freqtpm[0] * opb[0];
164                                pn1 = Freqtpm[1] * opb[1];
165                                pn2 = Freqtpm[2] * opb[2];
166                                pn3 = Freqtpm[3] * opb[3];
167                                sumlk = pn0 * tpb[] + pn1 * tpb[j+ 4]
168                                          + pn2 * tpb[j+8] + pn3 * tpb[j+12];
169                                sumd1 = pn0 * td1[] + pn1 * td1[j+ 4]
170                                          + pn2 * td1[j+8] + pn3 * td1[j+12];
171                                sumd2 = pn0 * td2[] + pn1 * td2[j+ 4]
172                                          + pn2 * td2[j+8] + pn3 * td2[j+12];
173#endif                  /* NUC */
174                                sumd1 /= sumlk;
175                                lkld1 += sumd1 * Weight[k];
176                                lkld2 += (sumd2 / sumlk - sumd1 * sumd1) * Weight[k];
177                        }
178                }
179                vari = 1.0 / fabs(lkld2);
180                arcold = arc;
181                arcdiff = - (lkld1 / lkld2);
182                arc += arcdiff;
183                if (arc > Mlimit && arcpre < 10.0) arc = Llimit;
184                if (arc < LOWERLIMIT) arc = LOWERLIMIT;
185                if (arc > Ulimit) arc = Ulimit;
186                if (lkld2 > 0.0) {
187                        arc = Llimit;
188                        if (Debug || Debug_optn)
189                        fprintf(stderr,"mli: second derivative is positive! %8.3f\n",lkld2);
190                        break;
191                }
192                /*      printf("mle %3d %3d %8.3f %8.3f %12.5f %12.5f %10.1f\n",
193                        kp->num+1, it+1, arc, arcold, arcdiff, lkld1, lkld2); */
194                if (fabs(arcold - arc) < DEPSILON) break;
195        }
196        /*      if (Debug) */
197        if (fabs(arcdiff) > DEPSILON)
198                printf("mle%4d%3d%8.3f%8.3f%8.3f%10.5f%9.5f%9.3f\n",
199                kp->num+1, it+1, arc, arcpre, sqrt(vari), arcdiff, lkld1, lkld2);
200        return arc;
201} /* emledis */
202
203
204double
205imledis(dis, ip, kp)
206double dis;
207Node *ip, *kp;
208{
209        int i, j, k, it, numloop;
210        double sumlk, sumd1, sumd2, lkld1, lkld2, prod1, prod2, vari;
211        double arc, arcold, arcdiff, arcpre, slk, sd1, sd2;
212        dmattpmty tprob, tdif1, tdif2;
213        dmatrix oprob, cprob;
214        dvector tpb, td1, td2, opb, cpb;
215#ifdef NUC
216        double cpb0, cpb1, cpb2, cpb3;
217#endif /* NUC */
218
219        oprob = ip->iprob;
220        cprob = kp->iprob;
221        arc = arcpre = dis;
222/*      if (Debug) { prprob(oprob); prprob(cprob); } */
223        numloop = 30;
224        for (it = 0; it < numloop; it++) {
225                tdiffmtrx(arc, tprob, tdif1, tdif2);
226                lkld1 = lkld2 = 0.0;
227                for (k = 0; k < Numptrn; k++) {
228                        sumlk = sumd1 = sumd2 = 0.0;
229                        opb = oprob[k];
230                        cpb = cprob[k];
231#ifdef NUC
232                        cpb0 = cpb[0]; cpb1 = cpb[1]; cpb2 = cpb[2]; cpb3 = cpb[3];
233#endif          /* NUC */
234                        for (i = 0; i < Tpmradix; i++) {
235                                tpb = tprob[i];
236                                td1 = tdif1[i];
237                                td2 = tdif2[i];
238                                prod1 = Freqtpm[i] * opb[i];
239#ifndef NUC
240                                slk = sd1 = sd2 = 0.0;
241                                for (j = 0; j < Tpmradix; j++) {
242                                        prod2 = cpb[j];
243                                        slk += prod2 * tpb[j];
244                                        sd1 += prod2 * td1[j];
245                                        sd2 += prod2 * td2[j];
246                                }
247                                sumlk += prod1 * slk;
248                                sumd1 += prod1 * sd1;
249                                sumd2 += prod1 * sd2;
250#else                   /* NUC */
251                                slk = cpb0*tpb[0] + cpb1*tpb[1] + cpb2*tpb[2] + cpb3*tpb[3];
252                                sd1 = cpb0*td1[0] + cpb1*td1[1] + cpb2*td1[2] + cpb3*td1[3];
253                                sd2 = cpb0*td2[0] + cpb1*td2[1] + cpb2*td2[2] + cpb3*td2[3];
254                                sumlk += prod1 * slk;
255                                sumd1 += prod1 * sd1;
256                                sumd2 += prod1 * sd2;
257#endif                  /* NUC */
258                        }
259                        sumd1 /= sumlk;
260                        lkld1 += sumd1 * Weight[k];
261                        lkld2 += (sumd2 / sumlk - sumd1 * sumd1) * Weight[k];
262                }
263                vari = 1.0 / fabs(lkld2);
264                arcold = arc;
265                arcdiff = - (lkld1 / lkld2);
266                arc += arcdiff;
267                if (arc > Mlimit && arcpre < 10.0) arc = Llimit;
268                if (arc < Llimit) arc = Llimit;
269                if (arc > Ulimit) arc = Ulimit;
270                if (lkld2 > 0.0) {
271                        arc = Llimit;
272                        if (Debug || Debug_optn)
273                        fprintf(stderr,"mli: second derivative is positive! %8.3f\n",lkld2);
274                        break;
275                }
276                /*      printf("mli %3d %3d %8.3f %8.3f %12.5f %12.5f %10.1f\n",
277                        kp->num+1, it+1, arc, arcold, arcdiff, lkld1, lkld2); */
278                if (fabs(arcold - arc) < DEPSILON) break;
279        }
280        /*      if (Debug) */
281        if (fabs(arcdiff) > DEPSILON)
282                printf("mli%4d%3d%8.3f%8.3f%8.3f%10.5f%9.5f%9.3f\n",
283                kp->num+1, it+1, arc, arcpre, sqrt(vari), arcdiff, lkld1, lkld2);
284        return arc;
285} /* imledis */
286
287
288#if ENJ
289void
290redmat(dmat, dij, psotu, otu, restsp, ii, jj, ns)
291dmatrix dmat;
292double dij;
293Node **psotu;
294ivector otu;
295int restsp, ii, jj, ns;
296{
297        int k, kk;
298        double dis, predis;
299        Node *ip, *jp, *kp;
300
301        ip = psotu[ii];
302        jp = psotu[jj];
303        if (ip->kinp->isop == NULL) /* external */
304                partelkl(ip);
305        else /* internal */
306                partilkl(ip);
307        if (jp->kinp->isop == NULL) /* external */
308                partelkl(jp);
309        else /* internal */
310                partilkl(jp);
311        prodpart(ip);
312
313        for (k = 0; k < restsp; k++) {
314                kk = otu[k];
315                if (kk != ii && kk != jj) {
316                        predis = (dmat[ii][kk] + dmat[jj][kk] - dij) * 0.5;
317                        kp = psotu[kk];
318                        if (kp->kinp->isop == NULL) { /* external */
319                                if (predis < LOWERLIMIT) predis = LOWERLIMIT;
320                                dis = emledis(predis, ip, kp->kinp);
321                        } else { /* internal */
322                                if (predis < Llimit) predis = Llimit;
323                                dis = imledis(predis, ip, kp->kinp->isop);
324                        }
325                        /*
326                        printf("%3d%3d%3d",ip->kinp->num+1,jp->kinp->num+1,kp->kinp->num+1);
327                        printf(" %3d%3d%9.4f%9.4f\n", ii+1, kk+1, predis, dis);
328                        */
329                        dmat[ii][kk] = dmat[kk][ii] = dis;
330                }
331                dmat[jj][kk] = dmat[kk][jj] = 0.0;
332        }
333} /* redmat */
334#endif /* ENJ */
335
336
337
338void
339enjtree(tr, distan, ns, flag)
340Tree *tr;
341dmatrix distan;
342int ns;
343boolean flag;
344{
345        int i, j, ii, jj, kk, otui, otuj, nsp2, cinode, restsp;
346        double dij, bix, bjx, bkx, sij, smax, dnsp2, dij2;
347        ivector otu;
348        dvector r;
349        dmatrix dmat;
350        Node **psotu, *cp, *ip, *jp, *kp;
351
352        dmat = new_dmatrix(ns, ns);
353        for (i = 0; i < ns; i++) {
354                        for (j = 0; j < ns; j++) dmat[i][j] = distan[i][j];
355        }
356        cinode = ns;
357        nsp2 = ns - 2;
358        dnsp2 = 1.0 / nsp2;
359        r = new_dvector(ns);
360        otu = new_ivector(ns);
361        psotu = (Node **)new_npvector(ns);
362        for (i = 0; i < ns; i++) {
363                otu[i] = i;
364                psotu[i] = tr->ebrnchp[i]->kinp;
365        }
366
367        for (restsp = ns; restsp > 3; restsp--) {
368
369                for (i = 0; i < restsp; i++) {
370                        ii = otu[i];
371                        for (j = 0, sij = 0.0; j < restsp; j++) sij += dmat[ii][otu[j]];
372                        r[ii] = sij;
373                }
374                for (i = 0, smax = - DBL_MAX; i < restsp-1; i++) {
375                        ii = otu[i];
376                        for (j = i + 1; j < restsp; j++) {
377                                jj = otu[j];
378                                sij = ( r[ii] + r[jj] ) * dnsp2 - dmat[ii][jj]; /* max */
379                                /* printf("%3d%3d %9.3f %9.3f %9.3f\n",
380                                        ii+1,jj+1,sij,r[ii],r[jj]); */
381                                if (!flag) sij = - sij;
382                                if (sij > smax) {
383                                        smax = sij; otui = i; otuj = j;
384                                }
385                        }
386                }
387
388                ii = otu[otui];
389                jj = otu[otuj];
390                dij = dmat[ii][jj];
391                dij2 = dij * 0.5;
392                bix = (dij + r[ii]/nsp2 - r[jj]/nsp2) * 0.5;
393                bjx = dij - bix;
394                cp = tr->ibrnchp[cinode - ns];
395                ip = psotu[ii];
396                jp = psotu[jj];
397                cp->isop = ip;
398                ip->isop = jp;
399                jp->isop = cp;
400                ip->length += bix;
401                jp->length += bjx;
402                if (ip->kinp->isop == NULL) {
403                        if (ip->length < LOWERLIMIT) ip->length = LOWERLIMIT;
404                } else {
405                        if (ip->length < Llimit) ip->length = Llimit;
406                }
407                if (jp->kinp->isop == NULL) {
408                        if (jp->length < LOWERLIMIT) jp->length = LOWERLIMIT;
409                } else {
410                        if (jp->length < Llimit) jp->length = Llimit;
411                }
412                ip->kinp->length = ip->length;
413                jp->kinp->length = jp->length;
414                cp = cp->kinp;
415
416#if             ENJ
417                cp->length = 0.0;
418                redmat(dmat, dij, psotu, otu, restsp, ii, jj, ns);
419#else   /* ENJ */
420                cp->length = - dij2;
421                for (j = 0; j < restsp; j++) {
422                        kk = otu[j];
423                        if (kk != ii && kk != jj) {
424                                dij = (dmat[ii][kk] + dmat[jj][kk]) * 0.5;
425                                dmat[ii][kk] = dmat[kk][ii] = dij;
426                        }
427                        dmat[jj][kk] = dmat[kk][jj] = 0.0;
428                }
429#endif  /* ENJ */
430
431                psotu[ii] = cp;
432                psotu[jj] = NULL;
433                Numibrnch++;
434                Numbrnch = ++cinode;
435                dnsp2 = 1.0 / --nsp2;
436                if (Debug_optn) {
437                        for (putchar('\n'), j = 0; j < restsp; j++) printf("%6d",otu[j]+1);
438                        for (putchar('\n'), i = 0; i < restsp; i++, putchar('\n')) {
439                                for (j = 0, ii = otu[i]; j < restsp; j++) {
440                                        printf("%6.0f", dmat[ii][otu[j]]*100);
441                                }
442                        }
443                }
444                for (j = otuj; j < restsp - 1; j++) otu[j] = otu[j + 1];
445
446        } /* for restsp */
447
448        ii = otu[0];
449        jj = otu[1];
450        kk = otu[2];
451        bix = (dmat[ii][jj] + dmat[ii][kk] - dmat[jj][kk]) * 0.5;
452        bjx = dmat[ii][jj] - bix;
453        bkx = dmat[ii][kk] - bix;
454        ip = psotu[ii];
455        jp = psotu[jj];
456        kp = psotu[kk];
457        ip->isop = jp;
458        jp->isop = kp;
459        kp->isop = ip;
460        ip->length += bix;
461        jp->length += bjx;
462        kp->length += bkx;
463        if (ip->kinp->isop == NULL) {
464                if (ip->length < LOWERLIMIT) ip->length = LOWERLIMIT;
465        } else {
466                if (ip->length < Llimit) ip->length = Llimit;
467        }
468        if (jp->kinp->isop == NULL) {
469                if (jp->length < LOWERLIMIT) jp->length = LOWERLIMIT;
470        } else {
471                if (jp->length < Llimit) jp->length = Llimit;
472        }
473        if (kp->kinp->isop == NULL) {
474                if (kp->length < LOWERLIMIT) kp->length = LOWERLIMIT;
475        } else {
476                if (kp->length < Llimit) kp->length = Llimit;
477        }
478        ip->kinp->length = ip->length;
479        jp->kinp->length = jp->length;
480        kp->kinp->length = kp->length;
481
482        reroot(tr, tr->ebrnchp[ns-1]->kinp); /* !? */
483
484        free_dvector(r);
485        free_ivector(otu);
486        free_npvector(psotu);
487        free_dmatrix(dmat);
488} /*_ enjtree */
Note: See TracBrowser for help on using the repository browser.