1 | #include "muscle.h" |
---|
2 | #include "cluster.h" |
---|
3 | #include "distfunc.h" |
---|
4 | |
---|
5 | static inline float Min(float d1, float d2) |
---|
6 | { |
---|
7 | return d1 < d2 ? d1 : d2; |
---|
8 | } |
---|
9 | |
---|
10 | static inline float Max(float d1, float d2) |
---|
11 | { |
---|
12 | return d1 > d2 ? d1 : d2; |
---|
13 | } |
---|
14 | |
---|
15 | static inline float Mean(float d1, float d2) |
---|
16 | { |
---|
17 | return (float) ((d1 + d2)/2.0); |
---|
18 | } |
---|
19 | |
---|
20 | #if _DEBUG |
---|
21 | void ClusterTree::Validate(unsigned uNodeCount) |
---|
22 | { |
---|
23 | unsigned n; |
---|
24 | ClusterNode *pNode; |
---|
25 | unsigned uDisjointListCount = 0; |
---|
26 | for (pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint()) |
---|
27 | { |
---|
28 | ClusterNode *pPrev = pNode->GetPrevDisjoint(); |
---|
29 | ClusterNode *pNext = pNode->GetNextDisjoint(); |
---|
30 | if (0 != pPrev) |
---|
31 | { |
---|
32 | if (pPrev->GetNextDisjoint() != pNode) |
---|
33 | { |
---|
34 | Log("Prev->This mismatch, prev=\n"); |
---|
35 | pPrev->LogMe(); |
---|
36 | Log("This=\n"); |
---|
37 | pNode->LogMe(); |
---|
38 | Quit("ClusterTree::Validate()"); |
---|
39 | } |
---|
40 | } |
---|
41 | else |
---|
42 | { |
---|
43 | if (pNode != m_ptrDisjoints) |
---|
44 | { |
---|
45 | Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n", |
---|
46 | pNode->GetIndex(), |
---|
47 | m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff); |
---|
48 | pNode->LogMe(); |
---|
49 | Quit("ClusterTree::Validate()"); |
---|
50 | } |
---|
51 | } |
---|
52 | if (0 != pNext) |
---|
53 | { |
---|
54 | if (pNext->GetPrevDisjoint() != pNode) |
---|
55 | { |
---|
56 | Log("Next->This mismatch, next=\n"); |
---|
57 | pNext->LogMe(); |
---|
58 | Log("This=\n"); |
---|
59 | pNode->LogMe(); |
---|
60 | Quit("ClusterTree::Validate()"); |
---|
61 | } |
---|
62 | } |
---|
63 | ++uDisjointListCount; |
---|
64 | if (uDisjointListCount > m_uNodeCount) |
---|
65 | Quit("Loop in disjoint list"); |
---|
66 | } |
---|
67 | |
---|
68 | unsigned uParentlessNodeCount = 0; |
---|
69 | for (n = 0; n < uNodeCount; ++n) |
---|
70 | if (0 == m_Nodes[n].GetParent()) |
---|
71 | ++uParentlessNodeCount; |
---|
72 | |
---|
73 | if (uDisjointListCount != uParentlessNodeCount) |
---|
74 | Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount, |
---|
75 | uParentlessNodeCount); |
---|
76 | } |
---|
77 | #else // !_DEBUG |
---|
78 | #define Validate(uNodeCount) // empty |
---|
79 | #endif |
---|
80 | |
---|
81 | void ClusterNode::LogMe() const |
---|
82 | { |
---|
83 | unsigned uClusterSize = GetClusterSize(); |
---|
84 | Log("[%02u] w=%5.3f CW=%5.3f LBW=%5.3f RBW=%5.3f LWT=%5.3f RWT=%5.3f L=%02d R=%02d P=%02d NxDj=%02d PvDj=%02d Sz=%02d {", |
---|
85 | m_uIndex, |
---|
86 | m_dWeight, |
---|
87 | GetClusterWeight(), |
---|
88 | GetLeftBranchWeight(), |
---|
89 | GetRightBranchWeight(), |
---|
90 | GetLeftWeight(), |
---|
91 | GetRightWeight(), |
---|
92 | m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff, |
---|
93 | m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff, |
---|
94 | m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff, |
---|
95 | m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff, |
---|
96 | m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff, |
---|
97 | uClusterSize); |
---|
98 | for (unsigned i = 0; i < uClusterSize; ++i) |
---|
99 | Log(" %u", GetClusterLeaf(i)->GetIndex()); |
---|
100 | Log(" }\n"); |
---|
101 | } |
---|
102 | |
---|
103 | // How many leaves in the sub-tree under this node? |
---|
104 | unsigned ClusterNode::GetClusterSize() const |
---|
105 | { |
---|
106 | unsigned uLeafCount = 0; |
---|
107 | |
---|
108 | if (0 == m_ptrLeft && 0 == m_ptrRight) |
---|
109 | return 1; |
---|
110 | |
---|
111 | if (0 != m_ptrLeft) |
---|
112 | uLeafCount += m_ptrLeft->GetClusterSize(); |
---|
113 | if (0 != m_ptrRight) |
---|
114 | uLeafCount += m_ptrRight->GetClusterSize(); |
---|
115 | assert(uLeafCount > 0); |
---|
116 | return uLeafCount; |
---|
117 | } |
---|
118 | |
---|
119 | double ClusterNode::GetClusterWeight() const |
---|
120 | { |
---|
121 | double dWeight = 0.0; |
---|
122 | if (0 != m_ptrLeft) |
---|
123 | dWeight += m_ptrLeft->GetClusterWeight(); |
---|
124 | if (0 != m_ptrRight) |
---|
125 | dWeight += m_ptrRight->GetClusterWeight(); |
---|
126 | return dWeight + GetWeight(); |
---|
127 | } |
---|
128 | |
---|
129 | double ClusterNode::GetLeftBranchWeight() const |
---|
130 | { |
---|
131 | const ClusterNode *ptrLeft = GetLeft(); |
---|
132 | if (0 == ptrLeft) |
---|
133 | return 0.0; |
---|
134 | |
---|
135 | return GetWeight() - ptrLeft->GetWeight(); |
---|
136 | } |
---|
137 | |
---|
138 | double ClusterNode::GetRightBranchWeight() const |
---|
139 | { |
---|
140 | const ClusterNode *ptrRight = GetRight(); |
---|
141 | if (0 == ptrRight) |
---|
142 | return 0.0; |
---|
143 | |
---|
144 | return GetWeight() - ptrRight->GetWeight(); |
---|
145 | } |
---|
146 | |
---|
147 | double ClusterNode::GetRightWeight() const |
---|
148 | { |
---|
149 | const ClusterNode *ptrRight = GetRight(); |
---|
150 | if (0 == ptrRight) |
---|
151 | return 0.0; |
---|
152 | return ptrRight->GetClusterWeight() + GetWeight(); |
---|
153 | } |
---|
154 | |
---|
155 | double ClusterNode::GetLeftWeight() const |
---|
156 | { |
---|
157 | const ClusterNode *ptrLeft = GetLeft(); |
---|
158 | if (0 == ptrLeft) |
---|
159 | return 0.0; |
---|
160 | return ptrLeft->GetClusterWeight() + GetWeight(); |
---|
161 | } |
---|
162 | |
---|
163 | // Return n'th leaf in the sub-tree under this node. |
---|
164 | const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const |
---|
165 | { |
---|
166 | if (0 != m_ptrLeft) |
---|
167 | { |
---|
168 | if (0 == m_ptrRight) |
---|
169 | return this; |
---|
170 | |
---|
171 | unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize(); |
---|
172 | |
---|
173 | if (uLeafIndex < uLeftLeafCount) |
---|
174 | return m_ptrLeft->GetClusterLeaf(uLeafIndex); |
---|
175 | |
---|
176 | assert(uLeafIndex >= uLeftLeafCount); |
---|
177 | return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount); |
---|
178 | } |
---|
179 | if (0 == m_ptrRight) |
---|
180 | return this; |
---|
181 | return m_ptrRight->GetClusterLeaf(uLeafIndex); |
---|
182 | } |
---|
183 | |
---|
184 | void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode) |
---|
185 | { |
---|
186 | ClusterNode *ptrPrev = ptrNode->GetPrevDisjoint(); |
---|
187 | ClusterNode *ptrNext = ptrNode->GetNextDisjoint(); |
---|
188 | |
---|
189 | if (0 != ptrPrev) |
---|
190 | ptrPrev->SetNextDisjoint(ptrNext); |
---|
191 | else |
---|
192 | m_ptrDisjoints = ptrNext; |
---|
193 | |
---|
194 | if (0 != ptrNext) |
---|
195 | ptrNext->SetPrevDisjoint(ptrPrev); |
---|
196 | |
---|
197 | #if _DEBUG |
---|
198 | // not algorithmically necessary, but improves clarity |
---|
199 | // and supports Validate(). |
---|
200 | ptrNode->SetPrevDisjoint(0); |
---|
201 | ptrNode->SetNextDisjoint(0); |
---|
202 | #endif |
---|
203 | } |
---|
204 | |
---|
205 | void ClusterTree::AddToDisjoints(ClusterNode *ptrNode) |
---|
206 | { |
---|
207 | ptrNode->SetNextDisjoint(m_ptrDisjoints); |
---|
208 | ptrNode->SetPrevDisjoint(0); |
---|
209 | if (0 != m_ptrDisjoints) |
---|
210 | m_ptrDisjoints->SetPrevDisjoint(ptrNode); |
---|
211 | m_ptrDisjoints = ptrNode; |
---|
212 | } |
---|
213 | |
---|
214 | ClusterTree::ClusterTree() |
---|
215 | { |
---|
216 | m_ptrDisjoints = 0; |
---|
217 | m_Nodes = 0; |
---|
218 | m_uNodeCount = 0; |
---|
219 | } |
---|
220 | |
---|
221 | ClusterTree::~ClusterTree() |
---|
222 | { |
---|
223 | delete[] m_Nodes; |
---|
224 | } |
---|
225 | |
---|
226 | void ClusterTree::LogMe() const |
---|
227 | { |
---|
228 | Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff); |
---|
229 | for (unsigned i = 0; i < m_uNodeCount; ++i) |
---|
230 | { |
---|
231 | m_Nodes[i].LogMe(); |
---|
232 | } |
---|
233 | } |
---|
234 | |
---|
235 | ClusterNode *ClusterTree::GetRoot() const |
---|
236 | { |
---|
237 | return &m_Nodes[m_uNodeCount - 1]; |
---|
238 | } |
---|
239 | |
---|
240 | // This is the UPGMA algorithm as described in Durbin et al. p166. |
---|
241 | void ClusterTree::Create(const DistFunc &Dist) |
---|
242 | { |
---|
243 | unsigned i; |
---|
244 | m_uLeafCount = Dist.GetCount(); |
---|
245 | m_uNodeCount = 2*m_uLeafCount - 1; |
---|
246 | |
---|
247 | delete[] m_Nodes; |
---|
248 | m_Nodes = new ClusterNode[m_uNodeCount]; |
---|
249 | |
---|
250 | for (i = 0; i < m_uNodeCount; ++i) |
---|
251 | m_Nodes[i].SetIndex(i); |
---|
252 | |
---|
253 | for (i = 0; i < m_uLeafCount - 1; ++i) |
---|
254 | m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]); |
---|
255 | |
---|
256 | for (i = 1; i < m_uLeafCount; ++i) |
---|
257 | m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]); |
---|
258 | |
---|
259 | m_ptrDisjoints = &m_Nodes[0]; |
---|
260 | |
---|
261 | // Log("Initial state\n"); |
---|
262 | // LogMe(); |
---|
263 | // Log("\n"); |
---|
264 | |
---|
265 | DistFunc ClusterDist; |
---|
266 | ClusterDist.SetCount(m_uNodeCount); |
---|
267 | double dMaxDist = 0.0; |
---|
268 | for (i = 0; i < m_uLeafCount; ++i) |
---|
269 | for (unsigned j = 0; j < m_uLeafCount; ++j) |
---|
270 | { |
---|
271 | float dDist = Dist.GetDist(i, j); |
---|
272 | ClusterDist.SetDist(i, j, dDist); |
---|
273 | } |
---|
274 | |
---|
275 | Validate(m_uLeafCount); |
---|
276 | |
---|
277 | // Iteration. N-1 joins needed to create a binary tree from N leaves. |
---|
278 | for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount; |
---|
279 | ++uJoinIndex) |
---|
280 | { |
---|
281 | // Find closest pair of clusters |
---|
282 | unsigned uIndexClosest1; |
---|
283 | unsigned uIndexClosest2; |
---|
284 | bool bFound = false; |
---|
285 | double dDistClosest = 9e99; |
---|
286 | for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1; |
---|
287 | ptrNode1 = ptrNode1->GetNextDisjoint()) |
---|
288 | { |
---|
289 | for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2; |
---|
290 | ptrNode2 = ptrNode2->GetNextDisjoint()) |
---|
291 | { |
---|
292 | unsigned i1 = ptrNode1->GetIndex(); |
---|
293 | unsigned i2 = ptrNode2->GetIndex(); |
---|
294 | double dDist = ClusterDist.GetDist(i1, i2); |
---|
295 | if (dDist < dDistClosest) |
---|
296 | { |
---|
297 | bFound = true; |
---|
298 | dDistClosest = dDist; |
---|
299 | uIndexClosest1 = i1; |
---|
300 | uIndexClosest2 = i2; |
---|
301 | } |
---|
302 | } |
---|
303 | } |
---|
304 | assert(bFound); |
---|
305 | |
---|
306 | ClusterNode &Join = m_Nodes[uJoinIndex]; |
---|
307 | ClusterNode &Child1 = m_Nodes[uIndexClosest1]; |
---|
308 | ClusterNode &Child2 = m_Nodes[uIndexClosest2]; |
---|
309 | |
---|
310 | Join.SetLeft(&Child1); |
---|
311 | Join.SetRight(&Child2); |
---|
312 | Join.SetWeight(dDistClosest); |
---|
313 | |
---|
314 | Child1.SetParent(&Join); |
---|
315 | Child2.SetParent(&Join); |
---|
316 | |
---|
317 | DeleteFromDisjoints(&Child1); |
---|
318 | DeleteFromDisjoints(&Child2); |
---|
319 | AddToDisjoints(&Join); |
---|
320 | |
---|
321 | // Log("After join %d %d\n", uIndexClosest1, uIndexClosest2); |
---|
322 | // LogMe(); |
---|
323 | |
---|
324 | // Calculate distance of every remaining disjoint cluster to the |
---|
325 | // new cluster created by the join |
---|
326 | for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode; |
---|
327 | ptrNode = ptrNode->GetNextDisjoint()) |
---|
328 | { |
---|
329 | unsigned uNodeIndex = ptrNode->GetIndex(); |
---|
330 | float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1); |
---|
331 | float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2); |
---|
332 | float dDist = Min(dDist1, dDist2); |
---|
333 | ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist); |
---|
334 | } |
---|
335 | Validate(uJoinIndex+1); |
---|
336 | } |
---|
337 | GetRoot()->GetClusterWeight(); |
---|
338 | // LogMe(); |
---|
339 | } |
---|