File: differentiable_function.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (127 lines) | stat: -rw-r--r-- 8,070 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s

// Test SILGen for `@differentiable` function typed values.

import _Differentiation

//===----------------------------------------------------------------------===//
// Return `@differentiable` function typed values unmodified.
//===----------------------------------------------------------------------===//

/// Returns its `@differentiable(reverse)` function argument unmodified.
///
/// `@_silgen_name` fixes the SIL symbol to `@differentiable` so the label
/// expectation for this function can match it. The expected lowering is a
/// `copy_value` of the guaranteed parameter followed by `return` of that copy.
@_silgen_name("differentiable")
func differentiable(_ fn: @escaping @differentiable(reverse) (Float) -> Float)
    -> @differentiable(reverse) (Float) -> Float {
  return fn
}

/// Returns its `@differentiable(_linear)` function argument unmodified.
///
/// `@_silgen_name` fixes the SIL symbol to `@linear` so the label expectation
/// for this function can match it. The expected lowering is a `copy_value` of
/// the guaranteed parameter followed by `return` of that copy.
@_silgen_name("linear")
func linear(_ fn: @escaping @differentiable(_linear) (Float) -> Float)
    -> @differentiable(_linear) (Float) -> Float {
  return fn
}

/// Returns its `@differentiable(reverse)` function argument unmodified. The
/// second parameter of that function type is marked `@noDerivative`.
///
/// This checks that the `@noDerivative` marker is preserved on the parameter
/// in the lowered SIL function type. `@_silgen_name` fixes the symbol to
/// `@differentiable_noDerivative`.
@_silgen_name("differentiable_noDerivative")
func differentiable_noDerivative(
  _ fn: @escaping @differentiable(reverse) (Float, @noDerivative Float) -> Float
) -> @differentiable(reverse) (Float, @noDerivative Float) -> Float {
  return fn
}

/// Returns its `@differentiable(_linear)` function argument unmodified. The
/// second parameter of that function type is marked `@noDerivative`.
///
/// This checks that the `@noDerivative` marker is preserved on the parameter
/// in the lowered SIL function type. `@_silgen_name` fixes the symbol to
/// `@linear_noDerivative`.
@_silgen_name("linear_noDerivative")
func linear_noDerivative(
  _ fn: @escaping @differentiable(_linear) (Float, @noDerivative Float) -> Float
) -> @differentiable(_linear) (Float, @noDerivative Float) -> Float {
  return fn
}

// CHECK-LABEL: sil hidden [ossa] @differentiable : $@convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(reverse) @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(reverse) @callee_guaranteed (Float) -> Float):
// CHECK:   [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   return [[COPIED_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @linear : $@convention(thin) (@guaranteed @differentiable(_linear) @callee_guaranteed (Float) -> Float) -> @owned @differentiable(_linear) @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(_linear) @callee_guaranteed (Float) -> Float):
// CHECK:   [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   return [[COPIED_FN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @differentiable_noDerivative : $@convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(reverse) @callee_guaranteed (Float, @noDerivative Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(reverse) @callee_guaranteed (Float, @noDerivative Float) -> Float):
// CHECK:   [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(reverse) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK:   return [[COPIED_FN]] : $@differentiable(reverse) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: }

// CHECK-LABEL: sil hidden [ossa] @linear_noDerivative : $@convention(thin) (@guaranteed @differentiable(_linear) @callee_guaranteed (Float, @noDerivative Float) -> Float) -> @owned @differentiable(_linear) @callee_guaranteed (Float, @noDerivative Float) -> Float {
// CHECK: bb0([[FN:%.*]] : @guaranteed $@differentiable(_linear) @callee_guaranteed (Float, @noDerivative Float) -> Float):
// CHECK:   [[COPIED_FN:%.*]] = copy_value [[FN]] : $@differentiable(_linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK:   return [[COPIED_FN]] : $@differentiable(_linear) @callee_guaranteed (Float, @noDerivative Float) -> Float
// CHECK: }

//===----------------------------------------------------------------------===//
// Closure conversion
//===----------------------------------------------------------------------===//

/// Identity on `Float`; used as a plain (thin) function value that gets
/// converted to differentiable and linear function types in `apply()`.
func thin(x: Float) -> Float { x }

/// Calls a `@differentiable(reverse)` function directly, then returns it as an
/// ordinary `(Float) -> Float`.
///
/// The direct call is expected to lower to an `apply` on a borrowed copy of
/// `f`. The return conversion is expected to extract the original function
/// with `differentiable_function_extract [original]` and copy it out.
func myfunction(_ f: @escaping @differentiable(reverse) (Float) -> (Float)) -> (Float) -> Float {
  // @differentiable(reverse) functions should be callable.
  _ = f(.zero)
  // Implicit conversion from @differentiable(reverse) to a plain function type.
  return f
}

/// Calls a `@differentiable(_linear)` function directly, then returns it as an
/// ordinary `(Float) -> Float`.
///
/// The direct call is expected to lower to an `apply` on a borrowed copy of
/// `f`. The return conversion is expected to extract the original function
/// with `linear_function_extract [original]` and copy it out.
func myfunction2(_ f: @escaping @differentiable(_linear) (Float) -> (Float)) -> (Float) -> Float {
  // @differentiable(_linear) functions should be callable.
  _ = f(.zero)
  // Implicit conversion from @differentiable(_linear) to a plain function type.
  return f
}

// Global variables holding differentiable / linear function values, each
// initialized from an identity closure that is converted to the annotated type.
var global_f: @differentiable(reverse) (Float) -> Float = {$0}
var global_f_linear: @differentiable(_linear) (Float) -> Float = {$0}

/// Loads the `@differentiable(reverse)` function stored in `global_f` and
/// calls it directly.
func calls_global_f() {
  _ = global_f(10)
  // TODO(TF-900, TF-902): Uncomment the following line to test loading a linear function from memory and direct calls to a linear function.
  // _ = global_f_linear(10)
}

/// Passes the plain function `thin` where differentiable and linear functions
/// are expected.
///
/// Each argument is expected to lower to a `function_ref` of `thin` followed by
/// `thin_to_thick_function`. The first call then bundles it with
/// `differentiable_function [parameters 0] [results 0]`, and the second with
/// `linear_function [parameters 0]`.
func apply() {
  _ = myfunction(thin)
  _ = myfunction2(thin)
}

// CHECK-LABEL: @{{.*}}myfunction{{.*}}
// CHECK: bb0([[DIFF:%.*]] : @guaranteed $@differentiable(reverse) @callee_guaranteed (Float) -> Float):
// CHECK:   [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   apply [[BORROWED_DIFF]]({{%.*}}) : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   end_borrow [[BORROWED_DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   destroy_value [[COPIED_DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_ORIG:%.*]] = differentiable_function_extract [original] [[BORROWED_DIFF]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK:   [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float
// CHECK:   return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float

// CHECK-LABEL: @{{.*}}myfunction2{{.*}}
// CHECK: bb0([[LIN:%.*]] : @guaranteed $@differentiable(_linear) @callee_guaranteed (Float) -> Float):
// CHECK:   [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   apply [[BORROWED_LIN]]({{%.*}}) : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   end_borrow [[BORROWED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   [[BORROWED_ORIG:%.*]] = linear_function_extract [original] [[BORROWED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float
// CHECK:   end_borrow [[BORROWED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   destroy_value [[COPIED_LIN]] : $@differentiable(_linear) @callee_guaranteed (Float) -> Float
// CHECK:   return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float

// CHECK-LABEL: @{{.*}}apply{{.*}}
// CHECK:       [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
// CHECK-NEXT:  [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK-NEXT:  [[DIFFED:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float
// CHECK:       [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float
// CHECK-NEXT:  [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK-NEXT:  [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float