File: wrapper.cc

package info (click to toggle)
sopt 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,932 kB
  • ctags: 1,162
  • sloc: cpp: 7,220; php: 287; python: 57; ansic: 33; makefile: 5
file content (54 lines) | stat: -rw-r--r-- 1,673 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <catch.hpp>
#include <random>

#include "sopt/wrapper.h"

// Exercises sopt::details::wrap, which turns a callable of the form
// `void(t_Array &output, t_Array const &input)` into an object `A` that can be applied
// either as `A * x` or as `A(x)`. Each section checks both call forms against a directly
// computed expected result, and checks the shape of the lazily-applied `A * x`.
//
// The optional second argument of `wrap` is a triple of sizes. The cases below are
// consistent with an output length of `input.size() * s[0] / s[1] + s[2]`
// (e.g. {1, 2, 0} maps 5 -> 2 and {0, 1, 3} maps 5 -> 3).
// NOTE(review): confirm that formula against sopt/wrapper.h.
TEST_CASE("Function wrappers", "[utility]") {
  using namespace sopt;
  using t_Array = Array<int>;
  // Output parameter of the wrapped callables.
  using t_RefArray = t_Array &;
  // Read-only input parameter of the wrapped callables (a genuine const reference).
  using t_ConstRefArray = t_Array const &;

  // Eigen documents integer Random() as spanning the whole range of the scalar type, so
  // `x * 2 + 1` on a raw draw could overflow `int` (undefined behaviour). Dividing each
  // draw by 4 keeps x within roughly a quarter of the int range, so `x * 2 + 1` cannot
  // overflow.

  SECTION("Square function") {
    // Output length equals input length (no size triple given).
    auto func = [](t_RefArray output, t_ConstRefArray input) { output = input * 2 + 1; };

    t_Array const x = t_Array::Random(5) / 4;
    auto const A = details::wrap<t_Array>(func);
    // Expected result
    t_Array const expected = (x * 2 + 1).eval();

    CHECK((A * x).cols() == 1);
    CHECK((A * x).rows() == x.size());
    CHECK((A * x).matrix() == expected.matrix());
    CHECK(A(x).matrix() == expected.matrix());
  }

  SECTION("Rectangular function") {
    // Output length is half the input length (integer division).
    auto func = [](t_RefArray output, t_ConstRefArray input) {
      output = input.head(input.size() / 2) * 2 + 1;
    };

    t_Array const x = t_Array::Random(5) / 4;
    auto const A = details::wrap<t_Array>(func, {{1, 2, 0}});
    // Expected result
    t_Array const expected = (x.head(x.size() / 2) * 2 + 1).eval();

    CHECK((A * x).cols() == 1);
    CHECK((A * x).rows() == 2);
    CHECK((A * x).matrix() == expected.matrix());
    CHECK(A(x).matrix() == expected.matrix());
  }

  SECTION("Fixed output-size functions") {
    // Output length is fixed at 3, independent of the input length.
    auto func = [](t_RefArray output, t_ConstRefArray input) { output = input.head(3) * 2 + 1; };

    t_Array const x = t_Array::Random(5) / 4;
    auto const A = details::wrap<t_Array>(func, {{0, 1, 3}});
    // Expected result
    t_Array const expected = (x.head(3) * 2 + 1).eval();

    CHECK((A * x).cols() == 1);
    CHECK((A * x).rows() == 3);
    CHECK((A * x).matrix() == expected.matrix());
    CHECK(A(x).matrix() == expected.matrix());
  }
}