File: pullback_generation.sil

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (186 lines) | stat: -rw-r--r-- 11,828 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// Pullback generation tests written in SIL for features 
// that may not be directly supported by the Swift frontend

// RUN: %target-sil-opt --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s

//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - The input to the pullback has non-owned ownership semantics, so the value
// must be copied and stored to the stack before any lifetime-ending use.
//===----------------------------------------------------------------------===//

sil_stage raw

import Builtin
import Swift
import SwiftShims

import _Differentiation

// Test fixture type: a struct with two stored properties, a Float `a` and a
// String `b`. It is the result type of $function_with_struct_extract_1 and the
// tuple element type of function_with_tuple_extract_2.
// NOTE(review): the String member is presumably what makes X non-trivial, which
// the ownership-sensitive expectations in this file appear to rely on - confirm.
struct X {
  @_hasStorage var a: Float { get set }
  @_hasStorage var b: String { get set }
  // Memberwise initializer; declaration only, no body is given in this file.
  init(a: Float, b: String)
}

// Declares X as Differentiable, Equatable and AdditiveArithmetic. Only the
// member declarations are listed here; no bodies are provided in this file.
extension X : Differentiable, Equatable, AdditiveArithmetic {
  // X is its own tangent vector type.
  public typealias TangentVector = X
  mutating func move(by offset: X)
  // `zero` is looked up through the AdditiveArithmetic witness in the expected
  // pullback output further down in this file.
  public static var zero: X { get }
  public static func + (lhs: X, rhs: X) -> X
  public static func - (lhs: X, rhs: X) -> X
  // Marked via @_implements as the implementation of Equatable's `==`.
  @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
}

// Test fixture type: a struct that nests an X in stored property `a` next to a
// String `b`. It is the argument type of $function_with_struct_extract_1, which
// projects `a` out of it.
struct Y {
  @_hasStorage var a: X { get set }
  @_hasStorage var b: String { get set }
  // Memberwise initializer; declaration only, no body is given in this file.
  init(a: X, b: String)
}

// Declares Y as Differentiable, Equatable and AdditiveArithmetic. Only the
// member declarations are listed here; no bodies are provided in this file.
extension Y : Differentiable, Equatable, AdditiveArithmetic {
  // Y is its own tangent vector type.
  public typealias TangentVector = Y
  mutating func move(by offset: Y)
  // `zero` is looked up through the AdditiveArithmetic witness in the expected
  // pullback output further down in this file.
  public static var zero: Y { get }
  public static func + (lhs: Y, rhs: Y) -> Y
  public static func - (lhs: Y, rhs: Y) -> Y
  // Marked via @_implements as the implementation of Equatable's `==`.
  @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
}

// Differentiability witness for the function below: reverse mode, with respect
// to parameter 0 and result 0. The witness body is empty (no derivative
// functions are listed here).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
}

// Original function under test: takes a @guaranteed Y, projects its stored
// property `a`, and returns an @owned copy of that X.
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
bb0(%0 : @guaranteed $Y):
  // Project field Y.a out of the guaranteed argument.
  %1 = struct_extract %0 : $Y, #Y.a               
  // Copy the projected value so an @owned X can be returned.
  %2 = copy_value %1 : $X                         
  return %2 : $X                                  
}

// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK:   %1 = alloc_stack $Y                             
// CHECK:   %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 
// CHECK:   %3 = metatype $@thick Y.Type                    
// CHECK:   %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK:   %5 = struct_element_addr %1 : $*Y, #Y.a         

// Since input parameter %0 has non-owned ownership semantics, it
// needs to be copied before a lifetime-ending use.
// CHECK:   %6 = copy_value %0 : $X                         

// CHECK:   %7 = alloc_stack $X                             
// CHECK:   store %6 to [init] %7 : $*X                     
// CHECK:   %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () 
// CHECK:   %10 = metatype $@thick X.Type                   
// CHECK:   %11 = apply %9<X>(%5, %7, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK:   destroy_addr %7 : $*X                           
// CHECK:   dealloc_stack %7 : $*X                          
// CHECK:   %14 = load [take] %1 : $*Y                      
// CHECK:   dealloc_stack %1 : $*Y                          
// CHECK:   %16 = copy_value %14 : $Y                       
// CHECK:   destroy_value %14 : $Y                          
// CHECK:   return %16 : $Y                                 
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Tuples as differentiable input arguments are not supported yet, so creating
// a basic test in SIL instead.
//===----------------------------------------------------------------------===//

// Differentiability witness for the function below: reverse mode, with respect
// to parameter 0 and result 0. The witness body is empty (no derivative
// functions are listed here).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
}

// Original function under test: takes a (Float, Float) tuple and returns its
// first element.
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
bb0(%0 : $(Float, Float)):
  // Element 0 of the argument tuple is the returned value.
  %1 = tuple_extract %0 : $(Float, Float), 0
  return %1 : $Float
}


// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
// CHECK: bb0(%0 : $Float):
// CHECK:   %1 = alloc_stack $(Float, Float)                
// CHECK:   %2 = tuple_element_addr %1 : $*(Float, Float), 0 
// CHECK:   %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 
// CHECK:   %4 = metatype $@thick Float.Type                
// CHECK:   %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK:   %6 = tuple_element_addr %1 : $*(Float, Float), 1 
// CHECK:   %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 
// CHECK:   %8 = metatype $@thick Float.Type                
// CHECK:   %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK:   %10 = tuple_element_addr %1 : $*(Float, Float), 0 
// CHECK:   %11 = alloc_stack $Float                        
// CHECK:   store %0 to [trivial] %11 : $*Float             
// CHECK:   %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () 
// CHECK:   %14 = metatype $@thick Float.Type               
// CHECK:   %15 = apply %13<Float>(%10, %11, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK:   destroy_addr %11 : $*Float                      
// CHECK:   dealloc_stack %11 : $*Float                     
// CHECK:   %18 = load [trivial] %1 : $*(Float, Float)      
// CHECK:   dealloc_stack %1 : $*(Float, Float)             
// CHECK:   return %18 : $(Float, Float)                    
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - Inner values of concrete adjoints must be copied 
// during direct materialization. 
// - If the input to pullback BB has non-owned ownership semantics we cannot 
// perform a lifetime-ending operation on it.
// - If the input to the pullback BB is an owned, non-trivial value we must 
// copy it or there will be a double consume when all owned parameters are 
// destroyed at the end of the basic block.
//===----------------------------------------------------------------------===//
// Differentiability witness for the function below: reverse mode, with respect
// to parameter 0 and result 0. The witness body is empty (no derivative
// functions are listed here).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
}

// Original function under test: takes a @guaranteed (X, X) tuple, projects its
// first element, and returns an @owned copy of that X.
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
bb0(%0 : @guaranteed $(X, X)):
  // Project element 0 out of the guaranteed tuple argument.
  %1 = tuple_extract %0 : $(X, X), 0
  // Copy the projected value so an @owned X can be returned.
  %2 = copy_value %1: $X
  return %2 : $X
}

// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK:   %1 = alloc_stack $(X, X)                        
// CHECK:   %2 = tuple_element_addr %1 : $*(X, X), 0        
// CHECK:   %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 
// CHECK:   %4 = metatype $@thick X.Type                    
// CHECK:   %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK:   %6 = tuple_element_addr %1 : $*(X, X), 1        
// CHECK:   %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 
// CHECK:   %8 = metatype $@thick X.Type                    
// CHECK:   %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK:   %10 = tuple_element_addr %1 : $*(X, X), 0       
// CHECK:   %11 = copy_value %0 : $X                        
// CHECK:   %12 = alloc_stack $X                            
// CHECK:   store %11 to [init] %12 : $*X                   
// CHECK:   %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () 
// CHECK:   %15 = metatype $@thick X.Type                   
// CHECK:   %16 = apply %14<X>(%10, %12, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK:   destroy_addr %12 : $*X                          
// CHECK:   dealloc_stack %12 : $*X                         
// CHECK:   %19 = load [take] %1 : $*(X, X)                 
// CHECK:   dealloc_stack %1 : $*(X, X)                     
// CHECK:   %21 = copy_value %19 : $(X, X)                  
// CHECK:   destroy_value %19 : $(X, X)                     
// CHECK:   return %21 : $(X, X)                            
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Adjoint of extracted element can be `AddElement`
// - Just need to make sure that we are able to generate a pullback
//===----------------------------------------------------------------------===//
// Differentiability witness for the function below: reverse mode, with respect
// to parameter 0 and result 0. The witness body is empty (no derivative
// functions are listed here).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
}

// Original function under test: takes a nested tuple ((Float, Float), Float)
// and returns the first element of its inner tuple.
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
bb0(%0 : $((Float, Float), Float)):
  // Outer projection: element 0 is the inner (Float, Float) tuple.
  %1 = tuple_extract %0 : $((Float, Float), Float), 0
  // Inner projection: element 0 of that inner tuple is the returned Float.
  %2 = tuple_extract %1 : $(Float, Float), 0
  return %2 : $Float
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {