Skip to content

Commit 7728bec

Browse files
Nitpick case for matrix inversion
1 parent 8524b45 commit 7728bec

7 files changed

Lines changed: 180 additions & 3 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ endmacro()
319319

320320
define_nit(basic_expr)
321321
define_nit(norm2)
322+
define_nit(inv)
322323

323324

324325
#set(src_name "nitpick_generate")

include/dag_printer.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ namespace details {
102102
std::cout.flush();
103103
}
104104

105-
105+
//TODO: Figure out why the print gets more and more empty lines with longer prints
106106
std::string print_in_cxx(const CXXVarMap& map) {
107107
std::ostringstream result;
108108
ExpressionDAG::NodeVec result_nodes;
109-
//TODO: Make a function returning se and getting node_map
109+
110110
result << "//AUTOGENERATED This file has been autogenerated by bparser::details::DagPrinter::print_in_cxx" << "\n";
111111
result << "// " << __DATE__ << " " << __TIME__ << "\n";
112112
result << "\n";

include/parser.hh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ public:
135135
return keys;
136136
}
137137

138+
const std::map<std::string, Array> get_raw_symbols() const {
139+
return symbols_;
140+
}
141+
138142
/**
139143
* Set given name to be a variable of given shape with values at
140144
* given address 'variable_space'.

test/cases/inv_def.hh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#ifndef NITPICK_IDE_IGNORE
2+
#include "nitpick_include.hh"
3+
#endif //NITPICK_IDE_IGNORE
4+
5+
void def(ExprCase& c) {
6+
7+
// parse an expression.
8+
c.parse("inv(m3)");
9+
10+
c.set_variable("m3", { 3,3 });
11+
12+
// Set the result variable shape
13+
c.set_result_shape({3,3});
14+
15+
16+
}

test/cases/inv_gen_edit.hh

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//AUTOGENERATED This file has been autogenerated by bparser::details::DagPrinter::print_in_cxx
2+
// Mar 22 2026 21:51:21
3+
4+
#ifndef NITPICK_IDE_IGNORE
5+
#include "nitpick_common.hh"
6+
#include "nitpick_include.hh"
7+
using namespace bparser;
8+
using namespace bparser::details;
9+
#endif //NITPICK_IDE_IGNORE //This is here only to stop any IDE warnings
10+
ExpressionDAG gen(const NodeMap& node_map){ //Do not rename this function, it is used later in the nitpick_run.cc file
11+
12+
ScalarNodePtr Const_51__1 = ScalarNode::create_const(1);
13+
ScalarNodePtr Value_50__2 = create_variable_node(node_map.at("m3[1,1]"),Variable);
14+
ScalarNodePtr Value_49__2 = create_variable_node(node_map.at("m3[2,2]"),Variable);
15+
ScalarNodePtr mul_48__3 = ScalarNode::create<_mul_>(Value_50__2, Value_49__2);
16+
ScalarNodePtr Value_47__2 = create_variable_node(node_map.at("m3[1,2]"),Variable);
17+
ScalarNodePtr Value_46__2 = create_variable_node(node_map.at("m3[2,1]"),Variable);
18+
ScalarNodePtr mul_45__3 = ScalarNode::create<_mul_>(Value_47__2, Value_46__2);
19+
ScalarNodePtr sub_44__3 = ScalarNode::create<_sub_>(mul_48__3, mul_45__3);
20+
ScalarNodePtr Value_43__2 = create_variable_node(node_map.at("m3[0,0]"),Variable);
21+
ScalarNodePtr mul_42__3 = ScalarNode::create<_mul_>(sub_44__3, Value_43__2);
22+
ScalarNodePtr Value_41__2 = create_variable_node(node_map.at("m3[0,2]"),Variable);
23+
ScalarNodePtr mul_40__3 = ScalarNode::create<_mul_>(Value_46__2, Value_41__2);
24+
ScalarNodePtr Value_39__2 = create_variable_node(node_map.at("m3[0,1]"),Variable);
25+
ScalarNodePtr mul_38__3 = ScalarNode::create<_mul_>(Value_49__2, Value_39__2);
26+
ScalarNodePtr sub_37__3 = ScalarNode::create<_sub_>(mul_40__3, mul_38__3);
27+
ScalarNodePtr Value_36__2 = create_variable_node(node_map.at("m3[1,0]"),Variable);
28+
ScalarNodePtr mul_35__3 = ScalarNode::create<_mul_>(sub_37__3, Value_36__2);
29+
ScalarNodePtr add_34__3 = ScalarNode::create<_add_>(mul_42__3, mul_35__3);
30+
ScalarNodePtr mul_33__3 = ScalarNode::create<_mul_>(Value_39__2, Value_47__2);
31+
ScalarNodePtr mul_32__3 = ScalarNode::create<_mul_>(Value_41__2, Value_50__2);
32+
ScalarNodePtr sub_31__3 = ScalarNode::create<_sub_>(mul_33__3, mul_32__3);
33+
ScalarNodePtr Value_30__2 = create_variable_node(node_map.at("m3[2,0]"),Variable);
34+
ScalarNodePtr mul_29__3 = ScalarNode::create<_mul_>(sub_31__3, Value_30__2);
35+
ScalarNodePtr add_28__3 = ScalarNode::create<_add_>(add_34__3, mul_29__3);
36+
ScalarNodePtr div_27__3 = ScalarNode::create<_div_>(Const_51__1, add_28__3);
37+
ScalarNodePtr mul_26__4 = ScalarNode::create<_mul_>(sub_44__3, div_27__3);
38+
ScalarNodePtr mul_25__4 = ScalarNode::create<_mul_>(sub_37__3, div_27__3);
39+
ScalarNodePtr mul_24__4 = ScalarNode::create<_mul_>(sub_31__3, div_27__3);
40+
ScalarNodePtr mul_23__3 = ScalarNode::create<_mul_>(Value_47__2, Value_30__2);
41+
ScalarNodePtr mul_22__3 = ScalarNode::create<_mul_>(Value_36__2, Value_49__2);
42+
ScalarNodePtr sub_21__3 = ScalarNode::create<_sub_>(mul_23__3, mul_22__3);
43+
ScalarNodePtr mul_20__4 = ScalarNode::create<_mul_>(sub_21__3, div_27__3);
44+
ScalarNodePtr mul_19__3 = ScalarNode::create<_mul_>(Value_49__2, Value_43__2);
45+
ScalarNodePtr mul_18__3 = ScalarNode::create<_mul_>(Value_30__2, Value_41__2);
46+
ScalarNodePtr sub_17__3 = ScalarNode::create<_sub_>(mul_19__3, mul_18__3);
47+
ScalarNodePtr mul_16__4 = ScalarNode::create<_mul_>(sub_17__3, div_27__3);
48+
ScalarNodePtr mul_15__3 = ScalarNode::create<_mul_>(Value_41__2, Value_36__2);
49+
ScalarNodePtr mul_14__3 = ScalarNode::create<_mul_>(Value_43__2, Value_47__2);
50+
ScalarNodePtr sub_13__3 = ScalarNode::create<_sub_>(mul_15__3, mul_14__3);
51+
ScalarNodePtr mul_12__4 = ScalarNode::create<_mul_>(sub_13__3, div_27__3);
52+
ScalarNodePtr mul_11__3 = ScalarNode::create<_mul_>(Value_36__2, Value_46__2);
53+
ScalarNodePtr mul_10__3 = ScalarNode::create<_mul_>(Value_50__2, Value_30__2);
54+
ScalarNodePtr sub_9__3 = ScalarNode::create<_sub_>(mul_11__3, mul_10__3);
55+
ScalarNodePtr mul_8__4 = ScalarNode::create<_mul_>(sub_9__3, div_27__3);
56+
ScalarNodePtr mul_7__3 = ScalarNode::create<_mul_>(Value_30__2, Value_39__2);
57+
ScalarNodePtr mul_6__3 = ScalarNode::create<_mul_>(Value_46__2, Value_43__2);
58+
ScalarNodePtr sub_5__3 = ScalarNode::create<_sub_>(mul_7__3, mul_6__3);
59+
ScalarNodePtr mul_4__4 = ScalarNode::create<_mul_>(sub_5__3, div_27__3);
60+
ScalarNodePtr mul_3__3 = ScalarNode::create<_mul_>(Value_43__2, Value_50__2);
61+
ScalarNodePtr mul_2__3 = ScalarNode::create<_mul_>(Value_39__2, Value_36__2);
62+
ScalarNodePtr sub_1__3 = ScalarNode::create<_sub_>(mul_3__3, mul_2__3);
63+
ScalarNodePtr mul_0__4 = ScalarNode::create<_mul_>(sub_1__3, div_27__3);
64+
65+
66+
ScalarNodePtr r0 = ScalarNode::create_result(mul_26__4, node_map.at("_result_[0,0]"));
67+
ScalarNodePtr r1 = ScalarNode::create_result(mul_25__4, node_map.at("_result_[0,1]"));
68+
ScalarNodePtr r2 = ScalarNode::create_result(mul_24__4, node_map.at("_result_[0,2]"));
69+
ScalarNodePtr r3 = ScalarNode::create_result(mul_20__4, node_map.at("_result_[1,0]"));
70+
ScalarNodePtr r4 = ScalarNode::create_result(mul_16__4, node_map.at("_result_[1,1]"));
71+
ScalarNodePtr r5 = ScalarNode::create_result(mul_12__4, node_map.at("_result_[1,2]"));
72+
ScalarNodePtr r6 = ScalarNode::create_result(mul_8__4, node_map.at("_result_[2,0]"));
73+
ScalarNodePtr r7 = ScalarNode::create_result(mul_4__4, node_map.at("_result_[2,1]"));
74+
ScalarNodePtr r8 = ScalarNode::create_result(mul_0__4, node_map.at("_result_[2,2]"));
75+
ExpressionDAG se({r0, r1, r2, r3, r4, r5, r6, r7, r8});
76+
77+
return se;
78+
} //gen

test/cases/inv_gen_ref.hh

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//AUTOGENERATED This file has been autogenerated by bparser::details::DagPrinter::print_in_cxx
2+
// Mar 22 2026 21:51:21
3+
4+
#ifndef NITPICK_IDE_IGNORE
5+
#include "nitpick_common.hh"
6+
#include "nitpick_include.hh"
7+
using namespace bparser;
8+
using namespace bparser::details;
9+
#endif //NITPICK_IDE_IGNORE //This is here only to stop any IDE warnings
10+
ExpressionDAG gen(const NodeMap& node_map){ //Do not rename this function, it is used later in the nitpick_run.cc file
11+
12+
ScalarNodePtr Const_51__1 = ScalarNode::create_const(1);
13+
ScalarNodePtr Value_50__2 = create_variable_node(node_map.at("m3[1,1]"),Variable);
14+
ScalarNodePtr Value_49__2 = create_variable_node(node_map.at("m3[2,2]"),Variable);
15+
ScalarNodePtr mul_48__3 = ScalarNode::create<_mul_>(Value_50__2, Value_49__2);
16+
ScalarNodePtr Value_47__2 = create_variable_node(node_map.at("m3[1,2]"),Variable);
17+
ScalarNodePtr Value_46__2 = create_variable_node(node_map.at("m3[2,1]"),Variable);
18+
ScalarNodePtr mul_45__3 = ScalarNode::create<_mul_>(Value_47__2, Value_46__2);
19+
ScalarNodePtr sub_44__3 = ScalarNode::create<_sub_>(mul_48__3, mul_45__3);
20+
ScalarNodePtr Value_43__2 = create_variable_node(node_map.at("m3[0,0]"),Variable);
21+
ScalarNodePtr mul_42__3 = ScalarNode::create<_mul_>(sub_44__3, Value_43__2);
22+
ScalarNodePtr Value_41__2 = create_variable_node(node_map.at("m3[0,2]"),Variable);
23+
ScalarNodePtr mul_40__3 = ScalarNode::create<_mul_>(Value_46__2, Value_41__2);
24+
ScalarNodePtr Value_39__2 = create_variable_node(node_map.at("m3[0,1]"),Variable);
25+
ScalarNodePtr mul_38__3 = ScalarNode::create<_mul_>(Value_49__2, Value_39__2);
26+
ScalarNodePtr sub_37__3 = ScalarNode::create<_sub_>(mul_40__3, mul_38__3);
27+
ScalarNodePtr Value_36__2 = create_variable_node(node_map.at("m3[1,0]"),Variable);
28+
ScalarNodePtr mul_35__3 = ScalarNode::create<_mul_>(sub_37__3, Value_36__2);
29+
ScalarNodePtr add_34__3 = ScalarNode::create<_add_>(mul_42__3, mul_35__3);
30+
ScalarNodePtr mul_33__3 = ScalarNode::create<_mul_>(Value_39__2, Value_47__2);
31+
ScalarNodePtr mul_32__3 = ScalarNode::create<_mul_>(Value_41__2, Value_50__2);
32+
ScalarNodePtr sub_31__3 = ScalarNode::create<_sub_>(mul_33__3, mul_32__3);
33+
ScalarNodePtr Value_30__2 = create_variable_node(node_map.at("m3[2,0]"),Variable);
34+
ScalarNodePtr mul_29__3 = ScalarNode::create<_mul_>(sub_31__3, Value_30__2);
35+
ScalarNodePtr add_28__3 = ScalarNode::create<_add_>(add_34__3, mul_29__3);
36+
ScalarNodePtr div_27__3 = ScalarNode::create<_div_>(Const_51__1, add_28__3);
37+
ScalarNodePtr mul_26__4 = ScalarNode::create<_mul_>(sub_44__3, div_27__3);
38+
ScalarNodePtr mul_25__4 = ScalarNode::create<_mul_>(sub_37__3, div_27__3);
39+
ScalarNodePtr mul_24__4 = ScalarNode::create<_mul_>(sub_31__3, div_27__3);
40+
ScalarNodePtr mul_23__3 = ScalarNode::create<_mul_>(Value_47__2, Value_30__2);
41+
ScalarNodePtr mul_22__3 = ScalarNode::create<_mul_>(Value_36__2, Value_49__2);
42+
ScalarNodePtr sub_21__3 = ScalarNode::create<_sub_>(mul_23__3, mul_22__3);
43+
ScalarNodePtr mul_20__4 = ScalarNode::create<_mul_>(sub_21__3, div_27__3);
44+
ScalarNodePtr mul_19__3 = ScalarNode::create<_mul_>(Value_49__2, Value_43__2);
45+
ScalarNodePtr mul_18__3 = ScalarNode::create<_mul_>(Value_30__2, Value_41__2);
46+
ScalarNodePtr sub_17__3 = ScalarNode::create<_sub_>(mul_19__3, mul_18__3);
47+
ScalarNodePtr mul_16__4 = ScalarNode::create<_mul_>(sub_17__3, div_27__3);
48+
ScalarNodePtr mul_15__3 = ScalarNode::create<_mul_>(Value_41__2, Value_36__2);
49+
ScalarNodePtr mul_14__3 = ScalarNode::create<_mul_>(Value_43__2, Value_47__2);
50+
ScalarNodePtr sub_13__3 = ScalarNode::create<_sub_>(mul_15__3, mul_14__3);
51+
ScalarNodePtr mul_12__4 = ScalarNode::create<_mul_>(sub_13__3, div_27__3);
52+
ScalarNodePtr mul_11__3 = ScalarNode::create<_mul_>(Value_36__2, Value_46__2);
53+
ScalarNodePtr mul_10__3 = ScalarNode::create<_mul_>(Value_50__2, Value_30__2);
54+
ScalarNodePtr sub_9__3 = ScalarNode::create<_sub_>(mul_11__3, mul_10__3);
55+
ScalarNodePtr mul_8__4 = ScalarNode::create<_mul_>(sub_9__3, div_27__3);
56+
ScalarNodePtr mul_7__3 = ScalarNode::create<_mul_>(Value_30__2, Value_39__2);
57+
ScalarNodePtr mul_6__3 = ScalarNode::create<_mul_>(Value_46__2, Value_43__2);
58+
ScalarNodePtr sub_5__3 = ScalarNode::create<_sub_>(mul_7__3, mul_6__3);
59+
ScalarNodePtr mul_4__4 = ScalarNode::create<_mul_>(sub_5__3, div_27__3);
60+
ScalarNodePtr mul_3__3 = ScalarNode::create<_mul_>(Value_43__2, Value_50__2);
61+
ScalarNodePtr mul_2__3 = ScalarNode::create<_mul_>(Value_39__2, Value_36__2);
62+
ScalarNodePtr sub_1__3 = ScalarNode::create<_sub_>(mul_3__3, mul_2__3);
63+
ScalarNodePtr mul_0__4 = ScalarNode::create<_mul_>(sub_1__3, div_27__3);
64+
65+
66+
ScalarNodePtr r0 = ScalarNode::create_result(mul_26__4, node_map.at("_result_[0,0]"));
67+
ScalarNodePtr r1 = ScalarNode::create_result(mul_25__4, node_map.at("_result_[0,1]"));
68+
ScalarNodePtr r2 = ScalarNode::create_result(mul_24__4, node_map.at("_result_[0,2]"));
69+
ScalarNodePtr r3 = ScalarNode::create_result(mul_20__4, node_map.at("_result_[1,0]"));
70+
ScalarNodePtr r4 = ScalarNode::create_result(mul_16__4, node_map.at("_result_[1,1]"));
71+
ScalarNodePtr r5 = ScalarNode::create_result(mul_12__4, node_map.at("_result_[1,2]"));
72+
ScalarNodePtr r6 = ScalarNode::create_result(mul_8__4, node_map.at("_result_[2,0]"));
73+
ScalarNodePtr r7 = ScalarNode::create_result(mul_4__4, node_map.at("_result_[2,1]"));
74+
ScalarNodePtr r8 = ScalarNode::create_result(mul_0__4, node_map.at("_result_[2,2]"));
75+
ExpressionDAG se({r0, r1, r2, r3, r4, r5, r6, r7, r8});
76+
77+
return se;
78+
} //gen

test/nitpick/nitpick_generate.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ int main() {
3333
//p.compile(exprcase.get_patch_arena()); //Add arena from ExprCase
3434

3535
ExpressionDAG dag(p.result_array().elements());
36-
//dag.print_in_dot2();
36+
DagPrinter(dag).print_in_dot2(p.get_raw_symbols());
3737

3838
std::ofstream file(NITPICK_GEN_FILE);
3939
file << DagPrinter(dag).print_in_cxx(exprcase.get_inv_map());

0 commit comments

Comments
 (0)