File: issue-58353.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (127 lines) | stat: -rw-r--r-- 4,135 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

// https://github.com/apple/swift/issues/58353

import _Differentiation
import StdlibUnittest

// StdlibUnittest suite that the pullback regression cases below register into;
// the suites are executed by the `runAllTests()` call at the end of the file.
var PullbackTests = TestSuite("Pullback")

extension Dictionary: Differentiable where Value: Differentiable {
    /// A dictionary's tangent holds one tangent component per key.
    public typealias TangentVector = [Key: Value.TangentVector]

    /// Moves each stored value along the matching component of `direction`.
    ///
    /// Every key that appears in `direction` must already be present in `self`;
    /// a component for an absent key is a programmer error and traps.
    public mutating func move(by direction: TangentVector) {
        for (key, delta) in direction {
            // Only evaluated by the `default:` autoclosure, i.e. only when `key`
            // is absent from `self`.
            func trapOnMissingKey() -> Value {
                fatalError("missing component \(key) in moved Dictionary")
            }
            // Mutate the stored value in place rather than copying it out and back.
            self[key, default: trapOnMissingKey()].move(by: delta)
        }
    }

    /// Returns a closure that builds a tangent with a `.zero` component for every
    /// key currently in the dictionary.
    public var zeroTangentVectorInitializer: () -> TangentVector {
        // Capture only the keys, not the whole dictionary, so the returned
        // closure does not keep the stored values alive.
        let capturedKeys = keys
        return {
            var zeroTangent = TangentVector()
            for capturedKey in capturedKeys {
                zeroTangent[capturedKey] = Value.TangentVector.zero
            }
            return zeroTangent
        }
    }
}

extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
    /// Key-wise sum. A key present in only one operand keeps that operand's value,
    /// which amounts to treating an absent key as holding zero.
    public static func + (_ lhs: Self, _ rhs: Self) -> Self {
        return lhs.merging(rhs) { leftValue, rightValue in
            leftValue + rightValue
        }
    }

    /// Key-wise difference, computed as `lhs` plus the negation of every value in
    /// `rhs` (negation is `Value.zero - value`). A key only in `rhs` therefore
    /// ends up holding its negated value.
    public static func - (_ lhs: Self, _ rhs: Self) -> Self {
        let negatedRHS = rhs.mapValues { Value.zero - $0 }
        return lhs.merging(negatedRHS) { leftValue, rightValue in
            leftValue + rightValue
        }
    }

    /// The additive identity: the empty dictionary.
    public static var zero: Self {
        return [:]
    }
}

extension Dictionary where Value: Differentiable {
    /// Reverse-mode derivative (VJP) registered for the getter of `subscript(_:)`.
    ///
    /// Returns the looked-up value together with a pullback that maps a tangent of
    /// the optional result back to a dictionary tangent. The pullback emits a sparse
    /// tangent: a single `[key: component]` entry when the incoming tangent carries
    /// a value, otherwise `.zero`. Dictionary addition treats an absent key as zero,
    /// so there is no need to build a dense, zero-filled dictionary here.
    @usableFromInline
    @derivative(of: subscript(_:))
    func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
        let lookedUp = self[key]
        let pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector = { tangent in
            guard let component = tangent.value else {
                return .zero
            }
            return [key: component]
        }
        return (lookedUp, pullback)
    }
}

public extension Dictionary where Value: Differentiable {
    /// Stores `newValue` under `key`, inserting or replacing the entry.
    ///
    /// A differentiable spelling of `self[key] = newValue`; its reverse-mode
    /// derivative is registered by `vjpUpdated(_:to:)`.
    @differentiable(reverse)
    mutating func set(_ key: Key, to newValue: Value) {
        self[key] = newValue
    }

    /// Reverse-mode derivative (VJP) of `set(_:to:)`.
    ///
    /// Performs the assignment, then returns a pullback that takes the tangent `v`
    /// of the updated dictionary (in/out) and:
    /// - adds a `.zero` entry for every key the dictionary held after the
    ///   assignment that `v` does not already contain, leaving existing entries
    ///   untouched;
    /// - returns the component for `key` as the derivative with respect to
    ///   `newValue`;
    /// - zeroes that component in `v`, because the value previously stored under
    ///   `key` was overwritten and receives no gradient.
    @derivative(of: set)
    mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
        self.set(key, to: newValue)

        let forwardCount = self.count
        // Captures every key of the updated dictionary; may be heavy for large
        // dictionaries, but it is needed to densify sparse incoming tangents.
        let forwardKeys = self.keys

        return ((), { v in
            // Incoming tangents may be sparse: `.zero` is `[:]` and subscript
            // pullbacks produce single-entry dictionaries. Fill in only the missing
            // keys; replacing `v` wholesale would discard the gradient it already
            // carries for other keys.
            if v.count < forwardCount {
                for forwardKey in forwardKeys where v[forwardKey] == nil {
                    v[forwardKey] = .zero
                }
            }

            // `key` is always among `forwardKeys`, so after densifying it is normally
            // present; fall back to `.zero` defensively rather than trapping.
            let dElement = v[key] ?? .zero
            v[key] = .zero
            return dElement
        })
    }
}


// Regression case for the issue linked at the top of this file, using a concrete
// `[String: Double]` argument.
//
// `getD` reads `newValues[key]` only after checking `keys.contains(key)`, and
// returns nil otherwise. NOTE(review): this shape is presumably what reproduced
// the original problem, so keep it verbatim rather than simplifying it to a plain
// subscript read, unless the issue confirms that is safe.
//
// `testFunctionD` returns the "s1" entry, so pulling back a seed of 2 at
// ["s1": 1.0] is expected to give the single-entry tangent ["s1": 2.0].
PullbackTests.test("ConcreteType") {
  // Returns the value for `key` if present, nil otherwise.
  func getD(from newValues: [String: Double], at key: String) -> Double? {
    if newValues.keys.contains(key) {
      return newValues[key]
    }
    return nil
  }
  
  // Force-unwrap is safe for the input used below, which always contains "s1".
  @differentiable(reverse)
  func testFunctionD(newValues: [String: Double]) -> Double {
    return getD(from: newValues, at: "s1")!
  }

  expectEqual(pullback(at: ["s1": 1.0], of: testFunctionD)(2), ["s1" : 2.0])
}

// Same scenario as the "ConcreteType" case, but the lookup helper is generic over
// a `Differentiable` value type and is instantiated with `Double` at the call site.
// NOTE(review): the generic helper is presumably the part that exercised the
// original issue, so keep its body verbatim.
//
// The expected pullback is the same single-entry tangent: ["s1": 2.0] for a seed of
// 2 at ["s1": 1.0].
PullbackTests.test("GenericType") {
  // Generic variant of the lookup: value for `key` if present, nil otherwise.
  func getG<DataType>(from newValues: [String: DataType], at key: String) -> DataType?
  where DataType: Differentiable {
    if newValues.keys.contains(key) {
      return newValues[key]
    }
    return nil
  }

  // Force-unwrap is safe for the input used below, which always contains "s1".
  @differentiable(reverse)
  func testFunctionG(newValues: [String: Double]) -> Double {
    return getG(from: newValues, at: "s1")!
  }

  expectEqual(pullback(at: ["s1": 1.0], of: testFunctionG)(2), ["s1" : 2.0])
}

// StdlibUnittest entry point: runs the registered test suites (here, `PullbackTests`).
runAllTests()