Skip to content

Commit c6ba90e

Browse files
P_SET_CONSTANT macro added, does not fully work
1 parent 955e006 commit c6ba90e

3 files changed

Lines changed: 29 additions & 13 deletions

File tree

nitpick/nitpick_def.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ int main() {
4040
}
4141
constexpr int vres_size = vec_size * 3;
4242
double vres[vres_size];
43+
for (uint i = 0; i < vres_size; ++i) {
44+
vres[i] = NAN;
45+
}
4346

4447
// Create parser, give the size of the value spaces.
4548
// That is maximal allocated space. Actual values and
@@ -51,11 +54,10 @@ int main() {
5154
// parse an expression.
5255
p.parse("1 * v1 + cs1 * v2");
5356

54-
//TODO: Create the SET_CONSTANT macro
5557
// "cs1" constant with shape {}, i.e. scalar and values {2}.
56-
p.set_constant("cs1", {}, {2});
58+
P_SET_CONSTANT(cs1, {}, {2});
5759
// "cv1" vector constant with shape {3}
58-
p.set_constant("cv1", {3}, {1, 2, 3});
60+
P_SET_CONSTANT(cv1, {3}, ARG({1, 2, 3}));
5961
// "v1" variable with shape {3}; v1 is pointer to the value space
6062
P_SET_VARIABLE(v1, { 3 }, v1);
6163
P_SET_VARIABLE(v2, { 3 }, v2);

nitpick/nitpick_generate.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ int main() {
1010
p.compile();
1111

1212
ExpressionDAG dag(p.result_array().elements());
13+
//dag.print_in_dot2();
1314

1415
std::ofstream file(NITPICK_AUTOGENERATED_FILE_NAME);
1516
file << dag.print_in_cxx(inv_map);

nitpick/nitpick_include.hh

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,53 @@
33

44
#define NITPICK_AUTOGENERATED_FILE_NAME "autogenerated.cc" //Keep the same in .gitignore
55

6+
#define ARG(...) __VA_ARGS__
7+
68
#define P_SET_(TYPE, NAME, SHAPE, POINTER) \
7-
Array NAME##_array = Array::value(POINTER, max_vec_size, SHAPE); \
9+
Array NAME##_array = Array::value(POINTER, max_vec_size, SHAPE); \
810
\
911
for (MultiIdx idx(NAME##_array.range()); idx.valid(); idx.inc_src()) { \
1012
std::string var_name = get_var_name(#NAME, idx.indices()); \
11-
inv_map[NAME##_array[idx]->values_] = var_name; \
13+
inv_map[NAME##_array[idx]->values_] = var_name; \
1214
node_map[var_name] = NAME##_array[idx]->values_; \
1315
} \
1416
p.set_##TYPE(#NAME, SHAPE, POINTER);
1517

16-
#define P_SET_VARIABLE(NAME, SHAPE, POINTER) P_SET_(variable, NAME, SHAPE, POINTER)
17-
//#define P_SET_CONSTANT(NAME, SHAPE, POINTER) P_SET_(constant, NAME, SHAPE, POINTER)
18-
#define P_SET_VAR_COPY(NAME, SHAPE, POINTER) P_SET_(var_copy, NAME, SHAPE, POINTER)
18+
#define P_SET_C(TYPE, NAME, SHAPE, VALUES) \
19+
Array NAME##_array = Array::constant(VALUES, SHAPE); \
20+
\
21+
for (MultiIdx idx(NAME##_array.range()); idx.valid(); idx.inc_src()) { \
22+
std::string var_name = get_var_name(#NAME, idx.indices()); \
23+
inv_map[NAME##_array[idx]->values_] = var_name; \
24+
node_map[var_name] = NAME##_array[idx]->values_; \
25+
} \
26+
p.set_##TYPE(#NAME, SHAPE, VALUES);
27+
28+
//this has different pointers than the parser's Array::constant, so it does not work
29+
30+
#define P_SET_VARIABLE(NAME, SHAPE, POINTER) P_SET_(variable, NAME, ARG(SHAPE), POINTER)
31+
#define P_SET_CONSTANT(NAME, SHAPE, VALUES) P_SET_C(constant, NAME, ARG(SHAPE), ARG(VALUES))
32+
#define P_SET_VAR_COPY(NAME, SHAPE, POINTER) P_SET_(var_copy, NAME, ARG(SHAPE), POINTER)
1933

2034
#include "test_tools.hh"
2135
#include "parser.hh"
2236
#include <iostream>
2337
#include <fstream>
2438

25-
std::string get_var_name(const std::string & name, const bparser::MultiIdx::VecUint & indices);
26-
27-
// v1, (1,2,3) => "v1__1_2_3"
39+
// v1, (1,2,3) => "v1[1,2,3]"
2840
std::string get_var_name(const std::string& name, const bparser::MultiIdx::VecUint& indices) {
2941
bparser::MultiIdx::VecUint::size_type size(indices.size());
3042

3143
std::ostringstream result;
3244
result << name;
33-
result << "__";
45+
result << "[";
3446
for (bparser::MultiIdx::VecUint::size_type i = 0; i < size; i++) {
3547
result << indices.at(i);
3648
if (i != size - 1) {
37-
result << '_';
49+
result << ',';
3850
}
3951
}
52+
result << "]";
4053
return result.str();
4154
}
4255

0 commit comments

Comments
 (0)