File: matlab_aottest.cpp

package info (click to toggle)
halide 14.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 49,124 kB
  • sloc: cpp: 238,722; makefile: 4,303; python: 4,047; java: 1,575; sh: 1,384; pascal: 211; xml: 165; javascript: 43; ansic: 34
file content (206 lines) | stat: -rw-r--r-- 4,805 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#include "HalideRuntime.h"

#include <cassert>
#include <math.h>
#include <stdio.h>
#include <vector>

// Provide a simple mock implementation of matlab's API so we can test the mexFunction.

// Export annotation for the mock MATLAB API functions below. On Windows the
// symbols are marked __declspec(dllexport) (presumably so the generated
// mexFunction can resolve them); on other platforms the macro is empty.
#if defined(_WIN32)
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

// The subset of MATLAB class IDs this mock understands. The numeric values
// presumably mirror MATLAB's own mxClassID numbering; do not renumber.
enum mxClassID {
    mxSINGLE_CLASS = 7,  // element type float (see get_class_id<float>)
    mxINT32_CLASS = 12   // element type int32_t (see get_class_id<int32_t>)
};

// Real vs. complex storage flag. Only mxREAL is accepted by the mock
// allocator in this file.
enum mxComplexity {
    mxREAL = 0,
    mxCOMPLEX = 1
};

// Compile-time mapping from a C++ element type to its mock mxClassID tag.
// The primary template is declared but not defined, so only the explicitly
// specialized element types (float, int32_t) can be used.
template<typename T>
mxClassID get_class_id();

template<>
mxClassID get_class_id<float>() { return mxSINGLE_CLASS; }

template<>
mxClassID get_class_id<int32_t>() { return mxINT32_CLASS; }

// Abstract stand-in for MATLAB's opaque mxArray. The mock mx* C functions in
// this file forward to these virtuals; mxArrayImpl<T> supplies the storage.
class mxArray {
public:
    // Element storage, mutable and read-only.
    virtual void *get_data() = 0;
    virtual const void *get_data() const = 0;

    // Shape: pointer to the extents array and how many extents it holds.
    virtual const size_t *get_dimensions() const = 0;
    virtual size_t get_number_of_dimensions() const = 0;

    // Element type tag, scalar value (as double), and bytes per element.
    virtual mxClassID get_class_id() const = 0;
    virtual double get_scalar() const = 0;
    virtual size_t get_element_size() const = 0;

    // Virtual so arrays can be destroyed through an mxArray pointer.
    virtual ~mxArray() = default;
};

// Concrete mock mxArray: an M x N matrix of T backed by a std::vector.
//
// Storage is column-major, the layout MATLAB uses for numeric arrays (and the
// layout the generated mexFunction is handed via mxGetData/mxGetDimensions):
// element (i, j), with row i in [0, M) and column j in [0, N), lives at
// data[i + j * M]. The previous formula, i * M + j, made distinct elements
// alias whenever N > M.
template<typename T>
class mxArrayImpl : public mxArray {
    std::vector<T> data;
    std::vector<size_t> dims;  // {M, N}

    // Column-major linear offset of element (i, j). Computed in size_t so
    // large matrices cannot overflow int arithmetic.
    size_t index(int i, int j) const {
        return static_cast<size_t>(i) + static_cast<size_t>(j) * dims[0];
    }

public:
    // Allocates an M x N matrix whose elements are value-initialized (zero).
    mxArrayImpl(size_t M, size_t N)
        : data(M * N), dims({M, N}) {
    }

    // data()/dims.data() rather than &v[0]: well-defined even when the
    // matrix is empty (M * N == 0).
    void *get_data() override {
        return data.data();
    }
    const void *get_data() const override {
        return data.data();
    }
    const size_t *get_dimensions() const override {
        return dims.data();
    }
    size_t get_number_of_dimensions() const override {
        return dims.size();
    }
    mxClassID get_class_id() const override {
        return ::get_class_id<T>();
    }
    // First element converted to double; the array must be non-empty.
    double get_scalar() const override {
        return data[0];
    }
    size_t get_element_size() const override {
        return sizeof(T);
    }

    // Element access by (row, column).
    T &operator()(int i, int j) {
        return data[index(i, j)];
    }
    T operator()(int i, int j) const {
        return data[index(i, j)];
    }
};

extern "C" {

// Mock mexWarnMsgTxt: prints the message verbatim plus a newline to stdout.
// Printf-style varargs are deliberately not emulated.
DLLEXPORT int mexWarnMsgTxt(const char *msg) {
    printf("%s\n", msg);
    return 0;
}

// Number of dimensions, size_t-based ("_730") variant.
DLLEXPORT size_t mxGetNumberOfDimensions_730(const mxArray *a) {
    return a->get_number_of_dimensions();
}

// Number of dimensions, int-based ("_700") variant.
DLLEXPORT int mxGetNumberOfDimensions_700(const mxArray *a) {
    const size_t count = a->get_number_of_dimensions();
    return static_cast<int>(count);
}

// Pointer to the extents array, size_t-based ("_730") variant.
DLLEXPORT const size_t *mxGetDimensions_730(const mxArray *a) {
    return a->get_dimensions();
}

// Pointer to the extents array, int-based ("_700") variant. The extents are
// stored as size_t, so reinterpreting them as int is only meaningful when
// both types have the same width; this is checked only in assert-enabled
// builds.
DLLEXPORT const int *mxGetDimensions_700(const mxArray *a) {
    assert(sizeof(size_t) == sizeof(int));
    const size_t *extents = a->get_dimensions();
    return reinterpret_cast<const int *>(extents);
}

// Element type tag of the array.
DLLEXPORT mxClassID mxGetClassID(const mxArray *a) {
    return a->get_class_id();
}

// Mutable pointer to the element storage. Like MATLAB's API, it accepts a
// const mxArray* but hands back writable data, hence the const_cast.
DLLEXPORT void *mxGetData(const mxArray *a) {
    mxArray *writable = const_cast<mxArray *>(a);
    return writable->get_data();
}

// Size of one element in bytes.
DLLEXPORT size_t mxGetElementSize(const mxArray *a) {
    return a->get_element_size();
}

// Type predicates. This mock only ever holds real numeric data, so the
// answers are constant and the argument is not inspected.
DLLEXPORT bool mxIsNumeric(const mxArray *a) {
    (void)a;
    return true;
}
DLLEXPORT bool mxIsLogical(const mxArray *a) {
    (void)a;
    return false;
}
DLLEXPORT bool mxIsComplex(const mxArray *a) {
    (void)a;
    return false;
}

// Value of the array's scalar (its first element), as a double.
DLLEXPORT double mxGetScalar(const mxArray *a) {
    return a->get_scalar();
}

// Allocates an M x N real matrix of the requested class on the heap; the
// caller is responsible for deleting it. Returns nullptr for any class ID
// other than single or int32. Complex matrices are not supported (asserted).
DLLEXPORT mxArray *mxCreateNumericMatrix_730(size_t M, size_t N, mxClassID type, mxComplexity complexity) {
    assert(complexity == mxREAL);
    if (type == mxSINGLE_CLASS) {
        return new mxArrayImpl<float>(M, N);
    }
    if (type == mxINT32_CLASS) {
        return new mxArrayImpl<int32_t>(M, N);
    }
    return nullptr;
}

// int-based ("_700") variant; widens the extents and defers to the _730 one.
DLLEXPORT mxArray *mxCreateNumericMatrix_700(int M, int N, mxClassID type, mxComplexity complexity) {
    const size_t rows = static_cast<size_t>(M);
    const size_t cols = static_cast<size_t>(N);
    return mxCreateNumericMatrix_730(rows, cols, type, complexity);
}

// Entry point under test; defined outside this file (presumably by the
// Halide-generated MATLAB pipeline).
void mexFunction(int, mxArray **, int, mxArray **);
}

// Drives mexFunction with a 3x5 float input, a scalar float scale and a
// scalar int32 negate flag, then checks every output element against
// input * scale * (negate ? -1 : 1).
// Returns 0 on success and 1 if mexFunction reports a non-zero result or any
// output element mismatches. Both checks are explicit rather than assert()
// based, so they still run in NDEBUG builds.
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;

    constexpr int kRows = 3;
    constexpr int kCols = 5;

    mxArray *lhs[1] = {nullptr};
    mxArray *rhs[4] = {nullptr, nullptr, nullptr, nullptr};

    mxArrayImpl<float> input(kRows, kCols);
    mxArrayImpl<float> scale(1, 1);
    mxArrayImpl<int32_t> negate(1, 1);
    mxArrayImpl<float> output(kRows, kCols);

    // Give every input element a distinct small integer value.
    for (int i = 0; i < kRows; i++) {
        for (int j = 0; j < kCols; j++) {
            input(i, j) = static_cast<float>(i * kCols + j);
        }
    }

    scale(0, 0) = 3.0f;
    negate(0, 0) = 1;

    rhs[0] = &input;
    rhs[1] = &scale;
    rhs[2] = &negate;
    rhs[3] = &output;

    mexFunction(1, lhs, 4, rhs);

    // lhs[0] is expected to be a scalar equal to 0 (presumably the pipeline's
    // return code). Guard against a missing result instead of dereferencing
    // null.
    if (lhs[0] == nullptr || lhs[0]->get_scalar() != 0) {
        printf("mexFunction did not produce a zero result\n");
        delete lhs[0];
        return 1;
    }
    delete lhs[0];
    lhs[0] = nullptr;

    int mismatches = 0;
    for (int i = 0; i < kRows; i++) {
        for (int j = 0; j < kCols; j++) {
            const float in = input(i, j);
            const float expected = in * scale(0, 0) * (negate(0, 0) ? -1.0f : 1.0f);
            // Exact float comparison is safe here: all values involved are
            // small integers, which float represents exactly.
            if (output(i, j) != expected) {
                printf("output(%d, %d) = %f instead of %f\n",
                       i, j, output(i, j), expected);
                mismatches++;
            }
        }
    }

    if (mismatches != 0) {
        return 1;
    }

    printf("Success!\n");
    return 0;
}