File: FeatMapUtils.py

package info (click to toggle)
rdkit 202009.4-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 129,624 kB
  • sloc: cpp: 288,030; python: 75,571; java: 6,999; ansic: 5,481; sql: 1,968; yacc: 1,842; lex: 1,254; makefile: 572; javascript: 461; xml: 229; fortran: 183; sh: 134; cs: 93
file content (238 lines) | stat: -rw-r--r-- 7,261 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# $Id$
#
# Copyright (C) 2006 Greg Landrum
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
import copy
from rdkit.Chem.FeatMaps import FeatMaps


class MergeMethod(object):
  """ Enumerates the ways the position of a fused feature point can be chosen """
  # The fused point sits at the weighted average position of the two points
  WeightedAverage = 0
  # The fused point sits at the plain (un-weighted) average position of the two points
  Average = 1
  # The fused point sits at the position of whichever point has the larger weight
  UseLarger = 2

  @classmethod
  def valid(cls, mergeMethod):
    """ Raise ValueError unless mergeMethod is one of the values defined above.

      Returns None when the value is recognized.
    """
    recognized = (cls.WeightedAverage, cls.Average, cls.UseLarger)
    if mergeMethod in recognized:
      return
    raise ValueError('unrecognized mergeMethod')


class MergeMetric(object):
  """ Enumerates the criteria used to decide whether two feature points get merged """
  # Points are never merged
  NoMerge = 0
  # Two points are merged when they come within a threshold distance of each other
  Distance = 1
  # Two points are merged when their percent overlap exceeds a threshold
  Overlap = 2

  @classmethod
  def valid(cls, mergeMetric):
    """ Raise ValueError unless mergeMetric is one of the values defined above.

      Returns None when the value is recognized.
    """
    recognized = (cls.NoMerge, cls.Distance, cls.Overlap)
    if mergeMetric in recognized:
      return
    raise ValueError('unrecognized mergeMetric')


class DirMergeMode(object):
  """ Enumerates the ways the direction vectors of merged points can be handled """
  # Directions are not merged (i.e. all direction vectors are kept)
  NoMerge = 0
  # Direction vectors are summed
  Sum = 1

  @classmethod
  def valid(cls, dirMergeMode):
    """ Raise ValueError unless dirMergeMode is one of the values defined above.

      Returns None when the value is recognized.
    """
    recognized = (cls.NoMerge, cls.Sum)
    if dirMergeMode in recognized:
      return
    raise ValueError('unrecognized dirMergeMode')


def __copyAll(res, fm1, fm2):
  """ no user-serviceable parts inside

    Adds a deep copy of every feature of fm1, then of every feature of fm2,
    to res via res.AddFeatPoint().
  """
  for sourceMap in (fm1, fm2):
    for feat in sourceMap.GetFeatures():
      res.AddFeatPoint(copy.deepcopy(feat))


def GetFeatFeatDistMatrix(fm, mergeMetric, mergeTol, dirMergeMode, compatFunc):
  """ Build the symmetric feature-feature "distance" matrix used for merging.

    The result is a list of lists with one row and one column per feature
    in fm. Every entry starts out as 1e8; the entries for a pair (i, j) are
    overwritten only when compatFunc(feature_i, feature_j) is true and the
    pair's value falls below the cutoff:

      - MergeMetric.NoMerge: nothing is overwritten
      - MergeMetric.Distance: the value is feature_i.GetDist2(feature_j) and
        the cutoff is mergeTol * mergeTol
      - MergeMetric.Overlap: the value is
        fm.GetFeatFeatScore(feature_i, feature_j, typeMatch=False) multiplied
        by -1 * feature_j.weight (j being the larger index of the pair) and
        the cutoff is mergeTol

    NOTE that mergeTol is a max value for merging when using distance-based
    merging and a min value when using score-based merging.

    dirMergeMode is accepted but not used by this function.

    Raises ValueError if mergeMetric is not a recognized MergeMetric value.

  """
  MergeMetric.valid(mergeMetric)

  nFeats = fm.GetNumFeatures()
  matrix = [[1e8] * nFeats for _ in range(nFeats)]

  # pick the pair-value function and the cutoff that go with the metric
  if mergeMetric == MergeMetric.Distance:
    cutoff = mergeTol * mergeTol

    def pairValue(featA, featB):
      return featA.GetDist2(featB)
  elif mergeMetric == MergeMetric.Overlap:
    cutoff = mergeTol

    def pairValue(featA, featB):
      value = fm.GetFeatFeatScore(featA, featB, typeMatch=False)
      value *= -1 * featB.weight
      return value
  else:
    # MergeMetric.NoMerge: every entry keeps its initial value
    return matrix

  # only the upper triangle is computed; each hit is mirrored
  for i in range(nFeats):
    featI = fm.GetFeature(i)
    for j in range(i + 1, nFeats):
      featJ = fm.GetFeature(j)
      if not compatFunc(featI, featJ):
        continue
      value = pairValue(featI, featJ)
      if value < cutoff:
        matrix[i][j] = value
        matrix[j][i] = value
  return matrix


def familiesMatch(f1, f2):
  """ Returns whether f1.GetFamily() and f2.GetFamily() compare equal """
  family1 = f1.GetFamily()
  family2 = f2.GetFamily()
  return family1 == family2


def feq(v1, v2, tol=1e-4):
  """ Returns whether v1 and v2 differ by strictly less than tol """
  delta = v1 - v2
  return abs(delta) < tol


def MergeFeatPoints(fm, mergeMetric=MergeMetric.NoMerge, mergeTol=1.5,
                    dirMergeMode=DirMergeMode.NoMerge, mergeMethod=MergeMethod.WeightedAverage,
                    compatFunc=familiesMatch):
  """ Merge compatible feature points of a feature map, modifying fm in place.

    Pairs of features that are mutual nearest neighbors (as measured by
    mergeMetric) are fused: the first feature of the pair receives the
    position and weight determined by mergeMethod and the second one is
    dropped from fm at the end of the call. A feature takes part in at most
    one merge per call, so the matrix computed up front is never consulted
    for a feature whose position has changed.

    NOTE that mergeTol is a max value for merging when using distance-based
    merging and a min value when using score-based merging.

    Arguments:
      - fm: the feature map to modify
      - mergeMetric: a MergeMetric value; with MergeMetric.NoMerge nothing
        is done
      - mergeTol: the merge threshold (see the NOTE above)
      - dirMergeMode: a DirMergeMode value. It is validated and handed to
        GetFeatFeatDistMatrix; this function does not touch feature
        directions itself
      - mergeMethod: a MergeMethod value selecting the position and weight of
        the fused point
      - compatFunc: called with two features, returns whether they are
        allowed to be merged

    Raises ValueError if mergeMetric, mergeMethod or dirMergeMode is not
    recognized.

    returns whether or not any points were actually merged

  """
  MergeMetric.valid(mergeMetric)
  MergeMethod.valid(mergeMethod)
  DirMergeMode.valid(dirMergeMode)

  res = False
  if mergeMetric == MergeMetric.NoMerge:
    return res
  dists = GetFeatFeatDistMatrix(fm, mergeMetric, mergeTol, dirMergeMode, compatFunc)

  # The distance-based matrix stores squared distances (already filtered
  # against mergeTol squared), so the same squared threshold has to be used
  # here; comparing them against mergeTol itself would wrongly discard valid
  # pairs whenever mergeTol > 1.
  if mergeMetric == MergeMetric.Distance:
    threshold = mergeTol * mergeTol
  else:
    threshold = mergeTol

  # for each feature: the sorted (dist, index) tuples of its merge candidates.
  # The diagonal is skipped: a feature is never its own merge partner.
  distOrders = [
    sorted((dist, j) for j, dist in enumerate(distV) if j != i and dist < threshold)
    for i, distV in enumerate(dists)
  ]

  # we now know the "distances" and have rank-ordered list of
  # each point's neighbors. Work with that.

  # progressively merge nearest neighbors until there
  # are no more points left to merge
  featsInPlay = list(range(fm.GetNumFeatures()))
  featsToRemove = []
  while featsInPlay:
    # find two features who are mutual nearest neighbors:
    mergeThem = False
    # iterate over a copy: featsInPlay is modified inside the loop
    for fi in featsInPlay[:]:
      if not distOrders[fi]:
        # no candidates left, this feature cannot be merged any more
        featsInPlay.remove(fi)
        continue
      dist, nbr = distOrders[fi][0]
      if nbr not in featsInPlay:
        continue
      nbrOrder = distOrders[nbr]
      if nbrOrder[0][1] == fi:
        mergeThem = True
      elif feq(nbrOrder[0][0], dist):
        # it may be that there are several points at about the same distance,
        # check for that now
        for distJ, nbrJ in nbrOrder[1:]:
          if not feq(dist, distJ):
            break
          if nbrJ == fi:
            mergeThem = True
            break
      if mergeThem:
        break
    if not mergeThem:
      # nothing found, we are done
      break

    res = True
    featI = fm.GetFeature(fi)
    nbrFeat = fm.GetFeature(nbr)

    if mergeMethod == MergeMethod.WeightedAverage:
      newPos = featI.GetPos() * featI.weight + nbrFeat.GetPos() * nbrFeat.weight
      newPos /= (featI.weight + nbrFeat.weight)
      newWeight = (featI.weight + nbrFeat.weight) / 2
    elif mergeMethod == MergeMethod.Average:
      newPos = featI.GetPos() + nbrFeat.GetPos()
      newPos /= 2
      newWeight = (featI.weight + nbrFeat.weight) / 2
    elif mergeMethod == MergeMethod.UseLarger:
      if featI.weight > nbrFeat.weight:
        newPos = featI.GetPos()
        newWeight = featI.weight
      else:
        newPos = nbrFeat.GetPos()
        newWeight = nbrFeat.weight

    featI.SetPos(newPos)
    featI.weight = newWeight

    # nbr and fi are no longer valid targets:
    featsToRemove.append(nbr)
    featsInPlay.remove(fi)
    featsInPlay.remove(nbr)
    # The neighbor lists hold (dist, index) tuples, so the entries have to be
    # matched on the index; calling list.remove() with the bare index never
    # matches anything. Filtering keeps the lists sorted.
    merged = (fi, nbr)
    for nbrList in distOrders:
      nbrList[:] = [entry for entry in nbrList if entry[1] not in merged]

  # drop the absorbed features; every earlier drop shifts the later indices
  # down by one
  featsToRemove.sort()
  for i, fIdx in enumerate(featsToRemove):
    fm.DropFeature(fIdx - i)
  return res


def CombineFeatMaps(fm1, fm2, mergeMetric=MergeMetric.NoMerge, mergeTol=1.5,
                    dirMergeMode=DirMergeMode.NoMerge):
  """ Returns a new feature map holding deep copies of the features of fm1 and fm2.

     the parameters will be taken from fm1

     Unless mergeMetric is MergeMetric.NoMerge, MergeFeatPoints() is then
     applied to the new map with mergeMetric, mergeTol and dirMergeMode
     (the remaining MergeFeatPoints arguments keep their defaults). fm1 and
     fm2 themselves are not modified.
  """
  res = FeatMaps.FeatMap(params=fm1.params)

  __copyAll(res, fm1, fm2)
  if mergeMetric != MergeMetric.NoMerge:
    # forward dirMergeMode as well so the caller's choice is not silently
    # ignored
    MergeFeatPoints(res, mergeMetric=mergeMetric, mergeTol=mergeTol,
                    dirMergeMode=dirMergeMode)
  return res