1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
|
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
import _Differentiation
// Suite grouping the regression tests for `inout` parameters and `mutating`
// methods combined with control flow inside `@differentiable` functions.
// The binding is never reassigned; tests are registered via method calls on it.
let InoutControlFlowTests = TestSuite("InoutControlFlow")
// https://github.com/apple/swift/issues/55999
/// Minimal two-parameter differentiable model used as the reproducer for
/// https://github.com/apple/swift/issues/55999.
///
/// NOTE: the body of `inner()` is intentionally shaped this way (write to
/// `self` first, then a branch). Do not restructure it, or the regression
/// this file guards against may no longer be exercised.
struct Model: Differentiable {
/// First parameter; defaults to 3.
var first: Float = 3
/// Second parameter; defaults to 1. `inner()` overwrites it with `first`.
var second: Float = 1
/// Forwards to `inner()`, so the mutation of `self` happens inside a
/// nested `mutating` call.
mutating func outer() {
inner()
}
/// Copies `first` into `second`, then executes a no-op branch.
/// The mutation must precede the branch ("MutatingBeforeControlFlow").
mutating func inner() {
self.second = self.first
// Dummy no-op if block, required to introduce control flow.
let x = 5
if x < 50 {}
}
}
/// Loss for the "MutatingBeforeControlFlow" test.
///
/// Calls `outer()` on a local copy of `model` and returns the resulting
/// `second`. That value equals the input's `first`, so the expected
/// gradient is 1 w.r.t. `first` and 0 w.r.t. `second`.
@differentiable(reverse)
func loss(model: Model) -> Float{
// Shadow the parameter with a mutable local so `mutating` methods can be called.
var model = model
model.outer()
return model.second
}
// `loss` returns the input's `first` (copied into `second` by `inner()`),
// so d(loss)/d(first) = 1 and d(loss)/d(second) = 0.
InoutControlFlowTests.test("MutatingBeforeControlFlow") {
  // Never mutated: `let` avoids the "never mutated" compiler warning.
  let model = Model()
  let grad = gradient(at: model, of: loss)
  expectEqual(1, grad.first)
  expectEqual(0, grad.second)
}
// https://github.com/apple/swift/issues/56444
/// Sets `model.first = model.second * multiplier` through an `inout`
/// parameter, followed by a no-op branch so the function contains control flow.
///
/// Reproducer for https://github.com/apple/swift/issues/56444. The `inout`
/// parameter is first here; see `adjust2(multiplier:model:)` for the
/// variant with `inout` last.
@differentiable(reverse)
func adjust(model: inout Model, multiplier: Float) {
model.first = model.second * multiplier
// Dummy no-op if block, required to introduce control flow.
let x = 5
if x < 50 {}
}
/// Loss for the "InoutParameterWithControlFlow" test.
///
/// Passes a mutable copy of `model` `inout` to `adjust(model:multiplier:)`
/// and returns the updated `first`, i.e. `model.second * multiplier`.
@differentiable(reverse)
func loss2(model: Model, multiplier: Float) -> Float {
// Mutable local copy so it can be passed `inout`.
var model = model
adjust(model: &model, multiplier: multiplier)
return model.first
}
// `loss2` computes `second * multiplier` with first = 1, second = 3, multiplier = 5:
//   d/d(first)      = 0
//   d/d(second)     = multiplier = 5
//   d/d(multiplier) = second     = 3
InoutControlFlowTests.test("InoutParameterWithControlFlow") {
  // Never mutated: `let` avoids the "never mutated" compiler warning.
  let model = Model(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss2)
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
  // Also check the gradient flowing to the non-inout `multiplier` argument.
  expectEqual(3, grad.1)
}
/// Same computation as `adjust(model:multiplier:)`, but with the `inout`
/// parameter in the last position, so the regression is also covered when
/// the `inout` argument is not the first parameter.
@differentiable(reverse)
func adjust2(multiplier: Float, model: inout Model) {
model.first = model.second * multiplier
// Dummy no-op if block, required to introduce control flow.
let x = 5
if x < 50 {}
}
/// Loss for the "LaterInoutParameterWithControlFlow" test.
///
/// Identical to `loss2(model:multiplier:)` except it calls
/// `adjust2(multiplier:model:)`, whose `inout` parameter comes last.
/// Returns `model.second * multiplier`.
@differentiable(reverse)
func loss3(model: Model, multiplier: Float) -> Float {
// Mutable local copy so it can be passed `inout`.
var model = model
adjust2(multiplier: multiplier, model: &model)
return model.first
}
// `loss3` computes `second * multiplier` with first = 1, second = 3, multiplier = 5:
//   d/d(first)      = 0
//   d/d(second)     = multiplier = 5
//   d/d(multiplier) = second     = 3
InoutControlFlowTests.test("LaterInoutParameterWithControlFlow") {
  // Never mutated: `let` avoids the "never mutated" compiler warning.
  let model = Model(first: 1, second: 3)
  let grad = gradient(at: model, 5.0, of: loss3)
  expectEqual(0, grad.0.first)
  expectEqual(5, grad.0.second)
  // Also check the gradient flowing to the non-inout `multiplier` argument.
  expectEqual(3, grad.1)
}
// Execute every test registered above.
runAllTests()
|