#pragma once

#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>

#include <unordered_map>
#include <unordered_set>
namespace torch {
namespace jit {
namespace tensorexpr {

// Walk the Statement looking for Half-sized loads/stores and promote them to
// float: half values are widened to float for computation and narrowed back
// to half only at the final store.
class CudaHalfChecker : public IRMutator {
 public:
  bool hasHalf() {
    return hasHalf_;
  }

  const Expr* mutate(const Load* v) override {
    const Expr* child = IRMutator::mutate(v);
    if (child->dtype().scalar_type() != ScalarType::Half) {
      return child;
    }
    hasHalf_ = true;

    // Wrap half loads in a cast up to float. TODO: discards lanes.
    return new Cast(kFloat, child);
  }

  Stmt* mutate(const Store* v) override {
    const Expr* new_val = v->value()->accept_mutator(this);
    if (v->value()->dtype().scalar_type() == ScalarType::Half) {
      // Narrow the (now float) value back to half at the store, and remember
      // that this cast is one of ours. TODO: discards lanes.
      new_val = new Cast(kHalf, new_val);
      inserted_half_casts_.insert(new_val);
      hasHalf_ = true;
    }
    return new Store(v->buf(), v->indices(), new_val, v->mask());
  }

  const Expr* mutate(const HalfImm* v) override {
    // Promote half immediates to float as well.
    hasHalf_ = true;
    return new Cast(kFloat, v);
  }

  const Expr* mutate(const Cast* v) override {
    const Expr* child = v->src_value()->accept_mutator(this);

    // Just don't allow half casts we didn't insert: rewrite them to produce
    // float instead, so no half arithmetic survives in the kernel body.
    if (v->dtype().scalar_type() == ScalarType::Half) {
      if (inserted_half_casts_.count(v) < 1) {
        // TODO: discards lanes.
        return new Cast(kFloat, child);
      }
    }

    if (child == v->src_value()) {
      return v;
    }
    return new Cast(v->dtype(), child);
  }

 private:
  bool hasHalf_{false};
  std::unordered_set<const Expr*> inserted_half_casts_;
};
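
// A minimal usage sketch (an assumption, not a fixed API: "stmt" stands for a
// Stmt* produced by earlier lowering, and "os" for the stream collecting the
// generated CUDA source). The checker both rewrites the statement and reports
// whether any half values were seen, which can gate emission of the
// half-precision helper source from resource_strings.h:
//
//   CudaHalfChecker halfChecker;
//   stmt = stmt->accept_mutator(&halfChecker);
//   if (halfChecker.hasHalf()) {
//     os << fuser::cuda::half_support_literal;
//   }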

// Rewrite half-typed scalar Let bindings to float, so locals introduced for
// half values also compute in float. Uses of each rebound variable are
// remapped through var_map_.
class CudaHalfScalarRewriter : public IRMutator {
  Stmt* mutate(const Let* v) override {
    if (v->dtype().scalar_type() == ScalarType::Half) {
      // Re-type the bound variable as float and cast its initializer to
      // match. TODO: discards lanes.
      const Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
      const Expr* new_value =
          new Cast(kFloat, v->value()->accept_mutator(this));
      var_map_[v->var()] = load_new_var;

      return new Let(load_new_var, new_value);
    }

    return IRMutator::mutate(v);
  }

  const Expr* mutate(const Var* v) override {
    // Redirect uses of a rewritten half variable to its float replacement.
    auto it = var_map_.find(v);
    if (it != var_map_.end()) {
      return it->second;
    }

    return v;
  }

 private:
  std::unordered_map<const Var*, const Var*> var_map_;
};
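
// A companion sketch for the scalar rewriter (same hypothetical "stmt" as
// above). One plausible wiring is to run it alongside CudaHalfChecker so
// that half scalar Lets are re-typed to float before the kernel is printed:
//
//   CudaHalfScalarRewriter scalarRewriter;
//   stmt = stmt->accept_mutator(&scalarRewriter);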
} // namespace tensorexpr
} // namespace jit
} // namespace torch