File: CrossValidate.py

package info (click to toggle)
rdkit 201203-3
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 37,840 kB
  • sloc: cpp: 93,902; python: 51,897; java: 5,192; ansic: 3,497; xml: 2,499; sql: 1,641; yacc: 1,518; lex: 1,076; makefile: 325; fortran: 183; sh: 153; cs: 51
file content (82 lines) | stat: -rw-r--r-- 2,956 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# $Id: CrossValidate.py 997 2009-02-25 06:12:43Z glandrum $
#
#  Copyright (C) 2004-2005 Rational Discovery LLC.
#   All Rights Reserved
#
""" handles doing cross validation with naive bayes models
and evaluation of individual models

"""
from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier
from rdkit.ML.Data import SplitData
try:
  from rdkit.ML.FeatureSelect import CMIM
except ImportError:
  CMIM=None

def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0,
                              useSigs=False,
                              ensemble=None,useCMIM=0,
                              **kwargs) :
  """ Build and train a NaiveBayesClassifier on the given training examples.

    If CMIM feature selection is available and requested (useCMIM>0, useSigs
    true) and no explicit ensemble was supplied, the attribute ensemble is
    chosen with CMIM first; any ensemble then overrides *attrs*.

    Returns the trained model.
  """
  # pick features with CMIM only when the module imported, the caller asked
  # for it, signatures are in use, and no ensemble was handed in explicitly
  wantCMIM = (CMIM is not None) and useCMIM > 0 and useSigs and not ensemble
  if wantCMIM:
    ensemble = CMIM.SelectFeatures(trainExamples,useCMIM,bvCol=1)
  if ensemble:
    attrs = ensemble

  classifier = NaiveBayesClassifier(attrs, nPossibleValues, nQuantBounds,
                                    mEstimateVal=mEstimateVal,useSigs=useSigs)
  classifier.SetTrainingExamples(trainExamples)
  classifier.trainModel()
  return classifier
    
def CrossValidate(NBmodel, testExamples, appendExamples=0) :
  """ Score a trained model against a set of test examples.

    Each example's true class is its last element; predictions come from
    NBmodel.ClassifyExamples().

    Returns a 2-tuple: (fraction of examples misclassified,
                        list of the misclassified examples).
  """
  nTest = len(testExamples)
  assert nTest,'no test examples: %s'%str(testExamples)
  preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
  assert len(preds) == nTest

  # collect every example whose predicted class disagrees with its label
  badExamples = [eg for eg, pred in zip(testExamples, preds) if eg[-1] != pred]
  return float(len(badExamples))/nTest, badExamples

def CrossValidationDriver(examples, attrs, nPossibleValues, nQuantBounds,
                          mEstimateVal=0.0,
                          holdOutFrac=0.3, modelBuilder=makeNBClassificationModel,
                          silent=0, calcTotalError=0, **kwargs) :
  """ Split *examples* into training/hold-out sets, build a model and
    cross-validate it.

    Arguments:
      - examples: the full data set; each example's last element is its class
      - attrs, nPossibleValues, nQuantBounds: passed through to *modelBuilder*
      - mEstimateVal: m-estimate smoothing value for the model builder
      - holdOutFrac: fraction of examples reserved for testing
      - modelBuilder: callable used to construct and train the model
      - silent: if false, prints the validation error
      - calcTotalError: if true, score against *all* examples instead of the
        hold-out set (and don't append examples to the model)
      - kwargs: 'replacementSelection' selects sampling with replacement;
        the rest are forwarded to *modelBuilder*

    Returns a 2-tuple: (trained model, cross-validation error fraction).
    The model gets a _trainIndices attribute recording the training split.
  """
  nTot = len(examples)
  if not kwargs.get('replacementSelection',0):
    testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
                                                      silent=1,legacy=1,
                                                      replacement=0)
  else :
    testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
                                                      silent=1,legacy=0,
                                                      replacement=1)

  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
                         mEstimateVal,**kwargs)

  if not calcTotalError:
    xValError, badExamples = CrossValidate(NBmodel, testExamples,appendExamples=1)
  else:
    xValError,badExamples = CrossValidate(NBmodel, examples,appendExamples=0)

  if not silent:
    # parenthesized single-argument print works under both Python 2 and 3
    # (the original Python 2-only `print` statement broke py3 imports);
    # '%%' emits a literal percent sign before the formatted value
    print('Validation error was %%%4.2f'%(100*xValError))
  NBmodel._trainIndices = trainIndices
  return NBmodel, xValError