1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
package com.wcohen.ss;
import java.util.*;
import com.wcohen.ss.tokens.*;
import com.wcohen.ss.api.*;
/**
 * TFIDF-based distance metric.
 *
 * <p>Each string is turned into a unit-length vector of TFIDF-weighted
 * tokens, and the score of two strings is the inner product of their
 * vectors. The statistics used for weighting (<code>documentFrequency</code>,
 * <code>collectionSize</code>) are fields that are not declared in this
 * class; presumably they are inherited from the superclass.
 */
public class TFIDF extends AbstractStatisticalTokenDistance
{
    /** Vector built by the most recent call to prepare(); null until then. */
    private UnitVector lastVector = null;

    /** Create a TFIDF distance that splits strings with the given tokenizer. */
    public TFIDF(Tokenizer tokenizer) { super(tokenizer); }

    /** Create a TFIDF distance using the superclass's default construction. */
    public TFIDF() { super(); }

    /**
     * Score two strings as the inner product of their unit-length TFIDF
     * vectors: the sum, over tokens present in both, of the product of the
     * token's weight in each vector.
     */
    public double score(StringWrapper s,StringWrapper t) {
        checkTrainingHasHappened(s,t);
        UnitVector sBag = asUnitVector(s);
        UnitVector tBag = asUnitVector(t);
        double sim = 0.0;
        for (Iterator i = sBag.tokenIterator(); i.hasNext(); ) {
            Token tok = (Token)i.next();
            if (tBag.contains(tok)) {
                sim += sBag.getWeight(tok) * tBag.getWeight(tok);
            }
        }
        return sim;
    }

    /**
     * View a wrapper as a UnitVector: returned as-is if it already is one,
     * rebuilt from its tokens if it is some other BagOfTokens, and
     * otherwise tokenized from the unwrapped string.
     */
    protected UnitVector asUnitVector(StringWrapper w) {
        if (w instanceof UnitVector) return (UnitVector)w;
        else if (w instanceof BagOfTokens) return new UnitVector((BagOfTokens)w);
        else return new UnitVector(w.unwrap(),tokenizer.tokenize(w.unwrap()));
    }

    /**
     * Preprocess a string by finding tokens and giving them TFIDF weights.
     * The resulting vector is also remembered so that getTokens() and
     * getWeight() can inspect it.
     */
    public StringWrapper prepare(String s) {
        lastVector = new UnitVector(s, tokenizer.tokenize(s));
        return lastVector;
    }

    //
    // some special methods added mostly for SoftTFIDFDictionary
    //

    /**
     * Access the tokens of the last prepare()-ed string.
     * prepare() must have been called first; otherwise there is no vector
     * to read and a NullPointerException results.
     */
    public Token[] getTokens() { return lastVector.getTokens(); }

    /**
     * Access the weight of a token in the vector created for the last
     * prepare()-ed string.
     * prepare() must have been called first; otherwise there is no vector
     * to read and a NullPointerException results.
     */
    public double getWeight(Token token) { return lastVector.getWeight(token); }

    /** Get the document frequency of the token, or 0 if none is recorded. */
    public int getDocumentFrequency(Token token)
    {
        Integer df = (Integer)documentFrequency.get(token);
        if (df == null) return 0;
        else return df.intValue();
    }

    /** Set the document frequency of the token to some value.
     * Setting the collectionSize and also setting the document
     * frequency of every token is an alternative to explicit
     * training.
     */
    public void setDocumentFrequency(Token token, int df)
    {
        documentFrequency.put(token,new Integer(df));
    }

    /** Set the total token count to some value. */
    public void setTokenCount(int tc) {
        this.totalTokenCount = tc;
    }

    /** Return the size of the collection that this TFIDF measure was
     * trained on.
     */
    public int getCollectionSize()
    {
        return collectionSize;
    }

    /** Set the size of the collection that this TFIDF measure was
     * trained on to some value.
     * Setting the collectionSize and also setting the document
     * frequency of every token is an alternative to explicit
     * training.
     */
    public void setCollectionSize(int n)
    {
        collectionSize=n;
    }

    /** Return the number of distinct tokens that have a recorded document frequency. */
    public int getVocabularySize() { return documentFrequency.size(); }

    /**
     * A BagOfTokens whose weights have been converted from term
     * frequencies to unit-length TFIDF weights at construction time.
     */
    protected class UnitVector extends BagOfTokens
    {
        /** Build the bag for the given string and tokens, then convert its weights to TFIDF. */
        public UnitVector(String s,Token[] tokens) {
            super(s,tokens);
            termFreq2TFIDF();
        }

        /** Rebuild a unit vector from the string and tokens of an existing bag. */
        public UnitVector(BagOfTokens bag) {
            // The delegated constructor already performs the TFIDF conversion.
            // Converting a second time would re-weight the already-normalized
            // values and produce wrong weights.
            this(bag.unwrap(), bag.getTokens());
        }

        /** convert term frequency weights to unit-length TFIDF weights */
        private void termFreq2TFIDF() {
            double normalizer = 0.0;
            for (Iterator i=tokenIterator(); i.hasNext(); ) {
                Token tok = (Token)i.next();
                if (collectionSize>0) {
                    Integer dfInteger = (Integer)documentFrequency.get(tok);
                    // set previously unknown words to df==1, which gives them a high value
                    double df = dfInteger==null ? 1.0 : dfInteger.intValue();
                    // log(tf + 1) * log(N / df); df is a double, so the division is not truncated
                    double w = Math.log( getWeight(tok) + 1) * Math.log( collectionSize/df );
                    setWeight( tok, w );
                    normalizer += w*w;
                } else {
                    // no collection statistics: weight every token equally
                    setWeight( tok, 1.0 );
                    normalizer += 1.0;
                }
            }
            normalizer = Math.sqrt(normalizer);
            // A zero norm means every weight is zero (e.g. every token occurs in
            // every document). Dividing would turn the weights into NaN, so the
            // all-zero vector is left as it is.
            if (normalizer > 0.0) {
                for (Iterator i=tokenIterator(); i.hasNext(); ) {
                    Token tok = (Token)i.next();
                    setWeight( tok, getWeight(tok)/normalizer );
                }
            }
        }
    }

    /** Explain how the distance was computed.
     * The output lists each token common to S and T together with its
     * weight in S and its weight in T (separated by an asterisk),
     * followed by the overall score.
     */
    public String explainScore(StringWrapper s, StringWrapper t)
    {
        // Convert exactly as score() does, so the weights shown are the ones
        // score() multiplies and any StringWrapper is accepted.
        UnitVector sBag = asUnitVector(s);
        UnitVector tBag = asUnitVector(t);
        StringBuffer buf = new StringBuffer("");
        PrintfFormat fmt = new PrintfFormat("%.3f");
        buf.append("Common tokens: ");
        for (Iterator i = sBag.tokenIterator(); i.hasNext(); ) {
            Token tok = (Token)i.next();
            if (tBag.contains(tok)) {
                buf.append(" "+tok.getValue()+": ");
                buf.append(fmt.sprintf(sBag.getWeight(tok)));
                buf.append("*");
                buf.append(fmt.sprintf(tBag.getWeight(tok)));
            }
        }
        buf.append("\nscore = "+score(s,t));
        return buf.toString();
    }

    public String toString() { return "[TFIDF]"; }

    /** Command-line entry point; hands a fresh TFIDF and the arguments to doMain. */
    static public void main(String[] argv) {
        doMain(new TFIDF(), argv);
    }
}
|