File: derivative_sil.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (113 lines) | stat: -rw-r--r-- 7,458 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
// REQUIRES: asserts

// Simple generated derivative code FileCheck tests.

import _Differentiation

/// Test fixture: a binary `Float` addition with a hand-written reverse-mode
/// derivative. It is the differentiable callee used by both `foo` and
/// `ExampleStruct.fooMethod` below.
extension Float {
  // The explicit SIL name gives this function the stable symbol `@add`,
  // which the FileCheck patterns in this file reference directly.
  @_silgen_name("add")
  static func add(_ x: Float, _ y: Float) -> Float {
    return x + y
  }

  // NOTE(review): only a VJP is declared here. The `[jvp]` witness that the
  // SIL patterns reference for `@add` is presumably produced by the compiler
  // under -enable-experimental-forward-mode-differentiation; confirm.
  /// Registered reverse-mode derivative (VJP) of `add`.
  ///
  /// Returns the original result together with a pullback that maps an
  /// incoming cotangent `v` to `(v, v)`. The partial derivative of `x + y`
  /// is 1 with respect to each operand.
  @derivative(of: add)
  static func addVJP(_ x: Float, _ y: Float) -> (
    value: Float, pullback: (Float) -> (Float, Float)
  ) {
    return (add(x, y), { v in (v, v) })
  }
}

// Top-level function under test. Its SIL name `foo` is the base of the
// derivative symbols the patterns below look for:
//   - JVP `fooTJfSpSr`
//   - differential `fooTJdSpSr`
//   - VJP `fooTJrSpSr`
//   - pullback `fooTJpSpSr`
// The local name `y` and the parameter name `x` appear as `debug_value` names
// in the checked pullback. Renaming them would break the test.
@_silgen_name("foo")
@differentiable(reverse)
func foo(_ x: Float) -> Float {
  // `x` is passed as both operands, so the generated pullback must add the
  // two partials returned by `add`'s pullback (the `AdditiveArithmetic.+`
  // call in the checked pullback SIL).
  let y = Float.add(x, x)
  return y
}

// CHECK-SIL-LABEL: enum _AD__foo_bb0__Pred__src_0_wrt_0 {
// CHECK-SIL-NEXT:  }

// CHECK-SIL-LABEL: enum _AD__fooMethod_bb0__Pred__src_0_wrt_0 {
// CHECK-SIL-NEXT:  }

// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[X:%.*]] : $Float):
// CHECK-SIL:   [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL:   [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] @add
// CHECK-SIL:   [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] @add
// CHECK-SIL:   [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
// CHECK-SIL:   [[ADD_JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ADD_DIFF_FN]]
// CHECK-SIL:   [[ADD_RESULT:%.*]] = apply [[ADD_JVP_FN]]([[X]], [[X]], {{.*}})
// CHECK-SIL:   ([[ORIG_RES:%.*]], [[ADD_DF:%.*]]) = destructure_tuple [[ADD_RESULT]]
// CHECK-SIL:   [[DF_STRUCT:%.*]] = tuple ([[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float)
// CHECK-SIL:   [[DF_REF:%.*]] = function_ref @fooTJdSpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float
// CHECK-SIL:   [[DF_FN:%.*]] = partial_apply [callee_guaranteed] [[DF_REF]]([[DF_STRUCT]])
// CHECK-SIL:   [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[DF_FN]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL:   return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: }

// CHECK-SIL-LABEL: sil private [ossa] @fooTJdSpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float {
// CHECK-SIL: bb0([[DX:%.*]] : $Float, [[DF_STRUCT:%.*]] : @owned $(_: @callee_guaranteed (Float, Float) -> Float)):
// CHECK-SIL:   [[ADD_DF:%.*]] = destructure_tuple [[DF_STRUCT]] : $(_: @callee_guaranteed (Float, Float) -> Float)
// CHECK-SIL:   [[DY:%.*]] = apply [[ADD_DF]]([[DX]], [[DX]]) : $@callee_guaranteed (Float, Float) -> Float
// CHECK-SIL:   destroy_value [[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float
// CHECK-SIL:   return [[DY]] : $Float
// CHECK-SIL: }

// CHECK-SIL-LABEL: sil hidden [ossa] @fooTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[X:%.*]] : $Float):
// CHECK-SIL:   [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL:   [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] @add
// CHECK-SIL:   [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] @add
// CHECK-SIL:   [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [results 0] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
// CHECK-SIL:   [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]]
// CHECK-SIL:   [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}})
// CHECK-SIL:   ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]]
// CHECK-SIL:   [[PB_REF:%.*]] = function_ref @fooTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
// CHECK-SIL:   [[PB_FN:%.*]] = partial_apply [callee_guaranteed] [[PB_REF]]([[ADD_PB]])
// CHECK-SIL:   [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB_FN]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL:   return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: }

// CHECK-SIL-LABEL: sil private [ossa] @fooTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
// CHECK-SIL: bb0([[DY:%.*]] : $Float, [[ADD_PB:%.*]] : @owned $@callee_guaranteed (Float) -> (Float, Float)):
// CHECK-SIL:   debug_value [[DY]] : $Float, let, name "y"
// CHECK-SIL:   [[ADD_PB_RES:%.*]] = apply [[ADD_PB]]([[DY]]) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK-SIL:   ([[DX_1:%.*]], [[DX_2:%.*]]) = destructure_tuple [[ADD_PB_RES]] : $(Float, Float)
// CHECK-SIL:   [[TMP_BUF_RES:%.*]] = alloc_stack $Float
// CHECK-SIL:   [[TMP_BUF_LHS:%.*]] = alloc_stack $Float
// CHECK-SIL:   [[TMP_BUF_RHS:%.*]] = alloc_stack $Float
// CHECK-SIL:   store [[DX_1]] to [trivial] [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL:   store [[DX_2]] to [trivial] [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL:   [[PLUS_FN:%.*]] = witness_method $Float, #AdditiveArithmetic."+"
// CHECK-SIL:   apply [[PLUS_FN]]<Float>([[TMP_BUF_RES]], [[TMP_BUF_RHS]], [[TMP_BUF_LHS]], {{.*}})
// CHECK-SIL:   destroy_addr [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL:   destroy_addr [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL:   dealloc_stack [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL:   dealloc_stack [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL:   [[DX:%.*]] = load [trivial] [[TMP_BUF_RES]] : $*Float
// CHECK-SIL:   dealloc_stack [[TMP_BUF_RES]] : $*Float
// CHECK-SIL:   debug_value [[DX]] : $Float, let, name "x", argno 1
// CHECK-SIL:   return [[DX]] : $Float
// CHECK-SIL: }

// https://github.com/apple/swift/issues/56342
// Check the conventions of the generated functions for a method.
/// Minimal struct hosting a method form of `foo`. It is used to check the
/// calling conventions of derivative functions generated for a method.
struct ExampleStruct {
  // The body is the same as `foo`. The patterns below check the following:
  //   - the JVP and VJP are `@convention(method)` and take `ExampleStruct`
  //     after `x`;
  //   - the differential and pullback are `@convention(thin)` and take the
  //     callee's linear map as an `@owned` argument;
  //   - the returned linear maps are `(Float) -> Float`.
  @_silgen_name("fooMethod")
  @differentiable(reverse)
  func fooMethod(_ x: Float) -> Float {
    let y = Float.add(x, x)
    return y
  }
}

// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJfSUpSr  : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {

// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJdSUpSr : $@convention(thin) (Float, @owned (_: @callee_guaranteed (Float, Float) -> Float)) -> Float {

// CHECK-SIL-LABEL: sil hidden [ossa] @fooMethodTJrSUpSr : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {

// CHECK-SIL-LABEL: sil private [ossa] @fooMethodTJpSUpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {