11import pytest
2+ from functools import singledispatch
3+ from collections .abc import Sequence
4+ from example_code .expression_tools import postvisitor , evaluate
5+ from expressions .expressions import Symbol , Number , \
6+ Operator , Add , Mul
27try :
38 from expressions .expressions import differentiate
49except ImportError :
510 pass
611
712
13+ class ExpressionsError (Exception ):
14+ pass
15+
16+
17+ _test_expr = Symbol ("x" ) / Number (1 )
18+ _operands_attr = None
19+ for attr_name in dir (_test_expr ):
20+ attr = getattr (_test_expr , attr_name )
21+ if (
22+ isinstance (attr , Sequence ) and len (attr ) == 2
23+ and type (attr [0 ]) is Symbol and str (attr [0 ]) == "x"
24+ and type (attr [1 ]) is Number and float (str (attr [1 ])) == 1.0
25+ ):
26+ _operands_attr = attr_name
27+ break
28+
29+
30+ def try_eval (expression ):
31+ """Evaluate an expression if it doesn't have any symbols."""
32+ try :
33+ return postvisitor (expression , evaluate , symbol_map = {})
34+ except KeyError :
35+ return expression
36+
37+
38+ def operands (expression ):
39+ """Return the operands tuple of an expression."""
40+ if _operands_attr is None :
41+ raise ExpressionsError ("Could not find operands tuple on expression." )
42+ return tuple (getattr (expression , _operands_attr ))
43+
44+
45+ @singledispatch
46+ def expressions_equal (t1 , t2 ):
47+ """Return true if two expressions are equal.
48+
49+ This function takes into account commutative operators, but not the full
50+ class of equivalences between expressions.
51+ """
52+ try :
53+ if isinstance (t1 , tuple ) and isinstance (t2 , tuple ):
54+ return len (t1 ) == len (t2 )\
55+ and all (expressions_equal (e1 , e2 ) for e1 , e2 in zip (t1 , t2 ))
56+ except :
57+ return False
58+ return False # By default, expressions don't match.
59+
60+
61+ @expressions_equal .register (Number )
62+ def _ (e1 , e2 ):
63+ # Not the nicest way to do it, but `str` is the only way we have to
64+ # extract the value.
65+ return float (str (e1 )) == float (str (e2 ))
66+
67+
68+ @expressions_equal .register (Symbol )
69+ def _ (e1 , e2 ):
70+ return str (e1 ) == str (e2 )
71+
72+
73+ @expressions_equal .register (Operator )
74+ def _ (e1 , e2 ):
75+ return type (e1 ) is type (e2 )\
76+ and expressions_equal (operands (e1 ), operands (e2 ))
77+
78+
79+ @expressions_equal .register (Add )
80+ @expressions_equal .register (Mul )
81+ def _ (e1 , e2 ):
82+ return type (e1 ) is type (e2 )\
83+ and (
84+ expressions_equal (operands (e1 ), operands (e2 ))
85+ or expressions_equal (operands (e1 ), tuple (reversed (operands (e2 ))))
86+ ) or try_eval (e1 ) == try_eval (e2 )
87+
88+
89+ expressions_equal .register (tuple )
90+ def _ (t1 , t2 ):
91+ return type (t1 ) is type (t2 ) and len (t1 ) == len (t2 )\
92+ and all (expressions_equal (e1 , e2 ) for e1 , e2 in zip (t1 , t2 ))
93+
94+
895def test_diff_import ():
996 from expressions .expressions import differentiate
1097
@@ -14,30 +101,30 @@ def sample_diff_set():
14101 from expressions .expressions import Symbol
15102 x = Symbol ('x' )
16103 y = Symbol ('y' )
17- tests = [(2 * x + 1 , 'x' , 1.5 , 10 , 2 , ' 0.0 * x + 1.0 * 2 + 0.0' ),
18- (3 * x + y , 'x' , 1.5 , 10 , 3 , ' 0.0 * x + 1.0 * 3 + 0.0' ),
104+ tests = [(2 * x + 1 , 'x' , 1.5 , 10 , 2 , 0.0 * x + 1.0 * 2 + 0.0 ),
105+ (3 * x + y , 'x' , 1.5 , 10 , 3 , 0.0 * x + 1.0 * 3 + 0.0 ),
19106 (x * y + x / y , 'y' , 3 , 2 , 2.25 ,
20- ' 0.0 * y + 1.0 * x + (0.0 * y - x * 1.0) / y ^ 2' ),
107+ 0.0 * y + 1.0 * x + (0.0 * y - x * 1.0 ) / ( y ** 2 ) ),
21108 (2 * x ** 3 + x ** 2 * y , 'x' , 2 , 3 , 36 ,
22- ' 0.0 * x ^ 3 + 3 * x ^ (3 - 1) * 1.0 * 2 + 2 * x ^'
23- ' (2 - 1) * 1.0 * y + 0.0 * x ^ 2' )
109+ 0.0 * x ** 3 + 3 * x ** (3 - 1 ) * 1.0 * 2 + 2 * x ** ( 2 - 1 ) * 1.0 *
110+ y + 0.0 * x ** 2 )
24111 ]
25112 return tests
26113
27114
28115@pytest .mark .parametrize ("idx" , [
29116 (0 ),
30117 (1 ),
31- (2 ),
32- (3 )
118+ (2 )
33119])
34120def test_diff_expr_recursive (sample_diff_set , idx ):
35121 from expressions .expressions import differentiate
36122 from example_code .expression_tools import postvisitor
37123 expr , dvar , _ , _ , _ , diff_expr = sample_diff_set [idx ]
38- assert str (postvisitor (expr , differentiate , var = dvar )) == diff_expr , \
39- f"expected an expression of { diff_expr } " \
40- f" for expression d/d{ dvar } ({ expr } )"
124+ derivative = postvisitor (expr , differentiate , var = dvar )
125+ assert expressions_equal (derivative , diff_expr ), \
126+ f"Computing expression d/d{ dvar } ({ expr } ). Expected: \n { diff_expr } " \
127+ f"\n got: \n { derivative } "
41128
42129
43130@pytest .mark .parametrize ("idx" , [
@@ -59,15 +146,15 @@ def test_diff_val_recursive(sample_diff_set, idx):
59146@pytest .mark .parametrize ("idx" , [
60147 (0 ),
61148 (1 ),
62- (2 ),
63- (3 )
149+ (2 )
64150])
65151def test_diff_expr (sample_diff_set , idx ):
66152 from expressions .expressions import postvisitor , differentiate
67153 expr , dvar , _ , _ , _ , diff_expr = sample_diff_set [idx ]
68- assert str (postvisitor (expr , differentiate , var = dvar )) == diff_expr , \
154+ derivative = postvisitor (expr , differentiate , var = dvar )
155+ assert expressions_equal (derivative , diff_expr ), \
69156 f"expected an expression of { diff_expr } " \
70- f" for expression d/d{ dvar } ({ expr } )"
157+ f" for expression d/d{ dvar } ({ expr } ), got { derivative } "
71158
72159
73160@pytest .mark .parametrize ("idx" , [
0 commit comments