//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef POLYNOMIAL_CANONICALIZATION
#define POLYNOMIAL_CANONICALIZATION
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
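
// Overflow flags attached to the arith ops these patterns create: "none",
// i.e. neither nsw nor nuw.
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// Holds when the two bound arguments compare equal. The patterns below use it
// to require that both transforms carry the same second argument ($r1 == $r2).
def Equal : Constraint<CPred<"$0 == $1">>;
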
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
  : NativeCodeCall<
      "$_builder.getIntegerAttr("
        "cast<PolynomialType>($0.getType()).getRing().getCoefficientType(), -1)">;

// f - g -> f + (-1) * g
def SubAsAdd : Pat<
  (Polynomial_SubOp $f, $g),
  (Polynomial_AddOp $f,
    (Polynomial_MulScalarOp $g,
      (Arith_ConstantOp (getMinusOne $g))))>;

// intt(ntt(poly)) -> poly, provided both transforms use the same $r.
def INTTAfterNTT : Pat<
  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
  (replaceWithValue $poly),
  [(Equal $r1, $r2)]
>;

// ntt(intt(tensor)) -> tensor, provided both transforms use the same $r.
def NTTAfterINTT : Pat<
  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
  (replaceWithValue $tensor),
  [(Equal $r1, $r2)]
>;

// NTTs are expensive, while addition should cost about the same in the
// coefficient domain as in the NTT domain, so merging two transforms into one
// is a net win. The patterns fire only when both transforms use the same $r.
// ntt(a) + ntt(b) -> ntt(a + b)
def NTTOfAdd : Pat<
  (Arith_AddIOp
    (Polynomial_NTTOp $p1, $r1),
    (Polynomial_NTTOp $p2, $r2),
    $overflow),
  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
  [(Equal $r1, $r2)]
>;

// intt(a) + intt(b) -> intt(a + b)
def INTTOfAdd : Pat<
  (Polynomial_AddOp
    (Polynomial_INTTOp $t1, $r1),
    (Polynomial_INTTOp $t2, $r2)),
  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
  [(Equal $r1, $r2)]
>;

// The same rewrites, repeated for subtraction.
// ntt(a) - ntt(b) -> ntt(a - b)
def NTTOfSub : Pat<
  (Arith_SubIOp
    (Polynomial_NTTOp $p1, $r1),
    (Polynomial_NTTOp $p2, $r2),
    $overflow),
  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
  [(Equal $r1, $r2)]
>;

// intt(a) - intt(b) -> intt(a - b)
def INTTOfSub : Pat<
  (Polynomial_SubOp
    (Polynomial_INTTOp $t1, $r1),
    (Polynomial_INTTOp $t2, $r2)),
  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
  [(Equal $r1, $r2)]
>;

#endif // POLYNOMIAL_CANONICALIZATION