1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
package ml;
import java.util.Arrays;
import java.util.Random;
import shared.Tools;
/**
 * Base class for the activation functions used by the ml package.
 * <p>
 * Every concrete subclass is registered in a fixed table indexed by its
 * {@link #type()} constant ({@link #SIG}, {@link #TANH}, ...). Functions can be
 * looked up by type with {@link #getFunction(int)}, or drawn at random with
 * probabilities given by {@link #TYPE_RATES} once {@link #normalizeTypeRates()}
 * has built {@link #TYPE_RATES_CUM}.
 */
public abstract class Function {

    /*--------------------------------------------------------------*/
    /*----------------            Methods           ----------------*/
    /*--------------------------------------------------------------*/

    /** Evaluates this function at {@code x}. */
    public abstract double activate(double x);

    /** Derivative of this function computed from the input {@code x} (presumably f'(x)). */
    public abstract double derivativeX(double x);

    /** Derivative of this function computed from the activated output {@code fx} (presumably f'(x) expressed via f(x)). */
    public abstract double derivativeFX(double fx);

    /** Derivative of this function given both the input {@code x} and its output {@code fx}. */
    public abstract double derivativeXFX(double x, double fx);

    /** @return this function's index into {@code TYPES}; e.g. {@link #SIG} for the sigmoid. */
    public abstract int type();

    /** @return this function's short name; must equal {@code TYPES[type()]}. */
    public abstract String name();

    /*--------------------------------------------------------------*/
    /*----------------        Static Methods        ----------------*/
    /*--------------------------------------------------------------*/

    /**
     * Converts a function name (short or long form, case-insensitive) to its type
     * index, asserting that the name is valid.
     *
     * @param b function name such as "TANH" or "HYPERBOLICTANGENT"
     * @return the type index
     * @throws IllegalArgumentException if {@code b} does not start with a letter
     */
    static final int toType(String b) {
        return toType(b, true);
    }

    /**
     * Converts a function name (short or long form, case-insensitive) to its type index.
     *
     * @param b function name; numeric type codes are not accepted
     * @param assertValid if true, assert that the name resolved to a known type
     * @return the type index, or a negative value if the name is unknown and
     *         {@code assertValid} is false (or assertions are disabled)
     * @throws IllegalArgumentException if {@code b} does not start with a letter
     */
    static final int toType(String b, boolean assertValid) {
        if(!Tools.startsWithLetter(b)) {
            // Reject immediately rather than attempting Integer.parseInt first; otherwise
            // non-numeric input escapes as a NumberFormatException with an unrelated message.
            throw new IllegalArgumentException("Numbers are not allowed for defining types: "+b);
        }
        int type=Tools.findIC(b, TYPES);
        if(type<0) {type=Tools.findIC(b, TYPES_LONG);}
        assert(!assertValid || (type>=0 && type<TYPES.length)) : type;
        return type;
    }

    /**
     * Rescales {@link #TYPE_RATES} in place so it sums to 1, then builds the
     * cumulative table {@link #TYPE_RATES_CUM} consumed by {@link #randomFunction(Random)}.
     * <p>
     * If the rates sum to zero (or less), {@code TYPE_RATES_CUM} is set to null and
     * random selection is unavailable. Intended to be called once (asserted).
     */
    public static synchronized final void normalizeTypeRates() {
        assert(TYPE_RATES_CUM==null);
        final double sum=shared.Vector.sum(TYPE_RATES);
        assert(sum>=0) : sum;
        if(sum<=0) {
            TYPE_RATES_CUM=null;
            return;
        }

        if(Tools.absdif(sum, 1)>0.000001) {
            final double mult=1.0/sum;
            for(int i=0; i<TYPE_RATES.length; i++) {
                final double r=TYPE_RATES[i];
                assert(r>=0) : i+": "+r;
                TYPE_RATES[i]=(float)(r*mult);
            }
        }

        // Build locally and publish at the end so readers never see a half-filled table.
        final float[] cum=new float[TYPE_RATES.length];
        double c=0;
        int lastPositive=-1;
        for(int i=0; i<TYPE_RATES.length; i++) {
            final double r=TYPE_RATES[i];
            assert(r>=0) : i+": "+r;
            c+=r;
            cum[i]=(float)c;
            if(r>0) {lastPositive=i;}
        }
        assert(Tools.absdif(c, 1)<0.00001) : c;

        // Float rounding can leave the running total slightly below 1. Pinning every entry
        // from the last positive-rate type onward to exactly 1 guarantees a draw in [0,1)
        // always lands on a type with positive rate, never on a trailing zero-rate type.
        final int pinFrom=(lastPositive>=0 ? lastPositive : cum.length-1);
        for(int i=pinFrom; i<cum.length; i++) {cum[i]=1;}

        TYPE_RATES_CUM=cum;
    }

    /*--------------------------------------------------------------*/

    /**
     * Lazily-initialized holder for the function table.
     * <p>
     * Building the table reads the subclasses' {@code instance} fields. Doing that in
     * {@code Function}'s own static initializer would observe null for any subclass whose
     * initialization triggered {@code Function}'s (superclass-first initialization order).
     * Deferring to this holder avoids that cycle.
     */
    private static final class FunctionTable {
        static final Function[] FUNCTIONS=makeFunctions();
    }

    /**
     * Returns the shared instance for a type index.
     *
     * @param type one of {@link #SIG}, {@link #TANH}, ... {@link #BELL}
     * @return the registered function for that type
     */
    public static final Function getFunction(int type) {
        return FunctionTable.FUNCTIONS[type];
    }

    /** Builds the type-indexed function table and checks each entry's type and name. */
    private static final Function[] makeFunctions() {
        final Function[] array=new Function[TYPES.length];
        array[SIG]=Sigmoid.instance;
        array[TANH]=Tanh.instance;
        array[RSLOG]=RSLog.instance;
        array[MSIG]=MSig.instance;
        array[SWISH]=Swish.instance;
        array[ESIG]=ExtendedSigmoid.instance;
        array[EMSIG]=ExtendedMSig.instance;
        array[BELL]=Bell.instance;
        for(int i=0; i<array.length; i++) {
            final Function f=array[i];
            assert(f!=null) : i+", "+TYPES[i]+", "+f;
            assert(f.type()==i) : i+", "+TYPES[i]+", "+f;
            assert(f.name().equals(TYPES[i])) : i+", "+f;
        }
        return array;
    }

    /*--------------------------------------------------------------*/

    /**
     * Draws a function at random according to {@link #TYPE_RATES_CUM}.
     *
     * @param randy source of randomness
     * @return the selected function
     * @throws IllegalStateException if the cumulative table has not been built
     *         (normalizeTypeRates not called, or all rates were zero)
     */
    static final Function randomFunction(Random randy) {
        final float[] cum=TYPE_RATES_CUM;
        if(cum==null) {
            throw new IllegalStateException(
                "TYPE_RATES_CUM is null; call normalizeTypeRates() with at least one positive rate first.");
        }
        final int type=randomType(randy, cum);
        return FunctionTable.FUNCTIONS[type];
    }

    /**
     * Picks an index from a non-decreasing cumulative-probability table.
     *
     * @param randy source of randomness
     * @param cumRate cumulative rates; the final entry is expected to be 1
     * @return the first index whose cumulative value exceeds a uniform draw in [0,1)
     */
    static final int randomType(Random randy, float[] cumRate) {
        final float f=randy.nextFloat(); // uniform in [0, 1)
        for(int i=0; i<cumRate.length; i++) {
            // Strict '<' means an entry equal to its predecessor (rate 0) is never chosen,
            // including a leading zero-rate entry when f is exactly 0.
            if(f<cumRate[i]) {return i;}
        }
        assert(false) : f+", "+Arrays.toString(cumRate);
        return cumRate.length-1;
    }

    /*--------------------------------------------------------------*/
    /*----------------            Fields            ----------------*/
    /*--------------------------------------------------------------*/

    /** Type indices; each must match the position of the name in {@link #TYPES}. */
    public static final int SIG=0, TANH=1, RSLOG=2, MSIG=3, SWISH=4, ESIG=5, EMSIG=6, BELL=7;

    /** Short names, indexed by type. */
    static final String[] TYPES=new String[] {"SIG", "TANH", "RSLOG", "MSIG", "SWISH", "ESIG", "EMSIG", "BELL"};

    /** Long names, indexed by type; accepted by {@link #toType(String, boolean)} as aliases. */
    static final String[] TYPES_LONG=new String[] {"SIGMOID", "HYPERBOLICTANGENT",
        "ROTATIONALLYSYMMETRICLOGARITHM", "MIRROREDSIGMOID", "SWISH",
        "EXTENDEDSIGMOID", "EXTENDEDMIRROREDSIGMOID", "GAUSSIAN"};

    /** Relative selection weight per type; rescaled in place by {@link #normalizeTypeRates()}. */
    public static final float[] TYPE_RATES=new float[TYPES.length];

    /** Cumulative selection table built by {@link #normalizeTypeRates()}; null until then. */
    public static float[] TYPE_RATES_CUM=null;

    // Default weights: sig=.6, tanh=.4, msig=.02, rslog=.02; all other types 0.
    static {
        TYPE_RATES[TANH]=0.4f;
        TYPE_RATES[SIG]=0.6f;
        TYPE_RATES[MSIG]=0.02f;
        TYPE_RATES[RSLOG]=0.02f;
    }
}
|