File: RGroupScore.cpp

package info (click to toggle)
rdkit 202209.3-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 203,880 kB
  • sloc: cpp: 334,239; python: 80,247; ansic: 24,579; java: 7,667; sql: 2,123; yacc: 1,884; javascript: 1,358; lex: 1,260; makefile: 576; xml: 229; fortran: 183; cs: 181; sh: 101
file content (306 lines) | stat: -rw-r--r-- 11,160 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
//
//  Copyright (C) 2017 Novartis Institutes for BioMedical Research
//
//   @@ All Rights Reserved @@
//  This file is part of the RDKit.
//  The contents are covered by the terms of the BSD license
//  which is included in the file license.txt, found at the root
//  of the RDKit source tree.
//

// #define DEBUG

#include "RGroupScore.h"
#include "RDGeneral/Invariant.h"
#include <vector>
#include <map>
#include <algorithm>
namespace RDKit {

// Simple (brute-force) total score for one permutation of core matches.
//
// The caller has to try every permutation; nothing smart is done here, so
// for r-groups with large symmetries the overall search can take far too
// long.
//
// \param permutation  for each scored molecule, the index of the match to
//                     use. Must not be longer than \c matches; when it is
//                     shorter, it applies to the LAST permutation.size()
//                     molecules of \c matches (see \c offset below).
// \param matches      per molecule, the list of candidate matches.
// \param labels       the r-group labels to score.
// \return the sum over all labels of the per-label score. A label whose
//         linker heuristic fires contributes
//         maxLinkerMatches / matches.size(); any other label contributes the
//         accumulated squared, rank-penalised r-group frequencies.
//
// Side effects: resets d_current through restoreInitialState() and then
// updates d_current.N and d_current.labelDataMap.
double RGroupScorer::matchScore(
    const std::vector<size_t> &permutation,
    const std::vector<std::vector<RGroupMatch>> &matches,
    const std::set<int> &labels) {
  PRECONDITION(permutation.size() <= matches.size(),
               "permutation.size() should be <= matches.size()");
  double score = 0.;
  const std::string EMPTY_RGROUP = "";
  // permutation[m] selects a match of molecule (m + offset)
  size_t offset = matches.size() - permutation.size();
#ifdef DEBUG
  std::cerr << "---------------------------------------------------"
            << std::endl;
  std::cerr << "Scoring permutation "
            << " num matches: " << matches.size() << std::endl;

  BOOST_LOG(rdDebugLog) << "Scoring" << std::endl;
  for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule
    BOOST_LOG(rdDebugLog) << "Molecule " << m << " "
                          << matches[m + offset].at(permutation[m]).toString()
                          << std::endl;
  }
#endif

  // What is the largest rgroup count at any label
  // NOTE(review): restoreInitialState() is defined elsewhere; presumably it
  // rewinds d_current to the state captured by startProcessing() - confirm.
  restoreInitialState();
  std::map<int, int> num_rgroups;
  for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule
    // .at() bounds-checks permutation[m] once here; the per-label loop below
    // revisits exactly the same indices with operator[].
    // Bind by const reference: copying each map entry would also copy the
    // smart pointer it holds.
    for (const auto &l : matches[m + offset].at(permutation[m]).rgroups) {
      d_current.N = std::max(d_current.N, ++num_rgroups[l.first]);
    }
  }
  // for each label (r-group)
  for (auto l : labels) {
#ifdef DEBUG
    std::cerr << "Label: " << l << std::endl;
#endif
    auto &labelData = d_current.labelDataMap[l];
    for (size_t m = 0; m < permutation.size(); ++m) {  // for each molecule

      const auto &match = matches[m + offset][permutation[m]];
      auto rg = match.rgroups.find(l);
      if (rg == match.rgroups.end()) {
        // this molecule has no r-group at this label
        continue;
      }
      ++labelData.numRGroups;
      if (rg->second->is_linker) {
        ++labelData.linkerMatchSet[rg->second->attachments];
#ifdef DEBUG
        std::cerr << "  combined: " << MolToSmiles(*rg->second->combinedMol)
                  << std::endl;
        std::cerr << " RGroup: " << rg->second->smiles << " "
                  << rg->second->is_hydrogen << std::endl;
        ;
#endif
      }
#ifdef DEBUG
      std::cerr << l << " rgroup count" << labelData.numRGroups << " num atoms"
                << rg->second->combinedMol->getNumAtoms(false)
                // looks like code has been edited round this define
                // << " score: " << count
                << std::endl;
#endif
      // Count how often each SMILES occurs at position i of smilesVect;
      // matchSetVect grows on demand to hold one counter map per position.
      size_t i = 0;
      for (const auto &smiles : rg->second->smilesVect) {
        if (i == labelData.matchSetVect.size()) {
          labelData.matchSetVect.resize(i + 1);
        }
        unsigned int &count = labelData.matchSetVect[i][smiles];
        ++count;
#ifdef DEBUG
        std::cerr << i << " smiles:" << smiles << " " << count << std::endl;
        std::cerr << " Linker Score: "
                  << labelData.linkerMatchSet[rg->second->attachments]
                  << std::endl;
#endif
        ++i;
      }
    }

    // Neither operand changes inside the loop below, so compute the number
    // of "missing" r-groups for this label once.
    const auto numMissingRGroups = d_current.N - labelData.numRGroups;

    double tempScore = 0.;
    for (auto &matchSet : labelData.matchSetVect) {
      // get the counts for each rgroup found and sort in reverse order
      // If we don't have as many rgroups as the largest set, add empty ones
      if (numMissingRGroups > 0) {
        matchSet[EMPTY_RGROUP] = numMissingRGroups;
      }
      std::vector<unsigned int> equivalentRGroupCount;

      std::transform(matchSet.begin(), matchSet.end(),
                     std::back_inserter(equivalentRGroupCount),
                     [](const std::pair<std::string, unsigned int> &p) {
                       return p.second;
                     });
      std::sort(equivalentRGroupCount.begin(), equivalentRGroupCount.end(),
                std::greater<unsigned int>());

      // score the sets from the largest to the smallest
      // each smaller set gets penalized (i+1) below
      for (size_t i = 0; i < equivalentRGroupCount.size(); ++i) {
        auto lscore = static_cast<double>(equivalentRGroupCount[i]) /
                      static_cast<double>(((i + 1) * matches.size()));
        tempScore += lscore * lscore;
#ifdef DEBUG
        std::cerr << "    lscore^2 " << i << ": " << lscore * lscore
                  << std::endl;
#endif
      }
      // make sure to rescale groups like [*:1].[*:1]C otherwise this will be
      // double counted
      // WE SHOULD PROBABLY REJECT THESE OUTRIGHT
      // NOTE(review): this division runs once per matchSet, so the running
      // total is divided repeatedly when matchSetVect has more than one
      // entry. Left as is to keep scores unchanged - confirm it is intended.
      tempScore /= static_cast<double>(labelData.matchSetVect.size());
    }

    // overweight linkers with the same attachments points....
    //  because these belong to 2 (or more) rgroups we really want these to stay
    //  the size of the set is the number of labels that are being used
    //  ** this heuristic really should be taken care of above **
    unsigned int maxLinkerMatches = 0;
    for (const auto &it : labelData.linkerMatchSet) {
      if (it.first.size() > 1 || it.second > 1) {
        if (it.first.size() > maxLinkerMatches) {
          // the narrowing to unsigned int was implicit before; spelled out
          maxLinkerMatches =
              static_cast<unsigned int>(std::max(it.first.size(), it.second));
        }
      }
    }
#ifdef DEBUG
    std::cerr << "Max Linker Matches :" << maxLinkerMatches << std::endl;
#endif
    // Exactly one of the two factors departs from 1.0: the linker ratio when
    // the linker heuristic fired, the r-group frequency score otherwise.
    double increment = 1.0;        // no change in score
    double linkerIncrement = 1.0;  // no change in score
    if (maxLinkerMatches) {
      linkerIncrement = static_cast<double>(maxLinkerMatches) /
                        static_cast<double>(matches.size());
    } else {
      increment = tempScore;
    }
    score += increment * linkerIncrement;
#ifdef DEBUG
    std::cerr << "Increment: " << increment
              << " Linker_Increment: " << linkerIncrement << std::endl;
    std::cerr << "increment*linkerIncrement: " << increment * linkerIncrement
              << std::endl;
    std::cerr << "Score = " << score << std::endl;
#endif
  }  // end for each label

#ifdef DEBUG
  BOOST_LOG(rdDebugLog) << score << std::endl;
#endif

  return score;
}

// Builds a scorer from a set of already-tied permutations.
//
// \param permutations  tied permutations; each is queued in the tie store
// \param score         the score they share, recorded as the best score
//
// When at least one permutation is given, the first queued state also
// becomes the saved state.
RGroupScorer::RGroupScorer(const std::vector<std::vector<size_t>> &permutations,
                           double score) {
  d_bestScore = score;
  for (const auto &tiedPermutation : permutations) {
    pushTieToStore(tiedPermutation);
  }
  if (d_store.empty()) {
    // nothing was queued, so there is no state to promote
    return;
  }
  d_saved = d_store.front();
}

// Records \c permutation as the best one found so far.
//
// The permutation is written into the working state, the working state is
// snapshotted as the saved state, and \c score becomes the best score.
void RGroupScorer::setBestPermutation(const std::vector<size_t> &permutation,
                                      double score) {
  // d_saved must be taken after the permutation has been stored in d_current
  d_current.permutation = permutation;
  d_saved = d_current;
  d_bestScore = score;
}

// Prepares the scorer for a new round of scoring: the saved state becomes
// the initial state, the best score drops to the most negative finite
// double so that any real score beats it, and the tie store is emptied.
void RGroupScorer::startProcessing() {
  d_initial = d_saved;
  // lowest() is the most negative finite value, i.e. -max() for double
  d_bestScore = std::numeric_limits<double>::lowest();
  clearTieStore();
}

// Picks one winner among the tied states queued in d_store and saves it.
//
// Starting from the saved state, every queued state is compared against the
// current best using these criteria, in order:
//   1. more matched user r-groups wins
//   2. fewer added r-groups wins
//   3. the larger heavy-count vector (std::vector ordering) wins
//   4. the larger permutation value wins
// The store is drained as it is walked. On return both d_current and d_saved
// hold the winner, together with its heavy counts keyed by label.
//
// \param matches   per molecule, the list of candidate matches
// \param labels    the r-group labels in play
// \param iterator  optional; when set it supplies the permutation value used
//                  by criterion 4
// \param t0        start time handed to checkForTimeout()
// \param timeout   time limit handed to checkForTimeout()
void RGroupScorer::breakTies(
    const std::vector<std::vector<RGroupMatch>> &matches,
    const std::set<int> &labels,
    const std::unique_ptr<CartesianProduct> &iterator,
    const std::chrono::steady_clock::time_point &t0, double timeout) {
  d_current = d_saved;
  d_current.numAddedRGroups = labels.size();

  // Non-negative labels first in ascending order, then the negative labels
  // walked from the back of the set towards the front.
  std::vector<int> orderedLabels;
  orderedLabels.reserve(labels.size());
  std::copy_if(labels.begin(), labels.end(), std::back_inserter(orderedLabels),
               [](const int &label) { return label >= 0; });
  std::copy_if(labels.rbegin(), labels.rend(),
               std::back_inserter(orderedLabels),
               [](const int &label) { return label < 0; });

  // Only the sign of the ordered labels matters here, not their value, so
  // the cached per-label heavy counts are flattened into a vector that can
  // be compared against the tied permutations. A sign mismatch between the
  // cached label and the ordered label means a label was added since the
  // cache was written; that slot starts from 0 instead.
  std::vector<int> largestHeavyCounts;
  largestHeavyCounts.reserve(labels.size());
  auto cached = d_current.heavyCountPerLabel.begin();
  const auto cachedEnd = d_current.heavyCountPerLabel.end();
  for (const auto label : orderedLabels) {
    int seed = 0;
    if (cached != cachedEnd) {
      const bool sameSign = (cached->first > 0) == (label > 0);
      if (sameSign) {
        seed = cached->second;
      }
      ++cached;
    }
    largestHeavyCounts.push_back(seed);
  }

  // Every candidate starts from the same seed counts.
  const std::vector<int> seedHeavyCounts(largestHeavyCounts);
  size_t bestPermValue = 0;
  while (!d_store.empty()) {
    auto &state = d_store.front();
    std::vector<int> heavyCounts(seedHeavyCounts);
    // NOTE(review): computeTieBreakingCriteria() is defined elsewhere;
    // presumably it fills in the state's r-group counters and heavyCounts.
    state.computeTieBreakingCriteria(matches, orderedLabels, heavyCounts);
#ifdef DEBUG
    std::cerr << "tiedPermutation ";
    for (const auto &t : state.permutation) {
      std::cerr << t << ",";
    }
    std::cerr << " orderedLabels ";
    for (const auto &l : orderedLabels) {
      std::cerr << l << ",";
    }
    std::cerr << " heavyCounts ";
    for (auto hc : heavyCounts) {
      std::cerr << hc << ",";
    }
    std::cerr << " largestHeavyCounts ";
    for (auto hc : largestHeavyCounts) {
      std::cerr << hc << ",";
    }
    std::cerr << " state.numMatchedUserRGroups " << state.numMatchedUserRGroups
              << " d_current.numMatchedUserRGroups "
              << d_current.numMatchedUserRGroups << ", state.numAddedRGroups "
              << state.numAddedRGroups << ", d_current.numAddedRGroups "
              << d_current.numAddedRGroups << std::endl;
#endif
    // Without an iterator the candidate inherits the current best value, so
    // criterion 4 can never favour it.
    size_t permValue =
        iterator ? iterator->value(state.permutation) : bestPermValue;

    // Decide once whether this candidate displaces the current best.
    bool stateWins = false;
    if (state.numMatchedUserRGroups > d_current.numMatchedUserRGroups) {
      stateWins = true;
    } else if (state.numMatchedUserRGroups == d_current.numMatchedUserRGroups) {
      if (state.numAddedRGroups < d_current.numAddedRGroups) {
        stateWins = true;
      } else if (state.numAddedRGroups == d_current.numAddedRGroups) {
        if (heavyCounts > largestHeavyCounts) {
          stateWins = true;
        } else if (heavyCounts == largestHeavyCounts) {
          stateWins = permValue > bestPermValue;
        }
      }
    }
    if (stateWins) {
      d_current = state;
      largestHeavyCounts = heavyCounts;
      bestPermValue = permValue;
    }

    // Check the time budget before discarding the processed state.
    checkForTimeout(t0, timeout);
    d_store.pop_front();
  }

  // Convert the winning heavy counts back into the per-label map so they
  // are stored in the saved cache.
  d_current.heavyCountPerLabel.clear();
  for (size_t idx = 0; idx < orderedLabels.size(); ++idx) {
    d_current.heavyCountPerLabel[orderedLabels[idx]] = largestHeavyCounts[idx];
  }
  d_saved = d_current;
}

// Queues a tied permutation: it is written into the working state and a
// copy of that state is appended to the tie store.
void RGroupScorer::pushTieToStore(const std::vector<size_t> &permutation) {
  d_current.permutation = permutation;
  // copy-constructs the new element from d_current
  d_store.emplace_back(d_current);
}

// Drops every state currently queued in the tie store.
void RGroupScorer::clearTieStore() {
  d_store.clear();
}

}  // namespace RDKit