File: licm_context.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (90 lines) | stat: -rw-r--r-- 2,463 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// RUN: %target-swift-frontend -emit-sil -O %s | %FileCheck %s

// Ensure that the autoDiffCreateLinearMapContext call is not moved out of its loop by LICM (loop-invariant code motion).
import _Differentiation;

/// Element type stored in `Z.r`.
///
/// Its only stored property, `z`, is marked `@noDerivative`, so it is excluded
/// from differentiation; `B.a()` reads it as an index into `B.h`.
public struct R: Differentiable {
    @noDerivative public var z: Int
}

/// Differentiable wrapper around an array of `R` values (held by `B.e`).
public struct Z: Differentiable {
    public var r: [R] = []
}

/// Top-level differentiable value used by this test.
///
/// `h` is the `Float` array that `a()` updates in place; the `z` fields of
/// `e.r` supply the indices that `a()` writes to.
public struct B: Differentiable {
    public var h = [Float]();
    public var e = Z()
}

public extension Array {
    /// Stores `n` at index `x` (equivalent to `self[x] = n`).
    ///
    /// Differentiable in reverse mode when `Element` is `Differentiable`; its
    /// derivative is supplied by a separate `@derivative(of:)` registration.
    /// Like any array subscript assignment, this traps if `x` is out of bounds.
    @differentiable(reverse where Element: Differentiable)
    mutating func update(at x: Int, with n: Element) {
        self[x] = n
    }
}

public extension Array where Element: Differentiable {
    /// Custom reverse-mode derivative registered for `update(at:with:)`.
    ///
    /// Performs the same update, then returns a pullback that receives the
    /// array's cotangent `inout`. The pullback returns the cotangent component
    /// at index `x` (the gradient for `nv`) and zeroes that component, because
    /// the element previously stored at `x` was overwritten and no longer
    /// affects the result.
    @derivative(of: update(at:with:))
    mutating func v(at x: Int, with nv: Element) ->
      (value: Void,
       pullback: (inout TangentVector) -> (Element.TangentVector)) {
        update(at: x, with: nv);
        // The array's length, captured for use inside the pullback.
        let f = count;
        return ((),
                { v in
                    // The incoming cotangent can have fewer than `f` elements
                    // (e.g. an empty, default-initialized tangent); replace it
                    // with `f` zeros so that index `x` is valid.
                    // NOTE(review): any existing entries of a short but
                    // non-empty cotangent are discarded here as well.
                    if v.base.count < f {
                        v.base = [Element.TangentVector](repeating: .zero, count: f)
                    };
                    // Cotangent flowing back to the written value `nv`.
                    let d = v[x];
                    // The overwritten element receives no cotangent.
                    v.base[x] = .zero;
                    return d}
        )
    }
}

extension B {
    /// For every element of `e.r`, adds 2.4 to `h[z]`, where `z` is that
    /// element's `z` field, writing through the differentiable
    /// `update(at:with:)`.
    @differentiable(reverse)
    mutating func a() {
        // The trip count is read through `withoutDerivative(at:)`, so the loop
        // bound itself is not differentiated.
        for idx in 0 ..< withoutDerivative(at: self.e.r).count {
            let z = self.e.r[idx].z;
            let c = self.h[z];
            self.h.update(at: z, with: c + 2.4)
        }
    }
}

/// Returns the value of `s(y:)` at `y` together with its pullback, computed via
/// `valueWithPullback(at:of:)`.
public func b(y: B) -> (value: B,
                        pullback: (B.TangentVector) -> (B.TangentVector)) {
    // The `s` passed as `of:` is the global function `s(y:)`; the local
    // constant `s` only shadows it after this declaration.
    let s = valueWithPullback(at: y, of: s);
    return (value: s.value, pullback: s.pullback)
}

/// Differentiable function under test.
///
/// Wraps the local function `q` with `m(_:)` and applies the result to `y`.
/// `q` calls `B.a()`, which contains its own loop, from inside a loop of its
/// own.
///
/// The FileCheck expectations that follow this function are keyed to the
/// mangled symbol of `q`'s derivative, which encodes the names `s`, `y`, `B`
/// and `q`; renaming any of them would make those expectations stop matching.
@differentiable(reverse)
public func s(y: B) -> B {
    @differentiable(reverse)
    func q(_ u: B) -> B {
        var l = u;
        // Single-iteration loop. Presumably kept from test-case reduction
        // because the loop structure, not the trip count, matters to LICM.
        // TODO confirm.
        for _ in 0 ..< 1 {
            l.a()
        };
        return l
    };
    // Calls the generic function `m(_:)` declared later in this file; the
    // top-level constant also named `m` is a tuple and is not callable.
    let w = m(q);
    return w(y)
}

// CHECK-LABEL: sil private @$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr :
// CHECK: autoDiffCreateLinearMapContext
// CHECK: autoDiffCreateLinearMapContext
// CHECK-LABEL: end sil function '$s12licm_context1s1yAA1BVAE_tF1qL_yA2EFTJrSpSr'

/// Applies the reverse-differentiable function `f` to `x` and returns the result.
///
/// Note: the generic parameter `R` shadows the struct `R` within this declaration.
func o<T, R>(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {
    f(x)
}

/// Returns a reverse-differentiable closure that forwards its argument to `f`
/// through the generic helper `o(_:_:)`.
///
/// `f` is `@escaping` because the returned closure captures it.
func m<T, R>(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {
    { x in o(x, f) }
}

// Top-level script code: compute the value and pullback of `s` at a default `B`,
// then run the pullback with a tangent whose array components are all
// default-initialized (empty).
// NOTE: this constant shares its name with the generic function `m(_:)`, and
// `grad` is never used afterwards.
let m = b(y: B());
let grad = m.pullback(B.TangentVector(h: Array<Float>.TangentVector(), e: Z.TangentVector(r: Array<R>.TangentVector())))