File: storeborrow.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (140 lines) | stat: -rw-r--r-- 4,515 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
import DifferentiationUnittest

// StdlibUnittest suite that the test case at the bottom of this file is registered on.
var StoreBorrowAdjTest: TestSuite = TestSuite("StoreBorrowAdjTest")

/// Differentiable wrapper around an array of `Element`s plus one separately stored
/// `accessed` element.
public struct ConstantTimeAccessor<Element>: Differentiable where Element: Differentiable, Element: AdditiveArithmetic {
    /// Tangent of `ConstantTimeAccessor`: one component for the wrapped array (`_base`) and
    /// one for `accessed`. It serves as its own tangent type.
    public struct TangentVector: Differentiable, AdditiveArithmetic {
        public typealias TangentVector = ConstantTimeAccessor.TangentVector

        /// Tangent entries for the wrapped array's elements.
        public var _base: [Element.TangentVector]

        /// Tangent of the `accessed` element.
        public var accessed: Element.TangentVector

        /// Memberwise initializer, declared explicitly so that it is `public`.
        public init(_base: [Element.TangentVector], accessed: Element.TangentVector) {
            self.accessed = accessed
            self._base = _base
        }
    }

    /// Backing storage for `array`.
    @usableFromInline
    var _values: [Element]

    /// The separately stored element.
    public var accessed: Element

    /// Wraps `values`; `accessed` defaults to `.zero`.
    @inlinable
    @differentiable(reverse)
    public init(_ values: [Element], accessed: Element = .zero) {
        self.accessed = accessed
        _values = values
    }

    /// The wrapped elements.
    @inlinable
    @differentiable(reverse)
    public var array: [Element] { _values }

    /// Number of wrapped elements; excluded from differentiation.
    @noDerivative
    public var count: Int { _values.count }
}

public extension ConstantTimeAccessor {
    /// Custom reverse-mode derivative of `init(_:accessed:)`.
    ///
    /// The pullback maps a cotangent of the accessor to cotangents of `values` and `accessed`.
    /// A `_base` shorter than `values` (for example the empty `_base` of `TangentVector.zero`)
    /// is padded with `.zero` so the array cotangent has `values.count` entries.
    @inlinable
    @derivative(of: init(_:accessed:))
    static func _vjpInit(_ values: [Element],
                         accessed: Element = .zero)
        -> (value: ConstantTimeAccessor, pullback: (TangentVector) -> (Array<Element>.TangentVector, Element.TangentVector)) {
        return (ConstantTimeAccessor(values, accessed: accessed), { v in
            let base: Array<Element>.TangentVector
            if v._base.count < values.count {
                base = Array<Element>
                    .TangentVector(v._base + Array<Element.TangentVector>(repeating: .zero, count: values.count - v._base.count))
            }
            else {
                base = Array<Element>.TangentVector(v._base)
            }

            return (base, v.accessed)
        })
    }

    /// Custom reverse-mode derivative of the `array` getter.
    ///
    /// The pullback wraps the array cotangent in a `TangentVector` whose `accessed` component
    /// is zero. An all-zero array cotangent is stored as an empty `_base`, the same
    /// representation `TangentVector.zero` uses.
    @inlinable
    @derivative(of: array)
    func vjpArray() -> (value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
        func pullback(v: Array<Element>.TangentVector) -> TangentVector {
            let localZero = Element.TangentVector.zero
            // Collapse an all-zero cotangent to the sparse (empty) representation.
            let base: [Element.TangentVector] = v.base.allSatisfy({ $0 == localZero }) ? [] : v.base
            return TangentVector(_base: base, accessed: Element.TangentVector.zero)
        }
        return (_values, pullback)
    }

    /// Moves `self` along `offset`.
    ///
    /// `accessed` moves by `offset.accessed`, and `_values` moves by `offset._base`.
    /// A non-empty `_base` shorter than `_values` is first padded with `.zero`. Missing
    /// trailing entries are treated as zero by `_vjpInit`'s pullback and by the tangent `+`,
    /// and without padding the array's `move(by:)` would see mismatched counts.
    /// An empty `_base` is passed through unchanged.
    mutating func move(by offset: TangentVector) {
        accessed.move(by: offset.accessed)
        var direction = offset._base
        if !direction.isEmpty && direction.count < _values.count {
            direction.append(contentsOf: repeatElement(Element.TangentVector.zero, count: _values.count - direction.count))
        }
        _values.move(by: Array<Element>.TangentVector(direction))
    }
}

/// `AdditiveArithmetic` operations for the tangent of `ConstantTimeAccessor`.
///
/// `_base` is sparse: an empty array is the zero array tangent, and when two `_base` arrays
/// differ in length the shorter one's missing trailing entries behave as zero. `accessed` is
/// always combined component-wise.
public extension ConstantTimeAccessor.TangentVector {
    /// Component-wise sum. `_base` arrays are summed over their common prefix, and the longer
    /// array's tail is kept as-is.
    @inlinable
    static func + (lhs: Self, rhs: Self) -> Self {
        // Combine `accessed` on every path, including the empty-`_base` fast paths, so that a
        // tangent with an empty `_base` but a non-zero `accessed` is not dropped.
        let accessed = lhs.accessed + rhs.accessed
        if rhs._base.isEmpty {
            return Self(_base: lhs._base, accessed: accessed)
        }
        if lhs._base.isEmpty {
            return Self(_base: rhs._base, accessed: accessed)
        }
        var base = zip(lhs._base, rhs._base).map(+)
        if lhs._base.count < rhs._base.count {
            base.append(contentsOf: rhs._base.suffix(from: lhs._base.count))
        }
        else if lhs._base.count > rhs._base.count {
            base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
        }
        return Self(_base: base, accessed: accessed)
    }

    /// Component-wise difference. Over the common prefix the `_base` entries are subtracted.
    /// A longer `rhs._base` tail is negated, and a longer `lhs._base` tail is kept as-is.
    @inlinable
    static func - (lhs: Self, rhs: Self) -> Self {
        // As in `+`, `accessed` is combined even on the empty-`rhs._base` fast path.
        let accessed = lhs.accessed - rhs.accessed
        if rhs._base.isEmpty {
            return Self(_base: lhs._base, accessed: accessed)
        }
        var base = zip(lhs._base, rhs._base).map(-)
        if lhs._base.count < rhs._base.count {
            base.append(contentsOf: rhs._base.suffix(from: lhs._base.count).map { .zero - $0 })
        }
        else if lhs._base.count > rhs._base.count {
            base.append(contentsOf: lhs._base.suffix(from: rhs._base.count))
        }
        return Self(_base: base, accessed: accessed)
    }

    /// The zero tangent: an empty `_base` and a zero `accessed`.
    @inlinable
    static var zero: Self { Self(_base: [], accessed: .zero) }
}

// Differentiates through `ConstantTimeAccessor.init` and the `array` getter (both of which
// have custom derivatives registered in this file) and checks the gradient of reading
// element 1 of the input.
// NOTE(review): judging by the suite name, this presumably guards against a past compiler
// failure involving `store_borrow` in the adjoint. Keep the statement shape of `testInits`
// as is; restructuring it may stop exercising the original path.
StoreBorrowAdjTest.test("NonZeroGrad") {
  @differentiable(reverse)
  func testInits(input: [Float]) -> Float {
    // Each intermediate is bound to its own local; see the NOTE above before merging these.
    let internalAccessor = ConstantTimeAccessor(input)
    let internalArray = internalAccessor.array
    return internalArray[1]
  }

  // The output is `input[1]`, so its gradient is 1 at index 1. Only that entry is asserted.
  let grad = gradient(at: [42.0, 146.0, 73.0], of: testInits)
  expectEqual(grad[1], 1.0)
}

// Hand control to StdlibUnittest to run the registered test cases.
runAllTests()