1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
|
// RUN: %target-swift-frontend -emit-silgen %s | %target-sil-opt | %FileCheck %s
// Test SIL generation (SILGen) of SIL differentiability witnesses.
import _Differentiation
/// A stateless stand-in tangent vector used by `Foo` below.
///
/// It is its own `TangentVector`; `zero`, `+`, and `-` all produce a fresh
/// empty value.
public struct DummyTangentVector: Differentiable & AdditiveArithmetic {
  public typealias TangentVector = Self

  public static var zero: Self { .init() }

  public static func + (_: Self, _: Self) -> Self { .init() }

  public static func - (_: Self, _: Self) -> Self { .init() }
}
// Test public non-generic function.
// SIL differentiability witness:
// - Has public linkage (implicit: no linkage keyword is printed).
// - Has no `where` clause.
// - Is `[serialized]`.
// - Lists both a JVP and a VJP, since both are registered via `@derivative(of:)`.
public func foo(_ x: Float) -> Float { x }
// JVP for the identity function `foo`: the differential returns its input tangent.
@derivative(of: foo)
public func foo_jvp(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
(x, { $0 })
}
// VJP for `foo`: the pullback returns its input cotangent unchanged.
@derivative(of: foo)
public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { $0 })
}
// CHECK-LABEL: // differentiability witness for foo(_:)
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3fooyS2fF : $@convention(thin) (Float) -> Float {
// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3fooyS2fFTJfSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-NEXT: vjp: @$s29sil_differentiability_witness3fooyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-NEXT: }
// Test internal generic function with an unconstrained generic parameter.
// SIL differentiability witness:
// - Has hidden linkage.
// - Has a generic signature (`<τ_0_0>`) but no `where` clause, because neither
//   `bar` nor `bar_jvp` constrains `T`.
// - Covers only parameter 0 (`x`); `y: T` is not constrained to `Differentiable`.
// - Has only a JVP (no VJP is registered for `bar`).
func bar<T>(_ x: Float, _ y: T) -> Float { x }
@derivative(of: bar)
func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
(x, { $0 })
}
// CHECK-LABEL: // differentiability witness for bar<A>(_:_:)
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <τ_0_0> @$s29sil_differentiability_witness3baryS2f_xtlF : $@convention(thin) <T> (Float, @in_guaranteed T) -> Float {
// CHECK-NEXT: jvp: @$s29sil_differentiability_witness3baryS2f_xtlFlTJfSUpSr : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-NEXT: }
// Test internal generic function.
// SIL differentiability witness:
// - Has hidden linkage.
// - Has a `where` clause (`τ_0_0 : Differentiable`) matching the derivatives'
//   `T: Differentiable` requirement; `generic` itself leaves `T` unconstrained.
// - Covers both parameters (`x: T` and `y: Float`).
// - Has both a JVP and a VJP.
func generic<T>(_ x: T, _ y: Float) -> T { x }
// The differential ignores the `Float` tangent `dy` and forwards `dx`.
@derivative(of: generic)
func generic_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (
value: T, differential: (T.TangentVector, Float) -> T.TangentVector
) {
(x, { dx, dy in dx })
}
// The pullback sends the incoming cotangent to `x` and `.zero` to `y`.
@derivative(of: generic)
func generic_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (
value: T, pullback: (T.TangentVector) -> (T.TangentVector, Float)
) {
(x, { ($0, .zero) })
}
// CHECK-LABEL: // differentiability witness for generic<A>(_:_:)
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @$s29sil_differentiability_witness7genericyxx_SftlF : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
// CHECK-NEXT: jvp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJfSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
// CHECK-NEXT: vjp: @$s29sil_differentiability_witness7genericyxx_SftlF16_Differentiation14DifferentiableRzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
// CHECK-NEXT: }
// Test `@differentiable(reverse)` on members of a public struct: a stored
// property, an initializer, a method, a computed property, and a subscript.
// No custom derivatives are registered for these members, so each expected
// witness is `[serialized]` and has an empty body (no `jvp:`/`vjp:` entries).
public struct Foo: Differentiable {
public typealias TangentVector = DummyTangentVector
// No-op: `DummyTangentVector` has no stored state to apply.
public mutating func move(by _: TangentVector) {}
// Stored property: the witness is attached to its getter (parameter 0 is `self`).
@differentiable(reverse)
public var x: Float
// CHECK-LABEL: // differentiability witness for Foo.x.getter
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV1xSfvg : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: }
// Initializer: the witness is for the allocating initializer (note the
// `@thin Foo.Type` metatype parameter) and covers parameter 0, `x`.
@differentiable(reverse)
public init(_ x: Float) {
self.x = x
}
// CHECK-LABEL: // differentiability witness for Foo.init(_:)
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooVyACSfcfC : $@convention(method) (Float, @thin Foo.Type) -> Foo {
// CHECK-NEXT: }
// Method: parameter 0 is `self`.
@differentiable(reverse)
public func method() -> Float {
x
}
// CHECK-LABEL: // differentiability witness for Foo.method()
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV6methodSfyF : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: }
// Computed property: the witness is attached to its getter.
@differentiable(reverse)
public var computedProperty: Float {
x
}
// CHECK-LABEL: // differentiability witness for Foo.computedProperty.getter
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooV16computedPropertySfvg : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: }
// Subscript: the witness is attached to its getter.
@differentiable(reverse)
public subscript() -> Float {
x
}
// CHECK-LABEL: // differentiability witness for Foo.subscript.getter
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s29sil_differentiability_witness3FooVSfycig : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: }
}
// Test a function that is differentiable with respect to subsets of its
// parameters:
// - wrt x: explicit `@differentiable(reverse, wrt: x)` attribute, with no
//   custom derivative specified (the witness body is empty).
// - wrt y: custom JVP and VJP specified via `@derivative(of:wrt:)`, with no
//   explicit `@differentiable` attribute.
// - wrt x, y: custom JVP and VJP specified, with no explicit `@differentiable`
//   attribute.
// The tuple argument verifies that indices are correctly lowered to SIL. The
// `(Int, Int)` tuple lowers to two SIL parameters, so `x` and `y` become SIL
// parameter indices 2 and 3.
@differentiable(reverse, wrt: x)
public func wrt_subset(_ tup: (Int, Int), _ x: Float, _ y: Float) -> Float {
return 0
}
@derivative(of: wrt_subset, wrt: y)
public func wrt_subset_jvp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float) -> Float) {
return (0, { $0 })
}
@derivative(of: wrt_subset, wrt: y)
public func wrt_subset_vjp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> Float) {
return (0, { $0 })
}
@derivative(of: wrt_subset)
public func wrt_subset_jvp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float, Float) -> Float) {
return (0, { $0 + $1 })
}
@derivative(of: wrt_subset)
public func wrt_subset_vjp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (0, { ($0, $0) })
}
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 2] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
// CHECK-NEXT: }
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 3] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
// CHECK-NEXT: jvp:
// CHECK-NEXT: vjp:
// CHECK-NEXT: }
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 2 3] [results 0] @$s29sil_differentiability_witness10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
// CHECK-NEXT: jvp:
// CHECK-NEXT: vjp:
// CHECK-NEXT: }
// Test an original function that has a `@differentiable` attribute and a
// separately declared `@derivative` VJP (in another extension of the protocol).
// The expected witness is hidden (`P1` is internal) and lists the VJP.
protocol P1: Differentiable {}
extension P1 {
@differentiable(reverse) // derivative generic signature: none (no `where` clause)
func foo() -> Float { 1 }
}
extension P1 {
@derivative(of: foo) // derivative generic signature: `<Self where Self: P1>`
func vjpFoo() -> (value: Float, pullback: (Float) -> (TangentVector)) {
fatalError()
}
}
// CHECK-LABEL: // differentiability witness for P1.foo()
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s29sil_differentiability_witness2P1PAAE3fooSfyF : $@convention(method) <Self where Self : P1> (@in_guaranteed Self) -> Float {
// CHECK-NEXT: vjp: @$s29sil_differentiability_witness2P1PAAE3fooSfyFAaBRzlTJrSpSr : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>)
// CHECK-NEXT: }
// Test custom derivatives of functions with generic signatures and
// `@differentiable` attributes.
// `@_silgen_name` pins the SIL symbol to `genericWithDiffAttr`, which the
// expectations below match by name. The `@differentiable` attribute and the
// `@derivative` VJP should yield one witness. The trailing negative directive
// rejects a duplicate witness for this function following it, up to the next
// labelled witness.
@differentiable(reverse)
@_silgen_name("genericWithDiffAttr")
public func genericWithDiffAttr<T: Differentiable>(_ x: T) -> T { fatalError() }
@derivative(of: genericWithDiffAttr)
public func vjpGenericWithDiffAttr<T: Differentiable>(_ x: T)
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
fatalError()
}
// CHECK-LABEL: // differentiability witness for genericWithDiffAttr
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithDiffAttr : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> @out T {
// CHECK-NEXT: vjp
// CHECK-NEXT: }
// CHECK-NOT: // differentiability witness for genericWithDiffAttr
// Same scenario as `genericWithDiffAttr`, but the original is unconstrained and
// the `@differentiable` attribute adds `where T: Differentiable`. The VJP's
// `T: Differentiable` requirement matches that clause. A single witness with
// the constrained signature is expected, while the original's own SIL type
// stays `<T>`.
@differentiable(reverse where T: Differentiable)
@_silgen_name("genericWithConstrainedDifferentiable")
public func genericWithConstrainedDifferentiable<T>(_ x: T) -> T { fatalError() }
@derivative(of: genericWithConstrainedDifferentiable)
public func vjpGenericWithConstrainedDifferentiable<T: Differentiable>(_ x: T)
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
fatalError()
}
// CHECK-LABEL: // differentiability witness for genericWithConstrainedDifferentiable
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithConstrainedDifferentiable : $@convention(thin) <T> (@in_guaranteed T) -> @out T {
// CHECK-NEXT: vjp
// CHECK-NEXT: }
// CHECK-NOT: // differentiability witness for genericWithConstrainedDifferentiable
// Same scenario for a method in a public extension of `Differentiable`.
// Note: the VJP reuses the original's name and differs only in its result
// type; only the original carries `@_silgen_name`. A single witness is
// expected, and the trailing negative directive rejects a duplicate after it.
public extension Differentiable {
@differentiable(reverse)
@_silgen_name("protocolExtensionWithDiffAttr")
func protocolExtensionWithDiffAttr() -> Self { self }
@derivative(of: protocolExtensionWithDiffAttr)
func protocolExtensionWithDiffAttr() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
fatalError("unimplemented")
}
}
// CHECK-LABEL: // differentiability witness for protocolExtensionWithDiffAttr
// CHECK-NEXT: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @protocolExtensionWithDiffAttr : $@convention(method) <Self where Self : Differentiable> (@in_guaranteed Self) -> @out Self {
// CHECK-NEXT: vjp
// CHECK-NEXT: }
// CHECK-NOT: // differentiability witness for protocolExtensionWithDiffAttr
|