1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
|
//===----- DifferentiationMangler.cpp --------- differentiation -*- C++ -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/Demangling/ManglingMacros.h"
#include "swift/SIL/SILGlobalVariable.h"
using namespace swift;
using namespace Mangle;
/// Builds the demangle-tree form of a generic signature.
///
/// The derivative generic signature's requirements may mention names that
/// also occur in the original function's mangled name. Swift's mangling scheme
/// requires such repeats to be emitted as substitutions. So rather than
/// appending the signature as raw text, `DifferentiationMangler` splices this
/// tree into a larger node tree and lets the remangler introduce the
/// substitutions.
///
/// \param sig The generic signature to convert; may be null.
/// \param demangler The demangler whose node storage owns the returned tree.
/// \returns The `DependentGenericSignature` node, or null when `sig` is null.
static NodePointer mangleGenericSignatureAsNode(GenericSignature sig,
                                                Demangler &demangler) {
  // Without a signature there is nothing to splice into the tree.
  if (!sig)
    return nullptr;

  // Round-trip through text: mangle the signature with the AST mangler, then
  // demangle that string back into a node tree.
  ASTMangler signatureMangler;
  const auto signatureText = signatureMangler.mangleGenericSignature(sig);
  NodePointer globalNode = demangler.demangleSymbol(signatureText);

  // The demangled symbol should be a `Global` node wrapping exactly one
  // `DependentGenericSignature` child; unwrap and return that child.
  assert(globalNode->getKind() == Node::Kind::Global);
  assert(globalNode->getNumChildren() == 1);
  NodePointer signatureNode = globalNode->getFirstChild();
  assert(signatureNode->getKind() == Node::Kind::DependentGenericSignature);
  return signatureNode;
}
/// Builds the demangle tree for an autodiff function associated with an
/// original function whose name is already mangled.
///
/// The tree is `Global` -> `AutoDiffFunction`. The `AutoDiffFunction` node's
/// children are, in order:
/// - the children of the original symbol's `Global` node;
/// - the derivative generic signature (only when `config` has one);
/// - the function kind;
/// - the parameter index subset;
/// - the result index subset.
///
/// \param originalName Mangled name of the original function.
/// \param kind Which autodiff function is being named.
/// \param config Supplies the derivative generic signature and index subsets.
/// \param demangler Owns every node in the returned tree.
static NodePointer mangleAutoDiffFunctionAsNode(
    StringRef originalName, Demangle::AutoDiffFunctionKind kind,
    const AutoDiffConfig &config, Demangler &demangler) {
  assert(isMangledName(originalName));
  NodePointer originalGlobal = demangler.demangleSymbol(originalName);
  assert(originalGlobal && "Should only be called when the original "
                           "function has a mangled name");
  assert(originalGlobal->getKind() == Node::Kind::Global);
  NodePointer genericSignature = mangleGenericSignatureAsNode(
      config.derivativeGenericSignature, demangler);

  NodePointer autoDiffNode =
      demangler.createNode(Node::Kind::AutoDiffFunction);
  auto append = [&](NodePointer child) {
    autoDiffNode->addChild(child, demangler);
  };

  // The original symbol's contents form the prefix of the new symbol.
  for (NodePointer originalChild : *originalGlobal)
    append(originalChild);
  if (genericSignature)
    append(genericSignature);
  append(demangler.createNode(Node::Kind::AutoDiffFunctionKind,
                              static_cast<Node::IndexType>(kind)));
  append(demangler.createNode(Node::Kind::IndexSubset,
                              config.parameterIndices->getString()));
  append(demangler.createNode(Node::Kind::IndexSubset,
                              config.resultIndices->getString()));

  NodePointer global = demangler.createNode(Node::Kind::Global);
  global->addChild(autoDiffNode, demangler);
  return global;
}
/// Returns the mangled name of the autodiff function of kind `kind` for the
/// function named `originalName`, configured by `config`.
///
/// When `originalName` is itself a mangled name, the result comes from
/// building a node tree and remangling it. This lets substitutions be shared
/// with the original name. Otherwise the original name is appended verbatim
/// as an opaque operator, followed by the `TJ` autodiff parts.
std::string DifferentiationMangler::mangleAutoDiffFunction(
    StringRef originalName, Demangle::AutoDiffFunctionKind kind,
    const AutoDiffConfig &config) {
  if (!isMangledName(originalName)) {
    // Opaque original symbol: emit it as-is, then the autodiff suffix.
    beginManglingWithoutPrefix();
    appendOperator(originalName);
    appendAutoDiffFunctionParts("TJ", kind, config);
    return finalize();
  }

  // Mangled original symbol: build the combined tree and remangle it.
  Demangler demangler;
  NodePointer root =
      mangleAutoDiffFunctionAsNode(originalName, kind, config, demangler);
  auto remangled = Demangle::mangleNode(root);
  assert(remangled.isSuccess());
  return remangled.result();
}
/// Returns the mangled name for the derivative function of the given kind
/// associated with `originalName` under `config`.
///
/// The derivative kind is mapped to its `Demangle::AutoDiffFunctionKind` and
/// the work is delegated to `mangleAutoDiffFunction`.
std::string DifferentiationMangler::mangleDerivativeFunction(
    StringRef originalName, AutoDiffDerivativeFunctionKind kind,
    const AutoDiffConfig &config) {
  const auto functionKind = getAutoDiffFunctionKind(kind);
  return mangleAutoDiffFunction(originalName, functionKind, config);
}
/// Returns the mangled name for the linear map of the given kind associated
/// with the original function `originalName` under `config`.
///
/// The linear map kind is mapped to its `Demangle::AutoDiffFunctionKind` and
/// the work is delegated to `mangleAutoDiffFunction`, exactly as
/// `mangleDerivativeFunction` does for derivative kinds.
std::string DifferentiationMangler::mangleLinearMap(
    StringRef originalName, AutoDiffLinearMapKind kind,
    const AutoDiffConfig &config) {
  return mangleAutoDiffFunction(
      originalName, getAutoDiffFunctionKind(kind), config);
}
/// Builds the demangle tree for a derivative-function subset-parameters thunk
/// whose original function has a mangled name.
///
/// The tree is `Global` -> `AutoDiffSubsetParametersThunk`. The thunk node's
/// children are, in order:
/// - the children of the original symbol's `Global` node;
/// - the demangled `toType`;
/// - the function kind;
/// - the index subsets `fromParamIndices`, `fromResultIndices`, and
///   `toParamIndices`.
///
/// \param demangler Owns every node in the returned tree.
static NodePointer mangleDerivativeFunctionSubsetParametersThunkAsNode(
    StringRef originalName, Type toType, Demangle::AutoDiffFunctionKind kind,
    IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
    IndexSubset *toParamIndices, Demangler &demangler) {
  assert(isMangledName(originalName));
  NodePointer originalGlobal = demangler.demangleSymbol(originalName);
  assert(originalGlobal && "Should only be called when the original "
                           "function has a mangled name");
  assert(originalGlobal->getKind() == Node::Kind::Global);

  NodePointer thunkNode =
      demangler.createNode(Node::Kind::AutoDiffSubsetParametersThunk);
  // The original symbol's contents form the prefix of the thunk symbol.
  for (NodePointer originalChild : *originalGlobal)
    thunkNode->addChild(originalChild, demangler);

  // Turn the target type into a node: mangle it without the global prefix,
  // then demangle that string as a type.
  ASTMangler typeMangler;
  NodePointer toTypeNode =
      demangler.demangleType(typeMangler.mangleTypeWithoutPrefix(toType));
  assert(toTypeNode && "Cannot demangle the to-type as node");
  thunkNode->addChild(toTypeNode, demangler);

  thunkNode->addChild(
      demangler.createNode(Node::Kind::AutoDiffFunctionKind,
                           static_cast<Node::IndexType>(kind)),
      demangler);

  auto addIndexSubset = [&](IndexSubset *indices) {
    thunkNode->addChild(
        demangler.createNode(Node::Kind::IndexSubset, indices->getString()),
        demangler);
  };
  addIndexSubset(fromParamIndices);
  addIndexSubset(fromResultIndices);
  addIndexSubset(toParamIndices);

  NodePointer global = demangler.createNode(Node::Kind::Global);
  global->addChild(thunkNode, demangler);
  return global;
}
/// Returns the mangled name of the thunk that re-indexes a derivative function
/// of `originalName` from one parameter/result index subset to another. The
/// thunk has type `toType`.
///
/// This mirrors `mangleAutoDiffFunction`:
/// - A mangled `originalName` is handled by building a node tree and
///   remangling it, so substitutions can be shared with the original name.
/// - Any other name is appended verbatim as an opaque operator, followed by
///   the mangled `toType` and the `TJS` suffix.
///
/// The suffix is: the kind code, then `fromParamIndices` + `p`,
/// `fromResultIndices` + `r`, and `toParamIndices` + `P`.
///
/// NOTE(review): `linearMapKind` actually holds a derivative-function kind.
/// The name is kept to match the existing declaration.
std::string
DifferentiationMangler::mangleDerivativeFunctionSubsetParametersThunk(
    StringRef originalName, CanType toType,
    AutoDiffDerivativeFunctionKind linearMapKind,
    IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
    IndexSubset *toParamIndices) {
  auto kind = getAutoDiffFunctionKind(linearMapKind);
  // If the original function is mangled, mangle the tree. The tree path never
  // touches this mangler's own buffer.
  if (isMangledName(originalName)) {
    Demangler demangler;
    auto node = mangleDerivativeFunctionSubsetParametersThunkAsNode(
        originalName, toType, kind, fromParamIndices, fromResultIndices,
        toParamIndices, demangler);
    auto mangling = Demangle::mangleNode(node);
    assert(mangling.isSuccess());
    return mangling.result();
  }
  // Otherwise, treat the original function symbol as a black box and just
  // mangle the other parts. Mangling state is started only here, because an
  // up-front `beginMangling()` would be superseded by this call. This matches
  // `mangleAutoDiffFunction`.
  beginManglingWithoutPrefix();
  appendOperator(originalName);
  appendType(toType, nullptr);
  const auto kindCode = static_cast<char>(kind);
  appendOperator("TJS", StringRef(&kindCode, 1));
  appendIndexSubset(fromParamIndices);
  appendOperator("p");
  appendIndexSubset(fromResultIndices);
  appendOperator("r");
  appendIndexSubset(toParamIndices);
  appendOperator("P");
  return finalize();
}
/// Returns the mangled name of the thunk that re-indexes a linear map of type
/// `fromType` from one parameter/result index subset to another.
///
/// Layout, in order:
/// - the standard mangling prefix (via `beginMangling()`);
/// - the mangled `fromType`;
/// - `TJS` followed by a one-character function-kind code;
/// - `fromParamIndices` + `p`, `fromResultIndices` + `r`, and
///   `toParamIndices` + `P`.
std::string DifferentiationMangler::mangleLinearMapSubsetParametersThunk(
    CanType fromType, AutoDiffLinearMapKind linearMapKind,
    IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
    IndexSubset *toParamIndices) {
  // One-character code identifying the function kind of this linear map.
  const char kindCode =
      static_cast<char>(getAutoDiffFunctionKind(linearMapKind));

  // Each index subset is immediately followed by a one-letter tag.
  auto appendTaggedSubset = [this](IndexSubset *indices, StringRef tag) {
    appendIndexSubset(indices);
    appendOperator(tag);
  };

  beginMangling();
  appendType(fromType, nullptr);
  appendOperator("TJS", StringRef(&kindCode, 1));
  appendTaggedSubset(fromParamIndices, "p");
  appendTaggedSubset(fromResultIndices, "r");
  appendTaggedSubset(toParamIndices, "P");
  return finalize();
}
|