File: mul_cskip.cpp

package info (click to toggle)
cppad 2026.00.00.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 11,584 kB
  • sloc: cpp: 112,960; sh: 6,146; ansic: 179; python: 71; sed: 12; makefile: 10
file content (64 lines) | stat: -rw-r--r-- 1,737 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
// SPDX-FileCopyrightText: Bradley M. Bell <bradbell@seanet.com>
// SPDX-FileContributor: 2003-22 Bradley M. Bell
// ----------------------------------------------------------------------------
# include <iostream>
# include <cppad/cppad.hpp>

// Test multiple-level conditional skip, where the value of the comparison is
// uncertain during forward mode because the Base value can be a variable.
bool mul_cskip(void)
{  bool ok = true;
   using namespace CppAD;
   using CppAD::vector;

   typedef AD<double>  a1type;
   typedef AD<a1type>  a2type;

   size_t n = 2;
   size_t m = 1;
   vector<double> x(n), y(m);
   x[0] = 0.0;
   x[1] = 1.0;

   // start recording a2type operations
   vector<a2type> a2x(n), a2y(m);
   for (size_t j = 0; j < n; j++)
      a2x[j] = a2type( a1type(x[j]) );
   Independent(a2x);

   // a1f(x) = x_0 * x_1 if x[0] == 1
   //         0.0       otherwise
   a2type a2zero = a2type(0.0);
   a2type a2one  = a2type(1.0);
   a2type a2p    = a2x[0] * a2x[1];
   a2y[0]        = CondExpEq(a2x[0], a2one, a2p, a2zero);
   ADFun<a1type> a1f(a2x, a2y);

   // Optimization will check to see if we can skip part of conditional
   // expression that is not used.
   a1f.optimize();

   // f(x) = x_0 * x_1 if x[0] == 1
   //        0.0       otherwise
   vector<a1type> a1x(n), a1y(m);
   for (size_t j = 0; j < n; j++)
      a1x[j] = a1type(x[j]);
   Independent(a1x);
   a1y = a1f.Forward(0, a1x);
   CppAD::ADFun<double> f(a1x, a1y);

   // check case where x[0] == 1
   x[0] = 1.0;
   x[1] = 2.0;
   y = f.Forward(0, x);
   ok &= y[0] == x[1];

   // check case where x[0] != 1
   x[0] = 3.0;
   x[1] = 2.0;
   y = f.Forward(0, x);
   ok &= y[0] == 0.0;

   return ok;
}