1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
# rbf2.py
# tilde
# 2006/08/20
import numpy as N
import random
from scipy.optimize import leastsq
class rbf:
    """Radial basis function (RBF) network with Gaussian hidden units.

    Hidden unit j computes exp(-||x - c_j||^2 / (2 * variance)) for a fixed
    center c_j; centers and the shared variance are chosen in ``train`` and
    only the hidden-to-output weights ``w`` and output biases ``b`` are
    fitted, using scipy's ``leastsq``.
    """
    _type = 'rbf'
    # Output activation functions accepted by __init__ / fwd_all.
    _outfxns = ('linear', 'logistic', 'softmax')

    def __init__(self, ni, no, f='linear'):
        """Set up an (untrained) RBF net.

        N.B. centers, variance and output weights are created at training time.

        Input:
            ni - <int> number of inputs
            no - <int> number of outputs
            f  - <str> output activation function, one of ``_outfxns``
        Raises:
            ValueError if ``f`` is not a supported activation function.
        """
        if f not in self._outfxns:
            # Fail early: an unknown name would otherwise only surface later
            # as an UnboundLocalError inside fwd_all.
            raise ValueError("unknown output function %r; expected one of %s"
                             % (f, ', '.join(self._outfxns)))
        self.ni = ni
        self.no = no
        self.outfxn = f

    def unpack(self):
        """Split the packed 1-d weight vector ``self.wp`` into ``self.w``
        (shape (n_centers, no)) and ``self.b`` (shape (1, no)).
        """
        n_w = self.centers.shape[0] * self.no
        wp = N.array(self.wp)  # copy, so w/b never alias the optimizer's buffer
        self.w = wp[:n_w].reshape(self.centers.shape[0], self.no)
        self.b = wp[n_w:].reshape(1, self.no)

    def pack(self):
        """Concatenate ``self.w`` (row-major) and ``self.b`` into the single
        vector ``self.wp``, the layout used by ``unpack`` and the optimizer.
        """
        self.wp = N.hstack([self.w.ravel(), self.b.ravel()])

    def _hidden(self, X):
        """Return the (n_patterns, n_centers) matrix of Gaussian activations
        exp(-||x - c||^2 / (2 * variance)) for every row x of X and center c.
        """
        X = N.asarray(X, dtype=float)
        # Broadcast (n, 1, d) - (1, m, d) -> (n, m, d); uses O(n*m*d) memory.
        diff = X[:, N.newaxis, :] - self.centers[N.newaxis, :, :]
        return N.exp(-N.sum(diff ** 2, axis=2) / (2.0 * self.variance))

    def fwd_all(self, X, w=None):
        """Propagate input patterns forward through the net.

        Inputs:
            X - 2-d array of input patterns, one pattern per row
            w - optional packed weight vector; when given it replaces the
                net's current weights (``self.wp``, ``self.w``, ``self.b``)
        Returns:
            (n_patterns, no) array of output activations
        Raises:
            ValueError if ``self.outfxn`` is not a supported function.
        """
        if w is not None:
            self.wp = w
            self.unpack()
        # b has shape (1, no) and broadcasts over all pattern rows.
        o = N.dot(self._hidden(X), self.w) + self.b
        if self.outfxn == 'linear':
            y = o
        elif self.outfxn == 'logistic':
            # Clip so exp() cannot overflow; the output saturates at 0/1 anyway.
            y = 1.0 / (1.0 + N.exp(-N.clip(o, -500.0, 500.0)))
        elif self.outfxn == 'softmax':
            # Softmax is invariant to subtracting the row maximum, and doing
            # so keeps exp() from overflowing.
            e = N.exp(o - N.max(o, axis=1)[:, N.newaxis])
            y = e / N.sum(e, axis=1)[:, N.newaxis]
        else:
            raise ValueError("unknown output function %r" % (self.outfxn,))
        return N.array(y)

    def residuals(self, w, X, Y):
        """Return the flattened residual vector (outputs - targets) for ``w``.

        This is the objective handed to ``leastsq``, which squares and sums
        the vector itself, so it must receive raw residuals, not squared errors.
        Y may be 1-d when the net has a single output.
        """
        O = self.fwd_all(X, w)
        return (O - N.asarray(Y, dtype=float).reshape(O.shape)).ravel()

    def err_fxn(self, w, X, Y):
        """Return the per-pattern sum of squared errors for weights ``w``:
        element i is sum_k (y_ik - t_ik)^2.
        """
        r = self.residuals(w, X, Y)
        return N.sum(r.reshape(-1, self.no) ** 2, axis=1)

    def train(self, X, Y, frac=0.1):
        """Train the RBF network.

        (i)   select fixed centers at random from the input patterns
              (``frac`` of them, default 10%, but at least 2 and at most
              n - 1);
        (ii)  set the shared width from the largest Euclidean distance d_max
              between centers: sigma = d_max / sqrt(2 * n_centers), i.e.
              variance = d_max**2 / (2 * n_centers);
        (iii) learn output weights and biases with scipy's ``leastsq``.

        Input:
            X    - 2-d array of input patterns, one per row
            Y    - array of targets, one row per pattern
            frac - fraction of the patterns used as centers
        Raises:
            ValueError if there are fewer than 3 patterns or all selected
            centers coincide.
        """
        X = N.asarray(X, dtype=float)
        Y = N.asarray(Y, dtype=float)
        n = len(X)
        if n < 3:
            # Two distinct centers are needed for a width, and leastsq needs
            # at least as many residuals as free parameters.
            raise ValueError("need at least 3 training patterns, got %d" % n)
        # (i) Sample row indices rather than the array itself: random.sample
        # rejects numpy arrays on Python 3.
        n_centers = min(n - 1, max(2, int(frac * n)))
        self.centers = X[random.sample(range(n), n_centers)]
        # (ii) Largest squared Euclidean distance between any two centers.
        diff = self.centers[:, N.newaxis, :] - self.centers[N.newaxis, :, :]
        d2_max = N.max(N.sum(diff ** 2, axis=2))
        if d2_max == 0.0:
            raise ValueError("selected centers all coincide; cannot set RBF width")
        self.variance = d2_max / (2.0 * n_centers)
        # (iii) Random initial output layer, then least-squares fit.
        m = self.centers.shape[0]
        self.nw = m * self.no  # number of output weights (biases excluded)
        scale = N.sqrt(m + 1)
        self.w = N.random.randn(m, self.no) / scale
        self.b = N.random.randn(1, self.no) / scale
        self.pack()
        self.wp = leastsq(self.residuals, self.wp, args=(X, Y))[0]
        # leastsq's last function evaluation need not be at the returned
        # optimum, so rebuild w/b from the final packed vector.
        self.unpack()

    def test_all(self, X, Y):
        """Test the trained network on an array of patterns.

        Input:
            X - 2-d array of input patterns
            Y - array of targets
        Returns:
            sum-squared error over all patterns and outputs
        """
        return N.sum(self.residuals(self.wp, X, Y) ** 2)
def main():
    """Build, train and test an RBF net on the oil data set.

    Reads data/oil-trn.dat and data/oil-tst.dat (whitespace-separated,
    '#' comments), skipping the first 3 lines of each. Columns 0-11 are the
    inputs and the remaining columns are the targets. Prints the final
    sum-squared error on both sets.
    """
    import sys

    def load(path):
        # Replaces scipy.io.read_array (removed from SciPy), which was called
        # with lines=(3,-1) (line 3 to the end) and columns (0,(1,12)) /
        # (12,-1) (columns 0-11 / 12 to the end). Reads each file only once.
        data = N.loadtxt(path, skiprows=3, ndmin=2)
        return data[:, :12], data[:, 12:]

    # sys.stdout.write + flush gives an unterminated progress message that
    # shows up immediately, on both Python 2 and 3.
    sys.stdout.write("\nLoading training and test sets...")
    sys.stdout.flush()
    X_trn, Y_trn = load('data/oil-trn.dat')
    X_tst, Y_tst = load('data/oil-tst.dat')
    print("done.")
    print("\nCreating RBF net")
    # Size the net from the data instead of hard-coding the output count.
    net = rbf(X_trn.shape[1], Y_trn.shape[1])
    sys.stdout.write("Training...")
    sys.stdout.flush()
    net.train(X_trn, Y_trn)
    print("done.")
    print("\nFinal SSE:\n")
    print("\ttraining set:  %s" % net.test_all(X_trn, Y_trn))
    print("\ttesting set:  %s\n" % net.test_all(X_tst, Y_tst))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|