File: test_djb.c

package info (click to toggle)
libm4ri 20200125-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye, sid
  • size: 2,560 kB
  • sloc: ansic: 12,633; sh: 4,304; makefile: 137
file content (86 lines) | stat: -rw-r--r-- 2,230 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <m4ri/config.h>
#include <stdlib.h>
#include <m4ri/m4ri.h>
#include <m4ri/djb.h>
#include "testing.h"

/**
 * Cross-check the implementation of Dan Bernstein's "Optimizing linear maps
 * mod 2" (DJB) against the matrix multiplication routine on random input.
 *
 * Two random matrices A (m x l) and B (l x n) are created. Their product is
 * computed once with mzd_mul() and once by compiling A with djb_compile() and
 * applying the result to B with djb_apply_mzd(); both products are then
 * compared with mzd_equal(). A one-line report is printed to stdout.
 *
 * \param m Number of rows of A
 * \param l Number of columns of A/number of rows of B
 * \param n Number of columns of B
 *
 * \return 0 if both products are equal, -1 otherwise.
 */
int mul_test_equality(rci_t m, rci_t l, rci_t n) {
  printf("   mul: m: %4d, l: %4d, n: %4d", m, l, n);

  /* random operands; A is randomized before B */
  mzd_t *A = mzd_init(m, l);
  mzd_t *B = mzd_init(l, n);
  mzd_randomize(A);
  mzd_randomize(B);

  /* reference product A*B via Strassen */
  mzd_t *expected = mzd_mul(NULL, A, B, __M4RI_STRASSEN_MUL_CUTOFF);

  /* the same product via DJB: compile A, then apply it to B */
  djb_t *compiled = djb_compile(A);
  mzd_t *actual   = mzd_init(m, n);
  djb_apply_mzd(compiled, actual, B);

  const int mismatch = (mzd_equal(expected, actual) != TRUE);
  if (mismatch) {
    printf(" Strassen != DJB");
  }

  /* release everything before reporting so both exits below are leak-free */
  mzd_free(actual);
  djb_free(compiled);

  mzd_free(expected);
  mzd_free(B);
  mzd_free(A);

  if (mismatch) {
    printf(" ... FAILED\n");
    return -1;
  }

  printf(" ... passed\n");
  return 0;
}

/**
 * Run mul_test_equality() over a fixed list of matrix dimensions after
 * seeding the random number generator with a constant.
 *
 * \return 0 if every case passed, -1 otherwise.
 */
int main() {
  /* {m, l, n} for each call to mul_test_equality(), in execution order */
  static const int dims[][3] = {
    {   1,    1,    1},
    {   1,  128,  128},
    {   3,  131,  257},
    {  64,   64,   64},
    { 128,  128,  128},
    {  21,  171,   31},
    {  21,  171,   31}, /* same dimensions run a second time */
    { 193,   65,   65},
    {1025, 1025, 1025},
    {2048, 2048, 4096},
    {4096, 3528, 4096},
    {1024, 1025,    1},
    {1000, 1000, 1000},
    {1000,   10,   20},
    {1710, 1290, 1000},
    {1290, 1710,  200},
    {1290, 1710, 2000},
    {1290, 1290, 2000},
    {1000,  210,  200},
  };
  const size_t n_cases = sizeof(dims) / sizeof(dims[0]);

  int total = 0;

  srandom(17);

  /* each failing case contributes a non-zero value to the total */
  for (size_t i = 0; i < n_cases; ++i) {
    total += mul_test_equality(dims[i][0], dims[i][1], dims[i][2]);
  }

  if (total != 0) {
    return -1;
  }

  printf("All tests passed.\n");
  return 0;
}