File: differentiation_control_flow_diagnostics.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (179 lines) | stat: -rw-r--r-- 4,661 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
// RUN: %target-swift-frontend -emit-sil -verify %s

import _Differentiation

// Test supported `br`, `cond_br`, and `switch_enum` terminators.

// Returns `x` on every path. The `if` / `else if` / fall-through shape is the
// point of this case: it gives the differentiation transform conditional
// branches to handle. No diagnostics are annotated here, so the function is
// expected to differentiate cleanly.
@differentiable(reverse)
func branch(_ x: Float) -> Float {
  if x > 0 {
    return x
  } else if x < 10 {
    return x
  }
  return x
}

// Two-case enum whose cases each carry a single `Float` payload. It does not
// declare a `Differentiable` conformance.
enum Enum {
  case a(Float), b(Float)
}

// Switches over `e` without binding either payload and returns `x` from both
// cases, so the enum is only inspected and never contributes to the result.
// No diagnostics are annotated here.
@differentiable(reverse)
func enum_nonactive1(_ e: Enum, _ x: Float) -> Float {
  switch e {
    case .a: return x
    case .b: return x
  }
}

// Switches over `e`, binds the `Float` payload of the matched case, and
// returns it added to `x`. No diagnostics are annotated here.
// NOTE(review): `Enum` declares no `Differentiable` conformance in this file,
// so presumably only `x` is differentiated with respect to — confirm.
@differentiable(reverse)
func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
  switch e {
    case let .a(a): return x + a
    case let .b(b): return x + b
  }
}

// Test loops.

// Starts from `x` and multiplies the accumulator by `x` once per iteration of
// a `for`-`in` loop over `0..<3` (three iterations). The loop variable is
// unused. No diagnostics are annotated here.
@differentiable(reverse)
func for_loop(_ x: Float) -> Float {
  var result: Float = x
  for _ in 0..<3 {
    result = result * x
  }
  return result
}

// Starts from `x` and multiplies the accumulator by `x` inside a `while` loop
// driven by an integer counter that runs from 1 while it is below 3 (two
// iterations). No diagnostics are annotated here.
@differentiable(reverse)
func while_loop(_ x: Float) -> Float {
  var result = x
  var i = 1
  while i < 3 {
    result = result * x
    i += 1
  }
  return result
}

// Nests a counter-driven `while` loop inside a `for`-`in` loop. Each outer
// iteration (range `1..<3`) multiplies `outer` by `x`, then the inner loop
// divides a copy by `x` while its counter is below 3, and the result is
// written back to `outer`. No diagnostics are annotated here.
@differentiable(reverse)
func nested_loop(_ x: Float) -> Float {
  var outer = x
  for _ in 1..<3 {
    outer = outer * x

    // The inner accumulator starts from the freshly updated outer value.
    var inner = outer
    var i = 1
    while i < 3 {
      inner = inner / x
      i += 1
    }
    outer = inner
  }
  return outer
}

// TF-433: Test throwing functions.

// Helper for the throwing-function tests: accepts a closure that may throw
// and is itself `rethrows`. The body is empty, so the closure is never
// invoked.
func rethrowing(_ body: () throws -> Void) rethrows {}

// Calls the `rethrowing` helper with an empty, non-throwing closure and then
// returns `x` unchanged; the call does not involve `x`. No diagnostics are
// annotated here.
// NOTE(review): per the function name, the call presumably lowers to a
// `try_apply` instruction — confirm in the emitted SIL.
@differentiable(reverse)
func testTryApply(_ x: Float) -> Float {
  rethrowing({})
  return x
}

// Generic function that forwards `x` to a throwing closure and rethrows.
// The annotations below require the function to be rejected as not
// differentiable, with the `try body(x)` call flagged as the offending
// expression.
// Keep each annotation directly above the line it targets: the `@+1`
// offsets are relative line counts.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func withoutDerivative<T : Differentiable, R: Differentiable>(
  at x: T, in body: (T) throws -> R
) rethrows -> R {
  // expected-note @+1 {{expression is not differentiable}}
  try body(x)
}

// Tests active `try_apply`.
// The annotations below require the function to be rejected as not
// differentiable, with the `maybeX ?? 10` expression flagged as the cause.
// NOTE(review): presumably `??` lowers to a `try_apply` because its
// right-hand side is a rethrowing autoclosure — confirm against the standard
// library declaration.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func testNilCoalescing(_ maybeX: Float?) -> Float {
  // expected-note @+1 {{expression is not differentiable}}
  return maybeX ?? 10
}

// Test unsupported differentiation of active enum values.

// Builds an `Enum` value whose payload is `x` in both branches, then switches
// on it and adds the payload back to `x`, so the enum value depends on the
// differentiated input. The annotations below require the function to be
// rejected, with the "differentiating enum values is not yet supported" note
// reported on the `let e: Enum` declaration.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+1 {{when differentiating this function definition}}
func enum_active(_ x: Float) -> Float {
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  let e: Enum
  if x > 0 {
    e = .a(x)
  } else {
    e = .b(x)
  }
  switch e {
    case let .a(a): return x + a
    case let .b(b): return x + b
  }
}

// Enum that conforms to `Differentiable` and `AdditiveArithmetic` and uses
// itself as its tangent vector. Its `+` and `-` operators switch over enum
// operands, and the annotations below require both operators to be rejected
// with the "differentiating enum values is not yet supported" note.
enum Tree : Differentiable & AdditiveArithmetic {
  case leaf(Float)
  case branch(Float, Float)

  typealias TangentVector = Self
  // The additive identity is a leaf holding zero.
  static var zero: Self { .leaf(0) }

  // Component-wise sum of two trees of the same shape; traps on mismatched
  // shapes.
  // expected-error @+1 {{function is not differentiable}}
  @differentiable(reverse)
  // TODO(TF-956): Improve location of active enum non-differentiability errors
  // so that they are closer to the source of the non-differentiability.
  // expected-note @+2 {{when differentiating this function definition}}
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  static func +(_ lhs: Self, _ rhs: Self) -> Self {
    switch (lhs, rhs) {
    case let (.leaf(x), .leaf(y)):
      return .leaf(x + y)
    case let (.branch(x1, x2), .branch(y1, y2)):
      // Pair each component of `lhs` with the matching component of `rhs`.
      return .branch(x1 + y1, x2 + y2)
    default:
      // A leaf and a branch have no defined sum.
      fatalError()
    }
  }

  // Component-wise difference of two trees of the same shape; traps on
  // mismatched shapes.
  // expected-error @+1 {{function is not differentiable}}
  @differentiable(reverse)
  // TODO(TF-956): Improve location of active enum non-differentiability errors
  // so that they are closer to the source of the non-differentiability.
  // expected-note @+2 {{when differentiating this function definition}}
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  static func -(_ lhs: Self, _ rhs: Self) -> Self {
    switch (lhs, rhs) {
    case let (.leaf(x), .leaf(y)):
      return .leaf(x - y)
    case let (.branch(x1, x2), .branch(y1, y2)):
      // Pair each component of `lhs` with the matching component of `rhs`.
      return .branch(x1 - y1, x2 - y2)
    default:
      // A leaf and a branch have no defined difference.
      fatalError()
    }
  }
}

// TODO(TF-957): Improve non-differentiability errors for for-in loops
// (`Collection.makeIterator` and `IteratorProtocol.next`).
//
// Multiplies a running product (starting at 1) by each element of `array`.
// The annotations below require the function to be rejected and to offer a
// fix-it that wraps the iterated `array` expression in
// `withoutDerivative(at: ...)`.
// The fix-it positions are line- and column-relative, so do not insert lines
// between the annotations and the `for` statement, and do not re-indent it.
// expected-error @+1 {{function is not differentiable}}
@differentiable(reverse)
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} {{+2:12-12=withoutDerivative(at: }} {{+2:17-17=)}}
func loop_array(_ array: [Float]) -> Float {
  var result: Float = 1
  for x in array {
    result = result * x
  }
  return result
}