1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
// This test would fail wherever `swift_autoDiffCreateLinearMapContext` is unavailable.
import StdlibUnittest
import DifferentiationUnittest
// Test suite for differentiation of types whose `TangentVector` is address-only.
// The suite reference is never reassigned (tests are registered via `.test`),
// so it is bound with `let`.
let AddressOnlyTangentVectorTests = TestSuite("AddressOnlyTangentVector")
// TF-1149: Test loadable class type with an address-only `TangentVector` type.
//
// `LoadableClass<T>` is a class, so a value of it is a single reference, but
// its compiler-synthesized `TangentVector` holds a `stored` member for a
// generic `T`, which is what makes that tangent vector address-only. Each
// nested function below routes the class value through a different construct
// (a local binding, a nested tuple, a branch, an array built in a loop, a
// nested array literal) before reading `stored`. The test then checks that the
// reverse-mode gradient at `LoadableClass<Float>(10)` has `stored == 1`.
//
// NOTE(review): this is a compiler regression test. Its exact source shape
// (e.g. `var` bindings that are never mutated) is presumably what is being
// exercised, so do not "clean up" the code without confirming the intent.
AddressOnlyTangentVectorTests.test("LoadableClassAddressOnlyTangentVector") {
  // Generic class whose only stored property is differentiable. The
  // `Differentiable` conformance (and thus `TangentVector`, with its
  // `init(stored:)` used in the expectations below) is synthesized.
  final class LoadableClass<T: Differentiable>: Differentiable {
    @differentiable(reverse)
    var stored: T

    @differentiable(reverse)
    init(_ stored: T) {
      self.stored = stored
    }

    // Not called by the functions below. It is presumably declared to check
    // that a `@differentiable(reverse)` method on this class compiles.
    @differentiable(reverse)
    func method(_ x: T) -> T {
      stored
    }
  }

  // Reads `stored` through a local binding. `x` is never mutated; the `var` is
  // presumably deliberate (mutable-local lowering) - confirm before changing.
  @differentiable(reverse)
  func projection<T: Differentiable>(_ s: LoadableClass<T>) -> T {
    var x = s.stored
    return x
  }
  expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), of: projection))

  // Places `s` three times in a nested tuple and reads only the element at
  // `.1.0`. Per the expectation, the gradient's `stored` tangent is 1.
  @differentiable(reverse)
  func tuple<T: Differentiable>(_ s: LoadableClass<T>) -> T {
    var tuple = (s, (s, s))
    return tuple.1.0.stored
  }
  expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), of: tuple))

  // Same as `tuple`, but with a branch in the differentiated function.
  @differentiable(reverse)
  func conditional<T: Differentiable>(_ s: LoadableClass<T>) -> T {
    var tuple = (s, (s, s))
    // TODO: A literal `false` condition cannot be used here because it
    // crashes, so `1 == 0` serves as an always-false stand-in. The empty
    // branch only exists to add control flow to this function.
    if 1 == 0 {}
    return tuple.1.0.stored
  }
  expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), of: conditional))

  // Copies the input array element by element, then reads element 0.
  @differentiable(reverse)
  func loop<T: Differentiable>(_ array: [LoadableClass<T>]) -> T {
    var result: [LoadableClass<T>] = []
    // The indices are plain integers, so they are wrapped in
    // `withoutDerivative(at:)` to exclude them from differentiation.
    for i in withoutDerivative(at: array.indices) {
      result.append(array[i])
    }
    return result[0].stored
  }
  // The gradient w.r.t. a one-element array is a one-element tangent array.
  expectEqual([.init(stored: 1)], gradient(at: [LoadableClass<Float>(10)], of: loop))

  // Builds a nested array literal containing `s` twice and reads `[0][1]`.
  @differentiable(reverse)
  func arrayLiteral<T: Differentiable>(_ s: LoadableClass<T>) -> T {
    var result: [[LoadableClass<T>]] = [[s, s]]
    return result[0][1].stored
  }
  expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), of: arrayLiteral))
}
// Run every test registered above (StdlibUnittest entry point).
runAllTests()
|