File: custom_derivatives.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (60 lines) | stat: -rw-r--r-- 1,461 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
#if canImport(Darwin)
  import Darwin.C
#elseif canImport(Glibc)
  import Glibc
#elseif canImport(Android)
  import Android
#elseif os(Windows)
  import CRT
#else
#error("Unsupported platform")
#endif
import DifferentiationUnittest

// Test suite that every test below registers on. The binding is never
// reassigned, so it is a `let`.
// NOTE(review): assumes `TestSuite` is a class (as in StdlibUnittest), so
// registering tests does not need a mutable binding — confirm.
let CustomDerivativesTests = TestSuite("CustomDerivatives")

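// Specify non-differentiable functions.
// These are meant to be wrapped in `differentiableFunction` and tested.
// NOTE(review): no test in this file currently calls `unary` or `binary`;
// either add that coverage or remove the helpers.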

/// Doubles `x` by scaling a mutable local copy with the compound `*=` operator.
///
/// - Parameter x: The value to double.
/// - Returns: `x` multiplied by 2.
func unary(_ x: Tracked<Float>) -> Tracked<Float> {
  var result = x
  result *= 2
  return result
}

/// Multiplies `x` by `y`, accumulating into a mutable local copy of `x`
/// with the compound `*=` operator.
///
/// - Parameters:
///   - x: The left-hand factor.
///   - y: The right-hand factor.
/// - Returns: The product of `x` and `y`.
func binary(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
  var product = x
  product *= y
  return product
}

// Each `withDerivative` hook is handed the derivative flowing back through one
// factor of `x * x * x`. At x = 4 each factor receives the product of the other
// two (16), so the hooks accumulate 16 + 16 + 16 = 48 into `grad`.
// `addToGrad` only reads its argument. The derivative reaching `x` must
// therefore also be the untouched 3 * x^2 = 48, and it is asserted explicitly
// rather than discarded.
CustomDerivativesTests.testWithLeakChecking("SumOfGradPieces") {
  var grad: Tracked<Float> = 0
  func addToGrad(_ x: inout Tracked<Float>) { grad += x }
  let dx = gradient(at: 4) { (x: Tracked<Float>) in
    x.withDerivative(addToGrad)
      * x.withDerivative(addToGrad)
        * x.withDerivative(addToGrad)
  }
  // Observe-only hooks must not alter the propagated derivative.
  expectEqual(48, dx)
  expectEqual(48, grad)
}

// `x + x` contributes a derivative of 1 through each operand. The two hooks
// rescale those branch derivatives to 10 and 20, so the gradient at `x` is
// expected to be 10 + 20 = 30.
CustomDerivativesTests.testWithLeakChecking("ModifyGradientOfSum") {
  let dx = gradient(at: 4) { (x: Tracked<Float>) -> Tracked<Float> in
    let scaledByTen = x.withDerivative { $0 *= 10 }
    let scaledByTwenty = x.withDerivative { $0 *= 20 }
    return scaledByTen + scaledByTwenty
  }
  expectEqual(30, dx)
}

// `withoutDerivative(at:)` detaches its argument from differentiation. Whatever
// the trailing closure computes from it, including `sinf`/`cosf` of the raw
// value, contributes nothing, and the expected gradient is 0.
CustomDerivativesTests.testWithLeakChecking("WithoutDerivative") {
  let dx = gradient(at: Tracked<Float>(4)) { (input: Tracked<Float>) -> Tracked<Float> in
    withoutDerivative(at: input) { (detached: Tracked<Float>) -> Tracked<Float> in
      let raw = detached.value
      return Tracked<Float>(sinf(raw) + cosf(raw))
    }
  }
  expectEqual(0, dx)
}

// Run the registered tests. Kept as the final top-level statement, after all
// of the `testWithLeakChecking` registrations above.
runAllTests()