package com.wcohen.ss;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer;
/**
* TFIDF-based distance metric, extended to use "soft" token-matching. Specifically, tokens are considered a partial
* match if they get a good score using an inner string comparator.
*
* <p>
* On the WHIRL datasets, thresholding JaroWinkler at 0.9 or 0.95 seems to be about right.
*/
public class SoftTFIDF extends TFIDF {
// distance to use to compare tokens
private StringDistance tokenDistance;
// threshold beyond which tokens are considered a match
private double tokenMatchThreshold;
// default token distance
private static final StringDistance DEFAULT_TOKEN_DISTANCE = new JaroWinkler();
public SoftTFIDF(Tokenizer tokenizer, StringDistance tokenDistance, double tokenMatchThreshold) {
super(tokenizer);
this.tokenDistance = tokenDistance;
this.tokenMatchThreshold = tokenMatchThreshold;
}
public SoftTFIDF(StringDistance tokenDistance, double tokenMatchThreshold) {
super();
this.tokenDistance = tokenDistance;
this.tokenMatchThreshold = tokenMatchThreshold;
}
public SoftTFIDF(StringDistance tokenDistance) {
this(tokenDistance, 0.9);
}
public SoftTFIDF() {
this(new JaroWinkler(), 0.9);
}
public void setTokenMatchThreshold(double d) {
tokenMatchThreshold = d;
}
public void setTokenMatchThreshold(Double d) {
tokenMatchThreshold = d.doubleValue();
}
public double getTokenMatchThreshold() {
return tokenMatchThreshold;
}
private class Similarity implements Comparable<Similarity> {
int r1;
int r2;
double sim;
public Similarity(int r1, int r2, double sim) {
this.r1 = r1;
this.r2 = r2;
this.sim = sim;
}
public int compareTo(Similarity o) {
if (sim > o.sim)
return 1;
else if (sim < o.sim)
return -1;
return 0;
}
}
public double score(StringWrapper s, StringWrapper t) {
checkTrainingHasHappened(s, t);
UnitVector sBag = asUnitVector(s);
UnitVector tBag = asUnitVector(t);
List<Similarity> similarities = new ArrayList<Similarity>(sBag.size());
double sim = 0.0;
int i = 0;
for (Iterator<Token> ti = sBag.tokenIterator(); ti.hasNext(); i++) {
Token tok = ti.next();
int j = 0;
for (Iterator<Token> tj = tBag.tokenIterator(); tj.hasNext(); j++) {
Token tokJ = tj.next();
double distItoJ = tokenDistance.score(tok.getValue(), tokJ.getValue());
if (distItoJ >= tokenMatchThreshold) {
similarities.add(new Similarity(i, j, distItoJ * sBag.getWeight(tok) * tBag.getWeight(tokJ)));
}
}
}
/*
* This could be O(sBag.size() * tBag.size() * log (sBag.size() * tBag.size())) in the worst case but usually
* the threshold will make it much better and likely less than O(sBag.size() * tBag.size()) which is the current
* complexity.
*/
Collections.sort(similarities, Collections.reverseOrder());
boolean[] sUsed = new boolean[sBag.size()];
boolean[] tUsed = new boolean[tBag.size()];
// enforce that each word is only used for one similarity, to make sure normalization works
for (int k = 0; k < similarities.size(); k++) {
Similarity similarity = similarities.get(k);
if (sUsed[similarity.r1] || tUsed[similarity.r2])
continue;
sim += similarity.sim;
sUsed[similarity.r1] = true;
tUsed[similarity.r2] = true;
}
return sim;
}
/**
* Explain how the distance was computed. In the output, the tokens in S and T are listed, and the common tokens are
* marked with an asterisk.
*/
public String explainScore(StringWrapper s, StringWrapper t) {
BagOfTokens sBag = (BagOfTokens) s;
BagOfTokens tBag = (BagOfTokens) t;
StringBuilder buf = new StringBuilder("");
PrintfFormat fmt = new PrintfFormat("%.3f");
buf.append("Common tokens: ");
for (Iterator<Token> i = sBag.tokenIterator(); i.hasNext();) {
Token tok = i.next();
if (tBag.contains(tok)) {
buf.append(" " + tok.getValue() + ": ");
buf.append(fmt.sprintf(sBag.getWeight(tok)));
buf.append("*");
buf.append(fmt.sprintf(tBag.getWeight(tok)));
} else {
// find best matching token
double matchScore = tokenMatchThreshold;
Token matchTok = null;
for (Iterator<Token> j = tBag.tokenIterator(); j.hasNext();) {
Token tokJ = j.next();
double distItoJ = tokenDistance.score(tok.getValue(), tokJ.getValue());
if (distItoJ >= matchScore) {
matchTok = tokJ;
matchScore = distItoJ;
}
}
if (matchTok != null) {
buf.append(" '" + tok.getValue() + "'~='" + matchTok.getValue() + "': ");
buf.append(fmt.sprintf(sBag.getWeight(tok)));
buf.append("*");
buf.append(fmt.sprintf(tBag.getWeight(matchTok)));
buf.append("*");
buf.append(fmt.sprintf(matchScore));
}
}
}
buf.append("\nscore = " + score(s, t));
return buf.toString();
}
public String toString() {
return "[SoftTFIDF thresh=" + tokenMatchThreshold + ";" + tokenDistance + "]";
}
}