switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
""" Unification in SymPy
|
||||
|
||||
See sympy.unify.core docstring for algorithmic details
|
||||
|
||||
See http://matthewrocklin.com/blog/work/2012/11/01/Unification/ for discussion
|
||||
"""
|
||||
|
||||
from .usympy import unify, rebuild
|
||||
from .rewrite import rewriterule
|
||||
|
||||
__all__ = [
|
||||
'unify', 'rebuild',
|
||||
|
||||
'rewriterule',
|
||||
]
|
||||
@@ -0,0 +1,234 @@
|
||||
""" Generic Unification algorithm for expression trees with lists of children
|
||||
|
||||
This implementation is a direct translation of
|
||||
|
||||
Artificial Intelligence: A Modern Approach by Stuart Russel and Peter Norvig
|
||||
Second edition, section 9.2, page 276
|
||||
|
||||
It is modified in the following ways:
|
||||
|
||||
1. We allow associative and commutative Compound expressions. This results in
|
||||
combinatorial blowup.
|
||||
2. We explore the tree lazily.
|
||||
3. We provide generic interfaces to symbolic algebra libraries in Python.
|
||||
|
||||
A more traditional version can be found here
|
||||
http://aima.cs.berkeley.edu/python/logic.html
|
||||
"""
|
||||
|
||||
from sympy.utilities.iterables import kbins
|
||||
|
||||
class Compound:
|
||||
""" A little class to represent an interior node in the tree
|
||||
|
||||
This is analogous to SymPy.Basic for non-Atoms
|
||||
"""
|
||||
def __init__(self, op, args):
|
||||
self.op = op
|
||||
self.args = args
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other) and self.op == other.op and
|
||||
self.args == other.args)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((type(self), self.op, self.args))
|
||||
|
||||
def __str__(self):
|
||||
return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))
|
||||
|
||||
class Variable:
|
||||
""" A Wild token """
|
||||
def __init__(self, arg):
|
||||
self.arg = arg
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) is type(other) and self.arg == other.arg
|
||||
|
||||
def __hash__(self):
|
||||
return hash((type(self), self.arg))
|
||||
|
||||
def __str__(self):
|
||||
return "Variable(%s)" % str(self.arg)
|
||||
|
||||
class CondVariable:
|
||||
""" A wild token that matches conditionally.
|
||||
|
||||
arg - a wild token.
|
||||
valid - an additional constraining function on a match.
|
||||
"""
|
||||
def __init__(self, arg, valid):
|
||||
self.arg = arg
|
||||
self.valid = valid
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other) and
|
||||
self.arg == other.arg and
|
||||
self.valid == other.valid)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((type(self), self.arg, self.valid))
|
||||
|
||||
def __str__(self):
|
||||
return "CondVariable(%s)" % str(self.arg)
|
||||
|
||||
def unify(x, y, s=None, **fns):
|
||||
""" Unify two expressions.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
x, y - expression trees containing leaves, Compounds and Variables.
|
||||
s - a mapping of variables to subtrees.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
lazy sequence of mappings {Variable: subtree}
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.unify.core import unify, Compound, Variable
|
||||
>>> expr = Compound("Add", ("x", "y"))
|
||||
>>> pattern = Compound("Add", ("x", Variable("a")))
|
||||
>>> next(unify(expr, pattern, {}))
|
||||
{Variable(a): 'y'}
|
||||
"""
|
||||
s = s or {}
|
||||
|
||||
if x == y:
|
||||
yield s
|
||||
elif isinstance(x, (Variable, CondVariable)):
|
||||
yield from unify_var(x, y, s, **fns)
|
||||
elif isinstance(y, (Variable, CondVariable)):
|
||||
yield from unify_var(y, x, s, **fns)
|
||||
elif isinstance(x, Compound) and isinstance(y, Compound):
|
||||
is_commutative = fns.get('is_commutative', lambda x: False)
|
||||
is_associative = fns.get('is_associative', lambda x: False)
|
||||
for sop in unify(x.op, y.op, s, **fns):
|
||||
if is_associative(x) and is_associative(y):
|
||||
a, b = (x, y) if len(x.args) < len(y.args) else (y, x)
|
||||
if is_commutative(x) and is_commutative(y):
|
||||
combs = allcombinations(a.args, b.args, 'commutative')
|
||||
else:
|
||||
combs = allcombinations(a.args, b.args, 'associative')
|
||||
for aaargs, bbargs in combs:
|
||||
aa = [unpack(Compound(a.op, arg)) for arg in aaargs]
|
||||
bb = [unpack(Compound(b.op, arg)) for arg in bbargs]
|
||||
yield from unify(aa, bb, sop, **fns)
|
||||
elif len(x.args) == len(y.args):
|
||||
yield from unify(x.args, y.args, sop, **fns)
|
||||
|
||||
elif is_args(x) and is_args(y) and len(x) == len(y):
|
||||
if len(x) == 0:
|
||||
yield s
|
||||
else:
|
||||
for shead in unify(x[0], y[0], s, **fns):
|
||||
yield from unify(x[1:], y[1:], shead, **fns)
|
||||
|
||||
def unify_var(var, x, s, **fns):
|
||||
if var in s:
|
||||
yield from unify(s[var], x, s, **fns)
|
||||
elif occur_check(var, x):
|
||||
pass
|
||||
elif isinstance(var, CondVariable) and var.valid(x):
|
||||
yield assoc(s, var, x)
|
||||
elif isinstance(var, Variable):
|
||||
yield assoc(s, var, x)
|
||||
|
||||
def occur_check(var, x):
|
||||
""" var occurs in subtree owned by x? """
|
||||
if var == x:
|
||||
return True
|
||||
elif isinstance(x, Compound):
|
||||
return occur_check(var, x.args)
|
||||
elif is_args(x):
|
||||
if any(occur_check(var, xi) for xi in x): return True
|
||||
return False
|
||||
|
||||
def assoc(d, key, val):
|
||||
""" Return copy of d with key associated to val """
|
||||
d = d.copy()
|
||||
d[key] = val
|
||||
return d
|
||||
|
||||
def is_args(x):
|
||||
""" Is x a traditional iterable? """
|
||||
return type(x) in (tuple, list, set)
|
||||
|
||||
def unpack(x):
|
||||
if isinstance(x, Compound) and len(x.args) == 1:
|
||||
return x.args[0]
|
||||
else:
|
||||
return x
|
||||
|
||||
def allcombinations(A, B, ordered):
|
||||
"""
|
||||
Restructure A and B to have the same number of elements.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
ordered must be either 'commutative' or 'associative'.
|
||||
|
||||
A and B can be rearranged so that the larger of the two lists is
|
||||
reorganized into smaller sublists.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.unify.core import allcombinations
|
||||
>>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
|
||||
(((1,), (2, 3)), ((5,), (6,)))
|
||||
(((1, 2), (3,)), ((5,), (6,)))
|
||||
|
||||
>>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
|
||||
(((1,), (2, 3)), ((5,), (6,)))
|
||||
(((1, 2), (3,)), ((5,), (6,)))
|
||||
(((1,), (3, 2)), ((5,), (6,)))
|
||||
(((1, 3), (2,)), ((5,), (6,)))
|
||||
(((2,), (1, 3)), ((5,), (6,)))
|
||||
(((2, 1), (3,)), ((5,), (6,)))
|
||||
(((2,), (3, 1)), ((5,), (6,)))
|
||||
(((2, 3), (1,)), ((5,), (6,)))
|
||||
(((3,), (1, 2)), ((5,), (6,)))
|
||||
(((3, 1), (2,)), ((5,), (6,)))
|
||||
(((3,), (2, 1)), ((5,), (6,)))
|
||||
(((3, 2), (1,)), ((5,), (6,)))
|
||||
"""
|
||||
|
||||
if ordered == "commutative":
|
||||
ordered = 11
|
||||
if ordered == "associative":
|
||||
ordered = None
|
||||
sm, bg = (A, B) if len(A) < len(B) else (B, A)
|
||||
for part in kbins(list(range(len(bg))), len(sm), ordered=ordered):
|
||||
if bg == B:
|
||||
yield tuple((a,) for a in A), partition(B, part)
|
||||
else:
|
||||
yield partition(A, part), tuple((b,) for b in B)
|
||||
|
||||
def partition(it, part):
|
||||
""" Partition a tuple/list into pieces defined by indices.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.unify.core import partition
|
||||
>>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
|
||||
((10, 20, 30), (40,))
|
||||
"""
|
||||
return type(it)([index(it, ind) for ind in part])
|
||||
|
||||
def index(it, ind):
|
||||
""" Fancy indexing into an indexable iterable (tuple, list).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.unify.core import index
|
||||
>>> index([10, 20, 30], (1, 2, 0))
|
||||
[20, 30, 10]
|
||||
"""
|
||||
return type(it)([it[i] for i in ind])
|
||||
@@ -0,0 +1,55 @@
|
||||
""" Functions to support rewriting of SymPy expressions """
|
||||
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.assumptions import ask
|
||||
from sympy.strategies.tools import subs
|
||||
from sympy.unify.usympy import rebuild, unify
|
||||
|
||||
def rewriterule(source, target, variables=(), condition=None, assume=None):
|
||||
""" Rewrite rule.
|
||||
|
||||
Transform expressions that match source into expressions that match target
|
||||
treating all ``variables`` as wilds.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import w, x, y, z
|
||||
>>> from sympy.unify.rewrite import rewriterule
|
||||
>>> from sympy import default_sort_key
|
||||
>>> rl = rewriterule(x + y, x**y, [x, y])
|
||||
>>> sorted(rl(z + 3), key=default_sort_key)
|
||||
[3**z, z**3]
|
||||
|
||||
Use ``condition`` to specify additional requirements. Inputs are taken in
|
||||
the same order as is found in variables.
|
||||
|
||||
>>> rl = rewriterule(x + y, x**y, [x, y], lambda x, y: x.is_integer)
|
||||
>>> list(rl(z + 3))
|
||||
[3**z]
|
||||
|
||||
Use ``assume`` to specify additional requirements using new assumptions.
|
||||
|
||||
>>> from sympy.assumptions import Q
|
||||
>>> rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
|
||||
>>> list(rl(z + 3))
|
||||
[3**z]
|
||||
|
||||
Assumptions for the local context are provided at rule runtime
|
||||
|
||||
>>> list(rl(w + z, Q.integer(z)))
|
||||
[z**w]
|
||||
"""
|
||||
|
||||
def rewrite_rl(expr, assumptions=True):
|
||||
for match in unify(source, expr, {}, variables=variables):
|
||||
if (condition and
|
||||
not condition(*[match.get(var, var) for var in variables])):
|
||||
continue
|
||||
if (assume and not ask(assume.xreplace(match), assumptions)):
|
||||
continue
|
||||
expr2 = subs(match)(target)
|
||||
if isinstance(expr2, Expr):
|
||||
expr2 = rebuild(expr2)
|
||||
yield expr2
|
||||
return rewrite_rl
|
||||
@@ -0,0 +1,74 @@
|
||||
from sympy.unify.rewrite import rewriterule
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.abc import x, y
|
||||
from sympy.strategies.rl import rebuild
|
||||
from sympy.assumptions import Q
|
||||
|
||||
p, q = Symbol('p'), Symbol('q')
|
||||
|
||||
def test_simple():
|
||||
rl = rewriterule(Basic(p, S(1)), Basic(p, S(2)), variables=(p,))
|
||||
assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
|
||||
|
||||
p1 = p**2
|
||||
p2 = p**3
|
||||
rl = rewriterule(p1, p2, variables=(p,))
|
||||
|
||||
expr = x**2
|
||||
assert list(rl(expr)) == [x**3]
|
||||
|
||||
def test_simple_variables():
|
||||
rl = rewriterule(Basic(x, S(1)), Basic(x, S(2)), variables=(x,))
|
||||
assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
|
||||
|
||||
rl = rewriterule(x**2, x**3, variables=(x,))
|
||||
assert list(rl(y**2)) == [y**3]
|
||||
|
||||
def test_moderate():
|
||||
p1 = p**2 + q**3
|
||||
p2 = (p*q)**4
|
||||
rl = rewriterule(p1, p2, (p, q))
|
||||
|
||||
expr = x**2 + y**3
|
||||
assert list(rl(expr)) == [(x*y)**4]
|
||||
|
||||
def test_sincos():
|
||||
p1 = sin(p)**2 + sin(p)**2
|
||||
p2 = 1
|
||||
rl = rewriterule(p1, p2, (p, q))
|
||||
|
||||
assert list(rl(sin(x)**2 + sin(x)**2)) == [1]
|
||||
assert list(rl(sin(y)**2 + sin(y)**2)) == [1]
|
||||
|
||||
def test_Exprs_ok():
|
||||
rl = rewriterule(p+q, q+p, (p, q))
|
||||
next(rl(x+y)).is_commutative
|
||||
str(next(rl(x+y)))
|
||||
|
||||
def test_condition_simple():
|
||||
rl = rewriterule(x, x+1, [x], lambda x: x < 10)
|
||||
assert not list(rl(S(15)))
|
||||
assert rebuild(next(rl(S(5)))) == 6
|
||||
|
||||
|
||||
def test_condition_multiple():
|
||||
rl = rewriterule(x + y, x**y, [x,y], lambda x, y: x.is_integer)
|
||||
|
||||
a = Symbol('a')
|
||||
b = Symbol('b', integer=True)
|
||||
expr = a + b
|
||||
assert list(rl(expr)) == [b**a]
|
||||
|
||||
c = Symbol('c', integer=True)
|
||||
d = Symbol('d', integer=True)
|
||||
assert set(rl(c + d)) == {c**d, d**c}
|
||||
|
||||
def test_assumptions():
|
||||
rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
|
||||
|
||||
a, b = map(Symbol, 'ab')
|
||||
expr = a + b
|
||||
assert list(rl(expr, Q.integer(b))) == [b**a]
|
||||
@@ -0,0 +1,162 @@
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.logic.boolalg import And
|
||||
from sympy.core.symbol import Str
|
||||
from sympy.unify.core import Compound, Variable
|
||||
from sympy.unify.usympy import (deconstruct, construct, unify, is_associative,
|
||||
is_commutative)
|
||||
from sympy.abc import x, y, z, n
|
||||
|
||||
def test_deconstruct():
|
||||
expr = Basic(S(1), S(2), S(3))
|
||||
expected = Compound(Basic, (1, 2, 3))
|
||||
assert deconstruct(expr) == expected
|
||||
|
||||
assert deconstruct(1) == 1
|
||||
assert deconstruct(x) == x
|
||||
assert deconstruct(x, variables=(x,)) == Variable(x)
|
||||
assert deconstruct(Add(1, x, evaluate=False)) == Compound(Add, (1, x))
|
||||
assert deconstruct(Add(1, x, evaluate=False), variables=(x,)) == \
|
||||
Compound(Add, (1, Variable(x)))
|
||||
|
||||
def test_construct():
|
||||
expr = Compound(Basic, (S(1), S(2), S(3)))
|
||||
expected = Basic(S(1), S(2), S(3))
|
||||
assert construct(expr) == expected
|
||||
|
||||
def test_nested():
|
||||
expr = Basic(S(1), Basic(S(2)), S(3))
|
||||
cmpd = Compound(Basic, (S(1), Compound(Basic, Tuple(2)), S(3)))
|
||||
assert deconstruct(expr) == cmpd
|
||||
assert construct(cmpd) == expr
|
||||
|
||||
def test_unify():
|
||||
expr = Basic(S(1), S(2), S(3))
|
||||
a, b, c = map(Symbol, 'abc')
|
||||
pattern = Basic(a, b, c)
|
||||
assert list(unify(expr, pattern, {}, (a, b, c))) == [{a: 1, b: 2, c: 3}]
|
||||
assert list(unify(expr, pattern, variables=(a, b, c))) == \
|
||||
[{a: 1, b: 2, c: 3}]
|
||||
|
||||
def test_unify_variables():
|
||||
assert list(unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))) == [{x: 2}]
|
||||
|
||||
def test_s_input():
|
||||
expr = Basic(S(1), S(2))
|
||||
a, b = map(Symbol, 'ab')
|
||||
pattern = Basic(a, b)
|
||||
assert list(unify(expr, pattern, {}, (a, b))) == [{a: 1, b: 2}]
|
||||
assert list(unify(expr, pattern, {a: 5}, (a, b))) == []
|
||||
|
||||
def iterdicteq(a, b):
|
||||
a = tuple(a)
|
||||
b = tuple(b)
|
||||
return len(a) == len(b) and all(x in b for x in a)
|
||||
|
||||
def test_unify_commutative():
|
||||
expr = Add(1, 2, 3, evaluate=False)
|
||||
a, b, c = map(Symbol, 'abc')
|
||||
pattern = Add(a, b, c, evaluate=False)
|
||||
|
||||
result = tuple(unify(expr, pattern, {}, (a, b, c)))
|
||||
expected = ({a: 1, b: 2, c: 3},
|
||||
{a: 1, b: 3, c: 2},
|
||||
{a: 2, b: 1, c: 3},
|
||||
{a: 2, b: 3, c: 1},
|
||||
{a: 3, b: 1, c: 2},
|
||||
{a: 3, b: 2, c: 1})
|
||||
|
||||
assert iterdicteq(result, expected)
|
||||
|
||||
def test_unify_iter():
|
||||
expr = Add(1, 2, 3, evaluate=False)
|
||||
a, b, c = map(Symbol, 'abc')
|
||||
pattern = Add(a, c, evaluate=False)
|
||||
assert is_associative(deconstruct(pattern))
|
||||
assert is_commutative(deconstruct(pattern))
|
||||
|
||||
result = list(unify(expr, pattern, {}, (a, c)))
|
||||
expected = [{a: 1, c: Add(2, 3, evaluate=False)},
|
||||
{a: 1, c: Add(3, 2, evaluate=False)},
|
||||
{a: 2, c: Add(1, 3, evaluate=False)},
|
||||
{a: 2, c: Add(3, 1, evaluate=False)},
|
||||
{a: 3, c: Add(1, 2, evaluate=False)},
|
||||
{a: 3, c: Add(2, 1, evaluate=False)},
|
||||
{a: Add(1, 2, evaluate=False), c: 3},
|
||||
{a: Add(2, 1, evaluate=False), c: 3},
|
||||
{a: Add(1, 3, evaluate=False), c: 2},
|
||||
{a: Add(3, 1, evaluate=False), c: 2},
|
||||
{a: Add(2, 3, evaluate=False), c: 1},
|
||||
{a: Add(3, 2, evaluate=False), c: 1}]
|
||||
|
||||
assert iterdicteq(result, expected)
|
||||
|
||||
def test_hard_match():
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
expr = sin(x) + cos(x)**2
|
||||
p, q = map(Symbol, 'pq')
|
||||
pattern = sin(p) + cos(p)**2
|
||||
assert list(unify(expr, pattern, {}, (p, q))) == [{p: x}]
|
||||
|
||||
def test_matrix():
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
X = MatrixSymbol('X', n, n)
|
||||
Y = MatrixSymbol('Y', 2, 2)
|
||||
Z = MatrixSymbol('Z', 2, 3)
|
||||
assert list(unify(X, Y, {}, variables=[n, Str('X')])) == [{Str('X'): Str('Y'), n: 2}]
|
||||
assert list(unify(X, Z, {}, variables=[n, Str('X')])) == []
|
||||
|
||||
def test_non_frankenAdds():
|
||||
# the is_commutative property used to fail because of Basic.__new__
|
||||
# This caused is_commutative and str calls to fail
|
||||
expr = x+y*2
|
||||
rebuilt = construct(deconstruct(expr))
|
||||
# Ensure that we can run these commands without causing an error
|
||||
str(rebuilt)
|
||||
rebuilt.is_commutative
|
||||
|
||||
def test_FiniteSet_commutivity():
|
||||
from sympy.sets.sets import FiniteSet
|
||||
a, b, c, x, y = symbols('a,b,c,x,y')
|
||||
s = FiniteSet(a, b, c)
|
||||
t = FiniteSet(x, y)
|
||||
variables = (x, y)
|
||||
assert {x: FiniteSet(a, c), y: b} in tuple(unify(s, t, variables=variables))
|
||||
|
||||
def test_FiniteSet_complex():
|
||||
from sympy.sets.sets import FiniteSet
|
||||
a, b, c, x, y, z = symbols('a,b,c,x,y,z')
|
||||
expr = FiniteSet(Basic(S(1), x), y, Basic(x, z))
|
||||
pattern = FiniteSet(a, Basic(x, b))
|
||||
variables = a, b
|
||||
expected = ({b: 1, a: FiniteSet(y, Basic(x, z))},
|
||||
{b: z, a: FiniteSet(y, Basic(S(1), x))})
|
||||
assert iterdicteq(unify(expr, pattern, variables=variables), expected)
|
||||
|
||||
|
||||
def test_and():
|
||||
variables = x, y
|
||||
expected = ({x: z > 0, y: n < 3},)
|
||||
assert iterdicteq(unify((z>0) & (n<3), And(x, y), variables=variables),
|
||||
expected)
|
||||
|
||||
def test_Union():
|
||||
from sympy.sets.sets import Interval
|
||||
assert list(unify(Interval(0, 1) + Interval(10, 11),
|
||||
Interval(0, 1) + Interval(12, 13),
|
||||
variables=(Interval(12, 13),)))
|
||||
|
||||
def test_is_commutative():
|
||||
assert is_commutative(deconstruct(x+y))
|
||||
assert is_commutative(deconstruct(x*y))
|
||||
assert not is_commutative(deconstruct(x**y))
|
||||
|
||||
def test_commutative_in_commutative():
|
||||
from sympy.abc import a,b,c,d
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
|
||||
pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
|
||||
assert next(unify(eq, pat, variables=(a,b,c,d)))
|
||||
@@ -0,0 +1,88 @@
|
||||
from sympy.unify.core import Compound, Variable, CondVariable, allcombinations
|
||||
from sympy.unify import core
|
||||
|
||||
a,b,c = 'a', 'b', 'c'
|
||||
w,x,y,z = map(Variable, 'wxyz')
|
||||
|
||||
C = Compound
|
||||
|
||||
def is_associative(x):
|
||||
return isinstance(x, Compound) and (x.op in ('Add', 'Mul', 'CAdd', 'CMul'))
|
||||
def is_commutative(x):
|
||||
return isinstance(x, Compound) and (x.op in ('CAdd', 'CMul'))
|
||||
|
||||
|
||||
def unify(a, b, s={}):
|
||||
return core.unify(a, b, s=s, is_associative=is_associative,
|
||||
is_commutative=is_commutative)
|
||||
|
||||
def test_basic():
|
||||
assert list(unify(a, x, {})) == [{x: a}]
|
||||
assert list(unify(a, x, {x: 10})) == []
|
||||
assert list(unify(1, x, {})) == [{x: 1}]
|
||||
assert list(unify(a, a, {})) == [{}]
|
||||
assert list(unify((w, x), (y, z), {})) == [{w: y, x: z}]
|
||||
assert list(unify(x, (a, b), {})) == [{x: (a, b)}]
|
||||
|
||||
assert list(unify((a, b), (x, x), {})) == []
|
||||
assert list(unify((y, z), (x, x), {}))!= []
|
||||
assert list(unify((a, (b, c)), (a, (x, y)), {})) == [{x: b, y: c}]
|
||||
|
||||
def test_ops():
|
||||
assert list(unify(C('Add', (a,b,c)), C('Add', (a,x,y)), {})) == \
|
||||
[{x:b, y:c}]
|
||||
assert list(unify(C('Add', (C('Mul', (1,2)), b,c)), C('Add', (x,y,c)), {})) == \
|
||||
[{x: C('Mul', (1,2)), y:b}]
|
||||
|
||||
def test_associative():
|
||||
c1 = C('Add', (1,2,3))
|
||||
c2 = C('Add', (x,y))
|
||||
assert tuple(unify(c1, c2, {})) == ({x: 1, y: C('Add', (2, 3))},
|
||||
{x: C('Add', (1, 2)), y: 3})
|
||||
|
||||
def test_commutative():
|
||||
c1 = C('CAdd', (1,2,3))
|
||||
c2 = C('CAdd', (x,y))
|
||||
result = list(unify(c1, c2, {}))
|
||||
assert {x: 1, y: C('CAdd', (2, 3))} in result
|
||||
assert ({x: 2, y: C('CAdd', (1, 3))} in result or
|
||||
{x: 2, y: C('CAdd', (3, 1))} in result)
|
||||
|
||||
def _test_combinations_assoc():
|
||||
assert set(allcombinations((1,2,3), (a,b), True)) == \
|
||||
{(((1, 2), (3,)), (a, b)), (((1,), (2, 3)), (a, b))}
|
||||
|
||||
def _test_combinations_comm():
|
||||
assert set(allcombinations((1,2,3), (a,b), None)) == \
|
||||
{(((1,), (2, 3)), ('a', 'b')), (((2,), (3, 1)), ('a', 'b')),
|
||||
(((3,), (1, 2)), ('a', 'b')), (((1, 2), (3,)), ('a', 'b')),
|
||||
(((2, 3), (1,)), ('a', 'b')), (((3, 1), (2,)), ('a', 'b'))}
|
||||
|
||||
def test_allcombinations():
|
||||
assert set(allcombinations((1,2), (1,2), 'commutative')) ==\
|
||||
{(((1,),(2,)), ((1,),(2,))), (((1,),(2,)), ((2,),(1,)))}
|
||||
|
||||
|
||||
def test_commutativity():
|
||||
c1 = Compound('CAdd', (a, b))
|
||||
c2 = Compound('CAdd', (x, y))
|
||||
assert is_commutative(c1) and is_commutative(c2)
|
||||
assert len(list(unify(c1, c2, {}))) == 2
|
||||
|
||||
|
||||
def test_CondVariable():
|
||||
expr = C('CAdd', (1, 2))
|
||||
x = Variable('x')
|
||||
y = CondVariable('y', lambda a: a % 2 == 0)
|
||||
z = CondVariable('z', lambda a: a > 3)
|
||||
pattern = C('CAdd', (x, y))
|
||||
assert list(unify(expr, pattern, {})) == \
|
||||
[{x: 1, y: 2}]
|
||||
|
||||
z = CondVariable('z', lambda a: a > 3)
|
||||
pattern = C('CAdd', (z, y))
|
||||
|
||||
assert list(unify(expr, pattern, {})) == []
|
||||
|
||||
def test_defaultdict():
|
||||
assert next(unify(Variable('x'), 'foo')) == {Variable('x'): 'foo'}
|
||||
@@ -0,0 +1,124 @@
|
||||
""" SymPy interface to Unification engine
|
||||
|
||||
See sympy.unify for module level docstring
|
||||
See sympy.unify.core for algorithmic docstring """
|
||||
|
||||
from sympy.core import Basic, Add, Mul, Pow
|
||||
from sympy.core.operations import AssocOp, LatticeOp
|
||||
from sympy.matrices import MatAdd, MatMul, MatrixExpr
|
||||
from sympy.sets.sets import Union, Intersection, FiniteSet
|
||||
from sympy.unify.core import Compound, Variable, CondVariable
|
||||
from sympy.unify import core
|
||||
|
||||
basic_new_legal = [MatrixExpr]
|
||||
eval_false_legal = [AssocOp, Pow, FiniteSet]
|
||||
illegal = [LatticeOp]
|
||||
|
||||
def sympy_associative(op):
|
||||
assoc_ops = (AssocOp, MatAdd, MatMul, Union, Intersection, FiniteSet)
|
||||
return any(issubclass(op, aop) for aop in assoc_ops)
|
||||
|
||||
def sympy_commutative(op):
|
||||
comm_ops = (Add, MatAdd, Union, Intersection, FiniteSet)
|
||||
return any(issubclass(op, cop) for cop in comm_ops)
|
||||
|
||||
def is_associative(x):
|
||||
return isinstance(x, Compound) and sympy_associative(x.op)
|
||||
|
||||
def is_commutative(x):
|
||||
if not isinstance(x, Compound):
|
||||
return False
|
||||
if sympy_commutative(x.op):
|
||||
return True
|
||||
if issubclass(x.op, Mul):
|
||||
return all(construct(arg).is_commutative for arg in x.args)
|
||||
|
||||
def mk_matchtype(typ):
|
||||
def matchtype(x):
|
||||
return (isinstance(x, typ) or
|
||||
isinstance(x, Compound) and issubclass(x.op, typ))
|
||||
return matchtype
|
||||
|
||||
def deconstruct(s, variables=()):
|
||||
""" Turn a SymPy object into a Compound """
|
||||
if s in variables:
|
||||
return Variable(s)
|
||||
if isinstance(s, (Variable, CondVariable)):
|
||||
return s
|
||||
if not isinstance(s, Basic) or s.is_Atom:
|
||||
return s
|
||||
return Compound(s.__class__,
|
||||
tuple(deconstruct(arg, variables) for arg in s.args))
|
||||
|
||||
def construct(t):
|
||||
""" Turn a Compound into a SymPy object """
|
||||
if isinstance(t, (Variable, CondVariable)):
|
||||
return t.arg
|
||||
if not isinstance(t, Compound):
|
||||
return t
|
||||
if any(issubclass(t.op, cls) for cls in eval_false_legal):
|
||||
return t.op(*map(construct, t.args), evaluate=False)
|
||||
elif any(issubclass(t.op, cls) for cls in basic_new_legal):
|
||||
return Basic.__new__(t.op, *map(construct, t.args))
|
||||
else:
|
||||
return t.op(*map(construct, t.args))
|
||||
|
||||
def rebuild(s):
|
||||
""" Rebuild a SymPy expression.
|
||||
|
||||
This removes harm caused by Expr-Rules interactions.
|
||||
"""
|
||||
return construct(deconstruct(s))
|
||||
|
||||
def unify(x, y, s=None, variables=(), **kwargs):
|
||||
""" Structural unification of two expressions/patterns.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.unify.usympy import unify
|
||||
>>> from sympy import Basic, S
|
||||
>>> from sympy.abc import x, y, z, p, q
|
||||
|
||||
>>> next(unify(Basic(S(1), S(2)), Basic(S(1), x), variables=[x]))
|
||||
{x: 2}
|
||||
|
||||
>>> expr = 2*x + y + z
|
||||
>>> pattern = 2*p + q
|
||||
>>> next(unify(expr, pattern, {}, variables=(p, q)))
|
||||
{p: x, q: y + z}
|
||||
|
||||
Unification supports commutative and associative matching
|
||||
|
||||
>>> expr = x + y + z
|
||||
>>> pattern = p + q
|
||||
>>> len(list(unify(expr, pattern, {}, variables=(p, q))))
|
||||
12
|
||||
|
||||
Symbols not indicated to be variables are treated as literal,
|
||||
else they are wild-like and match anything in a sub-expression.
|
||||
|
||||
>>> expr = x*y*z + 3
|
||||
>>> pattern = x*y + 3
|
||||
>>> next(unify(expr, pattern, {}, variables=[x, y]))
|
||||
{x: y, y: x*z}
|
||||
|
||||
The x and y of the pattern above were in a Mul and matched factors
|
||||
in the Mul of expr. Here, a single symbol matches an entire term:
|
||||
|
||||
>>> expr = x*y + 3
|
||||
>>> pattern = p + 3
|
||||
>>> next(unify(expr, pattern, {}, variables=[p]))
|
||||
{p: x*y}
|
||||
|
||||
"""
|
||||
decons = lambda x: deconstruct(x, variables)
|
||||
s = s or {}
|
||||
s = {decons(k): decons(v) for k, v in s.items()}
|
||||
|
||||
ds = core.unify(decons(x), decons(y), s,
|
||||
is_associative=is_associative,
|
||||
is_commutative=is_commutative,
|
||||
**kwargs)
|
||||
for d in ds:
|
||||
yield {construct(k): construct(v) for k, v in d.items()}
|
||||
Reference in New Issue
Block a user