1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
|
//
// Copyright (C) 2017 Novartis Institutes for BioMedical Research
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
// The contents are covered by the terms of the BSD license
// which is included in the file license.txt, found at the root
// of the RDKit source tree.
//
// #define DEBUG
#include "RGroupScore.h"
#include "RDGeneral/Invariant.h"
#include <vector>
#include <map>
#include <algorithm>
namespace RDKit {
// Naive total score for a single permutation.
// Every permutation has to be scored this way and nothing clever is done to
// prune the search, so for r-groups with large symmetries this can take a
// very long time.
//
// Parameters:
//   permutation - permutation[m] is the index of the RGroupMatch chosen for
//                 molecule (m + offset), where
//                 offset = matches.size() - permutation.size(); i.e. the
//                 permutation covers the LAST permutation.size() molecules
//                 of `matches`.
//   matches     - for each molecule, its candidate RGroupMatch objects.
//   labels      - r-group labels to score; each label adds one increment to
//                 the total.
// Returns the sum of the per-label increments. Presumably larger is better
// (startProcessing() resets d_bestScore to -max) -- confirm in the caller.
//
// Side effects: calls restoreInitialState() and then accumulates the counts
// for this permutation into d_current (N and labelDataMap). The result
// therefore builds on whatever state restoreInitialState() restores.
double RGroupScorer::matchScore(
    const std::vector<size_t> &permutation,
    const std::vector<std::vector<RGroupMatch>> &matches,
    const std::set<int> &labels) {
  PRECONDITION(permutation.size() <= matches.size(),
               "permutation.size() should be <= matches.size()");
  double score = 0.;
  // Key used to pad a match set with "no r-group at this label" entries.
  const std::string EMPTY_RGROUP = "";
  // The permutation only covers the trailing molecules of `matches`.
  size_t offset = matches.size() - permutation.size();
#ifdef DEBUG
  // NOTE(review): std::cerr and MolToSmiles (below) are not declared by the
  // includes visible in this file; DEBUG builds may need extra headers.
  std::cerr << "---------------------------------------------------"
            << std::endl;
  std::cerr << "Scoring permutation "
            << " num matches: " << matches.size() << std::endl;
  BOOST_LOG(rdDebugLog) << "Scoring" << std::endl;
  for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule
    BOOST_LOG(rdDebugLog) << "Molecule " << m << " "
                          << matches[m + offset].at(permutation[m]).toString()
                          << std::endl;
  }
#endif
  // What is the largest rgroup count at any label?
  // restoreInitialState() presumably resets d_current to the cached starting
  // state (see RGroupScore.h). d_current.N then becomes the maximum of its
  // restored value and, for every label, the number of molecules in this
  // permutation that carry an r-group at that label.
  restoreInitialState();
  std::map<int, int> num_rgroups;
  for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule
    for (auto l : matches[m + offset].at(permutation[m]).rgroups) {
      d_current.N = std::max(d_current.N, ++num_rgroups[l.first]);
    }
  }
  // for each label (r-group)
  for (auto l : labels) {
#ifdef DEBUG
    std::cerr << "Label: " << l << std::endl;
#endif
    // operator[] creates an empty entry for labels not seen before.
    auto &labelData = d_current.labelDataMap[l];
    // Accumulate, per molecule that has an r-group at this label:
    //  - numRGroups: how many molecules have this label;
    //  - linkerMatchSet: how often each attachment set occurs for linkers;
    //  - matchSetVect[i][smiles]: how often each SMILES occurs at position i
    //    of the r-group's smilesVect.
    for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule
      const auto &match = matches[m + offset][permutation[m]];
      auto rg = match.rgroups.find(l);
      if (rg == match.rgroups.end()) {
        continue;
      }
      ++labelData.numRGroups;
      if (rg->second->is_linker) {
        ++labelData.linkerMatchSet[rg->second->attachments];
#ifdef DEBUG
        std::cerr << " combined: " << MolToSmiles(*rg->second->combinedMol)
                  << std::endl;
        std::cerr << " RGroup: " << rg->second->smiles << " "
                  << rg->second->is_hydrogen << std::endl;
        ;
#endif
      }
#ifdef DEBUG
      std::cerr << l << " rgroup count" << labelData.numRGroups << " num atoms"
                << rg->second->combinedMol->getNumAtoms(false)
                // the per-match score is no longer printed here
                // << " score: " << count
                << std::endl;
#endif
      size_t i = 0;
      for (const auto &smiles : rg->second->smilesVect) {
        // Grow matchSetVect lazily so there is one match set per position.
        if (i == labelData.matchSetVect.size()) {
          labelData.matchSetVect.resize(i + 1);
        }
        unsigned int &count = labelData.matchSetVect[i][smiles];
        ++count;
#ifdef DEBUG
        std::cerr << i << " smiles:" << smiles << " " << count << std::endl;
        std::cerr << " Linker Score: "
                  << labelData.linkerMatchSet[rg->second->attachments]
                  << std::endl;
#endif
        ++i;
      }
    }
    double tempScore = 0.;
    for (auto &matchSet : labelData.matchSetVect) {
      // Get the counts for each rgroup found and sort them in descending
      // order. If fewer molecules have this label than the largest label
      // count (d_current.N), pad the set with EMPTY_RGROUP entries.
      // Note: this assignment is stored in d_current's match set (overwriting
      // any previous EMPTY_RGROUP count rather than adding to it).
      if (d_current.N - labelData.numRGroups > 0) {
        matchSet[EMPTY_RGROUP] = d_current.N - labelData.numRGroups;
      }
      std::vector<unsigned int> equivalentRGroupCount;
      std::transform(matchSet.begin(), matchSet.end(),
                     std::back_inserter(equivalentRGroupCount),
                     [](const std::pair<std::string, unsigned int> &p) {
                       return p.second;
                     });
      std::sort(equivalentRGroupCount.begin(), equivalentRGroupCount.end(),
                std::greater<unsigned int>());
      // Score the sets from largest to smallest. The i-th largest set
      // contributes (count / ((i + 1) * matches.size()))^2, so each smaller
      // set is penalized by its rank (i + 1).
      for (size_t i = 0; i < equivalentRGroupCount.size(); ++i) {
        auto lscore = static_cast<double>(equivalentRGroupCount[i]) /
                      static_cast<double>(((i + 1) * matches.size()));
        tempScore += lscore * lscore;
#ifdef DEBUG
        std::cerr << " lscore^2 " << i << ": " << lscore * lscore
                  << std::endl;
#endif
      }
      // Rescale groups like [*:1].[*:1]C, otherwise they would be double
      // counted.
      // WE SHOULD PROBABLY REJECT THESE OUTRIGHT
      // NOTE(review): this division runs inside the loop, so with k > 1
      // match sets the earlier sets are divided by k several times. The
      // result is not the plain average (s0 + ... + s_{k-1}) / k. Confirm
      // whether that weighting is intended before changing it, because it
      // alters scores.
      tempScore /= static_cast<double>(labelData.matchSetVect.size());
    }
    // Overweight linkers with the same attachment points.
    // Because these belong to 2 (or more) rgroups we really want them to
    // stay together. The size of the attachment set is the number of labels
    // that are being used.
    // ** this heuristic really should be taken care of above **
    // NOTE(review): the inner test compares only it.first.size() with the
    // running maximum but assigns max(size, count). The final value therefore
    // depends on the iteration order of linkerMatchSet and is not always the
    // largest max(size, count) over the qualifying entries. Also,
    // std::max(...) may return size_t, which is narrowed to unsigned int
    // here.
    unsigned int maxLinkerMatches = 0;
    for (const auto &it : labelData.linkerMatchSet) {
      if (it.first.size() > 1 || it.second > 1) {
        if (it.first.size() > maxLinkerMatches) {
          maxLinkerMatches = std::max(it.first.size(), it.second);
        }
      }
    }
#ifdef DEBUG
    std::cerr << "Max Linker Matches :" << maxLinkerMatches << std::endl;
#endif
    // A label with qualifying linkers contributes
    // maxLinkerMatches / matches.size(); tempScore is ignored in that case.
    // Otherwise the label contributes tempScore.
    double increment = 1.0;        // no change in score
    double linkerIncrement = 1.0;  // no change in score
    if (maxLinkerMatches) {
      linkerIncrement = static_cast<double>(maxLinkerMatches) /
                        static_cast<double>(matches.size());
    } else {
      increment = tempScore;
    }
    score += increment * linkerIncrement;
#ifdef DEBUG
    std::cerr << "Increment: " << increment
              << " Linker_Increment: " << linkerIncrement << std::endl;
    std::cerr << "increment*linkerIncrement: " << increment * linkerIncrement
              << std::endl;
    std::cerr << "Score = " << score << std::endl;
#endif
  }  // end for each label
#ifdef DEBUG
  BOOST_LOG(rdDebugLog) << score << std::endl;
#endif
  return score;
}
// Construct a scorer seeded with permutations that all reached `score`.
// Every permutation is queued as a tie candidate via pushTieToStore(). If at
// least one was supplied, the first queued state becomes the saved state.
RGroupScorer::RGroupScorer(const std::vector<std::vector<size_t>> &permutations,
                           double score) {
  d_bestScore = score;
  std::for_each(permutations.begin(), permutations.end(),
                [this](const std::vector<size_t> &tiedPermutation) {
                  pushTieToStore(tiedPermutation);
                });
  // Nothing was queued: leave d_saved as default-constructed.
  if (d_store.empty()) {
    return;
  }
  d_saved = d_store.front();
}
// Record `permutation` as the best one found so far, with its `score`.
// The current state (with the permutation set) is also copied into d_saved.
void RGroupScorer::setBestPermutation(const std::vector<size_t> &permutation,
                                      double score) {
  d_current.permutation = permutation;
  d_saved = d_current;
  d_bestScore = score;
}
// Begin a new scoring round:
//  - drop any queued tie candidates;
//  - reset the best score to the lowest finite double, so the first scored
//    permutation always improves on it;
//  - snapshot the saved state as the initial state for this round.
void RGroupScorer::startProcessing() {
  clearTieStore();
  d_bestScore = std::numeric_limits<double>::lowest();
  d_initial = d_saved;
}
// Choose the best of the tied permutations queued in d_store and make it the
// saved state (d_saved). The queue is drained (pop_front) in the process.
//
// Each candidate is compared against the running best (d_current) by, in
// order:
//   1. larger state.numMatchedUserRGroups wins;
//   2. then smaller state.numAddedRGroups wins;
//   3. then the lexicographically larger per-label heavy-count vector wins
//      (std::vector operator>);
//   4. then the larger iterator->value(state.permutation) wins. When
//      `iterator` is null, permValue is set to the current maxPermValue, so
//      this step never replaces the running best.
// checkForTimeout(t0, timeout) is called once per candidate.
//
// Parameters:
//   matches  - per-molecule candidate matches, passed through to
//              computeTieBreakingCriteria().
//   labels   - the r-group labels currently in use.
//   iterator - optional; when non-null it supplies the final tie-break value
//              for a permutation.
//   t0, timeout - forwarded to checkForTimeout().
void RGroupScorer::breakTies(
    const std::vector<std::vector<RGroupMatch>> &matches,
    const std::set<int> &labels,
    const std::unique_ptr<CartesianProduct> &iterator,
    const std::chrono::steady_clock::time_point &t0, double timeout) {
  size_t maxPermValue = 0;
  // Start from the saved state, with numAddedRGroups set to the total number
  // of labels (presumably an upper bound, so that any candidate adding fewer
  // r-groups can win at step 2).
  d_current = d_saved;
  d_current.numAddedRGroups = labels.size();
  std::vector<int> largestHeavyCounts;
  largestHeavyCounts.reserve(labels.size());
  // orderedLabels = non-negative labels in ascending order, followed by the
  // negative labels in descending order (-1, -2, ...).
  std::vector<int> orderedLabels;
  orderedLabels.reserve(labels.size());
  std::copy_if(labels.begin(), labels.end(), std::back_inserter(orderedLabels),
               [](const int &i) { return !(i < 0); });
  std::copy_if(labels.rbegin(), labels.rend(), std::back_inserter(orderedLabels),
               [](const int &i) { return (i < 0); });
  // We only care about the sign of the ordered labels,
  // not about their value, so we convert the ordered map
  // into a vector for comparing with the tied permutations.
  // If there is a change in sign, then it means a new
  // label was added compared to the cached version,
  // so we need to add a new counter initialized to 0.
  // NOTE(review): this lockstep walk assumes heavyCountPerLabel iterates in
  // the same order as orderedLabels. If heavyCountPerLabel is a
  // std::map<int, int> with the default comparator, it iterates in
  // ascending key order (negative labels first), which does not match.
  // Confirm the container/comparator declared in RGroupScore.h. Also,
  // orderedLabels treats label 0 as non-negative, while the sign test below
  // (`> 0`) groups 0 with the negative labels.
  auto it = d_current.heavyCountPerLabel.begin();
  for (auto label : orderedLabels) {
    int count = 0;
    if (it != d_current.heavyCountPerLabel.end()) {
      if (!((it->first > 0) ^ (label > 0))) {
        count = it->second;
      }
      ++it;
    }
    largestHeavyCounts.push_back(count);
  }
  // Every candidate starts from the cached heavy counts.
  std::vector<int> initialHeavyCounts(largestHeavyCounts);
  while (!d_store.empty()) {
    auto &state = d_store.front();
    std::vector<int> heavyCounts(initialHeavyCounts);
    // Presumably fills state.numMatchedUserRGroups / state.numAddedRGroups
    // and updates heavyCounts for this permutation (defined elsewhere --
    // confirm in RGroupScore.h).
    state.computeTieBreakingCriteria(matches, orderedLabels, heavyCounts);
#ifdef DEBUG
    std::cerr << "tiedPermutation ";
    for (const auto &t : state.permutation) {
      std::cerr << t << ",";
    }
    std::cerr << " orderedLabels ";
    for (const auto &l : orderedLabels) {
      std::cerr << l << ",";
    }
    std::cerr << " heavyCounts ";
    for (auto hc : heavyCounts) {
      std::cerr << hc << ",";
    }
    std::cerr << " largestHeavyCounts ";
    for (auto hc : largestHeavyCounts) {
      std::cerr << hc << ",";
    }
    std::cerr << " state.numMatchedUserRGroups " << state.numMatchedUserRGroups
              << " d_current.numMatchedUserRGroups "
              << d_current.numMatchedUserRGroups << ", state.numAddedRGroups "
              << state.numAddedRGroups << ", d_current.numAddedRGroups "
              << d_current.numAddedRGroups << std::endl;
#endif
    size_t permValue =
        iterator ? iterator->value(state.permutation) : maxPermValue;
    // Tie-break cascade (see the header comment for the criteria order).
    // A winning candidate is copied into d_current before it is popped.
    if (state.numMatchedUserRGroups > d_current.numMatchedUserRGroups) {
      d_current = state;
      largestHeavyCounts = heavyCounts;
      maxPermValue = permValue;
    } else if (state.numMatchedUserRGroups == d_current.numMatchedUserRGroups) {
      if (state.numAddedRGroups < d_current.numAddedRGroups) {
        d_current = state;
        largestHeavyCounts = heavyCounts;
        maxPermValue = permValue;
      } else if (state.numAddedRGroups == d_current.numAddedRGroups) {
        if (heavyCounts > largestHeavyCounts) {
          d_current = state;
          largestHeavyCounts = heavyCounts;
          maxPermValue = permValue;
        } else if (heavyCounts == largestHeavyCounts) {
          if (permValue > maxPermValue) {
            d_current = state;
            largestHeavyCounts = heavyCounts;
            maxPermValue = permValue;
          }
        }
      }
    }
    checkForTimeout(t0, timeout);
    d_store.pop_front();
  }
  // Convert the heavy count vector back into a per-label map so it can be
  // stored in the saved cache.
  // Assumes largestHeavyCounts still has orderedLabels.size() elements
  // (true unless computeTieBreakingCriteria resized heavyCounts).
  d_current.heavyCountPerLabel.clear();
  auto count = largestHeavyCounts.begin();
  for (auto label : orderedLabels) {
    d_current.heavyCountPerLabel[label] = *count++;
  }
  d_saved = d_current;
}
// Queue `permutation` as a tie candidate. The current state is stamped with
// the permutation and a copy of it is appended to the tie store.
void RGroupScorer::pushTieToStore(const std::vector<size_t> &permutation) {
  d_current.permutation = permutation;
  d_store.emplace_back(d_current);
}
// Discard every queued tie candidate.
void RGroupScorer::clearTieStore() {
  d_store.clear();
}
} // namespace RDKit
|