File: identical_fnc.rs

package info (click to toggle)
rustc 1.88.0%2Bdfsg1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 934,128 kB
  • sloc: xml: 158,127; python: 36,062; javascript: 19,855; sh: 19,700; cpp: 18,947; ansic: 12,993; asm: 4,792; makefile: 690; lisp: 29; perl: 29; ruby: 19; sql: 11
file content (45 lines) | stat: -rw-r--r-- 1,903 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
// level. If a user tries to differentiate two identical functions within the same compilation unit,
// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
// merged placeholder functions anymore, and compilation would fail. We prevent this by disabling
// LLVM's merge_function pass before AD. Here we implicitly test that our solution keeps working.
// We also explicitly test that we keep running merge_function after AD, by checking for two
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
#![feature(autodiff)]

use std::autodiff::autodiff;

// First of two functions with intentionally identical bodies. The `autodiff` attribute names a
// generated function `d_square`, which `main` calls as `d_square(&x, &mut dx1, 1.0)`.
// NOTE(review): `Reverse`, `Duplicated` and `Active` presumably select reverse mode, a shadow
// argument for `x`, and a seed for the returned value -- confirm against the autodiff docs.
#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
    // Product of the referenced value with itself. Keep this body textually the same as the
    // sibling function: the test relies on the two being identical.
    x * x
}

// Second function with a body intentionally identical to `square`. The `autodiff` attribute names
// a separate generated function `d_square2`, which `main` calls as
// `d_square2(&x, &mut dx2, 1.0)`.
// NOTE(review): `Reverse`, `Duplicated` and `Active` presumably select reverse mode, a shadow
// argument for `x`, and a seed for the returned value -- confirm against the autodiff docs.
#[autodiff(d_square2, Reverse, Duplicated, Active)]
fn square2(x: &f64) -> f64 {
    // Product of the referenced value with itself. Keep this body textually the same as
    // `square`: the test relies on the two being identical.
    x * x
}

// CHECK:; identical_fnc::main
// CHECK-NEXT:; Function Attrs:
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
// CHECK-NEXT:start:
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:; call identical_fnc::d_square
// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)

// Calls both generated derivative functions on the same input, each with its own shadow variable,
// giving two different calls in the Rust code for the FileCheck patterns above to match against.
// The local names and call order are load-bearing: those patterns refer to `%x.val`, `%dx1` and
// `%dx2`, so do not rename or reorder them.
fn main() {
    // `black_box` is presumably used so the optimizer cannot treat these values as constants and
    // fold the calls away.
    let x = std::hint::black_box(3.0);
    let mut dx1 = std::hint::black_box(1.0);
    let mut dx2 = std::hint::black_box(1.0);
    // The primal return values are discarded; only the shadow variables are inspected.
    let _ = d_square(&x, &mut dx1, 1.0);
    let _ = d_square2(&x, &mut dx2, 1.0);
    // d(x*x)/dx at x = 3.0 is 6.0.
    // NOTE(review): if the shadow is accumulated into rather than overwritten, starting it at 1.0
    // would yield 7.0 here -- confirm; this file looks like it is only inspected at the IR level.
    assert_eq!(dx1, 6.0);
    assert_eq!(dx2, 6.0);
}