//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Per-module contextual information for the differentiation transform.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"

#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/SourceFile.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"

using llvm::DenseMap;
using llvm::SmallPtrSet;
using llvm::SmallVector;

namespace swift {
namespace autodiff {

//===----------------------------------------------------------------------===//
// Local helpers
//===----------------------------------------------------------------------===//

/// Given an operator name, such as '+', and a protocol, returns the protocol's
/// static operator requirement with that name (e.g. `AdditiveArithmetic.+`).
/// Returns null if the protocol has no such requirement.
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
                                            ProtocolDecl *protocol) {
  assert(operatorName.isOperator());
  // Find the operator requirement in the given protocol declaration.
  auto opLookup = protocol->lookupDirect(operatorName);
  for (auto *decl : opLookup) {
    if (!decl->isProtocolRequirement())
      continue;
    auto *fd = dyn_cast<FuncDecl>(decl);
    if (!fd || !fd->isStatic() || !fd->isOperator())
      continue;
    return fd;
  }
  // Not found.
  return nullptr;
}
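
// Illustrative sketch only (assumed, not used by the transform): the filtering
// done by findOperatorDeclInProtocol, modeled with standard-library types so it
// can be exercised in isolation. `MockDecl` and `findOperatorRequirement` are
// hypothetical names; the boolean fields stand in for isProtocolRequirement(),
// a successful dyn_cast<FuncDecl>, isStatic() and isOperator().
namespace sketch_operator_lookup {

// Stand-in for one declaration returned by a direct name lookup.
struct MockDecl {
  std::string name;
  bool isProtocolRequirement;
  bool isFunction; // models dyn_cast<FuncDecl> succeeding
  bool isStatic;
  bool isOperator;
};

// Returns the first lookup result that is a static operator function
// requirement, or nullptr if none qualifies.
inline const MockDecl *
findOperatorRequirement(const std::vector<MockDecl> &lookupResults) {
  for (const MockDecl &decl : lookupResults) {
    if (!decl.isProtocolRequirement)
      continue;
    if (!decl.isFunction || !decl.isStatic || !decl.isOperator)
      continue;
    return &decl;
  }
  return nullptr;
}

} // end namespace sketch_operator_lookup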
//===----------------------------------------------------------------------===//
// ADContext methods
//===----------------------------------------------------------------------===//

ADContext::ADContext(SILModuleTransform &transform)
    : transform(transform), module(*transform.getModule()),
      passManager(*transform.getPassManager()) {}

/// Get the source file for the given `SILFunction`. Prefers the source file
/// that encloses the function's declaration; otherwise falls back to the first
/// source file in the function's module.
static SourceFile &getSourceFile(SILFunction *f) {
  if (f->hasLocation())
    if (auto *declContext = f->getLocation().getAsDeclContext())
      if (auto *parentSourceFile = declContext->getParentSourceFile())
        return *parentSourceFile;
  for (auto *file : f->getModule().getSwiftModule()->getFiles())
    if (auto *sourceFile = dyn_cast<SourceFile>(file))
      return *sourceFile;
  llvm_unreachable("Could not resolve SourceFile from SILFunction");
}
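
// Illustrative sketch only (assumed, not used by the transform): the fallback
// order of getSourceFile. `MockFile`, `MockFunction` and `resolveSourceFile`
// are hypothetical; `parentSourceFile` collapses the hasLocation() /
// getAsDeclContext() / getParentSourceFile() chain into one optional, and an
// empty result stands in for the llvm_unreachable path.
namespace sketch_source_file {

struct MockFile {
  std::string name;
  bool isSourceFile; // models dyn_cast<SourceFile> succeeding
};

struct MockFunction {
  // Set when the function's location resolves to an enclosing source file.
  std::optional<std::string> parentSourceFile;
};

inline std::optional<std::string>
resolveSourceFile(const MockFunction &f,
                  const std::vector<MockFile> &moduleFiles) {
  // 1. Prefer the source file enclosing the function's declaration.
  if (f.parentSourceFile)
    return f.parentSourceFile;
  // 2. Otherwise fall back to the first source file in the module.
  for (const MockFile &file : moduleFiles)
    if (file.isSourceFile)
      return file.name;
  // 3. No source file at all: the real code treats this as unreachable.
  return std::nullopt;
}

} // end namespace sketch_source_file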

SynthesizedFileUnit &
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
  auto &SF = getSourceFile(original);
  return SF.getOrCreateSynthesizedFile();
}

FuncDecl *ADContext::getPlusDecl() const {
  if (!cachedPlusFn) {
    cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"),
                                              additiveArithmeticProtocol);
    assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
  }
  return cachedPlusFn;
}

FuncDecl *ADContext::getPlusEqualDecl() const {
  if (!cachedPlusEqualFn) {
    cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="),
                                                   additiveArithmeticProtocol);
    assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
  }
  return cachedPlusEqualFn;
}

AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const {
  if (cachedZeroGetter)
    return cachedZeroGetter;
  auto zeroDeclLookup = getAdditiveArithmeticProtocol()
                            ->lookupDirect(getASTContext().Id_zero);
  auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
  assert(zeroDecl->isProtocolRequirement());
  cachedZeroGetter = zeroDecl->getOpaqueAccessor(AccessorKind::Get);
  return cachedZeroGetter;
}
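
// Illustrative sketch only (assumed, not used by the transform): the
// lazy-caching pattern shared by getPlusDecl, getPlusEqualDecl and
// getAdditiveArithmeticZeroGetter. Because those getters are const yet assign
// to their cache, the cache members must be `mutable`. `LazyLookup`, `getPlus`
// and the lookup counter are hypothetical, added only to show that the lookup
// runs once and later calls are served from the cache.
namespace sketch_lazy_cache {

class LazyLookup {
public:
  explicit LazyLookup(std::vector<std::string> members)
      : members(std::move(members)) {}

  // Looks up "+" on the first call only; later calls return the cached result.
  const std::string *getPlus() const {
    if (!cachedPlus)
      cachedPlus = find("+");
    return cachedPlus;
  }

  int lookupCount() const { return lookups; }

private:
  // Linear search standing in for ProtocolDecl::lookupDirect.
  const std::string *find(const std::string &name) const {
    ++lookups;
    for (const std::string &member : members)
      if (member == name)
        return &member;
    return nullptr;
  }

  std::vector<std::string> members;
  mutable const std::string *cachedPlus = nullptr; // plays the cachedPlusFn role
  mutable int lookups = 0;
};

} // end namespace sketch_lazy_cache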

void ADContext::cleanUp() {
  // Delete all references to generated functions.
  for (auto fnRef : generatedFunctionReferences) {
    if (auto *fnRefInst =
            peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
      fnRefInst->replaceAllUsesWithUndef();
      fnRefInst->eraseFromParent();
    }
  }
  // Delete all generated functions.
  for (auto *generatedFunction : generatedFunctions) {
    LLVM_DEBUG(getADDebugStream() << "Deleting generated function "
                                  << generatedFunction->getName() << '\n');
    generatedFunction->dropAllReferences();
    transform.notifyWillDeleteFunction(generatedFunction);
    module.eraseFunction(generatedFunction);
  }
}
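
// Illustrative sketch only (assumed, not used by the transform): the two-phase
// order of cleanUp, where references to generated functions are removed before
// the functions themselves are erased, so no surviving function refers to a
// deleted one. `MockFunction`, `isGenerated` and `cleanUpGenerated` are
// hypothetical. Unlike cleanUp, which visits only the recorded
// generatedFunctionReferences and also notifies the pass manager, this
// simplified model scans every function.
namespace sketch_cleanup {

// A function is a name plus the names of the functions it references
// (standing in for function_ref instructions in its body).
struct MockFunction {
  std::string name;
  std::vector<std::string> references;
};

inline bool isGenerated(const std::vector<std::string> &generated,
                        const std::string &name) {
  return std::find(generated.begin(), generated.end(), name) != generated.end();
}

inline void cleanUpGenerated(std::vector<MockFunction> &functions,
                             const std::vector<std::string> &generated) {
  // Phase 1: drop every reference to a generated function.
  for (MockFunction &fn : functions) {
    std::vector<std::string> &refs = fn.references;
    refs.erase(std::remove_if(refs.begin(), refs.end(),
                              [&](const std::string &callee) {
                                return isGenerated(generated, callee);
                              }),
               refs.end());
  }
  // Phase 2: erase the generated functions themselves.
  functions.erase(std::remove_if(functions.begin(), functions.end(),
                                 [&](const MockFunction &fn) {
                                   return isGenerated(generated, fn.name);
                                 }),
                  functions.end());
}

} // end namespace sketch_cleanup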

DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    IndexSubset *resultIndices, SILValue original,
    std::optional<std::pair<SILValue, SILValue>> derivativeFunctions) {
  auto *dfi = builder.createDifferentiableFunction(
      loc, parameterIndices, resultIndices, original, derivativeFunctions);
  // The new instruction may have the same address as a previously processed
  // and deleted instruction; ensure it is not treated as already processed.
  processedDifferentiableFunctionInsts.erase(dfi);
  return dfi;
}

LinearFunctionInst *ADContext::createLinearFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    SILValue original, std::optional<SILValue> transposeFunction) {
  auto *lfi = builder.createLinearFunction(loc, parameterIndices, original,
                                           transposeFunction);
  // The new instruction may have the same address as a previously processed
  // and deleted instruction; ensure it is not treated as already processed.
  processedLinearFunctionInsts.erase(lfi);
  return lfi;
}
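
// Illustrative sketch only (assumed, not used by the transform): why the
// create* helpers erase a new instruction from a processed set keyed by
// address. An allocator may give a new instruction the address of a deleted
// one, leaving a stale "processed" entry behind. `MockInst` and `ProcessedSet`
// are hypothetical; the test reuses one storage slot to imitate address reuse.
namespace sketch_processed_set {

struct MockInst {
  int id;
};

// Records handled instructions by address, like the processed*Insts sets.
class ProcessedSet {
public:
  void markProcessed(const MockInst *inst) { processed.push_back(inst); }

  bool isProcessed(const MockInst *inst) const {
    return std::find(processed.begin(), processed.end(), inst) !=
           processed.end();
  }

  // Call right after creating an instruction: forget any stale entry that
  // happens to share its address.
  void noteCreated(const MockInst *inst) {
    processed.erase(std::remove(processed.begin(), processed.end(), inst),
                    processed.end());
  }

private:
  std::vector<const MockInst *> processed;
};

} // end namespace sketch_processed_set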

DifferentiableFunctionExpr *
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
  return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>();
}

LinearFunctionExpr *
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
  return inst->getLoc().getAsASTNode<LinearFunctionExpr>();
}

} // end namespace autodiff
} // end namespace swift