#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import argparse
from utils import *
import sys
# Command-line interface of the alignment script. Options are registered in a
# fixed order, which is also the order shown by --help.
parser = argparse.ArgumentParser(description="RCSLS for supervised word alignment")

# --- embeddings, dictionaries and output location ---
parser.add_argument(
    "--src_emb",
    type=str,
    default="",
    help="Load source embeddings",
)
parser.add_argument(
    "--tgt_emb",
    type=str,
    default="",
    help="Load target embeddings",
)
parser.add_argument(
    "--center",
    action="store_true",
    help="whether to center embeddings or not",
)
parser.add_argument(
    "--dico_train",
    type=str,
    default="",
    help="train dictionary",
)
parser.add_argument(
    "--dico_test",
    type=str,
    default="",
    help="validation dictionary",
)
parser.add_argument(
    "--output",
    type=str,
    default="",
    help="where to save aligned embeddings",
)

# --- neighbourhood size and data limits ---
parser.add_argument(
    "--knn",
    type=int,
    default=10,
    help="number of nearest neighbors in RCSL/CSLS",
)
parser.add_argument(
    "--maxneg",
    type=int,
    default=200000,
    help="Maximum number of negatives for the Extended RCSLS",
)
parser.add_argument(
    "--maxsup",
    type=int,
    default=-1,
    help="Maximum number of training examples",
)
parser.add_argument(
    "--maxload",
    type=int,
    default=200000,
    help="Maximum number of loaded vectors",
)

# --- optimisation settings ---
parser.add_argument(
    "--model",
    type=str,
    default="none",
    help="Set of constraints: spectral or none",
)
parser.add_argument(
    "--reg",
    type=float,
    default=0.0,
    help="regularization parameters",
)
parser.add_argument(
    "--lr",
    type=float,
    default=1.0,
    help="learning rate",
)
parser.add_argument(
    "--niter",
    type=int,
    default=10,
    help="number of iterations",
)
parser.add_argument(
    "--sgd",
    action="store_true",
    help="use sgd",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=10000,
    help="batch size for sgd",
)
# Parse the process command line once, at module level; the main section of
# this script reads every setting from this namespace (params.src_emb,
# params.knn, params.lr, ...).
params = parser.parse_args()
###### SPECIFIC FUNCTIONS ######
# functions specific to RCSLS
# the rest of the functions are in utils.py
def getknn(sc, x, y, k=10):
    """Return the top-k score term and its companion gradient matrix.

    For every row of ``sc`` the ``k`` largest entries are selected. The
    column index of a selected entry is used as a row index into ``y``, so
    ``y`` must have one row per column of ``sc``.

    Returns a pair ``(f, df)`` where ``f`` is the sum of all selected scores
    divided by ``k`` and ``df`` is ``S.T @ x / k``, with ``S[i]`` the sum of
    the rows of ``y`` selected for row ``i`` of ``sc``.
    """
    n_rows = sc.shape[0]
    # Column indices of the k largest scores of each row (order within the
    # k selected entries is unspecified, which does not matter for sums).
    top_cols = np.argpartition(sc, -k, axis=1)[:, -k:]

    row_ids = np.arange(n_rows)[:, None]
    score_total = np.sum(sc[row_ids, top_cols])

    # Gather the selected rows of y as an (n_rows, k, dim) array and add up
    # the k neighbours of each row.
    gathered = y[top_cols.ravel()].reshape(top_cols.shape + (y.shape[1],))
    neighbour_sum = gathered.sum(1)

    grad = np.dot(neighbour_sum.T, x)
    return score_total / k, grad / k
def rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, knn=10):
    """Evaluate the RCSLS objective and its derivative matrix for mapping ``R``.

    ``X_src`` / ``Y_tgt`` hold the paired training vectors row by row, and
    ``Z_src`` / ``Z_tgt`` are the candidate pools searched for the ``knn``
    nearest neighbours.

    The value is built from an alignment term ``2 * sum((X_src R^T) * Y_tgt)``
    minus two top-k terms from ``getknn`` (mapped sources against ``Z_tgt``
    and targets against mapped ``Z_src``). Both the value and the matrix are
    negated and divided by the number of training pairs before being
    returned as ``(f, df)``.
    """
    n_pairs = X_src.shape[0]
    mapped_src = np.dot(X_src, R.T)

    # Alignment term between mapped sources and their paired targets.
    align_val = 2 * np.sum(mapped_src * Y_tgt)
    align_grad = 2 * np.dot(Y_tgt.T, X_src)

    # Neighbourhood of each mapped source vector among the target pool.
    fwd_scores = np.dot(mapped_src, Z_tgt.T)
    fwd_val, fwd_grad = getknn(fwd_scores, X_src, Z_tgt, knn)

    # Neighbourhood of each target vector among the mapped source pool.
    bwd_scores = np.dot(np.dot(Z_src, R.T), Y_tgt.T).T
    bwd_val, bwd_grad = getknn(bwd_scores, Y_tgt, Z_src, knn)

    value = align_val - fwd_val - bwd_val
    grad = align_grad - fwd_grad - bwd_grad.T
    return -value / n_pairs, -grad / n_pairs
def proj_spectral(R):
    """Clip the singular values of ``R`` to ``[0, 1]`` and rebuild the matrix.

    ``R`` is decomposed with a reduced SVD, its singular values are clipped
    to the unit interval and the factors are multiplied back together. The
    input is not modified; a new array with the shape of ``R`` is returned.

    The reduced decomposition (``full_matrices=False``) keeps the factor
    shapes compatible for rectangular ``R`` as well; for square ``R`` the
    factors are the same as with the full decomposition.
    """
    U, s, V = np.linalg.svd(R, full_matrices=False)
    s = np.clip(s, 0, 1)
    # Scaling the rows of V is equivalent to diag(s) @ V without
    # materialising the dense diagonal matrix.
    return np.dot(U, s[:, None] * V)
###### MAIN ######

# load word embeddings
words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)

# load validation bilingual lexicon
src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)

# word --> vector indices
idx_src = idx(words_src)
idx_tgt = idx(words_tgt)

# load train bilingual lexicon, optionally truncated to the first maxsup pairs
pairs = load_pairs(params.dico_train, idx_src, idx_tgt)
if params.maxsup > 0 and params.maxsup < len(pairs):
    pairs = pairs[:params.maxsup]

# selecting training vector pairs
X_src, Y_tgt = select_vectors_from_pairs(x_src, x_tgt, pairs)

# adding negatives for RCSLS: the first maxneg loaded vectors on each side
Z_src = x_src[:params.maxneg, :]
Z_tgt = x_tgt[:params.maxneg, :]

# initialization:
R = procrustes(X_src, Y_tgt)
nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
print("[init -- Procrustes] NN: %.4f"%(nnacc))
sys.stdout.flush()

# optimization
fold, Rold = 0, []
niter, lr = params.niter, params.lr
# np.random.choice(..., replace=False) raises ValueError when asked for more
# samples than there are training pairs, so cap the mini-batch size.
batchsize = min(params.batchsize, X_src.shape[0])

for it in range(0, niter + 1):
    # stop once the learning rate has been halved below 1e-4
    if lr < 1e-4:
        break

    if params.sgd:
        indices = np.random.choice(X_src.shape[0], size=batchsize, replace=False)
        f, df = rcsls(X_src[indices, :], Y_tgt[indices, :], Z_src, Z_tgt, R, params.knn)
    else:
        f, df = rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, params.knn)

    # descent step on R, with optional shrinkage (--reg) and optional
    # projection of the singular values (--model spectral)
    if params.reg > 0:
        R *= (1 - lr * params.reg)
    R -= lr * df
    if params.model == "spectral":
        R = proj_spectral(R)

    print("[it=%d] f = %.4f" % (it, f))
    sys.stdout.flush()

    # full-batch mode only: halve the learning rate when the loss went up
    # NOTE(review): R is updated in place above (`*=`, `-=`), and
    # `fold, Rold = f, R` stores a reference, not a copy. Unless
    # proj_spectral rebinds R, Rold and R are the same array, so the
    # restore below does not bring back earlier values. Left as is to keep
    # the optimisation behaviour unchanged; confirm the intent before
    # changing it.
    if f > fold and it > 0 and not params.sgd:
        lr /= 2
        f, R = fold, Rold

    fold, Rold = f, R

    # periodic evaluation on the validation lexicon
    if (it > 0 and it % 10 == 0) or it == niter:
        nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
        print("[it=%d] NN = %.4f - Coverage = %.4f" % (it, nnacc, len(src2tgt) / lexicon_size))

nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
print("[final] NN = %.4f - Coverage = %.4f" % (nnacc, len(src2tgt) / lexicon_size))

if params.output != "":
    print("Saving all aligned vectors at %s" % params.output)
    # reload the source embeddings without the maxload cap so that every
    # vector in the file is mapped and saved
    words_full, x_full = load_vectors(params.src_emb, maxload=-1, center=params.center, verbose=False)
    x = np.dot(x_full, R.T)
    # row-normalise; the small constant guards against division by zero
    x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
    save_vectors(params.output, x, words_full)
    save_matrix(params.output + "-mat", R)
|