1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
//===--- Differentiable.swift ---------------------------------*- swift -*-===//
//
// This source file is part of the Swift Numerics open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift Numerics project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
//
//===----------------------------------------------------------------------===//
#if swift(>=5.3) && canImport(_Differentiation)
import _Differentiation
extension Complex: Differentiable
where RealType: Differentiable, RealType.TangentVector == RealType {
  /// A complex number is used as its own tangent vector.
  ///
  /// The constraint `RealType.TangentVector == RealType` means each of the two
  /// components already serves as its own tangent, so the pair does as well.
  public typealias TangentVector = Complex<RealType>

  /// Returns a closure that yields the zero tangent vector, `Complex.zero`.
  @inlinable
  public var zeroTangentVectorInitializer: () -> Self {
    return { Self.zero }
  }
}
extension Complex
where RealType: Differentiable, RealType.TangentVector == RealType {
  // Convention shared by every derivative in this extension:
  //
  // `TangentVector == Self`, so a complex tangent or cotangent is treated as a
  // vector in R²: (component along the real axis, component along the
  // imaginary axis). The `init`, `real` and `imaginary` pullbacks below are the
  // transposes of those linear maps under that identification.
  //
  // For a holomorphic operation with complex derivative `c`, the forward
  // (tangent) map is "multiply by c". Its transpose in R² is "multiply by
  // conj(c)". That is why the `*` and `/` pullbacks scale the incoming
  // cotangent by the *conjugate* of the complex derivative.

  /// Registers the derivative of `init(_:_:)`.
  ///
  /// The pullback splits a complex cotangent `v` into the cotangents of the
  /// two real arguments: `(v.real, v.imaginary)`.
  @derivative(of: init(_:_:))
  @usableFromInline
  static func _derivativeInit(
    _ real: RealType,
    _ imaginary: RealType
  ) -> (value: Complex, pullback: (Complex) -> (RealType, RealType)) {
    (value: .init(real, imaginary), pullback: { v in
      (v.real, v.imaginary)
    })
  }

  /// Registers the derivative of `real`.
  ///
  /// The value is `real`. The pullback places a real cotangent `v` on the real
  /// axis: `Complex(v, 0)`.
  @derivative(of: real)
  @usableFromInline
  func _derivativeReal() -> (value: RealType, pullback: (RealType) -> Complex) {
    (value: real, pullback: { v in
      Complex(v, .zero)
    })
  }

  /// Registers the derivative of `imaginary`.
  ///
  /// The value is `imaginary`. The pullback places a real cotangent `v` on the
  /// imaginary axis: `Complex(0, v)`.
  @derivative(of: imaginary)
  @usableFromInline
  func _derivativeImaginary() -> (
    value: RealType,
    pullback: (RealType) -> Complex
  ) {
    // The derivative's value must match the original function, i.e. the
    // imaginary part (not the real part).
    (value: imaginary, pullback: { v in
      Complex(.zero, v)
    })
  }

  /// Registers the derivative of `+`.
  ///
  /// Addition is linear with identity partials, so the cotangent `v` flows
  /// unchanged to both operands.
  @derivative(of: +)
  @usableFromInline
  static func _derivativeAdd(lhs: Complex, rhs: Complex)
    -> (value: Complex, pullback: (Complex) -> (Complex, Complex))
  {
    (lhs + rhs, { v in (v, v) })
  }

  /// Registers the derivative of `-`.
  ///
  /// The pullback sends `v` to `lhs` and `-v` to `rhs`.
  @derivative(of: -)
  @usableFromInline
  static func _derivativeSubtract(lhs: Complex, rhs: Complex)
    -> (value: Complex, pullback: (Complex) -> (Complex, Complex))
  {
    (lhs - rhs, { v in (v, -v) })
  }

  /// Registers the derivative of `*`.
  ///
  /// The partials are ∂(lhs·rhs)/∂lhs = rhs and ∂(lhs·rhs)/∂rhs = lhs.
  /// Following the R² convention above, the pullback multiplies `v` by their
  /// conjugates.
  @derivative(of: *)
  @usableFromInline
  static func _derivativeMultiply(lhs: Complex, rhs: Complex)
    -> (value: Complex, pullback: (Complex) -> (Complex, Complex))
  {
    (lhs * rhs, { v in (rhs.conjugate * v, lhs.conjugate * v) })
  }

  /// Registers the derivative of `/`.
  ///
  /// With q = lhs / rhs, the partials are ∂q/∂lhs = 1 / rhs and
  /// ∂q/∂rhs = -lhs / rhs² = -q / rhs. The pullback applies their conjugates
  /// to `v`, per the R² convention above.
  @derivative(of: /)
  @usableFromInline
  static func _derivativeDivide(lhs: Complex, rhs: Complex)
    -> (value: Complex, pullback: (Complex) -> (Complex, Complex))
  {
    let quotient = lhs / rhs
    // Expressing -lhs/rhs² as -(q / rhs) reuses the quotient. It also avoids
    // forming `rhs * rhs`, which can overflow or underflow even when the
    // quotient itself is representable.
    return (quotient, { v in
      (v / rhs.conjugate, -(quotient / rhs).conjugate * v)
    })
  }

  /// Registers the derivative of `conjugate`.
  ///
  /// Conjugation is the R² reflection diag(1, -1). It is symmetric, so it is
  /// its own transpose, and the pullback is `v.conjugate`.
  @derivative(of: conjugate)
  @usableFromInline
  func _derivativeConjugate() -> (
    value: Complex,
    pullback: (Complex) -> Complex
  ) {
    (conjugate, { v in v.conjugate })
  }
}
#endif
|