1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
#include "deep_wide_pt.h"
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/script.h>
namespace {

// TorchScript source of the DeepAndWide module, written in the serialized
// class format that torch::jit::SourceImporter loads (see import_libs below).
// No ReplaceNaN (this removes the constant in the model).
//
// NOTE: the text inside every R"JIT(...)JIT" literal in this namespace is
// parsed by the TorchScript frontend, which is indentation-sensitive like
// Python: class members must be indented under `class`, and statement bodies
// must be indented under their `def`. Do not flatten the indentation.
const std::string deep_wide_pt = R"JIT(
class DeepAndWide(Module):
  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
  __buffers__ = []
  _mu : Tensor
  _sigma : Tensor
  _fc_w : Tensor
  _fc_b : Tensor
  training : bool
  def forward(self: __torch__.DeepAndWide,
    ad_emb_packed: Tensor,
    user_emb: Tensor,
    wide: Tensor) -> Tuple[Tensor]:
    _0 = self._fc_b
    _1 = self._fc_w
    _2 = self._sigma
    wide_offset = torch.add(wide, self._mu, alpha=1)
    wide_normalized = torch.mul(wide_offset, _2)
    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
    user_emb_t = torch.transpose(user_emb, 1, 2)
    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
    dp = torch.flatten(dp_unflatten, 1, -1)
    input = torch.cat([dp, wide_preproc], 1)
    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
    return (torch.sigmoid(fc1),)
)JIT";

// Module source: forward(a, b, c) returns a + b * c + s, where s is a
// constant 2x2 tensor filled with 3.
const std::string trivial_model_1 = R"JIT(
def forward(self, a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s
)JIT";

// Module source: five chained torch.leaky_relu calls with the slope baked in
// as the literal 0.1.
const std::string leaky_relu_model_const = R"JIT(
def forward(self, input):
    x = torch.leaky_relu(input, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    return torch.leaky_relu(x, 0.1)
)JIT";

// Module source: five chained torch.leaky_relu calls whose slope is the
// runtime argument `neg_slope`.
const std::string leaky_relu_model = R"JIT(
def forward(self, input, neg_slope):
    x = torch.leaky_relu(input, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    return torch.leaky_relu(x, neg_slope)
)JIT";

// Compiles the TorchScript class `class_name` from `src` into `cu` by running
// torch::jit::SourceImporter::loadType on it.
//
// cu           compilation unit that receives the compiled class type.
// class_name   fully qualified class name, e.g. "__torch__.DeepAndWide".
// src          source handed back for every qualifier the importer requests.
// tensor_table constant table handed to the importer by pointer.
void import_libs(
    std::shared_ptr<at::CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<torch::jit::Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  torch::jit::SourceImporter si(
      cu,
      &tensor_table,
      // Capture the shared_ptr by value: the loader is handed to the importer,
      // so it should not depend on a reference into this call's arguments.
      [src](const std::string& /* unused */)
          -> std::shared_ptr<torch::jit::Source> { return src; },
      /*version=*/2);
  si.loadType(c10::QualifiedName(class_name));
}

} // namespace
// Builds the DeepAndWide benchmark module from the TorchScript source in
// `deep_wide_pt` and attaches parameters drawn from torch::randn.
//
// num_features: determines parameter shapes -- `_mu` and `_sigma` are
//   {1, num_features}, `_fc_w` is {1, num_features + 1}, `_fc_b` is {1}.
// Returns: a module whose type is the compiled `__torch__.DeepAndWide` class.
torch::jit::Module getDeepAndWideSciptModel(int num_features) {
  const std::string qualified_name = "__torch__.DeepAndWide";

  auto unit = std::make_shared<at::CompilationUnit>();
  // deep_wide_pt needs no constant-table entries, so an empty table is passed.
  const std::vector<at::IValue> empty_constants;
  import_libs(
      unit,
      qualified_name,
      std::make_shared<torch::jit::Source>(deep_wide_pt),
      empty_constants);

  auto deep_wide_type = unit->get_class(c10::QualifiedName(qualified_name));
  torch::jit::Module module(unit, deep_wide_type);

  // Registration order matches the __parameters__ list in deep_wide_pt;
  // every tensor is a parameter, not a buffer.
  module.register_parameter(
      "_mu", torch::randn({1, num_features}), /*is_buffer=*/false);
  module.register_parameter(
      "_sigma", torch::randn({1, num_features}), /*is_buffer=*/false);
  module.register_parameter(
      "_fc_w", torch::randn({1, num_features + 1}), /*is_buffer=*/false);
  module.register_parameter("_fc_b", torch::randn({1}), /*is_buffer=*/false);

  // Debugging aid: module.dump(true, true, true);
  return module;
}
// Returns a fresh module named "m" whose methods are compiled from
// `trivial_model_1` (forward(a, b, c) -> a + b * c + constant tensor).
torch::jit::Module getTrivialScriptModel() {
  torch::jit::Module trivial("m");
  trivial.define(trivial_model_1);
  return trivial;
}
// Returns a fresh module named "leaky_relu" compiled from `leaky_relu_model`:
// forward(input, neg_slope) applies torch.leaky_relu five times using the
// runtime slope argument.
torch::jit::Module getLeakyReLUScriptModel() {
  torch::jit::Module leaky("leaky_relu");
  leaky.define(leaky_relu_model);
  return leaky;
}
// Returns a fresh module named "leaky_relu_const" compiled from
// `leaky_relu_model_const`: forward(input) applies torch.leaky_relu five
// times with the literal slope 0.1.
torch::jit::Module getLeakyReLUConstScriptModel() {
  torch::jit::Module leaky_const("leaky_relu_const");
  leaky_const.define(leaky_relu_model_const);
  return leaky_const;
}
// Module source: forward(a, b, c) evaluates a chain of elementwise products,
// passing each one through torch.relu, and returns the last result.
// The literal is TorchScript and indentation-sensitive: the statements must
// stay indented under `def forward`.
const std::string long_model = R"JIT(
def forward(self, a, b, c):
    d = torch.relu(a * b)
    e = torch.relu(a * c)
    f = torch.relu(e * d)
    g = torch.relu(f * f)
    h = torch.relu(g * c)
    return h
)JIT";

// Returns a fresh module named "m" whose methods are compiled from
// `long_model`.
torch::jit::Module getLongScriptModel() {
  torch::jit::Module module("m");
  module.define(long_model);
  return module;
}
// Module source: forward(a) returns sign(a) * log1p(abs(a)), computed
// elementwise.
// The literal is TorchScript and indentation-sensitive: the statements must
// stay indented under `def forward`.
const std::string signed_log1p_model = R"JIT(
def forward(self, a):
    b = torch.abs(a)
    c = torch.log1p(b)
    d = torch.sign(a)
    e = d * c
    return e
)JIT";

// Returns a fresh module named "signed_log1p" whose methods are compiled from
// `signed_log1p_model`.
torch::jit::Module getSignedLog1pModel() {
  torch::jit::Module module("signed_log1p");
  module.define(signed_log1p_model);
  return module;
}
|