File: differential_operators.swift.gyb

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (64 lines) | stat: -rw-r--r-- 2,227 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// RUN: %empty-directory(%t)
// RUN: %gyb %s -o %t/differential_operators.swift
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators
// RUN: %target-codesign %t/differential_operators
// RUN: %target-run %t/differential_operators
// REQUIRES: executable_test

import _Differentiation

import StdlibUnittest

// Suite named "DifferentialOperator"; every test case generated by the
// template loop below is registered on it via `.test(...)`.
var DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")

% for arity in range(1, 3 + 1):

% params = ', '.join(['_ x%d: Float' % i for i in range(arity)])
% pb_return_type = '(' + ', '.join(['Float' for _ in range(arity)]) + ')'
// Function of ${arity} `Float` argument(s) that the tests differentiate.
// Its body unconditionally traps via `fatalError()`, so any code path that
// actually executes it aborts the test. A custom derivative is registered
// for it with `@derivative(of:)` right after this declaration; presumably the
// trap is there to prove the differential operators go through that
// derivative rather than through this body.
func exampleDiffFunc_${arity}(${params}) -> Float {
    fatalError()
}
// Custom derivative registered for `exampleDiffFunc_${arity}`.
// The returned value is the sum of the squares of the arguments; the returned
// pullback maps an incoming `Float` to one `2 * x * input` entry per argument.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  let sumOfSquares = ${' + '.join(['x%d * x%d' % (i, i) for i in range(arity)])}
  return (
    value: sumOfSquares,
    pullback: { v in (${', '.join(['2 * x%d * v' % i for i in range(arity)])}) }
  )
}

% argValues = [i * 10 for i in range(1, arity + 1)]
% args = ', '.join([str(v) for v in argValues])
% expectedValue = sum([v * v for v in argValues])
% expectedGradientValues = [2 * v for v in argValues]
% expectedGradients = '(' + ', '.join([str(g) for g in expectedGradientValues]) + ')'

// `valueWithPullback(at:of:)` must return the expected value, and its pullback
// applied to 1 must produce the expected gradient entries.
DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let result = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result.value)
  let pulledBack = result.pullback(1)
  expectEqual(${expectedGradients}, pulledBack)
}

// The closure returned by `pullback(at:of:)`, applied to 1, must produce the
// expected gradient entries.
DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let backward = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  let pulledBack = backward(1)
  expectEqual(${expectedGradients}, pulledBack)
}

// `gradient(at:of:)` must return the expected gradient entries directly.
DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let actual: ${pb_return_type} = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, actual)
}

// `valueWithGradient(at:of:)` must return both the expected value and the
// expected gradient entries for the same arguments.
DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let result = valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result.value)
  expectEqual(${expectedGradients}, result.gradient)
}

// `gradient(of:)` must return a function that, called with the arguments,
// produces the expected gradient entries.
DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let gradientFunction = gradient(of: exampleDiffFunc_${arity})
  let actual = gradientFunction(${args})
  expectEqual(${expectedGradients}, actual)
}

% end

// Entry point of the generated executable: hand control to the test runner
// after all test cases above have been registered.
runAllTests()