File: Sample.java

package info (click to toggle)
bbmap 39.20+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 26,024 kB
  • sloc: java: 312,743; sh: 18,099; python: 5,247; ansic: 2,074; perl: 96; makefile: 39; xml: 38
file content (114 lines) | stat: -rwxr-xr-x 3,047 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package ml;

import java.util.Arrays;

import structures.ByteBuilder;

/**
 * One training example: an input vector ({@link #in}), a goal/target vector ({@link #goal}),
 * and a buffer ({@link #result}) holding the most recent output computed for it.
 * <p>
 * Samples are ordered by descending {@link #pivot}, with ties broken by ascending {@link #id}.
 * The pivot combines this sample's error magnitudes with an epoch penalty (see {@link #calcPivot()}).
 * <p>
 * Thread-safety: only the methods declared {@code synchronized} take this object's lock.
 * {@link #calcError(float)}, {@link #compareTo(Sample)}, {@link #checkPivot()} and
 * {@link #toString()} read or write fields such as {@code errorMagnitude}, {@code pivot}
 * and the contents of {@code result} without synchronization.
 */
public class Sample implements Comparable<Sample> {
	
	/**
	 * Creates a sample. The arrays are stored by reference, not copied.
	 * 
	 * @param in_ input values
	 * @param out_ goal (target) values; must contain at least one element, because
	 *        {@code goal[0]>=0.5f} decides whether the sample is {@link #positive}
	 * @param id_ identifier, used as the tie-breaker in {@link #compareTo(Sample)}
	 */
	public Sample(float[] in_, float[] out_, int id_) {
		in=in_;
		goal=out_;
		result=new float[goal.length];
		id=id_;
		positive=(goal[0]>=0.5f);
	}
	
	/**
	 * Orders samples by descending {@link #pivot}, then by ascending {@link #id}.
	 * <p>
	 * If either pivot is NaN, neither float comparison holds and the ids decide the order.
	 * The id tie-break uses {@link Integer#compare(int, int)} rather than subtraction,
	 * which can overflow and yield the wrong sign for ids of widely differing magnitude.
	 */
	@Override
	public int compareTo(Sample o) {
		final float a=pivot, b=o.pivot;
		if(a>b){return -1;}
		if(b>a){return 1;}
		return Integer.compare(id, o.id);
	}
	
	/**
	 * @return true if the cached {@link #pivot} exactly equals a freshly computed
	 *         {@link #calcPivot()} value (always false if the pivot is NaN)
	 */
	public boolean checkPivot() {
		return pivot==calcPivot();
	}
	
	/** Recomputes {@link #calcPivot()} and caches the result in {@link #pivot}. */
	synchronized void setPivot() {
		pivot=calcPivot();
	}
	
	/**
	 * Computes this sample's sort key from its current error state.
	 * <p>
	 * The base value is the mean of {@link #errorMagnitude} and {@link #weightedErrorMagnitude}.
	 * The error is called "excess" when {@code result[0]} lies on the class side of {@code goal[0]}:
	 * above the goal for a positive sample, or at/below it for a negative one.
	 * For excess errors the base is additionally scaled by {@link #excessPivotMult}.
	 * Finally {@code epoch*EPOCH_MULT} is subtracted, so a larger recorded epoch lowers the pivot.
	 * 
	 * @return the pivot value; it is not stored (see {@link #setPivot()})
	 */
	synchronized float calcPivot() {
		final float v=result[0];
		final boolean positiveError=v>goal[0];
		final boolean excess=(positiveError == positive);
		final float mult=(excess ? excessPivotMult*0.5f : 0.5f);
		return (errorMagnitude+weightedErrorMagnitude)*mult-epoch*EPOCH_MULT;
//		return (errorMagnitude+weightedErrorMagnitude)*0.5f-epoch*EPOCH_MULT;
	}
	
	/**
	 * Returns a one-line, tab-separated debug summary. The fields, in order, are:
	 * <ul>
	 * <li>id</li>
	 * <li>class sign (+/-)</li>
	 * <li>a confusion label (TP/TN/FP/FN, thresholding {@code goal[0]} and {@code result[0]} at 0.5)</li>
	 * <li>epoch, {@code goal[0]}, {@code result[0]} and {@link #errorMagnitude}</li>
	 * <li>a freshly computed pivot</li>
	 * <li>the input vector</li>
	 * </ul>
	 */
	@Override
	public String toString() {
		// NOTE(review): in "%4f"/"%6f" the number is a minimum width, not a precision, so these
		// print 6 decimals; "%.4f"/"%.6f" may have been intended. Left as-is to keep output stable.
//		String s="S%d\t%s\t%s\tep=%d\tg=%4f\tr=%4f\tem=%6f\tev=%.6f\tpv=%.6f";
		String s="S%d\t%s\t%s\tep=%d\tg=%4f\tr=%4f\tem=%6f\tpv=%.6f";
		
		final boolean gol=(goal[0]>=0.5f);
		final boolean pred=(result[0]>=0.5f);
		// The four outcomes are exhaustive, so no fallback label is needed.
		final String type;
		if(gol){type=(pred ? "TP" : "FN");}
		else{type=(pred ? "FP" : "TN");}
		final String sign=(positive ? "+" : "-");

//		s=String.format(s, id, sign, type, epoch, goal[0], result[0], errorMagnitude, errorValue, calcPivot());
		s=String.format(s, id, sign, type, epoch, goal[0], result[0], errorMagnitude, calcPivot());
		return s+"\t"+Arrays.toString(in);
	}
	
	/** @return a new ByteBuilder containing this sample's serialized line (see {@link #toBytes(ByteBuilder)}) */
	public ByteBuilder toBytes() {
		return toBytes(new ByteBuilder());
	}
	
	/**
	 * Appends one line to {@code bb}: every input value, then every goal value,
	 * tab-separated and terminated by a newline. Each value goes through
	 * {@code bb.append(f, 6)}, presumably 6 decimal places; confirm in ByteBuilder.
	 * 
	 * @param bb destination builder
	 * @return {@code bb}, for chaining
	 */
	public ByteBuilder toBytes(ByteBuilder bb) {
		for(float f : in) {bb.append(f, 6).tab();}
		for(float f : goal) {bb.append(f, 6).tab();}
		bb.trimLast(1);//Presumably removes the trailing tab; goal is non-empty, so one was appended
		bb.nl();
		return bb;
	}
	
//	synchronized boolean positive() {
//		return goal[0]>=0.5f;
//	}
	
	/**
	 * Recomputes the error fields from the current {@link #result}.
	 * <p>
	 * {@link #errorMagnitude} becomes the sum over all outputs of {@link #calcError(float, float)}.
	 * {@link #weightedErrorMagnitude} is derived from that sum and the first output via
	 * {@code Cell.toWeightedError}.
	 * This method is not synchronized and does not update {@link #pivot}.
	 * 
	 * @param weightMult forwarded to {@code Cell.toWeightedError} (semantics defined there)
	 */
	public void calcError(float weightMult){
		double error=0;
		for(int i=0; i<result.length; i++){
			float r=result[i];
			float g=goal[i];
			float e=calcError(g, r);
			error+=e;
		}
		errorMagnitude=(float)error;
		weightedErrorMagnitude=Cell.toWeightedError(error, result[0], goal[0], weightMult);
	}
	
	/** @return the recorded epoch (-1 until {@link #setEpoch(long)} is called) */
	public synchronized int epoch() {return epoch;}
	/** @return the recorded thread id (-1 until {@link #setLastTID(int)} is called) */
	public synchronized int lastTID() {return lastTID;}
	/**
	 * Records the epoch used by {@link #calcPivot()}.
	 * @param x epoch number; narrowed with an {@code (int)} cast, so values outside int range are truncated
	 */
	public synchronized void setEpoch(long x) {
		epoch=(int)x;
	}
	
	/** @param x thread id to record */
	public synchronized void setLastTID(int x) {
		lastTID=x;
	}
	
	/**
	 * Half squared error for a single output.
	 * @param goal target value
	 * @param pred predicted value
	 * @return {@code 0.5*(goal-pred)^2}
	 */
	public static final float calcError(float goal, float pred) {
		float e=goal-pred;
		return 0.5f*e*e;
	}

	/** True if {@code goal[0]>=0.5f} at construction time. */
	final boolean positive;
	/** Sum of per-output half squared errors, set by {@link #calcError(float)}. */
	float errorMagnitude=1;
	/** Weighted error, set by {@link #calcError(float)} via {@code Cell.toWeightedError}. */
	float weightedErrorMagnitude=1;
//	float errorValue=1;//Unused, commented for efficiency
	private int epoch=-1;
	private int lastTID=-1;
	/** Cached sort key; written by {@link #setPivot()}, read by {@link #compareTo(Sample)}. */
	float pivot=0;
	
	final float[] in;
	final float[] goal;
	final float[] result;//Can't be volatile (volatile would apply to the reference, not the elements)
	final int id;

	//0.2f is good for binary classifiers
	/** Extra scale applied in {@link #calcPivot()} to "excess" errors; mutable global setting. */
	public static float excessPivotMult=0.2f;
	/** Per-epoch amount subtracted from the pivot in {@link #calcPivot()}. */
	public static final float EPOCH_MULT=1/256f;
}