| 1 | #include "muscle.h" |
|---|
| 2 | #include "cluster.h" |
|---|
| 3 | #include "distfunc.h" |
|---|
| 4 | |
|---|
// Smaller of two floats. When "d1 < d2" is false (equal values, or either
// operand is NaN) the result is d2.
static inline float Min(float d1, float d2)
{
	if (d1 < d2)
		return d1;
	return d2;
}
|---|
| 9 | |
|---|
// Larger of two floats. When "d1 > d2" is false (equal values, or either
// operand is NaN) the result is d2.
static inline float Max(float d1, float d2)
{
	if (d1 > d2)
		return d1;
	return d2;
}
|---|
| 14 | |
|---|
// Arithmetic mean of two floats. The sum is formed in float, halved in
// double, and the result narrowed back to float.
static inline float Mean(float d1, float d2)
{
	const double dHalfSum = (d1 + d2)/2.0;
	return static_cast<float>(dHalfSum);
}
|---|
| 19 | |
|---|
#if _DEBUG
// Debug-only consistency check of the disjoint-cluster list.
// Checks that:
//   - every prev/next link in the list is mirrored by the neighbour,
//   - only the list head (m_ptrDisjoints) has a null prev link,
//   - the list is not longer than m_uNodeCount (i.e. it has no cycle),
//   - the list length equals the number of parentless nodes among the
//     first uNodeCount entries of m_Nodes.
// Any violation is logged and reported through Quit().
//
// uNodeCount: number of leading entries of m_Nodes that are in use so far.
void ClusterTree::Validate(unsigned uNodeCount)
{
	unsigned n;
	ClusterNode *pNode;
	unsigned uDisjointListCount = 0;
	for (pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint())
	{
		ClusterNode *pPrev = pNode->GetPrevDisjoint();
		ClusterNode *pNext = pNode->GetNextDisjoint();
		if (0 != pPrev)
		{
			if (pPrev->GetNextDisjoint() != pNode)
			{
				Log("Prev->This mismatch, prev=\n");
				pPrev->LogMe();
				Log("This=\n");
				pNode->LogMe();
				Quit("ClusterTree::Validate()");
			}
		}
		else
		{
			// A node without a predecessor must be the list head.
			if (pNode != m_ptrDisjoints)
			{
				// %d needs an int argument; -1 marks "no head".
				Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n",
					pNode->GetIndex(),
					m_ptrDisjoints ? static_cast<int>(m_ptrDisjoints->GetIndex()) : -1);
				pNode->LogMe();
				Quit("ClusterTree::Validate()");
			}
		}
		if (0 != pNext)
		{
			if (pNext->GetPrevDisjoint() != pNode)
			{
				Log("Next->This mismatch, next=\n");
				pNext->LogMe();
				Log("This=\n");
				pNode->LogMe();
				Quit("ClusterTree::Validate()");
			}
		}
		// The list can never legitimately hold more nodes than exist, so
		// exceeding that bound means the links form a cycle.
		++uDisjointListCount;
		if (uDisjointListCount > m_uNodeCount)
			Quit("Loop in disjoint list");
	}

	unsigned uParentlessNodeCount = 0;
	for (n = 0; n < uNodeCount; ++n)
		if (0 == m_Nodes[n].GetParent())
			++uParentlessNodeCount;

	// Every cluster root (node without parent) must be in the disjoint list.
	if (uDisjointListCount != uParentlessNodeCount)
		Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount,
		  uParentlessNodeCount);
}
#else // !_DEBUG
// In release builds calls to Validate() compile away to nothing.
#define Validate(uNodeCount) // empty
#endif
|---|
| 80 | |
|---|
| 81 | void ClusterNode::LogMe() const |
|---|
| 82 | { |
|---|
| 83 | unsigned uClusterSize = GetClusterSize(); |
|---|
| 84 | Log("[%02u] w=%5.3f CW=%5.3f LBW=%5.3f RBW=%5.3f LWT=%5.3f RWT=%5.3f L=%02d R=%02d P=%02d NxDj=%02d PvDj=%02d Sz=%02d {", |
|---|
| 85 | m_uIndex, |
|---|
| 86 | m_dWeight, |
|---|
| 87 | GetClusterWeight(), |
|---|
| 88 | GetLeftBranchWeight(), |
|---|
| 89 | GetRightBranchWeight(), |
|---|
| 90 | GetLeftWeight(), |
|---|
| 91 | GetRightWeight(), |
|---|
| 92 | m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff, |
|---|
| 93 | m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff, |
|---|
| 94 | m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff, |
|---|
| 95 | m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff, |
|---|
| 96 | m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff, |
|---|
| 97 | uClusterSize); |
|---|
| 98 | for (unsigned i = 0; i < uClusterSize; ++i) |
|---|
| 99 | Log(" %u", GetClusterLeaf(i)->GetIndex()); |
|---|
| 100 | Log(" }\n"); |
|---|
| 101 | } |
|---|
| 102 | |
|---|
| 103 | // How many leaves in the sub-tree under this node? |
|---|
| 104 | unsigned ClusterNode::GetClusterSize() const |
|---|
| 105 | { |
|---|
| 106 | unsigned uLeafCount = 0; |
|---|
| 107 | |
|---|
| 108 | if (0 == m_ptrLeft && 0 == m_ptrRight) |
|---|
| 109 | return 1; |
|---|
| 110 | |
|---|
| 111 | if (0 != m_ptrLeft) |
|---|
| 112 | uLeafCount += m_ptrLeft->GetClusterSize(); |
|---|
| 113 | if (0 != m_ptrRight) |
|---|
| 114 | uLeafCount += m_ptrRight->GetClusterSize(); |
|---|
| 115 | assert(uLeafCount > 0); |
|---|
| 116 | return uLeafCount; |
|---|
| 117 | } |
|---|
| 118 | |
|---|
| 119 | double ClusterNode::GetClusterWeight() const |
|---|
| 120 | { |
|---|
| 121 | double dWeight = 0.0; |
|---|
| 122 | if (0 != m_ptrLeft) |
|---|
| 123 | dWeight += m_ptrLeft->GetClusterWeight(); |
|---|
| 124 | if (0 != m_ptrRight) |
|---|
| 125 | dWeight += m_ptrRight->GetClusterWeight(); |
|---|
| 126 | return dWeight + GetWeight(); |
|---|
| 127 | } |
|---|
| 128 | |
|---|
| 129 | double ClusterNode::GetLeftBranchWeight() const |
|---|
| 130 | { |
|---|
| 131 | const ClusterNode *ptrLeft = GetLeft(); |
|---|
| 132 | if (0 == ptrLeft) |
|---|
| 133 | return 0.0; |
|---|
| 134 | |
|---|
| 135 | return GetWeight() - ptrLeft->GetWeight(); |
|---|
| 136 | } |
|---|
| 137 | |
|---|
| 138 | double ClusterNode::GetRightBranchWeight() const |
|---|
| 139 | { |
|---|
| 140 | const ClusterNode *ptrRight = GetRight(); |
|---|
| 141 | if (0 == ptrRight) |
|---|
| 142 | return 0.0; |
|---|
| 143 | |
|---|
| 144 | return GetWeight() - ptrRight->GetWeight(); |
|---|
| 145 | } |
|---|
| 146 | |
|---|
| 147 | double ClusterNode::GetRightWeight() const |
|---|
| 148 | { |
|---|
| 149 | const ClusterNode *ptrRight = GetRight(); |
|---|
| 150 | if (0 == ptrRight) |
|---|
| 151 | return 0.0; |
|---|
| 152 | return ptrRight->GetClusterWeight() + GetWeight(); |
|---|
| 153 | } |
|---|
| 154 | |
|---|
| 155 | double ClusterNode::GetLeftWeight() const |
|---|
| 156 | { |
|---|
| 157 | const ClusterNode *ptrLeft = GetLeft(); |
|---|
| 158 | if (0 == ptrLeft) |
|---|
| 159 | return 0.0; |
|---|
| 160 | return ptrLeft->GetClusterWeight() + GetWeight(); |
|---|
| 161 | } |
|---|
| 162 | |
|---|
| 163 | // Return n'th leaf in the sub-tree under this node. |
|---|
| 164 | const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const |
|---|
| 165 | { |
|---|
| 166 | if (0 != m_ptrLeft) |
|---|
| 167 | { |
|---|
| 168 | if (0 == m_ptrRight) |
|---|
| 169 | return this; |
|---|
| 170 | |
|---|
| 171 | unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize(); |
|---|
| 172 | |
|---|
| 173 | if (uLeafIndex < uLeftLeafCount) |
|---|
| 174 | return m_ptrLeft->GetClusterLeaf(uLeafIndex); |
|---|
| 175 | |
|---|
| 176 | assert(uLeafIndex >= uLeftLeafCount); |
|---|
| 177 | return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount); |
|---|
| 178 | } |
|---|
| 179 | if (0 == m_ptrRight) |
|---|
| 180 | return this; |
|---|
| 181 | return m_ptrRight->GetClusterLeaf(uLeafIndex); |
|---|
| 182 | } |
|---|
| 183 | |
|---|
| 184 | void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode) |
|---|
| 185 | { |
|---|
| 186 | ClusterNode *ptrPrev = ptrNode->GetPrevDisjoint(); |
|---|
| 187 | ClusterNode *ptrNext = ptrNode->GetNextDisjoint(); |
|---|
| 188 | |
|---|
| 189 | if (0 != ptrPrev) |
|---|
| 190 | ptrPrev->SetNextDisjoint(ptrNext); |
|---|
| 191 | else |
|---|
| 192 | m_ptrDisjoints = ptrNext; |
|---|
| 193 | |
|---|
| 194 | if (0 != ptrNext) |
|---|
| 195 | ptrNext->SetPrevDisjoint(ptrPrev); |
|---|
| 196 | |
|---|
| 197 | #if _DEBUG |
|---|
| 198 | // not algorithmically necessary, but improves clarity |
|---|
| 199 | // and supports Validate(). |
|---|
| 200 | ptrNode->SetPrevDisjoint(0); |
|---|
| 201 | ptrNode->SetNextDisjoint(0); |
|---|
| 202 | #endif |
|---|
| 203 | } |
|---|
| 204 | |
|---|
| 205 | void ClusterTree::AddToDisjoints(ClusterNode *ptrNode) |
|---|
| 206 | { |
|---|
| 207 | ptrNode->SetNextDisjoint(m_ptrDisjoints); |
|---|
| 208 | ptrNode->SetPrevDisjoint(0); |
|---|
| 209 | if (0 != m_ptrDisjoints) |
|---|
| 210 | m_ptrDisjoints->SetPrevDisjoint(ptrNode); |
|---|
| 211 | m_ptrDisjoints = ptrNode; |
|---|
| 212 | } |
|---|
| 213 | |
|---|
| 214 | ClusterTree::ClusterTree() |
|---|
| 215 | { |
|---|
| 216 | m_ptrDisjoints = 0; |
|---|
| 217 | m_Nodes = 0; |
|---|
| 218 | m_uNodeCount = 0; |
|---|
| 219 | } |
|---|
| 220 | |
|---|
| 221 | ClusterTree::~ClusterTree() |
|---|
| 222 | { |
|---|
| 223 | delete[] m_Nodes; |
|---|
| 224 | } |
|---|
| 225 | |
|---|
| 226 | void ClusterTree::LogMe() const |
|---|
| 227 | { |
|---|
| 228 | Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff); |
|---|
| 229 | for (unsigned i = 0; i < m_uNodeCount; ++i) |
|---|
| 230 | { |
|---|
| 231 | m_Nodes[i].LogMe(); |
|---|
| 232 | } |
|---|
| 233 | } |
|---|
| 234 | |
|---|
| 235 | ClusterNode *ClusterTree::GetRoot() const |
|---|
| 236 | { |
|---|
| 237 | return &m_Nodes[m_uNodeCount - 1]; |
|---|
| 238 | } |
|---|
| 239 | |
|---|
| 240 | // This is the UPGMA algorithm as described in Durbin et al. p166. |
|---|
| 241 | void ClusterTree::Create(const DistFunc &Dist) |
|---|
| 242 | { |
|---|
| 243 | unsigned i; |
|---|
| 244 | m_uLeafCount = Dist.GetCount(); |
|---|
| 245 | m_uNodeCount = 2*m_uLeafCount - 1; |
|---|
| 246 | |
|---|
| 247 | delete[] m_Nodes; |
|---|
| 248 | m_Nodes = new ClusterNode[m_uNodeCount]; |
|---|
| 249 | |
|---|
| 250 | for (i = 0; i < m_uNodeCount; ++i) |
|---|
| 251 | m_Nodes[i].SetIndex(i); |
|---|
| 252 | |
|---|
| 253 | for (i = 0; i < m_uLeafCount - 1; ++i) |
|---|
| 254 | m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]); |
|---|
| 255 | |
|---|
| 256 | for (i = 1; i < m_uLeafCount; ++i) |
|---|
| 257 | m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]); |
|---|
| 258 | |
|---|
| 259 | m_ptrDisjoints = &m_Nodes[0]; |
|---|
| 260 | |
|---|
| 261 | // Log("Initial state\n"); |
|---|
| 262 | // LogMe(); |
|---|
| 263 | // Log("\n"); |
|---|
| 264 | |
|---|
| 265 | DistFunc ClusterDist; |
|---|
| 266 | ClusterDist.SetCount(m_uNodeCount); |
|---|
| 267 | double dMaxDist = 0.0; |
|---|
| 268 | for (i = 0; i < m_uLeafCount; ++i) |
|---|
| 269 | for (unsigned j = 0; j < m_uLeafCount; ++j) |
|---|
| 270 | { |
|---|
| 271 | float dDist = Dist.GetDist(i, j); |
|---|
| 272 | ClusterDist.SetDist(i, j, dDist); |
|---|
| 273 | } |
|---|
| 274 | |
|---|
| 275 | Validate(m_uLeafCount); |
|---|
| 276 | |
|---|
| 277 | // Iteration. N-1 joins needed to create a binary tree from N leaves. |
|---|
| 278 | for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount; |
|---|
| 279 | ++uJoinIndex) |
|---|
| 280 | { |
|---|
| 281 | // Find closest pair of clusters |
|---|
| 282 | unsigned uIndexClosest1; |
|---|
| 283 | unsigned uIndexClosest2; |
|---|
| 284 | bool bFound = false; |
|---|
| 285 | double dDistClosest = 9e99; |
|---|
| 286 | for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1; |
|---|
| 287 | ptrNode1 = ptrNode1->GetNextDisjoint()) |
|---|
| 288 | { |
|---|
| 289 | for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2; |
|---|
| 290 | ptrNode2 = ptrNode2->GetNextDisjoint()) |
|---|
| 291 | { |
|---|
| 292 | unsigned i1 = ptrNode1->GetIndex(); |
|---|
| 293 | unsigned i2 = ptrNode2->GetIndex(); |
|---|
| 294 | double dDist = ClusterDist.GetDist(i1, i2); |
|---|
| 295 | if (dDist < dDistClosest) |
|---|
| 296 | { |
|---|
| 297 | bFound = true; |
|---|
| 298 | dDistClosest = dDist; |
|---|
| 299 | uIndexClosest1 = i1; |
|---|
| 300 | uIndexClosest2 = i2; |
|---|
| 301 | } |
|---|
| 302 | } |
|---|
| 303 | } |
|---|
| 304 | assert(bFound); |
|---|
| 305 | |
|---|
| 306 | ClusterNode &Join = m_Nodes[uJoinIndex]; |
|---|
| 307 | ClusterNode &Child1 = m_Nodes[uIndexClosest1]; |
|---|
| 308 | ClusterNode &Child2 = m_Nodes[uIndexClosest2]; |
|---|
| 309 | |
|---|
| 310 | Join.SetLeft(&Child1); |
|---|
| 311 | Join.SetRight(&Child2); |
|---|
| 312 | Join.SetWeight(dDistClosest); |
|---|
| 313 | |
|---|
| 314 | Child1.SetParent(&Join); |
|---|
| 315 | Child2.SetParent(&Join); |
|---|
| 316 | |
|---|
| 317 | DeleteFromDisjoints(&Child1); |
|---|
| 318 | DeleteFromDisjoints(&Child2); |
|---|
| 319 | AddToDisjoints(&Join); |
|---|
| 320 | |
|---|
| 321 | // Log("After join %d %d\n", uIndexClosest1, uIndexClosest2); |
|---|
| 322 | // LogMe(); |
|---|
| 323 | |
|---|
| 324 | // Calculate distance of every remaining disjoint cluster to the |
|---|
| 325 | // new cluster created by the join |
|---|
| 326 | for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode; |
|---|
| 327 | ptrNode = ptrNode->GetNextDisjoint()) |
|---|
| 328 | { |
|---|
| 329 | unsigned uNodeIndex = ptrNode->GetIndex(); |
|---|
| 330 | float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1); |
|---|
| 331 | float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2); |
|---|
| 332 | float dDist = Min(dDist1, dDist2); |
|---|
| 333 | ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist); |
|---|
| 334 | } |
|---|
| 335 | Validate(uJoinIndex+1); |
|---|
| 336 | } |
|---|
| 337 | GetRoot()->GetClusterWeight(); |
|---|
| 338 | // LogMe(); |
|---|
| 339 | } |
|---|