1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
import DifferentiationUnittest
// Suite that the test cases in this file register against. The binding is
// never reassigned (cases are only added through `.test(...)`), so it is a
// constant.
let StoreBorrowAdjTest = TestSuite("StoreBorrowAdjTest")
/// A `Differentiable` wrapper that stores an array of `Element`s (`_values`)
/// together with one separately stored `accessed` element.
///
/// `Element` must be both `Differentiable` and `AdditiveArithmetic`; the
/// latter supplies the `.zero` default for `accessed` in the initializer.
public struct ConstantTimeAccessor<Element>: Differentiable where Element: Differentiable, Element: AdditiveArithmetic {
/// Hand-written tangent space of `ConstantTimeAccessor`.
///
/// `_base` holds the (co)tangent of `_values` and `accessed` holds that of
/// `accessed`. `_base` is allowed to be shorter than the number of stored
/// values: the initializer's pullback pads it with zeros before use.
public struct TangentVector: Differentiable, AdditiveArithmetic {
// The tangent vector is its own tangent space.
public typealias TangentVector = ConstantTimeAccessor.TangentVector
/// (Co)tangent of the stored array; may be shorter than `_values`.
public var _base: [Element.TangentVector]
/// (Co)tangent of the `accessed` element.
public var accessed: Element.TangentVector
/// Memberwise initializer.
public init(_base: [Element.TangentVector], accessed: Element.TangentVector) {
self._base = _base
self.accessed = accessed
}
}
// Stored elements. `@usableFromInline` so the `@inlinable` members below
// (the initializer and `array`) can read it.
@usableFromInline
var _values: [Element]
/// Separately stored element; defaults to `.zero` in the initializer.
public var accessed: Element
/// Creates an accessor over `values` with the given `accessed` element.
///
/// Marked `@differentiable(reverse)`; its reverse-mode derivative is
/// registered separately through `@derivative(of: init(_:accessed:))`.
@inlinable
@differentiable(reverse)
public init(_ values: [Element], accessed: Element = .zero) {
self._values = values
self.accessed = accessed
}
/// The wrapped elements, exposed through a differentiable getter whose
/// derivative is registered through `@derivative(of: array)`.
@inlinable
@differentiable(reverse)
public var array: [Element] { return _values }
/// Number of stored elements; excluded from differentiation.
@noDerivative
public var count: Int { return _values.count }
}
public extension ConstantTimeAccessor {
    /// Reverse-mode derivative (VJP) registered for `init(_:accessed:)`.
    ///
    /// - Parameters:
    ///   - values: The array handed to the initializer.
    ///   - accessed: The separately stored element handed to the initializer.
    /// - Returns: The constructed accessor, and a pullback that maps a
    ///   cotangent of the accessor to cotangents of `values` and `accessed`.
    @inlinable
    @derivative(of: init(_:accessed:))
    static func _vjpInit(_ values: [Element],
                         accessed: Element = .zero)
    -> (value: ConstantTimeAccessor, pullback: (TangentVector) -> (Array<Element>.TangentVector, Element.TangentVector)) {
        return (ConstantTimeAccessor(values, accessed: accessed), { v in
            // `v._base` may be shorter than `values` (an empty `_base` is how
            // zero is represented), so pad it with zeros up to `values.count`.
            // When it is already long enough, the padding is empty and
            // `_base` passes through unchanged.
            let padding = [Element.TangentVector](
                repeating: .zero,
                count: Swift.max(0, values.count - v._base.count))
            return (Array<Element>.TangentVector(v._base + padding), v.accessed)
        })
    }

    /// Reverse-mode derivative (VJP) registered for the `array` getter.
    ///
    /// - Returns: The wrapped values, and a pullback that stores an array
    ///   cotangent in `TangentVector._base` with a zero `accessed` component.
    @inlinable
    @derivative(of: array)
    func vjpArray() -> (value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
        func pullback(v: Array<Element>.TangentVector) -> TangentVector {
            // Hoisted so `.zero` is not re-evaluated for every element.
            let localZero = Element.TangentVector.zero
            // An all-zero (or empty) cotangent is stored sparsely as an empty
            // `_base`, the same representation `TangentVector.zero` uses.
            let base: [Element.TangentVector] =
                v.base.allSatisfy({ $0 == localZero }) ? [] : v.base
            return TangentVector(_base: base, accessed: Element.TangentVector.zero)
        }
        return (_values, pullback)
    }

    /// Moves `self` along `offset`: `accessed` by `offset.accessed`, and
    /// `_values` by `offset._base`.
    ///
    /// NOTE(review): `offset._base` may be empty (the sparse zero). This
    /// relies on `Array.DifferentiableView.move(by:)` treating an empty offset
    /// as a no-op; confirm against the stdlib implementation.
    mutating func move(by offset: TangentVector) {
        accessed.move(by: offset.accessed)
        _values.move(by: Array<Element>.TangentVector(offset._base))
    }
}
public extension ConstantTimeAccessor.TangentVector {
    /// Adds two tangent vectors.
    ///
    /// The `_base` arrays may differ in length. The overlapping prefix is
    /// summed element-wise and the tail of the longer array is kept as-is, so
    /// an empty `_base` (the representation of zero) acts as the identity for
    /// the array part. The `accessed` components are always summed.
    @inlinable
    static func + (lhs: Self, rhs: Self) -> Self {
        let accessedSum = lhs.accessed + rhs.accessed
        // Fast paths: an empty `_base` contributes nothing to the array part,
        // but that operand's `accessed` component must still be included.
        if rhs._base.isEmpty {
            return Self(_base: lhs._base, accessed: accessedSum)
        }
        if lhs._base.isEmpty {
            return Self(_base: rhs._base, accessed: accessedSum)
        }
        var base = zip(lhs._base, rhs._base).map(+)
        if lhs._base.count < rhs._base.count {
            base.append(contentsOf: rhs._base.suffix(from: lhs._base.count))
        } else if lhs._base.count > rhs._base.count {
            base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
        }
        return Self(_base: base, accessed: accessedSum)
    }

    /// Subtracts `rhs` from `lhs`.
    ///
    /// The overlapping prefix of the `_base` arrays is subtracted
    /// element-wise. A longer `lhs` tail is kept as-is, and a longer `rhs`
    /// tail is negated. The `accessed` components are always subtracted.
    @inlinable
    static func - (lhs: Self, rhs: Self) -> Self {
        let accessedDifference = lhs.accessed - rhs.accessed
        // Fast path: an empty `rhs._base` leaves the array part unchanged, but
        // `rhs.accessed` must still be subtracted.
        if rhs._base.isEmpty {
            return Self(_base: lhs._base, accessed: accessedDifference)
        }
        var base = zip(lhs._base, rhs._base).map(-)
        if lhs._base.count < rhs._base.count {
            base.append(contentsOf: rhs._base.suffix(from: lhs._base.count).map { .zero - $0 })
        } else if lhs._base.count > rhs._base.count {
            base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
        }
        return Self(_base: base, accessed: accessedDifference)
    }

    /// The additive identity: an empty `_base` and a zero `accessed`.
    @inlinable
    static var zero: Self { Self(_base: [], accessed: .zero) }
}
// Checks that differentiating through `ConstantTimeAccessor`'s initializer and
// `array` getter yields the expected (non-zero) gradient entry. The suite and
// test names suggest this guards a store_borrow adjoint regression; confirm
// against the originating bug before reshaping the code below.
StoreBorrowAdjTest.test("NonZeroGrad") {
// Builds an accessor from `input` (relying on the defaulted `accessed:`
// argument), reads it back through the differentiable `array` getter, and
// returns element 1. Keep this exact shape: it is the code the
// differentiation transform runs on.
@differentiable(reverse)
func testInits(input: [Float]) -> Float {
let internalAccessor = ConstantTimeAccessor(input)
let internalArray = internalAccessor.array
return internalArray[1]
}
// The result is `input[1]`, so its gradient is 1 at index 1.
let grad = gradient(at: [42.0, 146.0, 73.0], of: testInits)
expectEqual(grad[1], 1.0)
}
runAllTests()
|