1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
|
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -module-name main %s
// NOTE: SILParser crashes for SILGen round-trip
// (https://github.com/apple/swift/issues/54370).
// This file tests:
// - The "derivative generic signature" of `@differentiable` and `@derivative`
// attributes.
// - The generic signature of lowered SIL differentiability witnesses.
// Context:
// - For `@differentiable` attributes: the derivative generic signature is
// resolved from the original declaration's generic signature and additional
// `where` clause requirements.
// - For `@derivative` attributes: the derivative generic signature is the
// attributed declaration's generic signature.
import _Differentiation
//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//
// Test original declaration with a generic signature and derivative generic
// signature where all generic parameters are concrete (i.e. bound to concrete
// types via same-type requirements).
/// Test fixture with a single generic parameter `T`. Extensions of this type
/// constrain `T` to `Float` via same-type requirements, so every generic
/// parameter becomes concrete. The `Differentiable` conformance is derived by
/// the compiler, since the body declares no members.
struct AllConcrete<T>: Differentiable {
}
extension AllConcrete {
  // The original declaration is generic over an unconstrained `T`. Only the
  // attribute's `where` clause pins `T` to `Float`. Because the derivative
  // signature therefore differs from the original `<T>`, the lowered witness
  // keeps a generic signature even though `T` is fully concrete.
  //   Original generic signature:   `<T>`
  //   Derivative generic signature: `<T where T == Float>`
  //   Witness generic signature:    `<T where T == Float>`
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("allconcrete_where_gensig_constrained")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignatureConstrained() -> AllConcrete {
    self
  }
}
extension AllConcrete where T == Float {
  /// Registers a JVP for `whereClauseGenericSignatureConstrained()`.
  ///
  /// For a `@derivative` attribute, the derivative generic signature is the
  /// attributed declaration's own signature, here `<T where T == Float>`.
  /// The returned differential is the identity map on tangent vectors.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = whereClauseGenericSignatureConstrained()
    return (value: value, differential: { tangent in tangent })
  }
}
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
// CHECK-NEXT: jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
// If a `@differentiable` or `@derivative` attribute satisfies both of these
// conditions:
// 1. The derivative generic signature is equal to the original generic signature.
// 2. All generic parameters of the derivative generic signature are concrete.
//
// then the attribute should be lowered to a SIL differentiability witness with
// *no* derivative generic signature.
extension AllConcrete where T == Float {
  // The original declaration already lives in a `T == Float` extension. Its
  // derivative signature therefore equals the original signature and is
  // fully concrete, so the witness carries no generic signature.
  //   Original generic signature:   `<T where T == Float>`
  //   Derivative generic signature: `<T where T == Float>`
  //   Witness generic signature:    none
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("allconcrete_original_gensig")
  @differentiable(reverse)
  func originalGenericSignature() -> AllConcrete {
    self
  }

  /// JVP for `originalGenericSignature()`. The differential is the identity.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = originalGenericSignature()
    return (value: value, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @allconcrete_original_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }

  // Here the attribute restates the extension's requirement with an explicit
  // `where` clause. The resolved derivative signature is unchanged, so the
  // witness again has no generic signature.
  //   Original generic signature:   `<T where T == Float>`
  //   Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
  //   Witness generic signature:    none
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("allconcrete_where_gensig")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignature() -> AllConcrete {
    self
  }

  /// JVP for `whereClauseGenericSignature()`. The differential is the identity.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = whereClauseGenericSignature()
    return (value: value, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
  // CHECK-NEXT: jvp: @allconcrete_where_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
  // CHECK-NEXT: }
}
// Test original declaration with a generic signature and derivative generic
// signature where *not* all generic parameters are concrete (i.e. at least one
// generic parameter is not bound to a concrete type via same-type requirements).
/// Test fixture with two generic parameters. Extensions of this type
/// constrain only `T` to `Float`, leaving `U` generic, so not every generic
/// parameter is concrete. The `Differentiable` conformance is derived by the
/// compiler, since the body declares no members.
struct NotAllConcrete<T, U>: Differentiable {
}
extension NotAllConcrete {
  // The attribute's `where` clause pins only `T`. `U` remains a generic
  // parameter, so the witness keeps a generic signature.
  //   Original generic signature:   `<T, U>`
  //   Derivative generic signature: `<T, U where T == Float>`
  //   Witness generic signature:    `<T, U where T == Float>` (not all concrete)
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("notallconcrete_where_gensig_constrained")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
    self
  }
}
extension NotAllConcrete where T == Float {
  /// Registers a JVP for `whereClauseGenericSignatureConstrained()` under this
  /// extension's generic signature, `<T, U where T == Float>`.
  ///
  /// The returned differential is the identity map on tangent vectors.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = whereClauseGenericSignatureConstrained()
    return (value: value, differential: { tangent in tangent })
  }
}
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
// CHECK-NEXT: jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
extension NotAllConcrete where T == Float {
  // The derivative signature equals the original signature. However, `U` is
  // still generic, so the witness keeps the generic signature.
  //   Original generic signature:   `<T, U where T == Float>`
  //   Derivative generic signature: `<T, U where T == Float>`
  //   Witness generic signature:    `<T, U where T == Float>` (not all concrete)
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("notallconcrete_original_gensig")
  @differentiable(reverse)
  func originalGenericSignature() -> NotAllConcrete {
    self
  }

  /// JVP for `originalGenericSignature()`. The differential is the identity.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = originalGenericSignature()
    return (value: value, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }

  // An explicit `where` clause restating the extension's requirement does not
  // change the resolved derivative signature, and `U` is still generic.
  //   Original generic signature:   `<T, U where T == Float>`
  //   Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
  //   Witness generic signature:    `<T, U where T == Float>` (not all concrete)
  /// Identity function. Only the lowering of its differentiability attribute
  /// is inspected.
  @_silgen_name("notallconcrete_where_gensig")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignature() -> NotAllConcrete {
    self
  }

  /// JVP for `whereClauseGenericSignature()`. The differential is the identity.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    let value = whereClauseGenericSignature()
    return (value: value, differential: { tangent in tangent })
  }

  // CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
  // CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
  // CHECK-NEXT: jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
  // CHECK-NEXT: }
}
|