1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
package ml;
import java.util.ArrayList;
import structures.ByteBuilder;
public class Matrix {

    /*--------------------------------------------------------------*/
    /*----------------           Methods            ----------------*/
    /*--------------------------------------------------------------*/

    /**
     * Computes output statistics and then optionally post-processes the outputs.
     * <p>
     * Always runs {@link #detectRange()}. If {@link #convertTo01} is set, every output
     * value is binarized at the detected midpoint. If either target-bound flag is set,
     * the outputs are then linearly rescaled by {@link #adjustRange()}.
     */
    void initializeRange() {
        detectRange();
        if (convertTo01) {
            convertToZeroOne(outputMidpoint);
        }
        if (setTargetOutputRangeMin || setTargetOutputRangeMax) {
            adjustRange();
        }
    }

    /*--------------------------------------------------------------*/

    /**
     * Scans every value in {@link #outputs} and records the minimum, maximum, mean,
     * range (max - min) and midpoint of the output values.
     * <p>
     * With assertions enabled, this fails unless at least two distinct output values
     * exist. If {@code outputs} holds no values at all, the mean is NaN (0/0).
     */
    void detectRange() {
        outputMin = Float.MAX_VALUE;
        outputMax = -Float.MAX_VALUE;
        double sum = 0;
        long count = 0;
        for (float[] line : outputs) {
            for (float f : line) {
                outputMin = Math.min(f, outputMin);
                outputMax = Math.max(f, outputMax);
                sum += f;
            }
            count += line.length;
        }
        assert (outputMin < outputMax) : outputMin + ", " + outputMax;
        updateDerivedStats(sum, count);
    }

    /**
     * Binarizes every output value in place: values strictly below {@code cutoff}
     * become 0, and all others (including NaN) become 1.
     * <p>
     * Afterwards min/max are fixed at 0/1, and the mean is the fraction of values
     * that became 1.
     *
     * @param cutoff threshold separating the 0 and 1 classes
     */
    void convertToZeroOne(final float cutoff) {
        double sum = 0;
        long count = 0;
        for (float[] line : outputs) {
            for (int j = 0; j < line.length; j++) {
                final float f = line[j] < cutoff ? 0 : 1;
                line[j] = f;
                sum += f;
            }
            // The count must be tracked, otherwise the mean below becomes sum/0.
            count += line.length;
        }
        outputMin = 0;
        outputMax = 1;
        updateDerivedStats(sum, count);
    }

    /**
     * Linearly maps every output value from [outputMin, outputMax] onto
     * [targetOutputRangeMin, targetOutputRangeMax], in place.
     * <p>
     * A target bound whose {@code setTargetOutputRange*} flag is false defaults to the
     * current bound. If the target equals the current range, nothing changes. A pure
     * shift (same width, different offset) is allowed. With assertions enabled, an
     * empty or inverted target range is rejected.
     */
    void adjustRange() {
        assert (setTargetOutputRangeMin || setTargetOutputRangeMax) : "Must set minoutput or maxoutput";
        assert (outputMin < outputMax) : outputMin + ", " + outputMax;
        // NOTE(review): unset bounds are written back into the shared static fields
        // (preserved behavior); callers may read them afterwards — confirm before changing.
        if (!setTargetOutputRangeMin) {
            targetOutputRangeMin = outputMin;
        }
        if (!setTargetOutputRangeMax) {
            targetOutputRangeMax = outputMax;
        }
        if (targetOutputRangeMin == outputMin && targetOutputRangeMax == outputMax) {
            return; // Nothing to do
        }
        final float range2 = targetOutputRangeMax - targetOutputRangeMin;
        // The old check (range2 != outputRange) wrongly rejected pure shifts; only a
        // degenerate or inverted target range is actually invalid.
        assert (range2 > 0) : "Invalid target range: " + targetOutputRangeMin + ", " + targetOutputRangeMax;
        final float mult = range2 / outputRange;
        double sum = 0;
        long count = 0;
        for (float[] line : outputs) {
            for (int i = 0; i < line.length; i++) {
                final float f = ((line[i] - outputMin) * mult) + targetOutputRangeMin;
                line[i] = f;
                sum += f;
            }
            count += line.length;
        }
        outputMin = targetOutputRangeMin;
        outputMax = targetOutputRangeMax;
        updateDerivedStats(sum, count);
    }

    /**
     * Recomputes mean, range and midpoint from the current outputMin/outputMax and
     * the given running totals.
     *
     * @param sum   sum of all output values
     * @param count number of output values summed
     */
    private void updateDerivedStats(final double sum, final long count) {
        outputMean = (float) (sum / count);
        outputRange = outputMax - outputMin;
        outputMidpoint = outputMin + outputRange * 0.5f;
    }

    /*--------------------------------------------------------------*/

    /**
     * Returns a multi-line summary: column names, line count, input/output counts and
     * the output statistics, each line terminated by a newline.
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(columns).append('\n');
        sb.append("lines=").append(inputs == null ? 0 : inputs.length).append('\n');
        sb.append("inputs=").append(numInputs).append('\n');
        sb.append("outputs=").append(numOutputs).append('\n');
        sb.append("mean=").append(outputMean).append('\n');
        sb.append("midpoint=").append(outputMidpoint).append('\n');
        sb.append("range=").append(outputRange).append('\n');
        // Previously a duplicated "inputs=" line; min/max are the useful missing stats.
        sb.append("min=").append(outputMin).append('\n');
        sb.append("max=").append(outputMax).append('\n');
        return sb.toString();
    }

    int numInputs() {
        return numInputs;
    }

    int numOutputs() {
        return numOutputs;
    }

    /** Returns the output midpoint computed by the most recent range operation. */
    public float outputMidpoint() {
        return outputMidpoint;
    }

    /*--------------------------------------------------------------*/
    /*----------------            Fields            ----------------*/
    /*--------------------------------------------------------------*/

    // NOTE(review): these are populated outside this class; their exact semantics are
    // not visible here.
    ArrayList<String> columns;
    int[] dims;
    int numInputs;
    int numOutputs;
    int numPositive = 0, numNegative = 0;
    int validLines = 0;
    int invalidLines = 0;

    // Output statistics maintained by detectRange/convertToZeroOne/adjustRange.
    private float outputMin;
    private float outputMax;
    private float outputMean;
    private float outputMidpoint;
    private float outputRange;

    float[][][] data;
    float[][] inputs, outputs;

    /*--------------------------------------------------------------*/
    /*----------------        Static Fields         ----------------*/
    /*--------------------------------------------------------------*/

    /** When true, initializeRange binarizes outputs at the detected midpoint. */
    static boolean convertTo01 = false;
    /** Target lower bound for adjustRange; used only if setTargetOutputRangeMin. */
    static float targetOutputRangeMin = 0;
    /** Target upper bound for adjustRange; used only if setTargetOutputRangeMax. */
    static float targetOutputRangeMax = 0;
    static boolean setTargetOutputRangeMin = false;
    static boolean setTargetOutputRangeMax = false;
}
|