File: differentiable_function.sil

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (116 lines) | stat: -rw-r--r-- 7,876 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s

sil_stage raw

import Swift
import Builtin
import _Differentiation

// Stub with signature (Float) -> Float. The body ignores its argument and
// returns `undef`. @test_form_diff_func references it via `function_ref @f`
// and uses it as the original-function operand of `differentiable_function`.
sil @f : $@convention(thin) (Float) -> Float {
bb0(%0 : $Float):
  return undef : $Float
}

// Stub whose declared result is a Float paired with an owned
// @callee_guaranteed (Float) -> Float closure. The body ignores its argument
// and returns `undef` of that tuple type. @test_form_diff_func passes it as
// the first entry of the `with_derivative` list.
sil @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
  return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

// Stub with the same signature as @f_jvp: a Float paired with an owned
// @callee_guaranteed (Float) -> Float closure. The body ignores its argument
// and returns `undef` of that tuple type. @test_form_diff_func passes it as
// the second entry of the `with_derivative` list.
sil @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
  return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

// Builds a @differentiable(reverse) function value out of @f, @f_jvp and
// @f_vjp and returns it. The FileCheck expectations that follow this function
// look for each function pointer stored into elements 0, 1 and 2 of the
// result, each with a null data pointer next to it.
sil @test_form_diff_func : $@convention(thin) () -> @owned @differentiable(reverse) @callee_guaranteed (Float) -> Float {
bb0:
  // Each thin function reference is converted to a @callee_guaranteed thick
  // function before bundling; the operands below are all of thick type.
  %orig = function_ref @f : $@convention(thin) (Float) -> Float
  %origThick = thin_to_thick_function %orig : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
  %jvp = function_ref @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
  %jvpThick = thin_to_thick_function %jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
  %vjp = function_ref @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
  %vjpThick = thin_to_thick_function %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
  // Bundle original + derivatives, marking parameter index 0 and result
  // index 0. The derivative list is ordered {jvp, vjp}.
  %result = differentiable_function [parameters 0] [results 0] %origThick : $@callee_guaranteed (Float) -> Float with_derivative {%jvpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
  return %result : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
}

// CHECK-LABEL: define{{.*}}test_form_diff_func(ptr noalias nocapture sret(<{ %swift.function, %swift.function, %swift.function }>)
// CHECK-SAME: [[OUT:%.*]])
// CHECK:   [[OUT_ORIG:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
// CHECK:   [[OUT_ORIG_FN:%.*]] = getelementptr{{.*}}[[OUT_ORIG]], i32 0, i32 0
// CHECK:   store{{.*}}@f{{.*}}[[OUT_ORIG_FN]]
// CHECK:   [[OUT_ORIG_DATA:%.*]] = getelementptr{{.*}} [[OUT_ORIG]], i32 0, i32 1
// CHECK:   store{{.*}}null{{.*}} [[OUT_ORIG_DATA]]
// CHECK:   [[OUT_JVP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
// CHECK:   [[OUT_JVP_FN:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 0
// CHECK:   store{{.*}}@f_jvp{{.*}}[[OUT_JVP_FN]]
// CHECK:   [[OUT_JVP_DATA:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 1
// CHECK:   store{{.*}} null{{.*}}[[OUT_JVP_DATA]]
// CHECK:   [[OUT_VJP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
// CHECK:   [[OUT_VJP_FN:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 0
// CHECK:   store{{.*}}@f_vjp{{.*}}[[OUT_VJP_FN]]
// CHECK:   [[OUT_VJP_DATA:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 1
// CHECK:   store{{.*}}null{{.*}}[[OUT_VJP_DATA]]

// Takes a guaranteed @differentiable(reverse) function value, pulls out its
// [original], [jvp] and [vjp] components, and returns the three as a tuple in
// that order. The FileCheck expectations that follow this function look for
// loads from elements 0, 1 and 2 of the input and stores of the loaded
// function/data pairs into the matching elements of the output.
sil @test_extract_components : $@convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> (@owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float):
  // One extract per component, all reading from the same input value %0.
  %orig = differentiable_function_extract [original] %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
  %jvp = differentiable_function_extract [jvp] %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
  %vjp = differentiable_function_extract [vjp] %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
  // Repackage the extracted components as (original, jvp, vjp).
  %result = tuple (%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
  return %result : $(@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
}

// CHECK-LABEL: define{{.*}}@test_extract_components(ptr noalias nocapture sret(<{ %swift.function, %swift.function, %swift.function }>)
// CHECK-SAME: [[OUT:%.*]], ptr{{.*}}[[IN:%.*]])
// CHECK:   [[ORIG:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 0
// CHECK:   [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
// CHECK:   [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
// CHECK:   [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
// CHECK:   [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
// CHECK:   [[JVP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 1
// CHECK:   [[JVP_FN_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 0
// CHECK:   [[JVP_FN:%.*]] = load{{.*}}[[JVP_FN_ADDR]]
// CHECK:   [[JVP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 1
// CHECK:   [[JVP_DATA:%.*]] = load{{.*}}[[JVP_DATA_ADDR]]
// CHECK:   [[VJP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 2
// CHECK:   [[VJP_FN_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 0
// CHECK:   [[VJP_FN:%.*]] = load{{.*}}[[VJP_FN_ADDR]]
// CHECK:   [[VJP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 1
// CHECK:   [[VJP_DATA:%.*]] = load{{.*}}[[VJP_DATA_ADDR]]
// CHECK:   [[OUT_0:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
// CHECK:   [[OUT_0_FN:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 0
// CHECK:   store{{.*}}[[ORIG_FN]]{{.*}}[[OUT_0_FN]]
// CHECK:   [[OUT_0_DATA:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 1
// CHECK:   store{{.*}}[[ORIG_DATA]]{{.*}}[[OUT_0_DATA]]
// CHECK:   [[OUT_1:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
// CHECK:   [[OUT_1_FN:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 0
// CHECK:   store{{.*}}[[JVP_FN]]{{.*}}[[OUT_1_FN]]
// CHECK:   [[OUT_1_DATA:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 1
// CHECK:   store{{.*}}[[JVP_DATA]]{{.*}}[[OUT_1_DATA]]
// CHECK:   [[OUT_2:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
// CHECK:   [[OUT_2_FN:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 0
// CHECK:   store{{.*}}[[VJP_FN]]{{.*}}[[OUT_2_FN]]
// CHECK:   [[OUT_2_DATA:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 1
// CHECK:   store{{.*}}[[VJP_DATA]]{{.*}}[[OUT_2_DATA]]

// Applies a @differentiable(reverse) function value directly to a Float and
// returns the Float result. No explicit component extraction appears in the
// SIL; the FileCheck expectations that follow this function look for the
// function pointer and data pointer loaded from element 0 of the input and a
// swiftcc call made through them.
sil @test_call_diff_fn : $@convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float, Float) -> Float {
bb0(%0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float, %1 : $Float):
  // %0 is the differentiable function value, %1 the argument passed to it.
  %2 = apply %0(%1) : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
  return %2 : $Float
}

// CHECK-LABEL: define{{.*}}@test_call_diff_fn(ptr
// CHECK-SAME: [[DIFF_FN:%.*]], float [[INPUT_FLOAT:%.*]])
// CHECK:   [[ORIG:%.*]] = getelementptr{{.*}}[[DIFF_FN]], i32 0, i32 0
// CHECK:   [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
// CHECK:   [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
// CHECK:   [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
// CHECK:   [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
// CHECK:   [[RESULT:%.*]] = call swiftcc float [[ORIG_FN]](float [[INPUT_FLOAT]], ptr swiftself [[ORIG_DATA]])
// CHECK:   ret float [[RESULT]]

// Converts a guaranteed @differentiable(reverse) function value to its
// @noescape counterpart and returns an undefined empty tuple. The converted
// value %1 is never used.
// NOTE(review): no FileCheck expectations are visible for this function;
// presumably it only needs to get through IR emission without failing --
// confirm against the test's history.
sil @test_convert_escape_to_noescape : $@convention(thin) (@guaranteed @differentiable(reverse) @callee_guaranteed (Float) -> Float) -> () {
bb0(%0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float):
  %1 = convert_escape_to_noescape %0 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float to $@noescape @differentiable(reverse) @callee_guaranteed (Float) -> Float
  return undef : $()
}