Skip to content

Commit 70404a9

Browse files
Inverse matrix (<4x4)
1 parent c58b6a2 commit 70404a9

4 files changed

Lines changed: 49 additions & 0 deletions

File tree

include/array.hh

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,43 @@ public:
14431443
}
14441444
}
14451445

1446+
/*static Array det(const Array& a) {
1447+
1448+
}*/
1449+
1450+
//Square matrix inverse
1451+
static Array inv(const Array& a) {
1452+
switch (a.shape().size()) {
1453+
case 0: //scalar
1454+
Throw() << "Cannot inverse a scalar" << "\n";
1455+
break;
1456+
case 1: //vector
1457+
{
1458+
Throw() << "Cannot inverse a vector" << "\n";
1459+
}
1460+
case 2: //matrix
1461+
{
1462+
if (a.shape()[0] != a.shape()[1]) {
1463+
Throw() << "Cannot inverse non-square matrix" << "\n";
1464+
}
1465+
WrappedArray m_a(wrap_array(a));
1466+
switch (a.shape()[0]) {
1467+
case 1:
1468+
return unwrap_array(Eigen::Ref<Eigen::Matrix<details::ScalarWrapper,1,1>>(m_a).inverse());
1469+
case 2:
1470+
return unwrap_array(Eigen::Ref<Eigen::Matrix2<details::ScalarWrapper>>(m_a).inverse());
1471+
case 3:
1472+
return unwrap_array(Eigen::Ref<Eigen::Matrix3<details::ScalarWrapper>>(m_a).inverse());
1473+
case 4:
1474+
return unwrap_array(Eigen::Ref<Eigen::Matrix4<details::ScalarWrapper>>(m_a).inverse());
1475+
}
1476+
//return unwrap_array(m_a.inverse());
1477+
}
1478+
default:
1479+
Throw() << "Cannot inverse ND tensors" << "\n";
1480+
}
1481+
}
1482+
14461483

14471484
static Array flatten(const Array &tensor) {
14481485
uint n_elements = shape_size(tensor.shape());

include/grammar.impl.hh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ struct grammar : qi::grammar<Iterator, ast::operand(), ascii::space_type> {
193193
FN("cross" , &Array::cross)
194194
FN("sym" , &Array::sym)
195195
FN("dev" , &Array::dev)
196+
//FN("det" , &Array::det)
197+
FN("inv" , &Array::inv)
196198
;
197199

198200
unary_op.add

include/scalar_wrapper.hh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "scalar_node.hh"
1414
#include <Eigen/Core>
1515
#include <Eigen/Geometry>
16+
#include <Eigen/LU>
1617
//#include <Eigen/Eigenvalues> //impossible
1718

1819
namespace bparser {

test/test_parser.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,15 @@ void test_expression() {
275275
BP_ASSERT(test_expr("dev([[1,2],[3,4]])", { 1-2.5, 2, 3, 4-2.5 }, {2,2}));
276276
BP_ASSERT(test_expr("tr(dev([[1,2],[3,4]]))", { 0 }, {}));
277277

278+
//BP_ASSERT(test_expr("det()", {}, {}));
279+
280+
BP_ASSERT(test_expr("inv([[1,2],[3,4]])", { -2., 1., 1.5, -0.5 }, { 2,2 }));
281+
BP_ASSERT(test_expr("a=[[1]]; inv(a)", { 1 }, { 1,1 }));
282+
BP_ASSERT(test_expr("a=[[1,2],[3,4]]; a @ inv(a)", { 1,0,0,1 }, { 2,2 }));
283+
BP_ASSERT(test_expr("a=[[1,2],[3,4]]; inv(a) @ a", { 1,0,0,1 }, { 2,2 }));
284+
BP_ASSERT(test_expr("a=[[1,2,3],[4,5,6],[7,8,9]]; inv(a) @ a", { 1,0,0, 0,1,0, 0,0,1 }, { 3,3 }));
285+
BP_ASSERT(test_expr("a=[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]; inv(a) @ a", { 1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1 }, { 4,4 }));
286+
278287
BP_ASSERT(test_expr("abs(-1)+abs(0)+abs(1)", {2}));
279288
BP_ASSERT(test_expr("floor(-3.5)", {-4}, {}));
280289
BP_ASSERT(test_expr("ceil(-3.5)", {-3}, {}));

0 commit comments

Comments
 (0)