1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
// RUN: %target-swift-frontend -emit-sil -verify %s
import _Differentiation
// Test supported `br`, `cond_br`, and `switch_enum` terminators.
// Differentiable function whose body is an `if` / `else if` chain followed by
// a fall-through `return`. Every path returns `x` unchanged, so only the
// branching structure matters. No diagnostics are attached to this function.
@differentiable(reverse)
func branch(_ x: Float) -> Float {
  if x > 0 {
    return x
  } else if x < 10 {
    return x
  }
  // Reached only when both conditions above are false.
  return x
}
// Two-case enum carrying a `Float` payload in each case. It is the `switch`
// subject of the enum tests in this file and declares no protocol
// conformances.
enum Enum {
  case a(Float)
  case b(Float)
}
// Switches over `e` without binding either payload; both cases return `x`
// unchanged. No diagnostics are attached to this function.
// NOTE(review): `Enum` declares no `Differentiable` conformance in this file,
// so presumably only `x` is a differentiation parameter and `e` is non-active,
// as the function name suggests — confirm.
@differentiable(reverse)
func enum_nonactive1(_ e: Enum, _ x: Float) -> Float {
  switch e {
  case .a: return x
  case .b: return x
  }
}
// Switches over `e`, binds the `Float` payload of whichever case matches, and
// returns that payload added to `x`. No diagnostics are attached to this
// function.
// NOTE(review): `Enum` declares no `Differentiable` conformance in this file,
// so presumably the payload is non-active and only `x` is differentiated, as
// the function name suggests — confirm.
@differentiable(reverse)
func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
  switch e {
  case .a(let payload):
    return x + payload
  case .b(let payload):
    return x + payload
  }
}
// Test loops.
// Differentiable `for`-`in` loop over the constant range `0..<3`. The
// accumulator starts at `x` and is multiplied by `x` once per iteration, so
// the result is `x` raised to the fourth power. No diagnostics are attached to
// this function.
@differentiable(reverse)
func for_loop(_ x: Float) -> Float {
  var product = x
  for _ in 0..<3 {
    product = product * x
  }
  return product
}
// Differentiable `while` loop driven by an integer counter that runs from 1
// up to (but not including) 3. The accumulator starts at `x` and is multiplied
// by `x` on each of the two iterations, giving `x` cubed. No diagnostics are
// attached to this function.
@differentiable(reverse)
func while_loop(_ x: Float) -> Float {
  var product: Float = x
  var step = 1
  while step < 3 {
    product = product * x
    step += 1
  }
  return product
}
// Differentiable `while` loop nested inside a `for`-`in` loop. Each of the two
// outer iterations multiplies the running value by `x`, then the inner loop
// divides a copy of it by `x` twice before the copy is written back. No
// diagnostics are attached to this function.
@differentiable(reverse)
func nested_loop(_ x: Float) -> Float {
  var accumulated = x
  for _ in 1..<3 {
    accumulated = accumulated * x
    // The inner loop works on a copy and only writes back after it finishes.
    var scaled = accumulated
    var step = 1
    while step < 3 {
      scaled = scaled / x
      step += 1
    }
    accumulated = scaled
  }
  return accumulated
}
// TF-433: Test throwing functions.
// Helper with a `rethrows` signature and an empty body: it accepts a
// possibly-throwing closure and never calls it. Used by `testTryApply`.
func rethrowing(_ x: () throws -> Void) rethrows -> Void {}
// Calls the `rethrowing` helper with an empty closure, then returns `x`
// unchanged. No diagnostics are attached to this function, so under `-verify`
// it must differentiate without errors.
// NOTE(review): per the function name, the `rethrowing({})` call is presumably
// what produces the `try_apply` being exercised — confirm against the emitted
// SIL.
@differentiable(reverse)
func testTryApply(_ x: Float) -> Float {
  rethrowing({})
  return x
}
// Generic `rethrows` function that returns the result of calling the throwing
// closure `body` on `x`. Differentiation is required to fail here: the error
// is anchored on the attribute, one note on the `func` line, and one note on
// the `try body(x)` call.
// The diagnostic directives use relative line offsets, so no line may be
// inserted between a directive and the line it targets.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func withoutDerivative<T : Differentiable, R: Differentiable>(
  at x: T, in body: (T) throws -> R
) rethrows -> R {
  // expected-note @+1 {{expression is not differentiable}}
  try body(x)
}
// Tests active `try_apply`.
// Returns `maybeX` when it is non-nil and 10 otherwise. Differentiation is
// required to fail here, with the `??` expression reported as not
// differentiable.
// NOTE(review): presumably `??` becomes a `try_apply` because its default
// operand is a rethrowing autoclosure — confirm against the standard library
// declaration.
// The diagnostic directives use relative line offsets, so no line may be
// inserted between a directive and the line it targets.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func testNilCoalescing(_ maybeX: Float?) -> Float {
  // expected-note @+1 {{expression is not differentiable}}
  return maybeX ?? 10
}
// Test unsupported differentiation of active enum values.
// Wraps `x` in an `Enum` case chosen by the sign of `x`, then switches over
// that value and adds the payload back to `x`. Because the enum payload is
// `x` itself, the enum value depends on the differentiation input.
// Differentiation is required to fail here; the "not yet supported" note is
// anchored on the `let e: Enum` declaration rather than on the `switch`.
// The diagnostic directives use relative line offsets, so no line may be
// inserted between a directive and the line it targets.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func enum_active(_ x: Float) -> Float {
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  let e: Enum
  if x > 0 {
    e = .a(x)
  } else {
    e = .b(x)
  }
  switch e {
  case let .a(a): return x + a
  case let .b(b): return x + b
  }
}
// Enum that declares itself `Differentiable` and `AdditiveArithmetic` and uses
// itself as its tangent vector. Its `+` and `-` operators take and return enum
// values, and differentiating them is required to fail with the "enum values
// not yet supported" note.
// The diagnostic directives below use relative line offsets, so no line may be
// inserted between a directive and the line it targets.
enum Tree : Differentiable & AdditiveArithmetic {
  case leaf(Float)
  case branch(Float, Float)

  typealias TangentVector = Self

  // The additive identity is modelled as a leaf holding zero.
  static var zero: Self { .leaf(0) }

  // Component-wise sum. Two leaves add their values; two branches add the
  // first component of `lhs` to the first of `rhs` and the second to the
  // second. Mixed leaf/branch operands trap via `fatalError()`.
  // expected-error @+1 {{function is not differentiable}}
  @differentiable(reverse)
  // TODO(TF-956): Improve location of active enum non-differentiability errors
  // so that they are closer to the source of the non-differentiability.
  // expected-note @+2 {{when differentiating this function definition}}
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  static func +(_ lhs: Self, _ rhs: Self) -> Self {
    switch (lhs, rhs) {
    case let (.leaf(x), .leaf(y)):
      return .leaf(x + y)
    case let (.branch(x1, x2), .branch(y1, y2)):
      // Pair each `lhs` component with the matching `rhs` component.
      return .branch(x1 + y1, x2 + y2)
    default:
      fatalError()
    }
  }

  // Component-wise difference, mirroring `+`. Two leaves subtract their
  // values; two branches subtract each `rhs` component from the matching
  // `lhs` component. Mixed leaf/branch operands trap via `fatalError()`.
  // expected-error @+1 {{function is not differentiable}}
  @differentiable(reverse)
  // TODO(TF-956): Improve location of active enum non-differentiability errors
  // so that they are closer to the source of the non-differentiability.
  // expected-note @+2 {{when differentiating this function definition}}
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  static func -(_ lhs: Self, _ rhs: Self) -> Self {
    switch (lhs, rhs) {
    case let (.leaf(x), .leaf(y)):
      return .leaf(x - y)
    case let (.branch(x1, x2), .branch(y1, y2)):
      // Pair each `lhs` component with the matching `rhs` component.
      return .branch(x1 - y1, x2 - y2)
    default:
      fatalError()
    }
  }
}
// TODO(TF-957): Improve non-differentiability errors for for-in loops
// (`Collection.makeIterator` and `IteratorProtocol.next`).
//
// Multiplies every element of `array` into a running product that starts at
// 1. Differentiation is required to fail here, with a note that suggests
// wrapping the iterated value in `withoutDerivative(at:)`.
//
// The fix-it ranges in that note (`+2:12-12` and `+2:17-17`) are column
// positions on the `for x in array {` line. They fall immediately before and
// after `array` only when that line is indented by two spaces, so the body
// indentation below is significant and must be kept.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} {{+2:12-12=withoutDerivative(at: }} {{+2:17-17=)}}
func loop_array(_ array: [Float]) -> Float {
  var result: Float = 1
  for x in array {
    result = result * x
  }
  return result
}
|