File: differentiability_witness_generic_signature.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (178 lines) | stat: -rw-r--r-- 9,205 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -module-name main %s

// NOTE: SILParser crashes for SILGen round-trip
// (https://github.com/apple/swift/issues/54370).

// This file tests:
// - The "derivative generic signature" of `@differentiable` and `@derivative`
//   attributes.
// - The generic signature of lowered SIL differentiability witnesses.

// Context:
// - For `@differentiable` attributes: the derivative generic signature is
//   resolved from the original declaration's generic signature and additional
//   `where` clause requirements.
// - For `@derivative` attributes: the derivative generic signature is the
//   attributed declaration's generic signature.

import _Differentiation

//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//

// Test original declaration with a generic signature and derivative generic
// signature where all generic parameters are concrete (i.e. bound to concrete
// types via same-type requirements).

// A fixture type with no stored properties. `T` is a phantom generic
// parameter (it is not used in the body); it exists so the declarations below
// can bind it to a concrete type via same-type requirements. Because the body
// is empty, the `Differentiable` requirements (including `TangentVector`) are
// provided by the compiler-derived conformance.
struct AllConcrete<T>: Differentiable {}

// Unconstrained extension: the original declaration below has generic
// signature `<T>`, and only its `@differentiable` attribute adds `T == Float`.
extension AllConcrete {
  //   Original generic signature: `<T>`
  // Derivative generic signature: `<T where T == Float>`
  //    Witness generic signature: `<T where T == Float>`
  //
  // All generic parameters are concrete in the derivative generic signature,
  // but it differs from the original generic signature. So the lowered
  // witness is expected to keep `<T where T == Float>`.
  //
  // `@_silgen_name` fixes the SIL symbol name so the FileCheck expectations
  // can refer to it directly.
  @_silgen_name("allconcrete_where_gensig_constrained")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignatureConstrained() -> AllConcrete {
    return self
  }
}
// Registers the derivative of `whereClauseGenericSignatureConstrained()`.
// For `@derivative` attributes, the derivative generic signature is the
// attributed declaration's generic signature: here `<T where T == Float>`,
// taken from this extension. That matches the `where` clause of the
// `@differentiable` attribute on the original declaration.
extension AllConcrete where T == Float {
  // Returns `(value:, differential:)`, i.e. a JVP. The differential is the
  // identity closure because the original simply returns `self`.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (whereClauseGenericSignatureConstrained(), { $0 })
  }
}

// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
// CHECK-NEXT:   jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }

// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
// 1. The derivative generic signature is equal to the original generic signature.
// 2. The derivative generic signature has *all concrete* generic parameters.
//
// Then the attribute should be lowered to a SIL differentiability witness with
// *no* derivative generic signature.

// This extension's own generic signature already fixes `T == Float`, so every
// original declaration in it has the all-concrete signature
// `<T where T == Float>`. Both cases below are expected to meet the two
// conditions above and lower to witnesses with no generic signature.
extension AllConcrete where T == Float {
  //   Original generic signature: `<T where T == Float>`
  // Derivative generic signature: `<T where T == Float>`
  //    Witness generic signature: none
  //
  // There is no `where` clause, so the derivative generic signature is simply
  // the original generic signature.
  @_silgen_name("allconcrete_original_gensig")
  @differentiable(reverse)
  func originalGenericSignature() -> AllConcrete {
    return self
  }

  // JVP for `originalGenericSignature()`: identity differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (originalGenericSignature(), { $0 })
  }

// The expected witness has no `<...>` generic signature, and its SIL function
// type is written with `T` already substituted by `Float`.
// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT:   jvp: @allconcrete_original_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }

  //   Original generic signature: `<T where T == Float>`
  // Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
  //    Witness generic signature: none
  //
  // The `where` clause only restates the extension's own requirement. The
  // resolved derivative generic signature is therefore identical to the
  // original one, and the lowering should match the previous case.
  @_silgen_name("allconcrete_where_gensig")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignature() -> AllConcrete {
    return self
  }

  // JVP for `whereClauseGenericSignature()`: identity differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: AllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (whereClauseGenericSignature(), { $0 })
  }

// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT:   jvp: @allconcrete_where_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
}

// Test original declaration with a generic signature and derivative generic
// signature where *not* all generic parameters are concrete (i.e. at least one
// generic parameter is not bound to a concrete type via a same-type
// requirement).

// A fixture like `AllConcrete`, but with two phantom generic parameters. A
// same-type requirement on `T` alone therefore still leaves `U` generic.
// The empty body relies on the compiler-derived `Differentiable` conformance.
struct NotAllConcrete<T, U>: Differentiable {}

// Unconstrained extension: the original declaration below has generic
// signature `<T, U>`; its `@differentiable` attribute constrains only `T`.
extension NotAllConcrete {
  //   Original generic signature: `<T, U>`
  // Derivative generic signature: `<T, U where T == Float>`
  //    Witness generic signature: `<T, U where T == Float>` (not all concrete)
  //
  // `U` stays generic, and the derivative generic signature also differs from
  // the original. So the lowered witness keeps its generic signature.
  @_silgen_name("notallconcrete_where_gensig_constrained")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
    return self
  }
}
// Registers the derivative of `whereClauseGenericSignatureConstrained()`.
// The derivative generic signature comes from this extension:
// `<T, U where T == Float>`, which matches the original's `where` clause.
extension NotAllConcrete where T == Float {
  // JVP: returns the original value plus an identity differential, since the
  // original simply returns `self`.
  @derivative(of: whereClauseGenericSignatureConstrained)
  func jvpWhereClauseGenericSignatureConstrained() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (whereClauseGenericSignatureConstrained(), { $0 })
  }
}

// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
// CHECK-NEXT:   jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }

// This extension fixes `T == Float` but leaves `U` generic. In both cases
// below, the derivative generic signature equals the original one (condition
// 1 holds), but not all generic parameters are concrete (condition 2 fails).
// So each witness is expected to keep `<T, U where T == Float>`.
extension NotAllConcrete where T == Float {
  //   Original generic signature: `<T, U where T == Float>`
  // Derivative generic signature: `<T, U where T == Float>`
  //    Witness generic signature: `<T, U where T == Float>` (not all concrete)
  @_silgen_name("notallconcrete_original_gensig")
  @differentiable(reverse)
  func originalGenericSignature() -> NotAllConcrete {
    return self
  }

  // JVP for `originalGenericSignature()`: identity differential.
  @derivative(of: originalGenericSignature)
  func jvpOriginalGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (originalGenericSignature(), { $0 })
  }

// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT:   jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }

  //   Original generic signature: `<T, U where T == Float>`
  // Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
  //    Witness generic signature: `<T, U where T == Float>` (not all concrete)
  //
  // The `where` clause only restates the extension's requirement, so this
  // case should lower exactly like the previous one.
  @_silgen_name("notallconcrete_where_gensig")
  @differentiable(reverse where T == Float)
  func whereClauseGenericSignature() -> NotAllConcrete {
    return self
  }

  // JVP for `whereClauseGenericSignature()`: identity differential.
  @derivative(of: whereClauseGenericSignature)
  func jvpWhereClauseGenericSignature() -> (
    value: NotAllConcrete, differential: (TangentVector) -> TangentVector
  ) {
    (whereClauseGenericSignature(), { $0 })
  }

// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT:   jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
}