File: scalar.rs

package info (click to toggle)
rustc 1.88.0+dfsg1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 934,128 kB
  • sloc: xml: 158,127; python: 36,062; javascript: 19,855; sh: 19,700; cpp: 18,947; ansic: 12,993; asm: 4,792; makefile: 690; lisp: 29; perl: 29; ruby: 19; sql: 11
file content (33 lines) | stat: -rw-r--r-- 947 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff;

// Function under test: the primal f(x) = x * x, taking its argument by reference.
//
// The `autodiff` attribute asks the compiler (via Enzyme) to generate a
// reverse-mode derivative named `d_square`. Judging by its use in `main`
// (`d_square(&x, &mut df_dx, 1.0)`):
//   - `Duplicated` for `x`: the generated function takes an extra `&mut f64`
//     shadow, into which d(output)/dx is accumulated (the IR expected below
//     loads the shadow, adds 2*x to it, and stores it back).
//   - `Active` for the return value: the generated function takes an extra
//     `f64` seed for the output and also returns the primal result.
//
// `#[no_mangle]` keeps the symbol name `square` unmangled, presumably so the
// generated derivative appears as `@diffesquare` in the IR matched by the
// FileCheck lines that follow this function.
#[autodiff(d_square, Reverse, Duplicated, Active)]
#[no_mangle]
fn square(x: &f64) -> f64 {
    // `&f64 * &f64` uses the by-reference `Mul` impl and yields an owned f64.
    x * x
}

// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
// CHECK-NEXT:invertstart:
// CHECK-NEXT:  %_0 = fmul double %x.0.val, %x.0.val
// CHECK-NEXT:  %0 = fadd fast double %x.0.val, %x.0.val
// CHECK-NEXT:  %1 = load double, ptr %"x'", align 8
// CHECK-NEXT:  %2 = fadd fast double %1, %0
// CHECK-NEXT:  store double %2, ptr %"x'", align 8
// CHECK-NEXT:  ret double %_0
// CHECK-NEXT:}

// Driver: evaluates the primal `square` and its generated reverse-mode
// derivative `d_square` at x = 3.0, and checks both results.
fn main() {
    // Route the input through `black_box`, presumably so the optimizer cannot
    // treat it as a compile-time constant and fold the calls away.
    let input: f64 = std::hint::black_box(3.0);

    // Primal evaluation: 3.0 * 3.0 == 9.0.
    let primal: f64 = square(&input);
    assert_eq!(9.0, primal);

    // Reverse pass: the gradient is accumulated into `grad` (so it starts at
    // zero), and the output seed is 1.0. The generated function also returns
    // the primal value, which must agree with the direct call above.
    let mut grad: f64 = 0.0;
    let primal_from_reverse: f64 = d_square(&input, &mut grad, 1.0);
    assert_eq!(primal, primal_from_reverse);

    // d/dx (x * x) = 2x, which is 6.0 at x = 3.0.
    assert_eq!(6.0, grad);
}