1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
|
// Pullback generation tests written in SIL for features
// that may not be directly supported by the Swift frontend
// RUN: %target-sil-opt --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Input to pullback has non-owned ownership semantics, which requires copying
// this value to the stack before any lifetime-ending use.
//===----------------------------------------------------------------------===//
sil_stage raw
import Builtin
import Swift
import SwiftShims
import _Differentiation
// Test fixture type: a struct with two stored properties (marked
// `@_hasStorage`), a `Float` named `a` and a `String` named `b`.
// It is the element type extracted in the struct/tuple tests in this file.
struct X {
@_hasStorage var a: Float { get set }
@_hasStorage var b: String { get set }
// Memberwise initializer; only the signature is declared here.
init(a: Float, b: String)
}
// Declares `X` as `Differentiable`, `Equatable` and `AdditiveArithmetic`.
// Only member signatures are listed here; no bodies are provided.
extension X : Differentiable, Equatable, AdditiveArithmetic {
// `X` is its own tangent vector type, so pullbacks over `X` take and
// return `X` values.
public typealias TangentVector = X
mutating func move(by offset: X)
// `zero` is the `AdditiveArithmetic` requirement whose getter the expected
// pullbacks look up via `witness_method` to initialize adjoint buffers.
public static var zero: X { get }
public static func + (lhs: X, rhs: X) -> X
public static func - (lhs: X, rhs: X) -> X
// Equality implementation bound to the `Equatable` `==` requirement.
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
}
// Test fixture type: a struct with two stored properties (marked
// `@_hasStorage`), an `X` named `a` and a `String` named `b`.
// Field `a` is the one projected by `@$function_with_struct_extract_1`.
struct Y {
@_hasStorage var a: X { get set }
@_hasStorage var b: String { get set }
// Memberwise initializer; only the signature is declared here.
init(a: X, b: String)
}
// Declares `Y` as `Differentiable`, `Equatable` and `AdditiveArithmetic`.
// Only member signatures are listed here; no bodies are provided.
extension Y : Differentiable, Equatable, AdditiveArithmetic {
// `Y` is its own tangent vector type, so the pullback of a function taking
// `Y` produces a `Y`.
public typealias TangentVector = Y
mutating func move(by offset: Y)
// `zero` is the `AdditiveArithmetic` requirement whose getter the expected
// pullback looks up via `witness_method` to initialize the `Y` adjoint buffer.
public static var zero: Y { get }
public static func + (lhs: Y, rhs: Y) -> Y
public static func - (lhs: Y, rhs: Y) -> Y
// Equality implementation bound to the `Equatable` `==` requirement.
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
}
// Differentiability witness marking `@$function_with_struct_extract_1` as
// reverse-differentiable with respect to parameter 0 and result 0.
// The witness body is empty; the expectations following the function cover
// the generated pullback (`...TJpSpSr`).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
}
// Original function under test: takes a borrowed (`@guaranteed`) `Y`,
// projects its `a` field, and returns an owned copy of that `X`.
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
bb0(%0 : @guaranteed $Y):
// Project field `Y.a` out of the borrowed argument.
%1 = struct_extract %0 : $Y, #Y.a
// The result convention is `@owned`, so the projected field is copied
// before being returned.
%2 = copy_value %1 : $X
return %2 : $X
}
// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $Y
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick Y.Type
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*Y, #Y.a
// Since input parameter %0 has non-owned ownership semantics, it
// needs to be copied before a lifetime-ending use.
// CHECK: %6 = copy_value %0 : $X
// CHECK: %7 = alloc_stack $X
// CHECK: store %6 to [init] %7 : $*X
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %10 = metatype $@thick X.Type
// CHECK: %11 = apply %9<X>(%5, %7, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %7 : $*X
// CHECK: dealloc_stack %7 : $*X
// CHECK: %14 = load [take] %1 : $*Y
// CHECK: dealloc_stack %1 : $*Y
// CHECK: %16 = copy_value %14 : $Y
// CHECK: destroy_value %14 : $Y
// CHECK: return %16 : $Y
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Tuples as differentiable input arguments are not supported yet, so a
// basic test is written in SIL instead.
//===----------------------------------------------------------------------===//
// Differentiability witness marking `@function_with_tuple_extract_1` as
// reverse-differentiable with respect to parameter 0 and result 0.
// The witness body is empty; the expectations following the function cover
// the generated pullback (`...TJpSpSr`).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
}
// Original function under test: takes a `(Float, Float)` tuple and returns
// its first element.
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
bb0(%0 : $(Float, Float)):
// Extract element 0 of the tuple argument; no copy is made before returning.
%1 = tuple_extract %0 : $(Float, Float), 0
return %1 : $Float
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $(Float, Float)
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick Float.Type
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick Float.Type
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %11 = alloc_stack $Float
// CHECK: store %0 to [trivial] %11 : $*Float
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %14 = metatype $@thick Float.Type
// CHECK: %15 = apply %13<Float>(%10, %11, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %11 : $*Float
// CHECK: dealloc_stack %11 : $*Float
// CHECK: %18 = load [trivial] %1 : $*(Float, Float)
// CHECK: dealloc_stack %1 : $*(Float, Float)
// CHECK: return %18 : $(Float, Float)
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - Inner values of concrete adjoints must be copied
// during direct materialization.
// - If the input to pullback BB has non-owned ownership semantics we cannot
// perform a lifetime-ending operation on it.
// - If the input to the pullback BB is an owned, non-trivial value we must
// copy it or there will be a double consume when all owned parameters are
// destroyed at the end of the basic block.
//===----------------------------------------------------------------------===//
// Differentiability witness marking `@function_with_tuple_extract_2` as
// reverse-differentiable with respect to parameter 0 and result 0.
// The witness body is empty; the expectations following the function cover
// the generated pullback (`...TJpSpSr`).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
}
// Original function under test: takes a borrowed (`@guaranteed`) `(X, X)`
// tuple and returns an owned copy of its first element.
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
bb0(%0 : @guaranteed $(X, X)):
// Extract element 0 of the borrowed tuple.
%1 = tuple_extract %0 : $(X, X), 0
// The result convention is `@owned`, so the extracted element is copied
// before being returned.
%2 = copy_value %1: $X
return %2 : $X
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $(X, X)
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick X.Type
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick X.Type
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %11 = copy_value %0 : $X
// CHECK: %12 = alloc_stack $X
// CHECK: store %11 to [init] %12 : $*X
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %15 = metatype $@thick X.Type
// CHECK: %16 = apply %14<X>(%10, %12, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %12 : $*X
// CHECK: dealloc_stack %12 : $*X
// CHECK: %19 = load [take] %1 : $*(X, X)
// CHECK: dealloc_stack %1 : $*(X, X)
// CHECK: %21 = copy_value %19 : $(X, X)
// CHECK: destroy_value %19 : $(X, X)
// CHECK: return %21 : $(X, X)
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Adjoint of extracted element can be `AddElement`
// - Just need to make sure that we are able to generate a pullback
//===----------------------------------------------------------------------===//
// Differentiability witness marking `@function_with_tuple_extract_3` as
// reverse-differentiable with respect to parameter 0 and result 0.
// The witness body is empty; the expectation following the function names
// the generated pullback (`...TJpSpSr`).
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
}
// Original function under test: takes a nested tuple `((Float, Float), Float)`
// and returns the first element of its inner pair via two chained extracts.
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
bb0(%0 : $((Float, Float), Float)):
// Outer extract: element 0 of the argument, i.e. the inner `(Float, Float)`.
%1 = tuple_extract %0 : $((Float, Float), Float), 0
// Inner extract: element 0 of that pair, a `Float`.
%2 = tuple_extract %1 : $(Float, Float), 0
return %2 : $Float
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {
|