source: trunk/GDE/MUSCLE/src/cluster.cpp

Last change on this file was 10390, checked in by aboeckma, 11 years ago

added muscle sourcles amd makefile

File size: 8.3 KB
Line 
1#include "muscle.h"
2#include "cluster.h"
3#include "distfunc.h"
4
5static inline float Min(float d1, float d2)
6        {
7        return d1 < d2 ? d1 : d2;
8        }
9
10static inline float Max(float d1, float d2)
11        {
12        return d1 > d2 ? d1 : d2;
13        }
14
15static inline float Mean(float d1, float d2)
16        {
17        return (float) ((d1 + d2)/2.0);
18        }
19
20#if     _DEBUG
21void ClusterTree::Validate(unsigned uNodeCount)
22        {
23        unsigned n;
24        ClusterNode *pNode;
25        unsigned uDisjointListCount = 0;
26        for (pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint())
27                {
28                ClusterNode *pPrev = pNode->GetPrevDisjoint();
29                ClusterNode *pNext = pNode->GetNextDisjoint();
30                if (0 != pPrev)
31                        {
32                        if (pPrev->GetNextDisjoint() != pNode)
33                                {
34                                Log("Prev->This mismatch, prev=\n");
35                                pPrev->LogMe();
36                                Log("This=\n");
37                                pNode->LogMe();
38                                Quit("ClusterTree::Validate()");
39                                }
40                        }
41                else
42                        {
43                        if (pNode != m_ptrDisjoints)
44                                {
45                                Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n",
46                                  pNode->GetIndex(),
47                                  m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
48                                pNode->LogMe();
49                                Quit("ClusterTree::Validate()");
50                                }
51                        }
52                if (0 != pNext)
53                        {
54                        if (pNext->GetPrevDisjoint() != pNode)
55                                {
56                                Log("Next->This mismatch, next=\n");
57                                pNext->LogMe();
58                                Log("This=\n");
59                                pNode->LogMe();
60                                Quit("ClusterTree::Validate()");
61                                }
62                        }
63                ++uDisjointListCount;
64                if (uDisjointListCount > m_uNodeCount)
65                        Quit("Loop in disjoint list");
66                }
67
68        unsigned uParentlessNodeCount = 0;
69        for (n = 0; n < uNodeCount; ++n)
70                if (0 == m_Nodes[n].GetParent())
71                        ++uParentlessNodeCount;
72       
73        if (uDisjointListCount != uParentlessNodeCount)
74                Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount,
75                  uParentlessNodeCount);
76        }
77#else   // !_DEBUG
78#define Validate(uNodeCount)    // empty
79#endif
80
81void ClusterNode::LogMe() const
82        {
83        unsigned uClusterSize = GetClusterSize();
84        Log("[%02u] w=%5.3f  CW=%5.3f  LBW=%5.3f  RBW=%5.3f  LWT=%5.3f  RWT=%5.3f  L=%02d  R=%02d  P=%02d  NxDj=%02d  PvDj=%02d  Sz=%02d  {",
85                m_uIndex,
86                m_dWeight,
87                GetClusterWeight(),
88                GetLeftBranchWeight(),
89                GetRightBranchWeight(),
90                GetLeftWeight(),
91                GetRightWeight(),
92                m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff,
93                m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff,
94                m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff,
95                m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff,
96                m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff,
97                uClusterSize);
98        for (unsigned i = 0; i < uClusterSize; ++i)
99                Log(" %u", GetClusterLeaf(i)->GetIndex());
100        Log(" }\n");
101        }
102
103// How many leaves in the sub-tree under this node?
104unsigned ClusterNode::GetClusterSize() const
105        {
106        unsigned uLeafCount = 0;
107
108        if (0 == m_ptrLeft && 0 == m_ptrRight)
109                return 1;
110
111        if (0 != m_ptrLeft)
112                uLeafCount += m_ptrLeft->GetClusterSize();
113        if (0 != m_ptrRight)
114                uLeafCount += m_ptrRight->GetClusterSize();
115        assert(uLeafCount > 0);
116        return uLeafCount;
117        }
118
119double ClusterNode::GetClusterWeight() const
120        {
121        double dWeight = 0.0;
122        if (0 != m_ptrLeft)
123                dWeight += m_ptrLeft->GetClusterWeight();
124        if (0 != m_ptrRight)
125                dWeight += m_ptrRight->GetClusterWeight();
126        return dWeight + GetWeight();
127        }
128
129double ClusterNode::GetLeftBranchWeight() const
130        {
131        const ClusterNode *ptrLeft = GetLeft();
132        if (0 == ptrLeft)
133                return 0.0;
134
135        return GetWeight() - ptrLeft->GetWeight();
136        }
137
138double ClusterNode::GetRightBranchWeight() const
139        {
140        const ClusterNode *ptrRight = GetRight();
141        if (0 == ptrRight)
142                return 0.0;
143
144        return GetWeight() - ptrRight->GetWeight();
145        }
146
147double ClusterNode::GetRightWeight() const
148        {
149        const ClusterNode *ptrRight = GetRight();
150        if (0 == ptrRight)
151                return 0.0;
152        return ptrRight->GetClusterWeight() + GetWeight();
153        }
154
155double ClusterNode::GetLeftWeight() const
156        {
157        const ClusterNode *ptrLeft = GetLeft();
158        if (0 == ptrLeft)
159                return 0.0;
160        return ptrLeft->GetClusterWeight() + GetWeight();
161        }
162
163// Return n'th leaf in the sub-tree under this node.
164const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const
165        {
166        if (0 != m_ptrLeft)
167                {
168                if (0 == m_ptrRight)
169                        return this;
170
171                unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize();
172
173                if (uLeafIndex < uLeftLeafCount)
174                        return m_ptrLeft->GetClusterLeaf(uLeafIndex);
175
176                assert(uLeafIndex >= uLeftLeafCount);
177                return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount);
178                }
179        if (0 == m_ptrRight)
180                return this;
181        return m_ptrRight->GetClusterLeaf(uLeafIndex);
182        }
183
184void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode)
185        {
186        ClusterNode *ptrPrev = ptrNode->GetPrevDisjoint();
187        ClusterNode *ptrNext = ptrNode->GetNextDisjoint();
188
189        if (0 != ptrPrev)
190                ptrPrev->SetNextDisjoint(ptrNext);
191        else
192                m_ptrDisjoints = ptrNext;
193
194        if (0 != ptrNext)
195                ptrNext->SetPrevDisjoint(ptrPrev);
196
197#if     _DEBUG
198// not algorithmically necessary, but improves clarity
199// and supports Validate().
200        ptrNode->SetPrevDisjoint(0);
201        ptrNode->SetNextDisjoint(0);
202#endif
203        }
204
205void ClusterTree::AddToDisjoints(ClusterNode *ptrNode)
206        {
207        ptrNode->SetNextDisjoint(m_ptrDisjoints);
208        ptrNode->SetPrevDisjoint(0);
209        if (0 != m_ptrDisjoints)
210                m_ptrDisjoints->SetPrevDisjoint(ptrNode);
211        m_ptrDisjoints = ptrNode;
212        }
213
214ClusterTree::ClusterTree()
215        {
216        m_ptrDisjoints = 0;
217        m_Nodes = 0;
218        m_uNodeCount = 0;
219        }
220
221ClusterTree::~ClusterTree()
222        {
223        delete[] m_Nodes;
224        }
225
226void ClusterTree::LogMe() const
227        {
228        Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
229        for (unsigned i = 0; i < m_uNodeCount; ++i)
230                {
231                m_Nodes[i].LogMe();
232                }
233        }
234
235ClusterNode *ClusterTree::GetRoot() const
236        {
237        return &m_Nodes[m_uNodeCount - 1];
238        }
239
240// This is the UPGMA algorithm as described in Durbin et al. p166.
241void ClusterTree::Create(const DistFunc &Dist)
242        {
243        unsigned i;
244        m_uLeafCount = Dist.GetCount();
245        m_uNodeCount = 2*m_uLeafCount - 1;
246
247        delete[] m_Nodes;
248        m_Nodes = new ClusterNode[m_uNodeCount];
249
250        for (i = 0; i < m_uNodeCount; ++i)
251                m_Nodes[i].SetIndex(i);
252
253        for (i = 0; i < m_uLeafCount - 1; ++i)
254                m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]);
255
256        for (i = 1; i < m_uLeafCount; ++i)
257                m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]);
258       
259        m_ptrDisjoints = &m_Nodes[0];
260
261//      Log("Initial state\n");
262//      LogMe();
263//      Log("\n");
264
265        DistFunc ClusterDist;
266        ClusterDist.SetCount(m_uNodeCount);
267        double dMaxDist = 0.0;
268        for (i = 0; i < m_uLeafCount; ++i)
269                for (unsigned j = 0; j < m_uLeafCount; ++j)
270                        {
271                        float dDist = Dist.GetDist(i, j);
272                        ClusterDist.SetDist(i, j, dDist);
273                        }
274
275        Validate(m_uLeafCount);
276
277// Iteration. N-1 joins needed to create a binary tree from N leaves.
278        for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount;
279          ++uJoinIndex)
280                {
281        // Find closest pair of clusters
282                unsigned uIndexClosest1;
283                unsigned uIndexClosest2;
284                bool bFound = false;
285                double dDistClosest = 9e99;
286                for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1;
287                  ptrNode1 = ptrNode1->GetNextDisjoint())
288                        {
289                        for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2;
290                          ptrNode2 = ptrNode2->GetNextDisjoint())
291                                {
292                                unsigned i1 = ptrNode1->GetIndex();
293                                unsigned i2 = ptrNode2->GetIndex();
294                                double dDist = ClusterDist.GetDist(i1, i2);
295                                if (dDist < dDistClosest)
296                                        {
297                                        bFound = true;
298                                        dDistClosest = dDist;
299                                        uIndexClosest1 = i1;
300                                        uIndexClosest2 = i2;
301                                        }
302                                }
303                        }
304                assert(bFound);
305
306                ClusterNode &Join = m_Nodes[uJoinIndex];
307                ClusterNode &Child1 = m_Nodes[uIndexClosest1];
308                ClusterNode &Child2 = m_Nodes[uIndexClosest2];
309
310                Join.SetLeft(&Child1);
311                Join.SetRight(&Child2);
312                Join.SetWeight(dDistClosest);
313
314                Child1.SetParent(&Join);
315                Child2.SetParent(&Join);
316
317                DeleteFromDisjoints(&Child1);
318                DeleteFromDisjoints(&Child2);
319                AddToDisjoints(&Join);
320
321//              Log("After join %d %d\n", uIndexClosest1, uIndexClosest2);
322//              LogMe();
323
324        // Calculate distance of every remaining disjoint cluster to the
325        // new cluster created by the join
326                for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode;
327                  ptrNode = ptrNode->GetNextDisjoint())
328                        {
329                        unsigned uNodeIndex = ptrNode->GetIndex();
330                        float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1);
331                        float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2);
332                        float dDist = Min(dDist1, dDist2);
333                        ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist);
334                        }
335                Validate(uJoinIndex+1);
336                }
337        GetRoot()->GetClusterWeight();
338//      LogMe();
339        }
Note: See TracBrowser for help on using the repository browser.