File: differentiable_attr.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (140 lines) | stat: -rw-r--r-- 4,488 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
// RUN: %target-sil-opt -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s

// BCANALYZER-NOT: UnknownCode

import _Differentiation

// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func simple(x: Float) -> Float
// Identity on `Float` with no explicit `wrt:` argument: the attribute is
// expected to round-trip with the inferred parameter list `wrt: x`.
@differentiable(reverse)
func simple(x: Float) -> Float {
  x
}

// CHECK: @differentiable(_linear, wrt: x)
// CHECK-NEXT: func simple2(x: Float) -> Float
// Identity using the `_linear` differentiability kind with inferred
// parameters; the kind is expected to survive serialization unchanged.
@differentiable(_linear)
func simple2(x: Float) -> Float {
  x
}

// CHECK: @differentiable(_linear, wrt: x)
// CHECK-NEXT: func simple4(x: Float) -> Float
// Like `simple2`, but with the `wrt:` clause spelled out explicitly. The
// expected printed attribute is the same as for the inferred form.
@differentiable(_linear, wrt: x)
func simple4(x: Float) -> Float {
  x
}

// JVP-shaped helper for `simple`: pairs the input value with an identity
// differential. It carries no `@derivative(of:)` attribute, so it is not
// registered as a derivative.
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
  let differential: (Float) -> Float = { $0 }
  return (x, differential)
}

// VJP-shaped helper for `simple`: pairs the input value with an identity
// pullback. It carries no `@derivative(of:)` attribute, so it is not
// registered as a derivative.
func vjpSimple(x: Float) -> (Float, (Float) -> Float) {
  let pullback: (Float) -> Float = { $0 }
  return (x, pullback)
}

// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
// Explicit `wrt: x` on a two-parameter function: `y` is excluded from the
// differentiability parameters and must not appear in the printed clause.
@differentiable(reverse, wrt: x)
func testWrtClause(x: Float, y: Float) -> Float {
  x
}

// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testInout(x: inout Float)
// A single `inout` parameter with no explicit `wrt:`; the inferred
// differentiability parameter list is expected to print as `wrt: x`.
@differentiable(reverse)
func testInout(x: inout Float) {
  // Doubles `x` in place via plain assignment (kept in this exact form so the
  // differentiated operation stays what the test was written against).
  x = x * 2.0
}

// CHECK: @differentiable(reverse, wrt: x)
// CHECK-NEXT: func testInoutResult(x: inout Float) -> Float
// A function with both an `inout` parameter and a formal result; the inferred
// differentiability parameter list is expected to print as `wrt: x`.
@differentiable(reverse)
func testInoutResult(x: inout Float) -> Float {
  // Doubles `x` in place, then returns the updated value.
  x = x * 2.0
  return x
}

// CHECK: @differentiable(reverse, wrt: (x, y))
// CHECK-NEXT: func testMultipleInout(x: inout Float, y: inout Float)
// Two `inout` parameters with no explicit `wrt:`; both are expected to be
// inferred, printing as `wrt: (x, y)`.
@differentiable(reverse)
func testMultipleInout(x: inout Float, y: inout Float) {
  x = x * y
  // Order matters: `y` receives the already-updated `x` (the product).
  y = x
}

// Minimal `Differentiable` struct used to test a `wrt:` clause that names
// `self` on an instance method.
struct InstanceMethod : Differentiable {
  // CHECK: @differentiable(reverse, wrt: (self, y))
  // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
  @differentiable(reverse, wrt: (self, y))
  func testWrtClause(x: Float, y: Float) -> Float {
    return x
  }

  // Hand-written tangent vector that is its own tangent space. Every
  // requirement is a stub that traps via `fatalError()`; presumably they are
  // never executed and exist only so the conformances type-check.
  struct TangentVector: Differentiable, AdditiveArithmetic {
    typealias TangentVector = Self
    static func ==(_: Self, _: Self) -> Bool { fatalError() }
    static var zero: Self { fatalError() }
    static func +(_: Self, _: Self) -> Self { fatalError() }
    static func -(_: Self, _: Self) -> Self { fatalError() }
  }
  // `Differentiable` requirement; a no-op because the struct has no stored
  // properties to update.
  mutating func move(by offset: TangentVector) {}
}

// CHECK: @differentiable(reverse, wrt: x where T : Differentiable)
// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric
// Only a `where` clause is supplied: `wrt: x` is inferred, and the clause adds
// a `Differentiable` requirement beyond the declaration's `Numeric` bound.
@differentiable(reverse where T : Differentiable)
func testOnlyWhereClause<T : Numeric>(x: T) -> T {
  x
}

// CHECK: @differentiable(reverse, wrt: x where T : Differentiable)
// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric
// Same `where` clause as `testOnlyWhereClause`, but preceded by an explicit
// `wrt:` argument. This covers a `where` clause that follows another attribute
// argument on a free generic function, rather than duplicating the
// `where`-only case. The expected printed form is unchanged, because `wrt: x`
// is exactly what would otherwise be inferred.
@differentiable(reverse, wrt: x where T : Differentiable)
func testWhereClause<T : Numeric>(x: T) -> T {
  return x
}

// Requirement-free protocol; its extensions host generic methods whose
// `@differentiable` attributes carry `where` clauses on `Self`.
protocol P {}
extension P {
  // CHECK: @differentiable(reverse, wrt: self where Self : Differentiable)
  // CHECK-NEXT: func testWhereClauseMethod() -> Self
  // Identity on `self`. The attribute applies only when the conforming type is
  // itself `Differentiable`, per its `where` clause.
  @differentiable(reverse, wrt: self where Self : Differentiable)
  func testWhereClauseMethod() -> Self {
    self
  }
}
extension P where Self : Differentiable {
  // VJP-shaped helper for `testWhereClauseMethod`: returns `self` with an
  // identity map on `TangentVector`. It has no `@derivative(of:)` attribute,
  // so it is not registered as a derivative.
  func vjpTestWhereClauseMethod() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
    let pullback: (Self.TangentVector) -> Self.TangentVector = { $0 }
    return (self, pullback)
  }
}

// CHECK: @differentiable(reverse, wrt: x where T : Differentiable, T == T.TangentVector)
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric
// Adds a same-type requirement to the attribute's `where` clause: `T` must be
// its own tangent vector. Both requirements should appear in the printed form.
@differentiable(reverse where T : Differentiable, T == T.TangentVector)
func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T {
  x
}
// VJP-shaped helper for `testWhereClauseMethodTypeConstraint`: since `T` is
// its own tangent vector here, the identity pullback maps `T` to `T`.
func vjpTestWhereClauseMethodTypeConstraint<T>(x: T) -> (T, (T) -> T)
  where T : Numeric, T : Differentiable, T == T.TangentVector
{
  let pullback: (T) -> T = { $0 }
  return (x, pullback)
}

extension P {
  // CHECK: @differentiable(reverse, wrt: self where Self : Differentiable, Self == Self.TangentVector)
  // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self
  // The source spelling of the `where` clause differs from the expected
  // printed form above: the same-type requirement comes first, with its sides
  // swapped. This appears deliberate, so the test also checks that
  // requirements are printed in canonical order. Keep the spelling as-is.
  @differentiable(reverse, wrt: self where Self.TangentVector == Self, Self : Differentiable)
  func testWhereClauseMethodTypeConstraint() -> Self {
    return self
  }
}
extension P where Self : Differentiable, Self == Self.TangentVector {
  // VJP-shaped helper for the `P` method `testWhereClauseMethodTypeConstraint`:
  // returns `self` with an identity pullback on `TangentVector` (here the same
  // type as `Self`). It is not registered via `@derivative(of:)`.
  func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
    let pullback: (Self.TangentVector) -> Self.TangentVector = { $0 }
    return (self, pullback)
  }
}