1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
// RUN: %target-sil-opt -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
// BCANALYZER-NOT: UnknownCode
import _Differentiation
// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func simple(x: Float) -> Float
/// Identity on `Float`.
///
/// The attribute omits `wrt:`, so the expected serialized form (`wrt: x`)
/// exercises inference of the differentiability parameters.
@differentiable(reverse)
func simple(x: Float) -> Float {
  x
}
// CHECK: @differentiable(_linear, wrt: x)
// CHECK-NEXT: func simple2(x: Float) -> Float
/// Identity on `Float`, marked `_linear` with the parameters left implicit.
@differentiable(_linear)
func simple2(x: Float) -> Float {
  x
}
// CHECK: @differentiable(_linear, wrt: x)
// CHECK-NEXT: func simple4(x: Float) -> Float
/// Identity on `Float`, marked `_linear` with an explicit `wrt: x`.
///
/// Its expected serialized attribute is the same as `simple2`'s, where the
/// parameter list is inferred instead.
@differentiable(_linear, wrt: x)
func simple4(x: Float) -> Float {
  x
}
/// Returns `x` paired with an identity differential (JVP-shaped helper).
///
/// NOTE(review): no attribute visible in this file references this function;
/// it looks like a leftover from an older derivative-registration syntax —
/// confirm before relying on or removing it.
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
  let differential: (Float) -> Float = { $0 }
  return (x, differential)
}
/// Returns `x` paired with an identity pullback (VJP-shaped helper).
///
/// NOTE(review): no attribute visible in this file references this function;
/// it looks like a leftover from an older derivative-registration syntax —
/// confirm before relying on or removing it.
func vjpSimple(x: Float) -> (Float, (Float) -> Float) {
  let pullback: (Float) -> Float = { $0 }
  return (x, pullback)
}
// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
/// Returns `x` and ignores `y`.
///
/// Only `x` appears in `wrt:`, so `y` is not a differentiability parameter.
@differentiable(reverse, wrt: x)
func testWrtClause(x: Float, y: Float) -> Float {
  x
}
// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testInout(x: inout Float)
/// Doubles `x` in place.
///
/// No `wrt:` clause is given; the inout parameter is expected to be inferred
/// as the differentiability parameter.
@differentiable(reverse)
func testInout(x: inout Float) {
  let scaled = x * 2.0
  x = scaled
}
// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testInoutResult(x: inout Float) -> Float
/// Doubles `x` in place and also returns the doubled value.
///
/// The differentiability parameters are inferred (`wrt: x` expected).
@differentiable(reverse)
func testInoutResult(x: inout Float) -> Float {
  let doubled = x * 2.0
  x = doubled
  return doubled
}
// CHECK: @differentiable(reverse, wrt: (x, y))
// CHECK-NEXT: func testMultipleInout(x: inout Float, y: inout Float)
/// Overwrites both `x` and `y` with their product.
///
/// No `wrt:` clause is given; both inout parameters are expected to be
/// inferred as differentiability parameters (`wrt: (x, y)`).
@differentiable(reverse)
func testMultipleInout(x: inout Float, y: inout Float) {
  let product = x * y
  x = product
  y = product
}
/// Stateless differentiable struct used to test a method attribute whose
/// `wrt:` list includes `self`.
struct InstanceMethod : Differentiable {
  // CHECK: @differentiable(reverse, wrt: (self, y))
  // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
  /// Returns `x`; differentiable with respect to `self` and `y` only.
  @differentiable(reverse, wrt: (self, y))
  func testWrtClause(x: Float, y: Float) -> Float {
    x
  }

  /// Hand-written tangent space. Every requirement traps via `fatalError()`
  /// because only the type structure matters here.
  struct TangentVector: Differentiable, AdditiveArithmetic {
    typealias TangentVector = Self
    static func == (lhs: Self, rhs: Self) -> Bool { fatalError() }
    static var zero: Self { fatalError() }
    static func + (lhs: Self, rhs: Self) -> Self { fatalError() }
    static func - (lhs: Self, rhs: Self) -> Self { fatalError() }
  }

  /// No-op: the struct has no stored properties to update.
  mutating func move(by offset: TangentVector) {}
}
// CHECK: @differentiable(reverse, wrt: x where T : Differentiable)
// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric
/// Generic identity.
///
/// It is differentiable only when `T : Differentiable`; the `wrt:` list is
/// left to inference.
@differentiable(reverse where T : Differentiable)
func testOnlyWhereClause<T : Numeric>(x: T) -> T {
  x
}
// CHECK: @differentiable(reverse, wrt: x where T : Differentiable)
// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric
/// Generic identity with a conditional differentiability requirement
/// (`T : Differentiable`).
@differentiable(reverse where T : Differentiable)
func testWhereClause<T : Numeric>(x: T) -> T {
  x
}
/// Empty protocol that only hosts the constrained extension methods that
/// follow.
protocol P {}
extension P {
  // CHECK: @differentiable(reverse, wrt: self where Self : Differentiable)
  // CHECK-NEXT: func testWhereClauseMethod() -> Self
  /// Returns `self`. It is differentiable with respect to `self` only when
  /// the conforming type is itself `Differentiable`.
  @differentiable(reverse, wrt: self where Self : Differentiable)
  func testWhereClauseMethod() -> Self {
    self
  }
}
extension P where Self : Differentiable {
  /// Returns `self` with an identity pullback, in the shape of a VJP for
  /// `testWhereClauseMethod`.
  ///
  /// NOTE(review): not registered as a derivative anywhere in the visible
  /// source — confirm whether it is still needed.
  func vjpTestWhereClauseMethod() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
    let pullback: (Self.TangentVector) -> Self.TangentVector = { $0 }
    return (self, pullback)
  }
}
// CHECK: @differentiable(reverse, wrt: x where T : Differentiable, T == T.TangentVector)
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric
/// Generic identity.
///
/// The attribute adds the same-type requirement `T == T.TangentVector` on
/// top of `T : Differentiable`.
@differentiable(reverse where T : Differentiable, T == T.TangentVector)
func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T {
  x
}
/// Returns `x` with an identity pullback.
///
/// It uses the same requirements that the attribute on
/// `testWhereClauseMethodTypeConstraint` imposes. Because
/// `T == T.TangentVector`, the pullback can be typed as `(T) -> T`.
func vjpTestWhereClauseMethodTypeConstraint<T>(x: T) -> (T, (T) -> T)
where T : Numeric, T : Differentiable, T == T.TangentVector {
  let pullback: (T) -> T = { $0 }
  return (x, pullback)
}
extension P {
  // CHECK: @differentiable(reverse, wrt: self where Self : Differentiable, Self == Self.TangentVector)
  // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self
  /// Returns `self`.
  ///
  /// The where clause lists its requirements in the opposite order from the
  /// expected serialized form. This exercises how requirements are
  /// normalized when the attribute is serialized.
  @differentiable(reverse, wrt: self where Self.TangentVector == Self, Self : Differentiable)
  func testWhereClauseMethodTypeConstraint() -> Self {
    self
  }
}
extension P where Self : Differentiable, Self == Self.TangentVector {
  /// Returns `self` with an identity pullback, in the shape of a VJP for
  /// `testWhereClauseMethodTypeConstraint`.
  ///
  /// NOTE(review): not registered as a derivative anywhere in the visible
  /// source — confirm whether it is still needed.
  func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
    let pullback: (Self.TangentVector) -> Self.TangentVector = { $0 }
    return (self, pullback)
  }
}
|