File: Function.java

package info (click to toggle)
bbmap 39.20%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 26,024 kB
  • sloc: java: 312,743; sh: 18,099; python: 5,247; ansic: 2,074; perl: 96; makefile: 39; xml: 38
file content (145 lines) | stat: -rwxr-xr-x 4,362 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package ml;

import java.util.Arrays;
import java.util.Random;

import shared.Tools;

/**
 * Abstract base for the activation functions used by the {@code ml} package.
 * <p>
 * One instance per type is registered in {@code functions}, indexed by the type
 * constants ({@link #SIG}, {@link #TANH}, ...); look it up with {@link #getFunction(int)}.
 * The class also owns the global per-type sampling weights ({@link #TYPE_RATES}) that
 * {@link #randomFunction(Random)} draws from once {@link #normalizeTypeRates()} has run.
 */
public abstract class Function {
	
	/*--------------------------------------------------------------*/
	/*----------------           Methods            ----------------*/
	/*--------------------------------------------------------------*/

	/** Applies this function to the input {@code x}. */
	public abstract double activate(double x);
	
	/**
	 * Derivative of this function, computed from the input {@code x}.
	 * (Assumed to be f'(x); confirm in subclasses.)
	 */
	public abstract double derivativeX(double x);
	
	/**
	 * Derivative of this function, computed from the output {@code fx}
	 * (presumably {@code fx == activate(x)}; confirm in subclasses).
	 */
	public abstract double derivativeFX(double fx);
	
	/**
	 * Derivative of this function given both the input {@code x} and the output {@code fx};
	 * presumably lets each subclass use whichever form is cheaper. TODO confirm.
	 */
	public abstract double derivativeXFX(double x, double fx);
	
	/** Returns this function's type constant; makeFunctions asserts it equals the function's slot index. */
	public abstract int type();
	
	/** Returns this function's short name; makeFunctions asserts it equals {@code TYPES[type()]}. */
	public abstract String name();
	
	/*--------------------------------------------------------------*/
	/*----------------        Static Methods        ----------------*/
	/*--------------------------------------------------------------*/
	
	/**
	 * Parses a function type name, asserting that the name is recognized.
	 * @see #toType(String, boolean)
	 */
	static final int toType(String b) {
		return toType(b, true);
	}

	/**
	 * Parses a function type from its short name (an entry of {@link #TYPES}) or its
	 * long name (an entry of {@link #TYPES_LONG}), looked up with {@code Tools.findIC}
	 * (presumably case-insensitive).
	 * 
	 * @param b Type name; must start with a letter.
	 * @param assertValid If true, assert that the name matched a known type.
	 * @return The type constant, or a negative value if neither name table matches
	 * (only returned when assertValid is false or assertions are disabled).
	 * @throws IllegalArgumentException if b does not start with a letter;
	 * numeric type codes are intentionally not supported.
	 */
	static final int toType(String b, boolean assertValid) {
		//Reject up front so every non-name input gets the same explicit error.
		if(!Tools.startsWithLetter(b)) {
			throw new IllegalArgumentException("Numbers are not allowed for defining types: "+b);
		}
		int type=Tools.findIC(b, TYPES);
		if(type<0) {type=Tools.findIC(b, TYPES_LONG);}
		assert(!assertValid || (type>=0 && type<TYPES.length)) : b+" -> "+type;
		return type;
	}

	/**
	 * Rescales {@link #TYPE_RATES} in place so the rates sum to 1, then builds the
	 * cumulative table {@link #TYPE_RATES_CUM} used by {@link #randomFunction(Random)}.
	 * If all rates are zero, TYPE_RATES_CUM is left null.
	 * Intended to be called once: it asserts that TYPE_RATES_CUM is still null.
	 */
	public static synchronized final void normalizeTypeRates() {
		assert(TYPE_RATES_CUM==null);
		double sum=shared.Vector.sum(TYPE_RATES);
		assert(sum>=0) : sum;
		
		if(sum<=0) {
			TYPE_RATES_CUM=null;
			return;
		}
		if(Tools.absdif(sum, 1)>0.000001){
			final double mult=1.0/sum;
			for(int i=0; i<TYPE_RATES.length; i++) {
				double r=TYPE_RATES[i];
				assert(r>=0) : i+": "+r;
				TYPE_RATES[i]=(float)(r*mult);
			}
		}
		
		//Build in a local so a failed assertion never leaves a half-filled table published.
		final float[] cum=new float[TYPE_RATES.length];
		double c=0;
		int lastPositive=-1;
		for(int i=0; i<TYPE_RATES.length; i++) {
			double r=TYPE_RATES[i];
			assert(r>=0) : i+": "+r;
			c+=r;
			cum[i]=(float)c;
			if(r>0) {lastPositive=i;}
		}
		assert(Tools.absdif(c, 1)<0.00001);
		//Rounding can leave the running total slightly below 1. Pin the last nonzero bucket
		//(and any zero-rate entries after it) to exactly 1, so every draw in [0,1) lands on a
		//type with a positive rate and never on a trailing disabled type.
		final int pinFrom=(lastPositive>=0 ? lastPositive : cum.length-1);
		Arrays.fill(cum, pinFrom, cum.length, 1f);
		TYPE_RATES_CUM=cum;
	}
	
	/*--------------------------------------------------------------*/
	
	/**
	 * Returns the shared function for a type constant.
	 * Out-of-range types throw ArrayIndexOutOfBoundsException.
	 */
	public static final Function getFunction(int type) {
		return functions[type];
	}
	
	/**
	 * Builds the per-type function table. It asserts that every slot is filled and that each
	 * function's type() and name() agree with its slot.
	 * <p>
	 * NOTE(review): this runs during Function's own static initialization and reads each
	 * subclass's {@code instance} field. If a subclass (e.g. Sigmoid) is initialized before
	 * Function, that field may presumably still be null here. The null assertion only
	 * catches this when assertions are enabled. Confirm against the subclasses.
	 */
	private static final Function[] makeFunctions() {
		assert(functions==null);
		Function[] array=new Function[TYPES.length];
		array[SIG]=Sigmoid.instance;
		array[TANH]=Tanh.instance;
		array[RSLOG]=RSLog.instance;
		array[MSIG]=MSig.instance;
		array[SWISH]=Swish.instance;
		array[ESIG]=ExtendedSigmoid.instance;
		array[EMSIG]=ExtendedMSig.instance;
		array[BELL]=Bell.instance;
		for(int i=0; i<array.length; i++) {
			Function f=array[i];
			assert(f!=null) : i+", "+TYPES[i]+", "+f;
			assert(f.type()==i) : i+", "+TYPES[i]+", "+f;
			assert(f.name().equals(TYPES[i])) : i+", "+f;
		}
		return array;
	}
	
	/*--------------------------------------------------------------*/
	
	/**
	 * Picks a function at random, weighted by the normalized {@link #TYPE_RATES}.
	 * @throws IllegalStateException if normalizeTypeRates() has not produced a table
	 * (it was never called, or every rate was zero).
	 */
	static final Function randomFunction(Random randy) {
		final float[] cum=TYPE_RATES_CUM;
		if(cum==null) {
			throw new IllegalStateException("TYPE_RATES_CUM is not initialized; "
					+ "call normalizeTypeRates() with at least one positive rate.");
		}
		final int type=randomType(randy, cum);
		return functions[type];
	}
	
	/**
	 * Draws an index from a cumulative distribution.
	 * @param randy Random source; one nextFloat() in [0,1) is consumed.
	 * @param cumRate Non-decreasing cumulative rates, expected to end at 1.
	 * @return The first index whose cumulative rate reaches the draw, skipping zero-width
	 * (zero-rate) buckets. Falls back to the last index, with an assertion failure,
	 * if the draw exceeds every entry.
	 */
	static final int randomType(Random randy, float[] cumRate) {
		final float f=randy.nextFloat();
		float prev=0;
		for(int i=0; i<cumRate.length; i++) {
			final float c=cumRate[i];
			//c>prev skips buckets of width 0. Without it, a draw of exactly 0.0f would
			//select index 0 even when that type's rate is 0.
			if(c>=f && c>prev) {return i;}
			prev=c;
		}
		assert(false) : f+", "+Arrays.toString(cumRate);
		return cumRate.length-1;
	}
	
	/*--------------------------------------------------------------*/
	/*----------------            Fields            ----------------*/
	/*--------------------------------------------------------------*/

	/** Type constants; each is the function's index in {@link #TYPES} and its slot in {@code functions}. */
	public static final int SIG=0, TANH=1, RSLOG=2, MSIG=3, SWISH=4, ESIG=5, EMSIG=6, BELL=7;
	
	/** Short type names, indexed by type constant; each must equal the corresponding function's name(). */
	static final String[] TYPES=new String[] {"SIG", "TANH", "RSLOG", "MSIG", "SWISH", "ESIG", "EMSIG", "BELL"};
	
	/** Long type names, parallel to {@link #TYPES}; toType accepts them as an alternative spelling. */
	static final String[] TYPES_LONG=new String[] {"SIGMOID", "HYPERBOLICTANGENT", 
	"ROTATIONALLYSYMMETRICLOGARITHM", "MIRROREDSIGMOID", "SWISH",
	"EXTENDEDSIGMOID", "EXTENDEDMIRROREDSIGMOID", "GAUSSIAN"};
	
	/** One function per type, indexed by type constant. Must be declared after TYPES, which makeFunctions reads. */
	private static final Function[] functions=makeFunctions();
	
	/**
	 * Relative weights used by randomFunction, indexed by type constant.
	 * Defaults are set in the static block below; normalizeTypeRates rescales them in place.
	 */
	public static final float[] TYPE_RATES=new float[TYPES.length];
	//Defaults: tanh=.4 sig=.6 msig=.02 rslog=.02; all other types (including swish) are 0
	
	/** Cumulative distribution built from TYPE_RATES by normalizeTypeRates; null until then, or if all rates are 0. */
	public static float[] TYPE_RATES_CUM=null;

	static {
		TYPE_RATES[TANH]=0.4f;
		TYPE_RATES[SIG]=0.6f;
		TYPE_RATES[MSIG]=0.02f;
		TYPE_RATES[RSLOG]=0.02f;
	}
}