Skip to content

Commit 9e51757

Browse files
authored
Merge pull request #2 from Imperial-MATH50009/main
better equality operator
2 parents a0a2532 + d404979 commit 9e51757

File tree

1 file changed

+101
-14
lines changed

1 file changed

+101
-14
lines changed

tests/test_exercise_9_8.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,97 @@
11
import pytest
2+
from functools import singledispatch
3+
from collections.abc import Sequence
4+
from example_code.expression_tools import postvisitor, evaluate
5+
from expressions.expressions import Symbol, Number, \
6+
Operator, Add, Mul
27
try:
38
from expressions.expressions import differentiate
49
except ImportError:
510
pass
611

712

13+
class ExpressionsError(Exception):
14+
pass
15+
16+
17+
_test_expr = Symbol("x") / Number(1)
18+
_operands_attr = None
19+
for attr_name in dir(_test_expr):
20+
attr = getattr(_test_expr, attr_name)
21+
if (
22+
isinstance(attr, Sequence) and len(attr) == 2
23+
and type(attr[0]) is Symbol and str(attr[0]) == "x"
24+
and type(attr[1]) is Number and float(str(attr[1])) == 1.0
25+
):
26+
_operands_attr = attr_name
27+
break
28+
29+
30+
def try_eval(expression):
31+
"""Evaluate an expression if it doesn't have any symbols."""
32+
try:
33+
return postvisitor(expression, evaluate, symbol_map={})
34+
except KeyError:
35+
return expression
36+
37+
38+
def operands(expression):
39+
"""Return the operands tuple of an expression."""
40+
if _operands_attr is None:
41+
raise ExpressionsError("Could not find operands tuple on expression.")
42+
return tuple(getattr(expression, _operands_attr))
43+
44+
45+
@singledispatch
46+
def expressions_equal(t1, t2):
47+
"""Return true if two expressions are equal.
48+
49+
This function takes into account commutative operators, but not the full
50+
class of equivalences between expressions.
51+
"""
52+
try:
53+
if isinstance(t1, tuple) and isinstance(t2, tuple):
54+
return len(t1) == len(t2)\
55+
and all(expressions_equal(e1, e2) for e1, e2 in zip(t1, t2))
56+
except:
57+
return False
58+
return False # By default, expressions don't match.
59+
60+
61+
@expressions_equal.register(Number)
62+
def _(e1, e2):
63+
#  Not the nicest way to do it, but `str` is the only way we have to
64+
#  extract the value.
65+
return float(str(e1)) == float(str(e2))
66+
67+
68+
@expressions_equal.register(Symbol)
69+
def _(e1, e2):
70+
return str(e1) == str(e2)
71+
72+
73+
@expressions_equal.register(Operator)
74+
def _(e1, e2):
75+
return type(e1) is type(e2)\
76+
and expressions_equal(operands(e1), operands(e2))
77+
78+
79+
@expressions_equal.register(Add)
80+
@expressions_equal.register(Mul)
81+
def _(e1, e2):
82+
return type(e1) is type(e2)\
83+
and (
84+
expressions_equal(operands(e1), operands(e2))
85+
or expressions_equal(operands(e1), tuple(reversed(operands(e2))))
86+
) or try_eval(e1) == try_eval(e2)
87+
88+
89+
expressions_equal.register(tuple)
90+
def _(t1, t2):
91+
return type(t1) is type(t2) and len(t1) == len(t2)\
92+
and all(expressions_equal(e1, e2) for e1, e2 in zip(t1, t2))
93+
94+
895
def test_diff_import():
996
from expressions.expressions import differentiate
1097

@@ -14,30 +101,30 @@ def sample_diff_set():
14101
from expressions.expressions import Symbol
15102
x = Symbol('x')
16103
y = Symbol('y')
17-
tests = [(2 * x + 1, 'x', 1.5, 10, 2, '0.0 * x + 1.0 * 2 + 0.0'),
18-
(3 * x + y, 'x', 1.5, 10, 3, '0.0 * x + 1.0 * 3 + 0.0'),
104+
tests = [(2 * x + 1, 'x', 1.5, 10, 2, 0.0 * x + 1.0 * 2 + 0.0),
105+
(3 * x + y, 'x', 1.5, 10, 3, 0.0 * x + 1.0 * 3 + 0.0),
19106
(x * y + x / y, 'y', 3, 2, 2.25,
20-
'0.0 * y + 1.0 * x + (0.0 * y - x * 1.0) / y ^ 2'),
107+
0.0 * y + 1.0 * x + (0.0 * y - x * 1.0) / (y ** 2)),
21108
(2 * x**3 + x**2 * y, 'x', 2, 3, 36,
22-
'0.0 * x ^ 3 + 3 * x ^ (3 - 1) * 1.0 * 2 + 2 * x ^'
23-
' (2 - 1) * 1.0 * y + 0.0 * x ^ 2')
109+
0.0 * x**3 + 3 * x**(3 - 1) * 1.0 * 2 + 2 * x**(2 - 1) * 1.0 *
110+
y + 0.0 * x**2)
24111
]
25112
return tests
26113

27114

28115
@pytest.mark.parametrize("idx", [
29116
(0),
30117
(1),
31-
(2),
32-
(3)
118+
(2)
33119
])
34120
def test_diff_expr_recursive(sample_diff_set, idx):
35121
from expressions.expressions import differentiate
36122
from example_code.expression_tools import postvisitor
37123
expr, dvar, _, _, _, diff_expr = sample_diff_set[idx]
38-
assert str(postvisitor(expr, differentiate, var=dvar)) == diff_expr, \
39-
f"expected an expression of {diff_expr}"\
40-
f" for expression d/d{dvar}({expr})"
124+
derivative = postvisitor(expr, differentiate, var=dvar)
125+
assert expressions_equal(derivative, diff_expr), \
126+
f"Computing expression d/d{dvar}({expr}). Expected: \n {diff_expr}"\
127+
f"\n got: \n {derivative}"
41128

42129

43130
@pytest.mark.parametrize("idx", [
@@ -59,15 +146,15 @@ def test_diff_val_recursive(sample_diff_set, idx):
59146
@pytest.mark.parametrize("idx", [
60147
(0),
61148
(1),
62-
(2),
63-
(3)
149+
(2)
64150
])
65151
def test_diff_expr(sample_diff_set, idx):
66152
from expressions.expressions import postvisitor, differentiate
67153
expr, dvar, _, _, _, diff_expr = sample_diff_set[idx]
68-
assert str(postvisitor(expr, differentiate, var=dvar)) == diff_expr, \
154+
derivative = postvisitor(expr, differentiate, var=dvar)
155+
assert expressions_equal(derivative, diff_expr), \
69156
f"expected an expression of {diff_expr}"\
70-
f" for expression d/d{dvar}({expr})"
157+
f" for expression d/d{dvar}({expr}), got {derivative}"
71158

72159

73160
@pytest.mark.parametrize("idx", [

0 commit comments

Comments
 (0)