switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,633 @@
|
||||
"""
|
||||
Important note on tests in this module - the Aesara printing functions use a
|
||||
global cache by default, which means that tests using it will modify global
|
||||
state and thus not be independent from each other. Instead of using the "cache"
|
||||
keyword argument each time, this module uses the aesara_code_ and
|
||||
aesara_function_ functions defined below which default to using a new, empty
|
||||
cache instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy
|
||||
|
||||
from sympy.utilities.exceptions import ignore_warnings
|
||||
|
||||
|
||||
aesaralogger = logging.getLogger('aesara.configdefaults')
|
||||
aesaralogger.setLevel(logging.CRITICAL)
|
||||
aesara = import_module('aesara')
|
||||
aesaralogger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
if aesara:
|
||||
import numpy as np
|
||||
aet = aesara.tensor
|
||||
from aesara.scalar.basic import ScalarType
|
||||
from aesara.graph.basic import Variable
|
||||
from aesara.tensor.var import TensorVariable
|
||||
from aesara.tensor.elemwise import Elemwise, DimShuffle
|
||||
from aesara.tensor.math import Dot
|
||||
|
||||
from sympy.printing.aesaracode import true_divide
|
||||
|
||||
xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz']
|
||||
Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ']
|
||||
else:
|
||||
#bin/test will not execute any tests now
|
||||
disabled = True
|
||||
|
||||
import sympy as sy
|
||||
from sympy.core.singleton import S
|
||||
from sympy.abc import x, y, z, t
|
||||
from sympy.printing.aesaracode import (aesara_code, dim_handling,
|
||||
aesara_function)
|
||||
|
||||
|
||||
# Default set of matrix symbols for testing - make square so we can both
|
||||
# multiply and perform elementwise operations between them.
|
||||
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
|
||||
|
||||
# For testing AppliedUndef
|
||||
f_t = sy.Function('f')(t)
|
||||
|
||||
|
||||
def aesara_code_(expr, **kwargs):
|
||||
""" Wrapper for aesara_code that uses a new, empty cache by default. """
|
||||
kwargs.setdefault('cache', {})
|
||||
with warns_deprecated_sympy():
|
||||
return aesara_code(expr, **kwargs)
|
||||
|
||||
def aesara_function_(inputs, outputs, **kwargs):
|
||||
""" Wrapper for aesara_function that uses a new, empty cache by default. """
|
||||
kwargs.setdefault('cache', {})
|
||||
with warns_deprecated_sympy():
|
||||
return aesara_function(inputs, outputs, **kwargs)
|
||||
|
||||
|
||||
def fgraph_of(*exprs):
|
||||
""" Transform SymPy expressions into Aesara Computation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
exprs
|
||||
SymPy expressions
|
||||
|
||||
Returns
|
||||
=======
|
||||
aesara.graph.fg.FunctionGraph
|
||||
"""
|
||||
outs = list(map(aesara_code_, exprs))
|
||||
ins = list(aesara.graph.basic.graph_inputs(outs))
|
||||
ins, outs = aesara.graph.basic.clone(ins, outs)
|
||||
return aesara.graph.fg.FunctionGraph(ins, outs)
|
||||
|
||||
|
||||
def aesara_simplify(fgraph):
|
||||
""" Simplify a Aesara Computation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
fgraph : aesara.graph.fg.FunctionGraph
|
||||
|
||||
Returns
|
||||
=======
|
||||
aesara.graph.fg.FunctionGraph
|
||||
"""
|
||||
mode = aesara.compile.get_default_mode().excluding("fusion")
|
||||
fgraph = fgraph.clone()
|
||||
mode.optimizer.rewrite(fgraph)
|
||||
return fgraph
|
||||
|
||||
|
||||
def theq(a, b):
|
||||
""" Test two Aesara objects for equality.
|
||||
|
||||
Also accepts numeric types and lists/tuples of supported types.
|
||||
|
||||
Note - debugprint() has a bug where it will accept numeric types but does
|
||||
not respect the "file" argument and in this case and instead prints the number
|
||||
to stdout and returns an empty string. This can lead to tests passing where
|
||||
they should fail because any two numbers will always compare as equal. To
|
||||
prevent this we treat numbers as a separate case.
|
||||
"""
|
||||
numeric_types = (int, float, np.number)
|
||||
a_is_num = isinstance(a, numeric_types)
|
||||
b_is_num = isinstance(b, numeric_types)
|
||||
|
||||
# Compare numeric types using regular equality
|
||||
if a_is_num or b_is_num:
|
||||
if not (a_is_num and b_is_num):
|
||||
return False
|
||||
|
||||
return a == b
|
||||
|
||||
# Compare sequences element-wise
|
||||
a_is_seq = isinstance(a, (tuple, list))
|
||||
b_is_seq = isinstance(b, (tuple, list))
|
||||
|
||||
if a_is_seq or b_is_seq:
|
||||
if not (a_is_seq and b_is_seq) or type(a) != type(b):
|
||||
return False
|
||||
|
||||
return list(map(theq, a)) == list(map(theq, b))
|
||||
|
||||
# Otherwise, assume debugprint() can handle it
|
||||
astr = aesara.printing.debugprint(a, file='str')
|
||||
bstr = aesara.printing.debugprint(b, file='str')
|
||||
|
||||
# Check for bug mentioned above
|
||||
for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
|
||||
if argstr == '':
|
||||
raise TypeError(
|
||||
'aesara.printing.debugprint(%s) returned empty string '
|
||||
'(%s is instance of %r)'
|
||||
% (argname, argname, type(argval))
|
||||
)
|
||||
|
||||
return astr == bstr
|
||||
|
||||
|
||||
def test_example_symbols():
|
||||
"""
|
||||
Check that the example symbols in this module print to their Aesara
|
||||
equivalents, as many of the other tests depend on this.
|
||||
"""
|
||||
assert theq(xt, aesara_code_(x))
|
||||
assert theq(yt, aesara_code_(y))
|
||||
assert theq(zt, aesara_code_(z))
|
||||
assert theq(Xt, aesara_code_(X))
|
||||
assert theq(Yt, aesara_code_(Y))
|
||||
assert theq(Zt, aesara_code_(Z))
|
||||
|
||||
|
||||
def test_Symbol():
|
||||
""" Test printing a Symbol to a aesara variable. """
|
||||
xx = aesara_code_(x)
|
||||
assert isinstance(xx, Variable)
|
||||
assert xx.broadcastable == ()
|
||||
assert xx.name == x.name
|
||||
|
||||
xx2 = aesara_code_(x, broadcastables={x: (False,)})
|
||||
assert xx2.broadcastable == (False,)
|
||||
assert xx2.name == x.name
|
||||
|
||||
def test_MatrixSymbol():
|
||||
""" Test printing a MatrixSymbol to a aesara variable. """
|
||||
XX = aesara_code_(X)
|
||||
assert isinstance(XX, TensorVariable)
|
||||
assert XX.broadcastable == (False, False)
|
||||
|
||||
@SKIP # TODO - this is currently not checked but should be implemented
|
||||
def test_MatrixSymbol_wrong_dims():
|
||||
""" Test MatrixSymbol with invalid broadcastable. """
|
||||
bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
|
||||
for bc in bcs:
|
||||
with raises(ValueError):
|
||||
aesara_code_(X, broadcastables={X: bc})
|
||||
|
||||
def test_AppliedUndef():
|
||||
""" Test printing AppliedUndef instance, which works similarly to Symbol. """
|
||||
ftt = aesara_code_(f_t)
|
||||
assert isinstance(ftt, TensorVariable)
|
||||
assert ftt.broadcastable == ()
|
||||
assert ftt.name == 'f_t'
|
||||
|
||||
|
||||
def test_add():
|
||||
expr = x + y
|
||||
comp = aesara_code_(expr)
|
||||
assert comp.owner.op == aesara.tensor.add
|
||||
|
||||
def test_trig():
|
||||
assert theq(aesara_code_(sy.sin(x)), aet.sin(xt))
|
||||
assert theq(aesara_code_(sy.tan(x)), aet.tan(xt))
|
||||
|
||||
def test_many():
|
||||
""" Test printing a complex expression with multiple symbols. """
|
||||
expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
|
||||
comp = aesara_code_(expr)
|
||||
expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
|
||||
assert theq(comp, expected)
|
||||
|
||||
|
||||
def test_dtype():
|
||||
""" Test specifying specific data types through the dtype argument. """
|
||||
for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
|
||||
assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype
|
||||
|
||||
# "floatX" type
|
||||
assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
|
||||
|
||||
# Type promotion
|
||||
assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
|
||||
assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
|
||||
|
||||
|
||||
def test_broadcastables():
|
||||
""" Test the "broadcastables" argument when printing symbol-like objects. """
|
||||
|
||||
# No restrictions on shape
|
||||
for s in [x, f_t]:
|
||||
for bc in [(), (False,), (True,), (False, False), (True, False)]:
|
||||
assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc
|
||||
|
||||
# TODO - matrix broadcasting?
|
||||
|
||||
def test_broadcasting():
|
||||
""" Test "broadcastable" attribute after applying element-wise binary op. """
|
||||
|
||||
expr = x + y
|
||||
|
||||
cases = [
|
||||
[(), (), ()],
|
||||
[(False,), (False,), (False,)],
|
||||
[(True,), (False,), (False,)],
|
||||
[(False, True), (False, False), (False, False)],
|
||||
[(True, False), (False, False), (False, False)],
|
||||
]
|
||||
|
||||
for bc1, bc2, bc3 in cases:
|
||||
comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2})
|
||||
assert comp.broadcastable == bc3
|
||||
|
||||
|
||||
def test_MatMul():
|
||||
expr = X*Y*Z
|
||||
expr_t = aesara_code_(expr)
|
||||
assert isinstance(expr_t.owner.op, Dot)
|
||||
assert theq(expr_t, Xt.dot(Yt).dot(Zt))
|
||||
|
||||
def test_Transpose():
|
||||
assert isinstance(aesara_code_(X.T).owner.op, DimShuffle)
|
||||
|
||||
def test_MatAdd():
|
||||
expr = X+Y+Z
|
||||
assert isinstance(aesara_code_(expr).owner.op, Elemwise)
|
||||
|
||||
|
||||
def test_Rationals():
|
||||
assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3))
|
||||
assert theq(aesara_code_(S.Half), true_divide(1, 2))
|
||||
|
||||
def test_Integers():
|
||||
assert aesara_code_(sy.Integer(3)) == 3
|
||||
|
||||
def test_factorial():
|
||||
n = sy.Symbol('n')
|
||||
assert aesara_code_(sy.factorial(n))
|
||||
|
||||
def test_Derivative():
|
||||
with ignore_warnings(UserWarning):
|
||||
simp = lambda expr: aesara_simplify(fgraph_of(expr))
|
||||
assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
|
||||
simp(aesara.grad(aet.sin(xt), xt)))
|
||||
|
||||
|
||||
def test_aesara_function_simple():
|
||||
""" Test aesara_function() with single output. """
|
||||
f = aesara_function_([x, y], [x+y])
|
||||
assert f(2, 3) == 5
|
||||
|
||||
def test_aesara_function_multi():
|
||||
""" Test aesara_function() with multiple outputs. """
|
||||
f = aesara_function_([x, y], [x+y, x-y])
|
||||
o1, o2 = f(2, 3)
|
||||
assert o1 == 5
|
||||
assert o2 == -1
|
||||
|
||||
def test_aesara_function_numpy():
|
||||
""" Test aesara_function() vs Numpy implementation. """
|
||||
f = aesara_function_([x, y], [x+y], dim=1,
|
||||
dtypes={x: 'float64', y: 'float64'})
|
||||
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
|
||||
|
||||
f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
|
||||
dim=1)
|
||||
xx = np.arange(3).astype('float64')
|
||||
yy = 2*np.arange(3).astype('float64')
|
||||
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
|
||||
|
||||
|
||||
def test_aesara_function_matrix():
|
||||
m = sy.Matrix([[x, y], [z, x + y + z]])
|
||||
expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
|
||||
f = aesara_function_([x, y, z], [m])
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
|
||||
f = aesara_function_([x, y, z], [m], scalar=True)
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
|
||||
f = aesara_function_([x, y, z], [m, m])
|
||||
assert isinstance(f(1.0, 2.0, 3.0), type([]))
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
|
||||
|
||||
def test_dim_handling():
|
||||
assert dim_handling([x], dim=2) == {x: (False, False)}
|
||||
assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
|
||||
y: (False, False)}
|
||||
assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
|
||||
|
||||
def test_aesara_function_kwargs():
|
||||
"""
|
||||
Test passing additional kwargs from aesara_function() to aesara.function().
|
||||
"""
|
||||
import numpy as np
|
||||
f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
|
||||
dtypes={x: 'float64', y: 'float64', z: 'float64'})
|
||||
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
|
||||
|
||||
f = aesara_function_([x, y, z], [x+y],
|
||||
dtypes={x: 'float64', y: 'float64', z: 'float64'},
|
||||
dim=1, on_unused_input='ignore')
|
||||
xx = np.arange(3).astype('float64')
|
||||
yy = 2*np.arange(3).astype('float64')
|
||||
zz = 2*np.arange(3).astype('float64')
|
||||
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
|
||||
|
||||
def test_aesara_function_scalar():
|
||||
""" Test the "scalar" argument to aesara_function(). """
|
||||
from aesara.compile.function.types import Function
|
||||
|
||||
args = [
|
||||
([x, y], [x + y], None, [0]), # Single 0d output
|
||||
([X, Y], [X + Y], None, [2]), # Single 2d output
|
||||
([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
|
||||
([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
|
||||
([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
|
||||
]
|
||||
|
||||
# Create and test functions with and without the scalar setting
|
||||
for inputs, outputs, in_dims, out_dims in args:
|
||||
for scalar in [False, True]:
|
||||
|
||||
f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar)
|
||||
|
||||
# Check the aesara_function attribute is set whether wrapped or not
|
||||
assert isinstance(f.aesara_function, Function)
|
||||
|
||||
# Feed in inputs of the appropriate size and get outputs
|
||||
in_values = [
|
||||
np.ones([1 if bc else 5 for bc in i.type.broadcastable])
|
||||
for i in f.aesara_function.input_storage
|
||||
]
|
||||
out_values = f(*in_values)
|
||||
if not isinstance(out_values, list):
|
||||
out_values = [out_values]
|
||||
|
||||
# Check output types and shapes
|
||||
assert len(out_dims) == len(out_values)
|
||||
for d, value in zip(out_dims, out_values):
|
||||
|
||||
if scalar and d == 0:
|
||||
# Should have been converted to a scalar value
|
||||
assert isinstance(value, np.number)
|
||||
|
||||
else:
|
||||
# Otherwise should be an array
|
||||
assert isinstance(value, np.ndarray)
|
||||
assert value.ndim == d
|
||||
|
||||
def test_aesara_function_bad_kwarg():
|
||||
"""
|
||||
Passing an unknown keyword argument to aesara_function() should raise an
|
||||
exception.
|
||||
"""
|
||||
raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3))
|
||||
|
||||
|
||||
def test_slice():
|
||||
assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)
|
||||
|
||||
def theq_slice(s1, s2):
|
||||
for attr in ['start', 'stop', 'step']:
|
||||
a1 = getattr(s1, attr)
|
||||
a2 = getattr(s2, attr)
|
||||
if a1 is None or a2 is None:
|
||||
if not (a1 is None or a2 is None):
|
||||
return False
|
||||
elif not theq(a1, a2):
|
||||
return False
|
||||
return True
|
||||
|
||||
dtypes = {x: 'int32', y: 'int32'}
|
||||
assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
|
||||
assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
|
||||
|
||||
def test_MatrixSlice():
|
||||
cache = {}
|
||||
|
||||
n = sy.Symbol('n', integer=True)
|
||||
X = sy.MatrixSymbol('X', n, n)
|
||||
|
||||
Y = X[1:2:3, 4:5:6]
|
||||
Yt = aesara_code_(Y, cache=cache)
|
||||
|
||||
s = ScalarType('int64')
|
||||
assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
|
||||
assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache)
|
||||
# == doesn't work in Aesara like it does in SymPy. You have to use
|
||||
# equals.
|
||||
assert all(Yt.owner.inputs[i].data == i for i in range(1, 7))
|
||||
|
||||
k = sy.Symbol('k')
|
||||
aesara_code_(k, dtypes={k: 'int32'})
|
||||
start, stop, step = 4, k, 2
|
||||
Y = X[start:stop:step]
|
||||
Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'})
|
||||
# assert Yt.owner.op.idx_list[0].stop == kt
|
||||
|
||||
def test_BlockMatrix():
|
||||
n = sy.Symbol('n', integer=True)
|
||||
A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
|
||||
At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D))
|
||||
Block = sy.BlockMatrix([[A, B], [C, D]])
|
||||
Blockt = aesara_code_(Block)
|
||||
solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)),
|
||||
aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))]
|
||||
assert any(theq(Blockt, solution) for solution in solutions)
|
||||
|
||||
@SKIP
|
||||
def test_BlockMatrix_Inverse_execution():
|
||||
k, n = 2, 4
|
||||
dtype = 'float32'
|
||||
A = sy.MatrixSymbol('A', n, k)
|
||||
B = sy.MatrixSymbol('B', n, n)
|
||||
inputs = A, B
|
||||
output = B.I*A
|
||||
|
||||
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
|
||||
B: [(n//2, n//2), (n//2, n//2)]}
|
||||
cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
|
||||
cutoutput = output.subs(dict(zip(inputs, cutinputs)))
|
||||
|
||||
dtypes = dict(zip(inputs, [dtype]*len(inputs)))
|
||||
f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
|
||||
fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
|
||||
dtypes=dtypes, cache={})
|
||||
|
||||
ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
|
||||
ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
|
||||
np.eye(n).astype(dtype)]
|
||||
ninputs[1] += np.ones(B.shape)*1e-5
|
||||
|
||||
assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
|
||||
|
||||
def test_DenseMatrix():
|
||||
from aesara.tensor.basic import Join
|
||||
|
||||
t = sy.Symbol('theta')
|
||||
for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
|
||||
X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
|
||||
tX = aesara_code_(X)
|
||||
assert isinstance(tX, TensorVariable)
|
||||
assert isinstance(tX.owner.op, Join)
|
||||
|
||||
|
||||
def test_cache_basic():
|
||||
""" Test single symbol-like objects are cached when printed by themselves. """
|
||||
|
||||
# Pairs of objects which should be considered equivalent with respect to caching
|
||||
pairs = [
|
||||
(x, sy.Symbol('x')),
|
||||
(X, sy.MatrixSymbol('X', *X.shape)),
|
||||
(f_t, sy.Function('f')(sy.Symbol('t'))),
|
||||
]
|
||||
|
||||
for s1, s2 in pairs:
|
||||
cache = {}
|
||||
st = aesara_code_(s1, cache=cache)
|
||||
|
||||
# Test hit with same instance
|
||||
assert aesara_code_(s1, cache=cache) is st
|
||||
|
||||
# Test miss with same instance but new cache
|
||||
assert aesara_code_(s1, cache={}) is not st
|
||||
|
||||
# Test hit with different but equivalent instance
|
||||
assert aesara_code_(s2, cache=cache) is st
|
||||
|
||||
def test_global_cache():
|
||||
""" Test use of the global cache. """
|
||||
from sympy.printing.aesaracode import global_cache
|
||||
|
||||
backup = dict(global_cache)
|
||||
try:
|
||||
# Temporarily empty global cache
|
||||
global_cache.clear()
|
||||
|
||||
for s in [x, X, f_t]:
|
||||
with warns_deprecated_sympy():
|
||||
st = aesara_code(s)
|
||||
assert aesara_code(s) is st
|
||||
|
||||
finally:
|
||||
# Restore global cache
|
||||
global_cache.update(backup)
|
||||
|
||||
def test_cache_types_distinct():
|
||||
"""
|
||||
Test that symbol-like objects of different types (Symbol, MatrixSymbol,
|
||||
AppliedUndef) are distinguished by the cache even if they have the same
|
||||
name.
|
||||
"""
|
||||
symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
|
||||
|
||||
cache = {} # Single shared cache
|
||||
printed = {}
|
||||
|
||||
for s in symbols:
|
||||
st = aesara_code_(s, cache=cache)
|
||||
assert st not in printed.values()
|
||||
printed[s] = st
|
||||
|
||||
# Check all printed objects are distinct
|
||||
assert len(set(map(id, printed.values()))) == len(symbols)
|
||||
|
||||
# Check retrieving
|
||||
for s, st in printed.items():
|
||||
with warns_deprecated_sympy():
|
||||
assert aesara_code(s, cache=cache) is st
|
||||
|
||||
def test_symbols_are_created_once():
|
||||
"""
|
||||
Test that a symbol is cached and reused when it appears in an expression
|
||||
more than once.
|
||||
"""
|
||||
expr = sy.Add(x, x, evaluate=False)
|
||||
comp = aesara_code_(expr)
|
||||
|
||||
assert theq(comp, xt + xt)
|
||||
assert not theq(comp, xt + aesara_code_(x))
|
||||
|
||||
def test_cache_complex():
|
||||
"""
|
||||
Test caching on a complicated expression with multiple symbols appearing
|
||||
multiple times.
|
||||
"""
|
||||
expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
|
||||
symbol_names = {s.name for s in expr.free_symbols}
|
||||
expr_t = aesara_code_(expr)
|
||||
|
||||
# Iterate through variables in the Aesara computational graph that the
|
||||
# printed expression depends on
|
||||
seen = set()
|
||||
for v in aesara.graph.basic.ancestors([expr_t]):
|
||||
# Owner-less, non-constant variables should be our symbols
|
||||
if v.owner is None and not isinstance(v, aesara.graph.basic.Constant):
|
||||
# Check it corresponds to a symbol and appears only once
|
||||
assert v.name in symbol_names
|
||||
assert v.name not in seen
|
||||
seen.add(v.name)
|
||||
|
||||
# Check all were present
|
||||
assert seen == symbol_names
|
||||
|
||||
|
||||
def test_Piecewise():
|
||||
# A piecewise linear
|
||||
expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
|
||||
result = aesara_code_(expr)
|
||||
assert result.owner.op == aet.switch
|
||||
|
||||
expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
|
||||
assert theq(result, expected)
|
||||
|
||||
expr = sy.Piecewise((x, x < 0))
|
||||
result = aesara_code_(expr)
|
||||
expected = aet.switch(xt < 0, xt, np.nan)
|
||||
assert theq(result, expected)
|
||||
|
||||
expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
|
||||
(x, sy.Or(x>2, x<0)))
|
||||
result = aesara_code_(expr)
|
||||
expected = aet.switch(aet.and_(xt>0,xt<2), 0, \
|
||||
aet.switch(aet.or_(xt>2, xt<0), xt, np.nan))
|
||||
assert theq(result, expected)
|
||||
|
||||
|
||||
def test_Relationals():
|
||||
assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt))
|
||||
# assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement
|
||||
assert theq(aesara_code_(x > y), xt > yt)
|
||||
assert theq(aesara_code_(x < y), xt < yt)
|
||||
assert theq(aesara_code_(x >= y), xt >= yt)
|
||||
assert theq(aesara_code_(x <= y), xt <= yt)
|
||||
|
||||
|
||||
def test_complexfunctions():
|
||||
dtypes = {x:'complex128', y:'complex128'}
|
||||
with warns_deprecated_sympy():
|
||||
xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes)
|
||||
from sympy.functions.elementary.complexes import conjugate
|
||||
from aesara.tensor import as_tensor_variable as atv
|
||||
from aesara.tensor import complex as cplx
|
||||
with warns_deprecated_sympy():
|
||||
assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj()))
|
||||
assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
|
||||
|
||||
|
||||
def test_constantfunctions():
|
||||
with warns_deprecated_sympy():
|
||||
tf = aesara_function([],[1+1j])
|
||||
assert(tf()==1+1j)
|
||||
@@ -0,0 +1,888 @@
|
||||
from sympy.core import (
|
||||
S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan,
|
||||
Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr
|
||||
)
|
||||
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
|
||||
from sympy.functions import (
|
||||
Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf,
|
||||
erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh,
|
||||
sqrt, tan, tanh, fibonacci, lucas
|
||||
)
|
||||
from sympy.sets import Range
|
||||
from sympy.logic import ITE, Implies, Equivalent
|
||||
from sympy.codegen import For, aug_assign, Assignment
|
||||
from sympy.testing.pytest import raises, XFAIL
|
||||
from sympy.printing.codeprinter import PrintMethodNotImplementedError
|
||||
from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros
|
||||
from sympy.codegen.ast import (
|
||||
AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const,
|
||||
While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return,
|
||||
real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString
|
||||
)
|
||||
from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt, isnan, isinf
|
||||
from sympy.codegen.cnodes import restrict
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
|
||||
|
||||
from sympy.printing.codeprinter import ccode
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
class fabs(Abs):
|
||||
def _ccode(self, printer):
|
||||
return "fabs(%s)" % printer._print(self.args[0])
|
||||
|
||||
assert ccode(fabs(x)) == "fabs(x)"
|
||||
|
||||
|
||||
def test_ccode_sqrt():
|
||||
assert ccode(sqrt(x)) == "sqrt(x)"
|
||||
assert ccode(x**0.5) == "sqrt(x)"
|
||||
assert ccode(sqrt(x)) == "sqrt(x)"
|
||||
|
||||
|
||||
def test_ccode_Pow():
|
||||
assert ccode(x**3) == "pow(x, 3)"
|
||||
assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)"
|
||||
assert ccode(x**-1.0) == '1.0/x'
|
||||
assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)'
|
||||
assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)'
|
||||
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
|
||||
(lambda base, exp: not exp.is_integer, "pow")]
|
||||
assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
|
||||
assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)'
|
||||
assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)'
|
||||
_cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp),
|
||||
(lambda base, exp: base != 2, 'pow')]
|
||||
# Related to gh-11353
|
||||
assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)'
|
||||
assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)'
|
||||
# For issue 14160
|
||||
assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
|
||||
evaluate=False)) == '-2*x/(y*y)'
|
||||
|
||||
|
||||
def test_ccode_Max():
|
||||
# Test for gh-11926
|
||||
assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
|
||||
|
||||
|
||||
def test_ccode_Min_performance():
|
||||
#Shouldn't take more than a few seconds
|
||||
big_min = Min(*symbols('a[0:50]'))
|
||||
for curr_standard in ('c89', 'c99', 'c11'):
|
||||
output = ccode(big_min, standard=curr_standard)
|
||||
assert output.count('(') == output.count(')')
|
||||
|
||||
|
||||
def test_ccode_constants_mathh():
|
||||
assert ccode(exp(1)) == "M_E"
|
||||
assert ccode(pi) == "M_PI"
|
||||
assert ccode(oo, standard='c89') == "HUGE_VAL"
|
||||
assert ccode(-oo, standard='c89') == "-HUGE_VAL"
|
||||
assert ccode(oo) == "INFINITY"
|
||||
assert ccode(-oo, standard='c99') == "-INFINITY"
|
||||
assert ccode(pi, type_aliases={real: float80}) == "M_PIl"
|
||||
|
||||
|
||||
def test_ccode_constants_other():
|
||||
assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
|
||||
assert ccode(
|
||||
2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
|
||||
assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
|
||||
|
||||
|
||||
def test_ccode_Rational():
|
||||
assert ccode(Rational(3, 7)) == "3.0/7.0"
|
||||
assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L"
|
||||
assert ccode(Rational(18, 9)) == "2"
|
||||
assert ccode(Rational(3, -7)) == "-3.0/7.0"
|
||||
assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L"
|
||||
assert ccode(Rational(-3, -7)) == "3.0/7.0"
|
||||
assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L"
|
||||
assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0"
|
||||
assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L"
|
||||
assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x"
|
||||
assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x"
|
||||
|
||||
|
||||
def test_ccode_Integer():
|
||||
assert ccode(Integer(67)) == "67"
|
||||
assert ccode(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_ccode_functions():
|
||||
assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
|
||||
|
||||
|
||||
def test_ccode_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert ccode(g(x)) == "2*x"
|
||||
g = implemented_function('g', Lambda(x, 2*x/Catalan))
|
||||
assert ccode(
|
||||
g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
assert ccode(g(A[i]), assign_to=A[i]) == (
|
||||
"for (int i=0; i<n; i++){\n"
|
||||
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
|
||||
def test_ccode_exceptions():
|
||||
assert ccode(gamma(x), standard='C99') == "tgamma(x)"
|
||||
with raises(PrintMethodNotImplementedError):
|
||||
ccode(gamma(x), standard='C89')
|
||||
with raises(PrintMethodNotImplementedError):
|
||||
ccode(gamma(x), standard='C89', allow_unknown_functions=False)
|
||||
|
||||
ccode(gamma(x), standard='C89', allow_unknown_functions=True)
|
||||
|
||||
|
||||
|
||||
def test_ccode_functions2():
|
||||
assert ccode(ceiling(x)) == "ceil(x)"
|
||||
assert ccode(Abs(x)) == "fabs(x)"
|
||||
assert ccode(gamma(x)) == "tgamma(x)"
|
||||
r, s = symbols('r,s', real=True)
|
||||
assert ccode(Mod(ceiling(r), ceiling(s))) == '((ceil(r) % ceil(s)) + '\
|
||||
'ceil(s)) % ceil(s)'
|
||||
assert ccode(Mod(r, s)) == "fmod(r, s)"
|
||||
p1, p2 = symbols('p1 p2', integer=True, positive=True)
|
||||
assert ccode(Mod(p1, p2)) == 'p1 % p2'
|
||||
assert ccode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
|
||||
assert ccode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
|
||||
assert ccode(-Mod(3, 7, evaluate=False)) == '-(3 % 7)'
|
||||
assert ccode(r*Mod(p1, p2)) == 'r*(p1 % p2)'
|
||||
assert ccode(Mod(p1, p2)**s) == 'pow(p1 % p2, s)'
|
||||
n = symbols('n', integer=True, negative=True)
|
||||
assert ccode(Mod(-n, p2)) == '(-n) % p2'
|
||||
assert ccode(fibonacci(n)) == '((1.0/5.0)*pow(2, -n)*sqrt(5)*(-pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
|
||||
assert ccode(lucas(n)) == '(pow(2, -n)*(pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
|
||||
|
||||
|
||||
def test_ccode_user_functions():
|
||||
x = symbols('x', integer=False)
|
||||
n = symbols('n', integer=True)
|
||||
custom_functions = {
|
||||
"ceiling": "ceil",
|
||||
"Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
|
||||
}
|
||||
assert ccode(ceiling(x), user_functions=custom_functions) == "ceil(x)"
|
||||
assert ccode(Abs(x), user_functions=custom_functions) == "fabs(x)"
|
||||
assert ccode(Abs(n), user_functions=custom_functions) == "abs(n)"
|
||||
|
||||
expr = Symbol('a')
|
||||
muladd = Function('muladd')
|
||||
for i in range(0, 100):
|
||||
# the large number of terms acts as a regression test for gh-23839
|
||||
expr = muladd(Rational(1, 2), Symbol(f'a{i}'), expr)
|
||||
out = ccode(expr, user_functions={'muladd':'muladd'})
|
||||
assert 'a99' in out
|
||||
assert out.count('muladd') == 100
|
||||
|
||||
|
||||
def test_ccode_boolean():
|
||||
assert ccode(True) == "true"
|
||||
assert ccode(S.true) == "true"
|
||||
assert ccode(False) == "false"
|
||||
assert ccode(S.false) == "false"
|
||||
assert ccode(x & y) == "x && y"
|
||||
assert ccode(x | y) == "x || y"
|
||||
assert ccode(~x) == "!x"
|
||||
assert ccode(x & y & z) == "x && y && z"
|
||||
assert ccode(x | y | z) == "x || y || z"
|
||||
assert ccode((x & y) | z) == "z || x && y"
|
||||
assert ccode((x | y) & z) == "z && (x || y)"
|
||||
# Automatic rewrites
|
||||
assert ccode(x ^ y) == '(x || y) && (!x || !y)'
|
||||
assert ccode((x ^ y) ^ z) == '(x || y || z) && (x || !y || !z) && (y || !x || !z) && (z || !x || !y)'
|
||||
assert ccode(Implies(x, y)) == 'y || !x'
|
||||
assert ccode(Equivalent(x, z ^ y, Implies(z, x))) == '(x || (y || !z) && (z || !y)) && (z && !x || (y || z) && (!y || !z))'
|
||||
|
||||
|
||||
def test_ccode_Relational():
|
||||
assert ccode(Eq(x, y)) == "x == y"
|
||||
assert ccode(Ne(x, y)) == "x != y"
|
||||
assert ccode(Le(x, y)) == "x <= y"
|
||||
assert ccode(Lt(x, y)) == "x < y"
|
||||
assert ccode(Gt(x, y)) == "x > y"
|
||||
assert ccode(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_ccode_Piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
assert ccode(expr) == (
|
||||
"((x < 1) ? (\n"
|
||||
" x\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" pow(x, 2)\n"
|
||||
"))")
|
||||
assert ccode(expr, assign_to="c") == (
|
||||
"if (x < 1) {\n"
|
||||
" c = x;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" c = pow(x, 2);\n"
|
||||
"}")
|
||||
expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))
|
||||
assert ccode(expr) == (
|
||||
"((x < 1) ? (\n"
|
||||
" x\n"
|
||||
")\n"
|
||||
": ((x < 2) ? (\n"
|
||||
" x + 1\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" pow(x, 2)\n"
|
||||
")))")
|
||||
assert ccode(expr, assign_to='c') == (
|
||||
"if (x < 1) {\n"
|
||||
" c = x;\n"
|
||||
"}\n"
|
||||
"else if (x < 2) {\n"
|
||||
" c = x + 1;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" c = pow(x, 2);\n"
|
||||
"}")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: ccode(expr))
|
||||
|
||||
|
||||
def test_ccode_sinc():
|
||||
from sympy.functions.elementary.trigonometric import sinc
|
||||
expr = sinc(x)
|
||||
assert ccode(expr) == (
|
||||
"(((x != 0) ? (\n"
|
||||
" sin(x)/x\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" 1\n"
|
||||
")))")
|
||||
|
||||
|
||||
def test_ccode_Piecewise_deep():
|
||||
p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
|
||||
assert p == (
|
||||
"2*((x < 1) ? (\n"
|
||||
" x\n"
|
||||
")\n"
|
||||
": ((x < 2) ? (\n"
|
||||
" x + 1\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" pow(x, 2)\n"
|
||||
")))")
|
||||
expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
|
||||
assert ccode(expr) == (
|
||||
"pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
|
||||
" 0\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" 1\n"
|
||||
")) + cos(z) - 1")
|
||||
assert ccode(expr, assign_to='c') == (
|
||||
"c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
|
||||
" 0\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" 1\n"
|
||||
")) + cos(z) - 1;")
|
||||
|
||||
|
||||
def test_ccode_ITE():
|
||||
expr = ITE(x < 1, y, z)
|
||||
assert ccode(expr) == (
|
||||
"((x < 1) ? (\n"
|
||||
" y\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" z\n"
|
||||
"))")
|
||||
|
||||
|
||||
def test_ccode_settings():
|
||||
raises(TypeError, lambda: ccode(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_ccode_Indexed():
|
||||
s, n, m, o = symbols('s n m o', integer=True)
|
||||
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
|
||||
|
||||
x = IndexedBase('x')[j]
|
||||
A = IndexedBase('A')[i, j]
|
||||
B = IndexedBase('B')[i, j, k]
|
||||
|
||||
p = C99CodePrinter()
|
||||
|
||||
assert p._print_Indexed(x) == 'x[j]'
|
||||
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
|
||||
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
|
||||
|
||||
A = IndexedBase('A', shape=(5,3))[i, j]
|
||||
assert p._print_Indexed(A) == 'A[%s]' % (3*i + j)
|
||||
|
||||
A = IndexedBase('A', shape=(5,3), strides='F')[i, j]
|
||||
assert ccode(A) == 'A[%s]' % (i + 5*j)
|
||||
|
||||
A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j]
|
||||
assert ccode(A) == 'A[o + s*j + i]'
|
||||
|
||||
Abase = IndexedBase('A', strides=(s, m, n), offset=o)
|
||||
assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]'
|
||||
assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]'
|
||||
|
||||
|
||||
def test_Element():
|
||||
assert ccode(Element('x', 'ij')) == 'x[i][j]'
|
||||
assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]'
|
||||
assert ccode(Element('x', (3,))) == 'x[3]'
|
||||
assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]'
|
||||
|
||||
|
||||
def test_ccode_Indexed_without_looking_for_contraction():
|
||||
len_y = 5
|
||||
y = IndexedBase('y', shape=(len_y,))
|
||||
x = IndexedBase('x', shape=(len_y,))
|
||||
Dy = IndexedBase('Dy', shape=(len_y-1,))
|
||||
i = Idx('i', len_y-1)
|
||||
e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
|
||||
code0 = ccode(e.rhs, assign_to=e.lhs, contract=False)
|
||||
assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
|
||||
|
||||
|
||||
def test_ccode_loops_matrix_vector():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
assert ccode(A[i, j]*x[j], assign_to=y[i]) == s
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
expected = (
|
||||
'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
|
||||
' y[i_%(icount)i] = x[i_%(icount)i];\n'
|
||||
'}'
|
||||
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
||||
|
||||
assert ccode(x[i], assign_to=y[i]) == expected
|
||||
|
||||
|
||||
def test_ccode_loops_add():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
z = IndexedBase('z')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = x[i] + z[i];\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
assert ccode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == s
|
||||
|
||||
|
||||
def test_ccode_loops_multiple_contractions():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' for (int l=0; l<p; l++){\n'
|
||||
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
assert ccode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == s
|
||||
|
||||
|
||||
def test_ccode_loops_addfactor():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' for (int l=0; l<p; l++){\n'
|
||||
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
assert ccode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) == s
|
||||
|
||||
|
||||
def test_ccode_loops_multiple_terms():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
|
||||
s0 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
)
|
||||
s1 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s2 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s3 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
c = ccode(b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
|
||||
assert (c == s0 + s1 + s2 + s3[:-1] or
|
||||
c == s0 + s1 + s3 + s2[:-1] or
|
||||
c == s0 + s2 + s1 + s3[:-1] or
|
||||
c == s0 + s2 + s3 + s1[:-1] or
|
||||
c == s0 + s3 + s1 + s2[:-1] or
|
||||
c == s0 + s3 + s2 + s1[:-1])
|
||||
|
||||
|
||||
def test_dereference_printing():
|
||||
expr = x + y + sin(z) + z
|
||||
assert ccode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
|
||||
|
||||
|
||||
def test_Matrix_printing():
|
||||
# Test returning a Matrix
|
||||
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
assert ccode(mat, A) == (
|
||||
"A[0] = x*y;\n"
|
||||
"if (y > 0) {\n"
|
||||
" A[1] = x + 2;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" A[1] = y;\n"
|
||||
"}\n"
|
||||
"A[2] = sin(z);")
|
||||
# Test using MatrixElements in expressions
|
||||
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
|
||||
assert ccode(expr) == (
|
||||
"((x > 0) ? (\n"
|
||||
" 2*A[2]\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" A[2]\n"
|
||||
")) + sin(A[1]) + A[0]")
|
||||
# Test using MatrixElements in a Matrix
|
||||
q = MatrixSymbol('q', 5, 1)
|
||||
M = MatrixSymbol('M', 3, 3)
|
||||
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
|
||||
[q[1,0] + q[2,0], q[3, 0], 5],
|
||||
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
|
||||
assert ccode(m, M) == (
|
||||
"M[0] = sin(q[1]);\n"
|
||||
"M[1] = 0;\n"
|
||||
"M[2] = cos(q[2]);\n"
|
||||
"M[3] = q[1] + q[2];\n"
|
||||
"M[4] = q[3];\n"
|
||||
"M[5] = 5;\n"
|
||||
"M[6] = 2*q[4]/q[1];\n"
|
||||
"M[7] = sqrt(q[0]) + 4;\n"
|
||||
"M[8] = 0;")
|
||||
|
||||
|
||||
def test_sparse_matrix():
|
||||
# gh-15791
|
||||
with raises(PrintMethodNotImplementedError):
|
||||
ccode(SparseMatrix([[1, 2, 3]]))
|
||||
|
||||
assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]]))
|
||||
|
||||
|
||||
|
||||
def test_ccode_reserved_words():
|
||||
x, y = symbols('x, if')
|
||||
with raises(ValueError):
|
||||
ccode(y**2, error_on_reserved=True, standard='C99')
|
||||
assert ccode(y**2) == 'pow(if_, 2)'
|
||||
assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x'
|
||||
assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)'
|
||||
|
||||
|
||||
def test_ccode_sign():
|
||||
expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))'
|
||||
expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))'
|
||||
expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))'
|
||||
assert ccode(expr1) == ref1
|
||||
assert ccode(expr1, 'z') == 'z = %s;' % ref1
|
||||
assert ccode(expr2) == ref2
|
||||
assert ccode(expr3) == ref3
|
||||
|
||||
def test_ccode_Assignment():
|
||||
assert ccode(Assignment(x, y + z)) == 'x = y + z;'
|
||||
assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;'
|
||||
|
||||
|
||||
def test_ccode_For():
|
||||
f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
|
||||
assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n"
|
||||
" y *= x;\n"
|
||||
"}")
|
||||
|
||||
def test_ccode_Max_Min():
|
||||
assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)'
|
||||
assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)'
|
||||
assert ccode(Min(x, 0, sqrt(x)), standard='c89') == (
|
||||
'((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))'
|
||||
)
|
||||
|
||||
def test_ccode_standard():
|
||||
assert ccode(expm1(x), standard='c99') == 'expm1(x)'
|
||||
assert ccode(nan, standard='c99') == 'NAN'
|
||||
assert ccode(float('nan'), standard='c99') == 'NAN'
|
||||
|
||||
|
||||
def test_C89CodePrinter():
|
||||
c89printer = C89CodePrinter()
|
||||
assert c89printer.language == 'C'
|
||||
assert c89printer.standard == 'C89'
|
||||
assert 'void' in c89printer.reserved_words
|
||||
assert 'template' not in c89printer.reserved_words
|
||||
assert c89printer.doprint(log10(x)) == 'log10(x)'
|
||||
|
||||
|
||||
def test_C99CodePrinter():
|
||||
assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)'
|
||||
assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)'
|
||||
assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)'
|
||||
assert C99CodePrinter().doprint(log2(x)) == 'log2(x)'
|
||||
assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)'
|
||||
assert C99CodePrinter().doprint(log10(x)) == 'log10(x)'
|
||||
assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken.
|
||||
assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)'
|
||||
assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)'
|
||||
assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
|
||||
assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)'
|
||||
c99printer = C99CodePrinter()
|
||||
assert c99printer.language == 'C'
|
||||
assert c99printer.standard == 'C99'
|
||||
assert 'restrict' in c99printer.reserved_words
|
||||
assert 'using' not in c99printer.reserved_words
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_C99CodePrinter__precision_f80():
|
||||
f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
|
||||
assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)'
|
||||
|
||||
|
||||
def test_C99CodePrinter__precision():
|
||||
n = symbols('n', integer=True)
|
||||
p = symbols('p', integer=True, positive=True)
|
||||
f32_printer = C99CodePrinter({"type_aliases": {real: float32}})
|
||||
f64_printer = C99CodePrinter({"type_aliases": {real: float64}})
|
||||
f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
|
||||
assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
|
||||
assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
|
||||
assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'
|
||||
|
||||
for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
|
||||
def check(expr, ref):
|
||||
assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
|
||||
check(Abs(n), 'abs(n)')
|
||||
check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})')
|
||||
check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))')
|
||||
check(exp(x*8.0), 'exp{s}(8.0{S}*x)')
|
||||
check(exp2(x), 'exp2{s}(x)')
|
||||
check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)')
|
||||
check(Mod(p, 2), 'p % 2')
|
||||
check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)')
|
||||
check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})')
|
||||
check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})')
|
||||
check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)')
|
||||
check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)')
|
||||
check(log2(x*8.0), 'log2{s}(8.0{S}*x)')
|
||||
check(log1p(x), 'log1p{s}(x)')
|
||||
check(2**x, 'pow{s}(2, x)')
|
||||
check(2.0**x, 'pow{s}(2.0{S}, x)')
|
||||
check(x**3, 'pow{s}(x, 3)')
|
||||
check(x**4.0, 'pow{s}(x, 4.0{S})')
|
||||
check(sqrt(3+x), 'sqrt{s}(x + 3)')
|
||||
check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})')
|
||||
check(hypot(x, y), 'hypot{s}(x, y)')
|
||||
check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})')
|
||||
check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})')
|
||||
check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})')
|
||||
check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})')
|
||||
check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})')
|
||||
check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})')
|
||||
check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)')
|
||||
|
||||
check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})')
|
||||
check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})')
|
||||
check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})')
|
||||
check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})')
|
||||
check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})')
|
||||
check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})')
|
||||
check(erf(42.*x), 'erf{s}(42.0{S}*x)')
|
||||
check(erfc(42.*x), 'erfc{s}(42.0{S}*x)')
|
||||
check(gamma(x), 'tgamma{s}(x)')
|
||||
check(loggamma(x), 'lgamma{s}(x)')
|
||||
|
||||
check(ceiling(x + 2.), "ceil{s}(x) + 2")
|
||||
check(floor(x + 2.), "floor{s}(x) + 2")
|
||||
check(fma(x, y, -z), 'fma{s}(x, y, -z)')
|
||||
check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))')
|
||||
check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)')
|
||||
|
||||
|
||||
def test_get_math_macros():
|
||||
macros = get_math_macros()
|
||||
assert macros[exp(1)] == 'M_E'
|
||||
assert macros[1/Sqrt(2)] == 'M_SQRT1_2'
|
||||
|
||||
|
||||
def test_ccode_Declaration():
|
||||
i = symbols('i', integer=True)
|
||||
var1 = Variable(i, type=Type.from_expr(i))
|
||||
dcl1 = Declaration(var1)
|
||||
assert ccode(dcl1) == 'int i'
|
||||
|
||||
var2 = Variable(x, type=float32, attrs={value_const})
|
||||
dcl2a = Declaration(var2)
|
||||
assert ccode(dcl2a) == 'const float x'
|
||||
dcl2b = var2.as_Declaration(value=pi)
|
||||
assert ccode(dcl2b) == 'const float x = M_PI'
|
||||
|
||||
var3 = Variable(y, type=Type('bool'))
|
||||
dcl3 = Declaration(var3)
|
||||
printer = C89CodePrinter()
|
||||
assert 'stdbool.h' not in printer.headers
|
||||
assert printer.doprint(dcl3) == 'bool y'
|
||||
assert 'stdbool.h' in printer.headers
|
||||
|
||||
u = symbols('u', real=True)
|
||||
ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict})
|
||||
dcl4 = Declaration(ptr4)
|
||||
assert ccode(dcl4) == 'double * const restrict u'
|
||||
|
||||
var5 = Variable(x, Type('__float128'), attrs={value_const})
|
||||
dcl5a = Declaration(var5)
|
||||
assert ccode(dcl5a) == 'const __float128 x'
|
||||
var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs)
|
||||
dcl5b = Declaration(var5b)
|
||||
assert ccode(dcl5b) == 'const __float128 x = M_PI'
|
||||
|
||||
|
||||
def test_C99CodePrinter_custom_type():
|
||||
# We will look at __float128 (new in glibc 2.26)
|
||||
f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp)
|
||||
p128 = C99CodePrinter({
|
||||
"type_aliases": {real: f128},
|
||||
"type_literal_suffixes": {f128: 'Q'},
|
||||
"type_func_suffixes": {f128: 'f128'},
|
||||
"type_math_macro_suffixes": {
|
||||
real: 'f128',
|
||||
f128: 'f128'
|
||||
},
|
||||
"type_macros": {
|
||||
f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',)
|
||||
}
|
||||
})
|
||||
assert p128.doprint(x) == 'x'
|
||||
assert not p128.headers
|
||||
assert not p128.libraries
|
||||
assert not p128.macros
|
||||
assert p128.doprint(2.0) == '2.0Q'
|
||||
assert not p128.headers
|
||||
assert not p128.libraries
|
||||
assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'}
|
||||
|
||||
assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q'
|
||||
assert p128.doprint(sin(x)) == 'sinf128(x)'
|
||||
assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)'
|
||||
assert p128.doprint(x**-1.0) == '1.0Q/x'
|
||||
|
||||
var5 = Variable(x, f128, attrs={value_const})
|
||||
|
||||
dcl5a = Declaration(var5)
|
||||
assert ccode(dcl5a) == 'const _Float128 x'
|
||||
var5b = Variable(x, f128, pi, attrs={value_const})
|
||||
dcl5b = Declaration(var5b)
|
||||
assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128'
|
||||
var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const})
|
||||
dcl5c = Declaration(var5b)
|
||||
assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig)
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert(ccode(A[0, 0]) == "A[0]")
|
||||
assert(ccode(3 * A[0, 0]) == "3*A[0]")
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert(ccode(F) == "(A - B)[0]")
|
||||
|
||||
def test_ccode_math_macros():
|
||||
assert ccode(z + exp(1)) == 'z + M_E'
|
||||
assert ccode(z + log2(exp(1))) == 'z + M_LOG2E'
|
||||
assert ccode(z + 1/log(2)) == 'z + M_LOG2E'
|
||||
assert ccode(z + log(2)) == 'z + M_LN2'
|
||||
assert ccode(z + log(10)) == 'z + M_LN10'
|
||||
assert ccode(z + pi) == 'z + M_PI'
|
||||
assert ccode(z + pi/2) == 'z + M_PI_2'
|
||||
assert ccode(z + pi/4) == 'z + M_PI_4'
|
||||
assert ccode(z + 1/pi) == 'z + M_1_PI'
|
||||
assert ccode(z + 2/pi) == 'z + M_2_PI'
|
||||
assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI'
|
||||
assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI'
|
||||
assert ccode(z + sqrt(2)) == 'z + M_SQRT2'
|
||||
assert ccode(z + Sqrt(2)) == 'z + M_SQRT2'
|
||||
assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2'
|
||||
assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2'
|
||||
|
||||
|
||||
def test_ccode_Type():
|
||||
assert ccode(Type('float')) == 'float'
|
||||
assert ccode(intc) == 'int'
|
||||
|
||||
|
||||
def test_ccode_codegen_ast():
|
||||
# Note that C only allows comments of the form /* ... */, double forward
|
||||
# slash is not standard C, and some C compilers will grind to a halt upon
|
||||
# encountering them.
|
||||
assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not //
|
||||
assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == (
|
||||
'while (fabs(x) > 1) {\n'
|
||||
' x -= 1;\n'
|
||||
'}'
|
||||
)
|
||||
assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == (
|
||||
'{\n'
|
||||
' x += 1;\n'
|
||||
'}'
|
||||
)
|
||||
inp_x = Declaration(Variable(x, type=real))
|
||||
assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)'
|
||||
assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == (
|
||||
'double pwer(double x){\n'
|
||||
' x = pow(x, 2);\n'
|
||||
'}'
|
||||
)
|
||||
|
||||
# Elements of CodeBlock are formatted as statements:
|
||||
block = CodeBlock(
|
||||
x,
|
||||
Print([x, y], "%d %d"),
|
||||
Print([QuotedString('hello'), y], "%s %d", file=stderr),
|
||||
FunctionCall('pwer', [x]),
|
||||
Return(x),
|
||||
)
|
||||
assert ccode(block) == '\n'.join([
|
||||
'x;',
|
||||
'printf("%d %d", x, y);',
|
||||
'fprintf(stderr, "%s %d", "hello", y);',
|
||||
'pwer(x);',
|
||||
'return x;',
|
||||
])
|
||||
|
||||
def test_ccode_UnevaluatedExpr():
|
||||
assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y"
|
||||
assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955
|
||||
w = symbols('w')
|
||||
assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)"
|
||||
|
||||
p, q, r = symbols("p q r", real=True)
|
||||
q_r = UnevaluatedExpr(q + r)
|
||||
expr = abs(exp(p+q_r))
|
||||
assert ccode(expr) == "exp(p + (q + r))"
|
||||
|
||||
|
||||
def test_ccode_array_like_containers():
|
||||
assert ccode([2,3,4]) == "{2, 3, 4}"
|
||||
assert ccode((2,3,4)) == "{2, 3, 4}"
|
||||
|
||||
def test_ccode__isinf_isnan():
|
||||
assert ccode(isinf(x)) == 'isinf(x)'
|
||||
assert ccode(isnan(x)) == 'isnan(x)'
|
||||
@@ -0,0 +1,77 @@
|
||||
from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError
|
||||
from sympy.core import symbols
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy import cos
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from math import cos as math_cos
|
||||
from sympy.printing.lambdarepr import LambdaPrinter
|
||||
|
||||
|
||||
def setup_test_printer(**kwargs):
|
||||
p = CodePrinter(settings=kwargs)
|
||||
p._not_supported = set()
|
||||
p._number_symbols = set()
|
||||
return p
|
||||
|
||||
|
||||
def test_print_Dummy():
|
||||
d = Dummy('d')
|
||||
p = setup_test_printer()
|
||||
assert p._print_Dummy(d) == "d_%i" % d.dummy_index
|
||||
|
||||
def test_print_Symbol():
|
||||
|
||||
x, y = symbols('x, if')
|
||||
|
||||
p = setup_test_printer()
|
||||
assert p._print(x) == 'x'
|
||||
assert p._print(y) == 'if'
|
||||
|
||||
p.reserved_words.update(['if'])
|
||||
assert p._print(y) == 'if_'
|
||||
|
||||
p = setup_test_printer(error_on_reserved=True)
|
||||
p.reserved_words.update(['if'])
|
||||
with raises(ValueError):
|
||||
p._print(y)
|
||||
|
||||
p = setup_test_printer(reserved_word_suffix='_He_Man')
|
||||
p.reserved_words.update(['if'])
|
||||
assert p._print(y) == 'if_He_Man'
|
||||
|
||||
|
||||
def test_lambdify_LaTeX_symbols_issue_23374():
|
||||
# Create symbols with Latex style names
|
||||
x1, x2 = symbols("x_{1} x_2")
|
||||
|
||||
# Lambdify the function
|
||||
f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2))
|
||||
|
||||
# Test that the function works correctly (numerically)
|
||||
assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2)
|
||||
|
||||
# Explicitly generate a custom printer to verify the naming convention
|
||||
p = LambdaPrinter()
|
||||
expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2))
|
||||
assert 'x_1' in expr_str
|
||||
assert 'x_2' in expr_str
|
||||
|
||||
|
||||
def test_issue_15791():
|
||||
class CrashingCodePrinter(CodePrinter):
|
||||
def emptyPrinter(self, obj):
|
||||
raise NotImplementedError
|
||||
|
||||
from sympy.matrices import (
|
||||
MutableSparseMatrix,
|
||||
ImmutableSparseMatrix,
|
||||
)
|
||||
|
||||
c = CrashingCodePrinter()
|
||||
|
||||
# these should not silently succeed
|
||||
with raises(PrintMethodNotImplementedError):
|
||||
c.doprint(ImmutableSparseMatrix(2, 2, {}))
|
||||
with raises(PrintMethodNotImplementedError):
|
||||
c.doprint(MutableSparseMatrix(2, 2, {}))
|
||||
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sympy.core.function import (Derivative, Function)
|
||||
from sympy.core.numbers import oo
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.functions.special.bessel import besselj
|
||||
from sympy.functions.special.polynomials import legendre
|
||||
from sympy.functions.combinatorial.numbers import bell
|
||||
from sympy.printing.conventions import split_super_sub, requires_partial
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
def test_super_sub():
|
||||
assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"])
|
||||
assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"])
|
||||
assert split_super_sub("beta_13") == ("beta", [], ["13"])
|
||||
assert split_super_sub("x_a_b") == ("x", [], ["a", "b"])
|
||||
assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"])
|
||||
assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"])
|
||||
assert split_super_sub("x_a_1") == ("x", [], ["a", "1"])
|
||||
assert split_super_sub("x_1_a") == ("x", [], ["1", "a"])
|
||||
assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"])
|
||||
assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"])
|
||||
assert split_super_sub("x_11^a") == ("x", ["a"], ["11"])
|
||||
assert split_super_sub("x_11__a") == ("x", ["a"], ["11"])
|
||||
assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"])
|
||||
assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"])
|
||||
assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"])
|
||||
assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"])
|
||||
assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"])
|
||||
assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"])
|
||||
assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"])
|
||||
assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], [])
|
||||
assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], [])
|
||||
assert split_super_sub("alpha_11") == ("alpha", [], ["11"])
|
||||
assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"])
|
||||
assert split_super_sub("w1") == ("w", [], ["1"])
|
||||
assert split_super_sub("w𝟙") == ("w", [], ["𝟙"])
|
||||
assert split_super_sub("w11") == ("w", [], ["11"])
|
||||
assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"])
|
||||
assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"])
|
||||
assert split_super_sub("w1^a") == ("w", ["a"], ["1"])
|
||||
assert split_super_sub("ω1") == ("ω", [], ["1"])
|
||||
assert split_super_sub("ω11") == ("ω", [], ["11"])
|
||||
assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"])
|
||||
assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"])
|
||||
assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"])
|
||||
assert split_super_sub("") == ("", [], [])
|
||||
|
||||
|
||||
def test_requires_partial():
|
||||
x, y, z, t, nu = symbols('x y z t nu')
|
||||
n = symbols('n', integer=True)
|
||||
|
||||
f = x * y
|
||||
assert requires_partial(Derivative(f, x)) is True
|
||||
assert requires_partial(Derivative(f, y)) is True
|
||||
|
||||
## integrating out one of the variables
|
||||
assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
|
||||
|
||||
## bessel function with smooth parameter
|
||||
f = besselj(nu, x)
|
||||
assert requires_partial(Derivative(f, x)) is True
|
||||
assert requires_partial(Derivative(f, nu)) is True
|
||||
|
||||
## bessel function with integer parameter
|
||||
f = besselj(n, x)
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
# this is not really valid (differentiating with respect to an integer)
|
||||
# but there's no reason to use the partial derivative symbol there. make
|
||||
# sure we don't throw an exception here, though
|
||||
assert requires_partial(Derivative(f, n)) is False
|
||||
|
||||
## bell polynomial
|
||||
f = bell(n, x)
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
# again, invalid
|
||||
assert requires_partial(Derivative(f, n)) is False
|
||||
|
||||
## legendre polynomial
|
||||
f = legendre(0, x)
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
|
||||
f = legendre(n, x)
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
# again, invalid
|
||||
assert requires_partial(Derivative(f, n)) is False
|
||||
|
||||
f = x ** n
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
|
||||
assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
|
||||
|
||||
# parametric equation
|
||||
f = (exp(t), cos(t))
|
||||
g = sum(f)
|
||||
assert requires_partial(Derivative(g, t)) is False
|
||||
|
||||
f = symbols('f', cls=Function)
|
||||
assert requires_partial(Derivative(f(x), x)) is False
|
||||
assert requires_partial(Derivative(f(x), y)) is False
|
||||
assert requires_partial(Derivative(f(x, y), x)) is True
|
||||
assert requires_partial(Derivative(f(x, y), y)) is True
|
||||
assert requires_partial(Derivative(f(x, y), z)) is True
|
||||
assert requires_partial(Derivative(f(x, y), x, y)) is True
|
||||
|
||||
@XFAIL
|
||||
def test_requires_partial_unspecified_variables():
|
||||
x, y = symbols('x y')
|
||||
# function of unspecified variables
|
||||
f = symbols('f', cls=Function)
|
||||
assert requires_partial(Derivative(f, x)) is False
|
||||
assert requires_partial(Derivative(f, x, y)) is True
|
||||
@@ -0,0 +1,56 @@
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.abc import x, i, a, b
|
||||
from sympy.codegen.numpy_nodes import logaddexp
|
||||
from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
|
||||
|
||||
from sympy.testing.pytest import skip, raises
|
||||
from sympy.external import import_module
|
||||
|
||||
cp = import_module('cupy')
|
||||
|
||||
def test_cupy_print():
|
||||
prntr = CuPyPrinter()
|
||||
assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)'
|
||||
assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)'
|
||||
assert prntr.doprint(log(x)) == 'cupy.log(x)'
|
||||
assert prntr.doprint("acos(x)") == 'cupy.arccos(x)'
|
||||
assert prntr.doprint("exp(x)") == 'cupy.exp(x)'
|
||||
assert prntr.doprint("Abs(x)") == 'abs(x)'
|
||||
|
||||
def test_not_cupy_print():
|
||||
prntr = CuPyPrinter()
|
||||
with raises(NotImplementedError):
|
||||
prntr.doprint("abcd(x)")
|
||||
|
||||
def test_cupy_sum():
|
||||
if not cp:
|
||||
skip("CuPy not installed")
|
||||
|
||||
s = Sum(x ** i, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'cupy')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = cp.linspace(-1, +1, 10)
|
||||
assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
s = Sum(i * x, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'numpy')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = cp.linspace(-1, +1, 10)
|
||||
assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
def test_cupy_known_funcs_consts():
|
||||
assert _cupy_known_constants['NaN'] == 'cupy.nan'
|
||||
assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma'
|
||||
|
||||
assert _cupy_known_functions['acos'] == 'cupy.arccos'
|
||||
assert _cupy_known_functions['log'] == 'cupy.log'
|
||||
|
||||
def test_cupy_print_methods():
|
||||
prntr = CuPyPrinter()
|
||||
assert hasattr(prntr, '_print_acos')
|
||||
assert hasattr(prntr, '_print_log')
|
||||
@@ -0,0 +1,86 @@
|
||||
from sympy.core.numbers import Float, Integer, Rational
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac
|
||||
from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode
|
||||
from sympy.codegen.cfunctions import log1p
|
||||
|
||||
|
||||
x, y, u, v = symbols('x y u v')
|
||||
|
||||
|
||||
def test_CXX98CodePrinter():
|
||||
assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)')
|
||||
assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))'
|
||||
cxx98printer = CXX98CodePrinter()
|
||||
assert cxx98printer.language == 'C++'
|
||||
assert cxx98printer.standard == 'C++98'
|
||||
assert 'template' in cxx98printer.reserved_words
|
||||
assert 'alignas' not in cxx98printer.reserved_words
|
||||
|
||||
|
||||
def test_CXX11CodePrinter():
|
||||
assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)'
|
||||
|
||||
cxx11printer = CXX11CodePrinter()
|
||||
assert cxx11printer.language == 'C++'
|
||||
assert cxx11printer.standard == 'C++11'
|
||||
assert 'operator' in cxx11printer.reserved_words
|
||||
assert 'noexcept' in cxx11printer.reserved_words
|
||||
assert 'concept' not in cxx11printer.reserved_words
|
||||
|
||||
|
||||
def test_subclass_print_method():
|
||||
class MyPrinter(CXX11CodePrinter):
|
||||
def _print_log1p(self, expr):
|
||||
return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args))
|
||||
|
||||
assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)'
|
||||
|
||||
|
||||
def test_subclass_print_method__ns():
|
||||
class MyPrinter(CXX11CodePrinter):
|
||||
_ns = 'my_library::'
|
||||
|
||||
p = CXX11CodePrinter()
|
||||
myp = MyPrinter()
|
||||
|
||||
assert p.doprint(log1p(x)) == 'std::log1p(x)'
|
||||
assert myp.doprint(log1p(x)) == 'my_library::log1p(x)'
|
||||
|
||||
|
||||
def test_CXX17CodePrinter():
|
||||
assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)'
|
||||
assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)'
|
||||
assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)'
|
||||
|
||||
# Automatic rewrite
|
||||
assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))'
|
||||
assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))'
|
||||
|
||||
|
||||
def test_cxxcode():
|
||||
assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)'])
|
||||
|
||||
def test_cxxcode_nested_minmax():
|
||||
assert cxxcode(Max(Min(x, y), Min(u, v))) \
|
||||
== 'std::max(std::min(u, v), std::min(x, y))'
|
||||
assert cxxcode(Min(Max(x, y), Max(u, v))) \
|
||||
== 'std::min(std::max(u, v), std::max(x, y))'
|
||||
|
||||
def test_subclass_Integer_Float():
|
||||
class MyPrinter(CXX17CodePrinter):
|
||||
def _print_Integer(self, arg):
|
||||
return 'bigInt("%s")' % super()._print_Integer(arg)
|
||||
|
||||
def _print_Float(self, arg):
|
||||
rat = Rational(arg)
|
||||
return 'bigFloat(%s, %s)' % (
|
||||
self._print(Integer(rat.p)),
|
||||
self._print(Integer(rat.q))
|
||||
)
|
||||
|
||||
p = MyPrinter()
|
||||
for i in range(13):
|
||||
assert p.doprint(i) == 'bigInt("%d")' % i
|
||||
assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))'
|
||||
assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x'
|
||||
@@ -0,0 +1,134 @@
|
||||
from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
|
||||
dotedges, dotprint)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.numbers import (Float, Integer)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.printing.repr import srepr
|
||||
from sympy.abc import x
|
||||
|
||||
|
||||
def test_purestr():
|
||||
assert purestr(Symbol('x')) == "Symbol('x')"
|
||||
assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))"
|
||||
assert purestr(Float(2)) == "Float('2.0', precision=53)"
|
||||
|
||||
assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
|
||||
assert purestr(Basic(S(1), S(2)), with_args=True) == \
|
||||
('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)'))
|
||||
assert purestr(Float(2), with_args=True) == \
|
||||
("Float('2.0', precision=53)", ())
|
||||
|
||||
|
||||
def test_styleof():
|
||||
styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
|
||||
(Expr, {'color': 'black'})]
|
||||
assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}
|
||||
|
||||
assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
|
||||
|
||||
|
||||
def test_attrprint():
|
||||
assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \
|
||||
'"color"="blue", "shape"="ellipse"'
|
||||
|
||||
def test_dotnode():
|
||||
|
||||
assert dotnode(x, repeat=False) == \
|
||||
'"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'
|
||||
assert dotnode(x+2, repeat=False) == \
|
||||
'"Add(Integer(2), Symbol(\'x\'))" ' \
|
||||
'["color"="black", "label"="Add", "shape"="ellipse"];', \
|
||||
dotnode(x+2,repeat=0)
|
||||
|
||||
assert dotnode(x + x**2, repeat=False) == \
|
||||
'"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \
|
||||
'["color"="black", "label"="Add", "shape"="ellipse"];'
|
||||
assert dotnode(x + x**2, repeat=True) == \
|
||||
'"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \
|
||||
'["color"="black", "label"="Add", "shape"="ellipse"];'
|
||||
|
||||
def test_dotedges():
|
||||
assert sorted(dotedges(x+2, repeat=False)) == [
|
||||
'"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";',
|
||||
'"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";'
|
||||
]
|
||||
assert sorted(dotedges(x + 2, repeat=True)) == [
|
||||
'"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";',
|
||||
'"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";'
|
||||
]
|
||||
|
||||
def test_dotprint():
|
||||
text = dotprint(x+2, repeat=False)
|
||||
assert all(e in text for e in dotedges(x+2, repeat=False))
|
||||
assert all(
|
||||
n in text for n in [dotnode(expr, repeat=False)
|
||||
for expr in (x, Integer(2), x+2)])
|
||||
assert 'digraph' in text
|
||||
|
||||
text = dotprint(x+x**2, repeat=False)
|
||||
assert all(e in text for e in dotedges(x+x**2, repeat=False))
|
||||
assert all(
|
||||
n in text for n in [dotnode(expr, repeat=False)
|
||||
for expr in (x, Integer(2), x**2)])
|
||||
assert 'digraph' in text
|
||||
|
||||
text = dotprint(x+x**2, repeat=True)
|
||||
assert all(e in text for e in dotedges(x+x**2, repeat=True))
|
||||
assert all(
|
||||
n in text for n in [dotnode(expr, pos=())
|
||||
for expr in [x + x**2]])
|
||||
|
||||
text = dotprint(x**x, repeat=True)
|
||||
assert all(e in text for e in dotedges(x**x, repeat=True))
|
||||
assert all(
|
||||
n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))])
|
||||
assert 'digraph' in text
|
||||
|
||||
def test_dotprint_depth():
|
||||
text = dotprint(3*x+2, depth=1)
|
||||
assert dotnode(3*x+2) in text
|
||||
assert dotnode(x) not in text
|
||||
text = dotprint(3*x+2)
|
||||
assert "depth" not in text
|
||||
|
||||
def test_Matrix_and_non_basics():
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
n = Symbol('n')
|
||||
assert dotprint(MatrixSymbol('X', n, n)) == \
|
||||
"""digraph{
|
||||
|
||||
# Graph style
|
||||
"ordering"="out"
|
||||
"rankdir"="TD"
|
||||
|
||||
#########
|
||||
# Nodes #
|
||||
#########
|
||||
|
||||
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
|
||||
"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
|
||||
"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
|
||||
"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];
|
||||
|
||||
#########
|
||||
# Edges #
|
||||
#########
|
||||
|
||||
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
|
||||
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
|
||||
"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
|
||||
}"""
|
||||
|
||||
|
||||
def test_labelfunc():
|
||||
text = dotprint(x + 2, labelfunc=srepr)
|
||||
assert "Symbol('x')" in text
|
||||
assert "Integer(2)" in text
|
||||
|
||||
|
||||
def test_commutative():
|
||||
x, y = symbols('x y', commutative=False)
|
||||
assert dotprint(x + y) == dotprint(y + x)
|
||||
assert dotprint(x*y) != dotprint(y*x)
|
||||
@@ -0,0 +1,854 @@
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.function import (Function, Lambda, diff)
|
||||
from sympy.core.mod import Mod
|
||||
from sympy.core import (Catalan, EulerGamma, GoldenRatio)
|
||||
from sympy.core.numbers import (E, Float, I, Integer, Rational, pi)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, symbols)
|
||||
from sympy.functions.combinatorial.factorials import factorial
|
||||
from sympy.functions.elementary.complexes import (conjugate, sign)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import (atan2, cos, sin)
|
||||
from sympy.functions.special.gamma_functions import gamma
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.sets.fancysets import Range
|
||||
|
||||
from sympy.codegen import For, Assignment, aug_assign
|
||||
from sympy.codegen.ast import Declaration, Variable, float32, float64, \
|
||||
value_const, real, bool_, While, FunctionPrototype, FunctionDefinition, \
|
||||
integer, Return, Element
|
||||
from sympy.core.expr import UnevaluatedExpr
|
||||
from sympy.core.relational import Relational
|
||||
from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor
|
||||
from sympy.matrices import Matrix, MatrixSymbol
|
||||
from sympy.printing.fortran import fcode, FCodePrinter
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.tensor.array.expressions import ArraySymbol, ArrayElement
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
def test_UnevaluatedExpr():
|
||||
p, q, r = symbols("p q r", real=True)
|
||||
q_r = UnevaluatedExpr(q + r)
|
||||
expr = abs(exp(p+q_r))
|
||||
assert fcode(expr, source_format="free") == "exp(p + (q + r))"
|
||||
x, y, z = symbols("x y z")
|
||||
y_z = UnevaluatedExpr(y + z)
|
||||
expr2 = abs(exp(x+y_z))
|
||||
assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))"
|
||||
assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))"
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
x = symbols('x')
|
||||
|
||||
class nint(Function):
|
||||
def _fcode(self, printer):
|
||||
return "nint(%s)" % printer._print(self.args[0])
|
||||
assert fcode(nint(x)) == " nint(x)"
|
||||
|
||||
|
||||
def test_fcode_sign(): #issue 12267
|
||||
x=symbols('x')
|
||||
y=symbols('y', integer=True)
|
||||
z=symbols('z', complex=True)
|
||||
assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
|
||||
assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)"
|
||||
assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
|
||||
raises(NotImplementedError, lambda: fcode(sign(x)))
|
||||
|
||||
|
||||
def test_fcode_Pow():
|
||||
x, y = symbols('x,y')
|
||||
n = symbols('n', integer=True)
|
||||
|
||||
assert fcode(x**3) == " x**3"
|
||||
assert fcode(x**(y**3)) == " x**(y**3)"
|
||||
assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
" (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)"
|
||||
assert fcode(sqrt(x)) == ' sqrt(x)'
|
||||
assert fcode(sqrt(n)) == ' sqrt(dble(n))'
|
||||
assert fcode(x**0.5) == ' sqrt(x)'
|
||||
assert fcode(sqrt(x)) == ' sqrt(x)'
|
||||
assert fcode(sqrt(10)) == ' sqrt(10.0d0)'
|
||||
assert fcode(x**-1.0) == ' 1d0/x'
|
||||
assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823
|
||||
assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)'
|
||||
|
||||
|
||||
def test_fcode_Rational():
|
||||
x = symbols('x')
|
||||
assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0"
|
||||
assert fcode(Rational(18, 9)) == " 2"
|
||||
assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0"
|
||||
assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0"
|
||||
assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0"
|
||||
assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x"
|
||||
|
||||
|
||||
def test_fcode_Integer():
|
||||
assert fcode(Integer(67)) == " 67"
|
||||
assert fcode(Integer(-1)) == " -1"
|
||||
|
||||
|
||||
def test_fcode_Float():
|
||||
assert fcode(Float(42.0)) == " 42.0000000000000d0"
|
||||
assert fcode(Float(-1e20)) == " -1.00000000000000d+20"
|
||||
|
||||
|
||||
def test_fcode_functions():
|
||||
x, y = symbols('x,y')
|
||||
assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)"
|
||||
raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66))
|
||||
raises(NotImplementedError, lambda: fcode(x % y, standard=66))
|
||||
raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77))
|
||||
raises(NotImplementedError, lambda: fcode(x % y, standard=77))
|
||||
for standard in [90, 95, 2003, 2008]:
|
||||
assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)"
|
||||
assert fcode(x % y, standard=standard) == " modulo(x, y)"
|
||||
|
||||
|
||||
def test_case():
|
||||
ob = FCodePrinter()
|
||||
x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y')
|
||||
assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \
|
||||
' exp(x_) + sin(x*y) + cos(X__*Y_)'
|
||||
assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \
|
||||
' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)'
|
||||
assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \
|
||||
' exp(x_) + sin(x*y) + cos(X*Y)'
|
||||
assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)'
|
||||
assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__'
|
||||
assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)'
|
||||
assert ob.doprint(x_, assign_to='ad') == ' ad = x__'
|
||||
n, m = symbols('n,m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
I = Idx('I', n)
|
||||
assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == (
|
||||
"do i = 1, m\n"
|
||||
" y(i) = 0\n"
|
||||
"end do\n"
|
||||
"do i = 1, m\n"
|
||||
" do I_ = 1, n\n"
|
||||
" y(i) = A(i, I_)*x(I_) + y(i)\n"
|
||||
" end do\n"
|
||||
"end do" )
|
||||
|
||||
|
||||
#issue 6814
|
||||
def test_fcode_functions_with_integers():
|
||||
x= symbols('x')
|
||||
log10_17 = log(10).evalf(17)
|
||||
loglog10_17 = '0.8340324452479558d0'
|
||||
assert fcode(x * log(10)) == " x*%sd0" % log10_17
|
||||
assert fcode(x * log(10)) == " x*%sd0" % log10_17
|
||||
assert fcode(x * log(S(10))) == " x*%sd0" % log10_17
|
||||
assert fcode(log(S(10))) == " %sd0" % log10_17
|
||||
assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17)
|
||||
assert fcode(x * log(log(10))) == " x*%s" % loglog10_17
|
||||
assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17
|
||||
|
||||
|
||||
def test_fcode_NumberSymbol():
|
||||
prec = 17
|
||||
p = FCodePrinter()
|
||||
assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec)
|
||||
assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec)
|
||||
assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec)
|
||||
assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec)
|
||||
assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec)
|
||||
assert fcode(
|
||||
pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5)
|
||||
assert fcode(Catalan, human=False) == ({
|
||||
(Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan')
|
||||
assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print(
|
||||
EulerGamma.evalf(prec)))}, set(), ' EulerGamma')
|
||||
assert fcode(E, human=False) == (
|
||||
{(E, p._print(E.evalf(prec)))}, set(), ' E')
|
||||
assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print(
|
||||
GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio')
|
||||
assert fcode(pi, human=False) == (
|
||||
{(pi, p._print(pi.evalf(prec)))}, set(), ' pi')
|
||||
assert fcode(pi, precision=5, human=False) == (
|
||||
{(pi, p._print(pi.evalf(5)))}, set(), ' pi')
|
||||
|
||||
|
||||
def test_fcode_complex():
|
||||
assert fcode(I) == " cmplx(0,1)"
|
||||
x = symbols('x')
|
||||
assert fcode(4*I) == " cmplx(0,4)"
|
||||
assert fcode(3 + 4*I) == " cmplx(3,4)"
|
||||
assert fcode(3 + 4*I + x) == " cmplx(3,4) + x"
|
||||
assert fcode(I*x) == " cmplx(0,1)*x"
|
||||
assert fcode(3 + 4*I - x) == " cmplx(3,4) - x"
|
||||
x = symbols('x', imaginary=True)
|
||||
assert fcode(5*x) == " 5*x"
|
||||
assert fcode(I*x) == " cmplx(0,1)*x"
|
||||
assert fcode(3 + x) == " x + 3"
|
||||
|
||||
|
||||
def test_implicit():
|
||||
x, y = symbols('x,y')
|
||||
assert fcode(sin(x)) == " sin(x)"
|
||||
assert fcode(atan2(x, y)) == " atan2(x, y)"
|
||||
assert fcode(conjugate(x)) == " conjg(x)"
|
||||
|
||||
|
||||
def test_not_fortran():
|
||||
x = symbols('x')
|
||||
g = Function('g')
|
||||
with raises(NotImplementedError):
|
||||
fcode(gamma(x))
|
||||
assert fcode(Integral(sin(x)), strict=False) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)"
|
||||
with raises(NotImplementedError):
|
||||
fcode(g(x))
|
||||
|
||||
|
||||
def test_user_functions():
|
||||
x = symbols('x')
|
||||
assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)"
|
||||
x = symbols('x')
|
||||
assert fcode(
|
||||
gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)"
|
||||
g = Function('g')
|
||||
assert fcode(g(x), user_functions={"g": "great"}) == " great(x)"
|
||||
n = symbols('n', integer=True)
|
||||
assert fcode(
|
||||
factorial(n), user_functions={"factorial": "fct"}) == " fct(n)"
|
||||
|
||||
|
||||
def test_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert fcode(g(x)) == " 2*x"
|
||||
g = implemented_function('g', Lambda(x, 2*pi/x))
|
||||
assert fcode(g(x)) == (
|
||||
" parameter (pi = %sd0)\n"
|
||||
" 2*pi/x"
|
||||
) % pi.evalf(17)
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
assert fcode(g(A[i]), assign_to=A[i]) == (
|
||||
" do i = 1, n\n"
|
||||
" A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n"
|
||||
" end do"
|
||||
)
|
||||
|
||||
|
||||
def test_assign_to():
|
||||
x = symbols('x')
|
||||
assert fcode(sin(x), assign_to="s") == " s = sin(x)"
|
||||
|
||||
|
||||
def test_line_wrapping():
|
||||
x, y = symbols('x,y')
|
||||
assert fcode(((x + y)**10).expand(), assign_to="var") == (
|
||||
" var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n"
|
||||
" @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n"
|
||||
" @ **8 + 10*x*y**9 + y**10"
|
||||
)
|
||||
e = [x**i for i in range(11)]
|
||||
assert fcode(Add(*e)) == (
|
||||
" x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n"
|
||||
" @ + 1"
|
||||
)
|
||||
|
||||
|
||||
def test_fcode_precedence():
|
||||
x, y = symbols("x y")
|
||||
assert fcode(And(x < y, y < x + 1), source_format="free") == \
|
||||
"x < y .and. y < x + 1"
|
||||
assert fcode(Or(x < y, y < x + 1), source_format="free") == \
|
||||
"x < y .or. y < x + 1"
|
||||
assert fcode(Xor(x < y, y < x + 1, evaluate=False),
|
||||
source_format="free") == "x < y .neqv. y < x + 1"
|
||||
assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \
|
||||
"x < y .eqv. y < x + 1"
|
||||
|
||||
|
||||
def test_fcode_Logical():
|
||||
x, y, z = symbols("x y z")
|
||||
# unary Not
|
||||
assert fcode(Not(x), source_format="free") == ".not. x"
|
||||
# binary And
|
||||
assert fcode(And(x, y), source_format="free") == "x .and. y"
|
||||
assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y"
|
||||
assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x"
|
||||
assert fcode(And(Not(x), Not(y)), source_format="free") == \
|
||||
".not. x .and. .not. y"
|
||||
assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \
|
||||
".not. (x .and. y)"
|
||||
# binary Or
|
||||
assert fcode(Or(x, y), source_format="free") == "x .or. y"
|
||||
assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y"
|
||||
assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x"
|
||||
assert fcode(Or(Not(x), Not(y)), source_format="free") == \
|
||||
".not. x .or. .not. y"
|
||||
assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \
|
||||
".not. (x .or. y)"
|
||||
# mixed And/Or
|
||||
assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)"
|
||||
assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)"
|
||||
assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)"
|
||||
assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z"
|
||||
assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z"
|
||||
assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y"
|
||||
# trinary And
|
||||
assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z"
|
||||
assert fcode(And(x, y, Not(z)), source_format="free") == \
|
||||
"x .and. y .and. .not. z"
|
||||
assert fcode(And(x, Not(y), z), source_format="free") == \
|
||||
"x .and. z .and. .not. y"
|
||||
assert fcode(And(Not(x), y, z), source_format="free") == \
|
||||
"y .and. z .and. .not. x"
|
||||
assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \
|
||||
".not. (x .and. y .and. z)"
|
||||
# trinary Or
|
||||
assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z"
|
||||
assert fcode(Or(x, y, Not(z)), source_format="free") == \
|
||||
"x .or. y .or. .not. z"
|
||||
assert fcode(Or(x, Not(y), z), source_format="free") == \
|
||||
"x .or. z .or. .not. y"
|
||||
assert fcode(Or(Not(x), y, z), source_format="free") == \
|
||||
"y .or. z .or. .not. x"
|
||||
assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \
|
||||
".not. (x .or. y .or. z)"
|
||||
|
||||
|
||||
def test_fcode_Xlogical():
|
||||
x, y, z = symbols("x y z")
|
||||
# binary Xor
|
||||
assert fcode(Xor(x, y, evaluate=False), source_format="free") == \
|
||||
"x .neqv. y"
|
||||
assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \
|
||||
"x .neqv. .not. y"
|
||||
assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \
|
||||
"y .neqv. .not. x"
|
||||
assert fcode(Xor(Not(x), Not(y), evaluate=False),
|
||||
source_format="free") == ".not. x .neqv. .not. y"
|
||||
assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False),
|
||||
source_format="free") == ".not. (x .neqv. y)"
|
||||
# binary Equivalent
|
||||
assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y"
|
||||
assert fcode(Equivalent(x, Not(y)), source_format="free") == \
|
||||
"x .eqv. .not. y"
|
||||
assert fcode(Equivalent(Not(x), y), source_format="free") == \
|
||||
"y .eqv. .not. x"
|
||||
assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \
|
||||
".not. x .eqv. .not. y"
|
||||
assert fcode(Not(Equivalent(x, y), evaluate=False),
|
||||
source_format="free") == ".not. (x .eqv. y)"
|
||||
# mixed And/Equivalent
|
||||
assert fcode(Equivalent(And(y, z), x), source_format="free") == \
|
||||
"x .eqv. y .and. z"
|
||||
assert fcode(Equivalent(And(z, x), y), source_format="free") == \
|
||||
"y .eqv. x .and. z"
|
||||
assert fcode(Equivalent(And(x, y), z), source_format="free") == \
|
||||
"z .eqv. x .and. y"
|
||||
assert fcode(And(Equivalent(y, z), x), source_format="free") == \
|
||||
"x .and. (y .eqv. z)"
|
||||
assert fcode(And(Equivalent(z, x), y), source_format="free") == \
|
||||
"y .and. (x .eqv. z)"
|
||||
assert fcode(And(Equivalent(x, y), z), source_format="free") == \
|
||||
"z .and. (x .eqv. y)"
|
||||
# mixed Or/Equivalent
|
||||
assert fcode(Equivalent(Or(y, z), x), source_format="free") == \
|
||||
"x .eqv. y .or. z"
|
||||
assert fcode(Equivalent(Or(z, x), y), source_format="free") == \
|
||||
"y .eqv. x .or. z"
|
||||
assert fcode(Equivalent(Or(x, y), z), source_format="free") == \
|
||||
"z .eqv. x .or. y"
|
||||
assert fcode(Or(Equivalent(y, z), x), source_format="free") == \
|
||||
"x .or. (y .eqv. z)"
|
||||
assert fcode(Or(Equivalent(z, x), y), source_format="free") == \
|
||||
"y .or. (x .eqv. z)"
|
||||
assert fcode(Or(Equivalent(x, y), z), source_format="free") == \
|
||||
"z .or. (x .eqv. y)"
|
||||
# mixed Xor/Equivalent
|
||||
assert fcode(Equivalent(Xor(y, z, evaluate=False), x),
|
||||
source_format="free") == "x .eqv. (y .neqv. z)"
|
||||
assert fcode(Equivalent(Xor(z, x, evaluate=False), y),
|
||||
source_format="free") == "y .eqv. (x .neqv. z)"
|
||||
assert fcode(Equivalent(Xor(x, y, evaluate=False), z),
|
||||
source_format="free") == "z .eqv. (x .neqv. y)"
|
||||
assert fcode(Xor(Equivalent(y, z), x, evaluate=False),
|
||||
source_format="free") == "x .neqv. (y .eqv. z)"
|
||||
assert fcode(Xor(Equivalent(z, x), y, evaluate=False),
|
||||
source_format="free") == "y .neqv. (x .eqv. z)"
|
||||
assert fcode(Xor(Equivalent(x, y), z, evaluate=False),
|
||||
source_format="free") == "z .neqv. (x .eqv. y)"
|
||||
# mixed And/Xor
|
||||
assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \
|
||||
"x .neqv. y .and. z"
|
||||
assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \
|
||||
"y .neqv. x .and. z"
|
||||
assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \
|
||||
"z .neqv. x .and. y"
|
||||
assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \
|
||||
"x .and. (y .neqv. z)"
|
||||
assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \
|
||||
"y .and. (x .neqv. z)"
|
||||
assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \
|
||||
"z .and. (x .neqv. y)"
|
||||
# mixed Or/Xor
|
||||
assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \
|
||||
"x .neqv. y .or. z"
|
||||
assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \
|
||||
"y .neqv. x .or. z"
|
||||
assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \
|
||||
"z .neqv. x .or. y"
|
||||
assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \
|
||||
"x .or. (y .neqv. z)"
|
||||
assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \
|
||||
"y .or. (x .neqv. z)"
|
||||
assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \
|
||||
"z .or. (x .neqv. y)"
|
||||
# trinary Xor
|
||||
assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \
|
||||
"x .neqv. y .neqv. z"
|
||||
assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \
|
||||
"x .neqv. y .neqv. .not. z"
|
||||
assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \
|
||||
"x .neqv. z .neqv. .not. y"
|
||||
assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \
|
||||
"y .neqv. z .neqv. .not. x"
|
||||
|
||||
|
||||
def test_fcode_Relational():
|
||||
x, y = symbols("x y")
|
||||
assert fcode(Relational(x, y, "=="), source_format="free") == "x == y"
|
||||
assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y"
|
||||
assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y"
|
||||
assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y"
|
||||
assert fcode(Relational(x, y, ">"), source_format="free") == "x > y"
|
||||
assert fcode(Relational(x, y, "<"), source_format="free") == "x < y"
|
||||
|
||||
|
||||
def test_fcode_Piecewise():
|
||||
x = symbols('x')
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
# Check that inline conditional (merge) fails if standard isn't 95+
|
||||
raises(NotImplementedError, lambda: fcode(expr))
|
||||
code = fcode(expr, standard=95)
|
||||
expected = " merge(x, x**2, x < 1)"
|
||||
assert code == expected
|
||||
assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == (
|
||||
" if (x < 1) then\n"
|
||||
" var = x\n"
|
||||
" else\n"
|
||||
" var = x**2\n"
|
||||
" end if"
|
||||
)
|
||||
a = cos(x)/x
|
||||
b = sin(x)/x
|
||||
for i in range(10):
|
||||
a = diff(a, x)
|
||||
b = diff(b, x)
|
||||
expected = (
|
||||
" if (x < 0) then\n"
|
||||
" weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n"
|
||||
" @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n"
|
||||
" @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n"
|
||||
" @ )/x**10 + 3628800*cos(x)/x**11\n"
|
||||
" else\n"
|
||||
" weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n"
|
||||
" @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n"
|
||||
" @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n"
|
||||
" @ )/x**10 + 3628800*sin(x)/x**11\n"
|
||||
" end if"
|
||||
)
|
||||
code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name")
|
||||
assert code == expected
|
||||
code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95)
|
||||
expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)"
|
||||
assert code == expected
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: fcode(expr))
|
||||
|
||||
|
||||
def test_wrap_fortran():
|
||||
# "########################################################################"
|
||||
printer = FCodePrinter()
|
||||
lines = [
|
||||
"C This is a long comment on a single line that must be wrapped properly to produce nice output",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
|
||||
]
|
||||
wrapped_lines = printer._wrap_fortran(lines)
|
||||
expected_lines = [
|
||||
"C This is a long comment on a single line that must be wrapped",
|
||||
"C properly to produce nice output",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that *",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that *",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that",
|
||||
" @ * must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that*",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that",
|
||||
" @ *must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement +",
|
||||
" @ that*must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that**",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that",
|
||||
" @ **must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement + that",
|
||||
" @ **must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement +",
|
||||
" @ that**must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement(that)/",
|
||||
" @ must + be + wrapped + properly",
|
||||
" this = is + a + long + and + nasty + fortran + statement(that)",
|
||||
" @ /must + be + wrapped + properly",
|
||||
]
|
||||
for line in wrapped_lines:
|
||||
assert len(line) <= 72
|
||||
for w, e in zip(wrapped_lines, expected_lines):
|
||||
assert w == e
|
||||
assert len(wrapped_lines) == len(expected_lines)
|
||||
|
||||
|
||||
def test_wrap_fortran_keep_d0():
|
||||
printer = FCodePrinter()
|
||||
lines = [
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0'
|
||||
]
|
||||
expected = [
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =',
|
||||
' @ 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =',
|
||||
' @ 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =',
|
||||
' @ 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =',
|
||||
' @ 1.0d0',
|
||||
' this_variable_is_very_long_because_we_try_to_test_line_break =',
|
||||
' @ 10.0d0'
|
||||
]
|
||||
assert printer._wrap_fortran(lines) == expected
|
||||
|
||||
|
||||
def test_settings():
|
||||
raises(TypeError, lambda: fcode(S(4), method="garbage"))
|
||||
|
||||
|
||||
def test_free_form_code_line():
|
||||
x, y = symbols('x,y')
|
||||
assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)"
|
||||
|
||||
|
||||
def test_free_form_continuation_line():
|
||||
x, y = symbols('x,y')
|
||||
result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free')
|
||||
expected = (
|
||||
'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n'
|
||||
' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n'
|
||||
' sin(y)*cos(x)**6 + cos(x)**7'
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_free_form_comment_line():
|
||||
printer = FCodePrinter({'source_format': 'free'})
|
||||
lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"]
|
||||
expected = [
|
||||
'! This is a long comment on a single line that must be wrapped properly',
|
||||
'! to produce nice output']
|
||||
assert printer._wrap_fortran(lines) == expected
|
||||
|
||||
|
||||
def test_loops():
|
||||
n, m = symbols('n,m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
expected = (
|
||||
'do i = 1, m\n'
|
||||
' y(i) = 0\n'
|
||||
'end do\n'
|
||||
'do i = 1, m\n'
|
||||
' do j = 1, n\n'
|
||||
' y(i) = %(rhs)s\n'
|
||||
' end do\n'
|
||||
'end do'
|
||||
)
|
||||
|
||||
code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free')
|
||||
assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or
|
||||
code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or
|
||||
code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or
|
||||
code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'})
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
expected = (
|
||||
'do i_%(icount)i = 1, m_%(mcount)i\n'
|
||||
' y(i_%(icount)i) = x(i_%(icount)i)\n'
|
||||
'end do'
|
||||
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
||||
code = fcode(x[i], assign_to=y[i], source_format='free')
|
||||
assert code == expected
|
||||
|
||||
|
||||
def test_fcode_Indexed_without_looking_for_contraction():
|
||||
len_y = 5
|
||||
y = IndexedBase('y', shape=(len_y,))
|
||||
x = IndexedBase('x', shape=(len_y,))
|
||||
Dy = IndexedBase('Dy', shape=(len_y-1,))
|
||||
i = Idx('i', len_y-1)
|
||||
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
|
||||
code0 = fcode(e.rhs, assign_to=e.lhs, contract=False)
|
||||
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
|
||||
|
||||
|
||||
def test_element_like_objects():
|
||||
len_y = 5
|
||||
y = ArraySymbol('y', shape=(len_y,))
|
||||
x = ArraySymbol('x', shape=(len_y,))
|
||||
Dy = ArraySymbol('Dy', shape=(len_y-1,))
|
||||
i = Idx('i', len_y-1)
|
||||
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
|
||||
code0 = fcode(Assignment(e.lhs, e.rhs))
|
||||
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
|
||||
|
||||
class ElementExpr(Element, Expr):
|
||||
pass
|
||||
|
||||
e = e.subs((a, ElementExpr(a.name, a.indices)) for a in e.atoms(ArrayElement) )
|
||||
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
|
||||
code0 = fcode(Assignment(e.lhs, e.rhs))
|
||||
assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
|
||||
|
||||
|
||||
def test_derived_classes():
|
||||
class MyFancyFCodePrinter(FCodePrinter):
|
||||
_default_settings = FCodePrinter._default_settings.copy()
|
||||
|
||||
printer = MyFancyFCodePrinter()
|
||||
x = symbols('x')
|
||||
assert printer.doprint(sin(x), "bork") == " bork = sin(x)"
|
||||
|
||||
|
||||
def test_indent():
|
||||
codelines = (
|
||||
'subroutine test(a)\n'
|
||||
'integer :: a, i, j\n'
|
||||
'\n'
|
||||
'do\n'
|
||||
'do \n'
|
||||
'do j = 1, 5\n'
|
||||
'if (a>b) then\n'
|
||||
'if(b>0) then\n'
|
||||
'a = 3\n'
|
||||
'donot_indent_me = 2\n'
|
||||
'do_not_indent_me_either = 2\n'
|
||||
'ifIam_indented_something_went_wrong = 2\n'
|
||||
'if_I_am_indented_something_went_wrong = 2\n'
|
||||
'end should not be unindented here\n'
|
||||
'end if\n'
|
||||
'endif\n'
|
||||
'end do\n'
|
||||
'end do\n'
|
||||
'enddo\n'
|
||||
'end subroutine\n'
|
||||
'\n'
|
||||
'subroutine test2(a)\n'
|
||||
'integer :: a\n'
|
||||
'do\n'
|
||||
'a = a + 1\n'
|
||||
'end do \n'
|
||||
'end subroutine\n'
|
||||
)
|
||||
expected = (
|
||||
'subroutine test(a)\n'
|
||||
'integer :: a, i, j\n'
|
||||
'\n'
|
||||
'do\n'
|
||||
' do \n'
|
||||
' do j = 1, 5\n'
|
||||
' if (a>b) then\n'
|
||||
' if(b>0) then\n'
|
||||
' a = 3\n'
|
||||
' donot_indent_me = 2\n'
|
||||
' do_not_indent_me_either = 2\n'
|
||||
' ifIam_indented_something_went_wrong = 2\n'
|
||||
' if_I_am_indented_something_went_wrong = 2\n'
|
||||
' end should not be unindented here\n'
|
||||
' end if\n'
|
||||
' endif\n'
|
||||
' end do\n'
|
||||
' end do\n'
|
||||
'enddo\n'
|
||||
'end subroutine\n'
|
||||
'\n'
|
||||
'subroutine test2(a)\n'
|
||||
'integer :: a\n'
|
||||
'do\n'
|
||||
' a = a + 1\n'
|
||||
'end do \n'
|
||||
'end subroutine\n'
|
||||
)
|
||||
p = FCodePrinter({'source_format': 'free'})
|
||||
result = p.indent_code(codelines)
|
||||
assert result == expected
|
||||
|
||||
def test_Matrix_printing():
|
||||
x, y, z = symbols('x,y,z')
|
||||
# Test returning a Matrix
|
||||
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
assert fcode(mat, A) == (
|
||||
" A(1, 1) = x*y\n"
|
||||
" if (y > 0) then\n"
|
||||
" A(2, 1) = x + 2\n"
|
||||
" else\n"
|
||||
" A(2, 1) = y\n"
|
||||
" end if\n"
|
||||
" A(3, 1) = sin(z)")
|
||||
# Test using MatrixElements in expressions
|
||||
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
|
||||
assert fcode(expr, standard=95) == (
|
||||
" merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)")
|
||||
# Test using MatrixElements in a Matrix
|
||||
q = MatrixSymbol('q', 5, 1)
|
||||
M = MatrixSymbol('M', 3, 3)
|
||||
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
|
||||
[q[1,0] + q[2,0], q[3, 0], 5],
|
||||
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
|
||||
assert fcode(m, M) == (
|
||||
" M(1, 1) = sin(q(2, 1))\n"
|
||||
" M(2, 1) = q(2, 1) + q(3, 1)\n"
|
||||
" M(3, 1) = 2*q(5, 1)/q(2, 1)\n"
|
||||
" M(1, 2) = 0\n"
|
||||
" M(2, 2) = q(4, 1)\n"
|
||||
" M(3, 2) = sqrt(q(1, 1)) + 4\n"
|
||||
" M(1, 3) = cos(q(3, 1))\n"
|
||||
" M(2, 3) = 5\n"
|
||||
" M(3, 3) = 0")
|
||||
|
||||
|
||||
def test_fcode_For():
|
||||
x, y = symbols('x y')
|
||||
|
||||
f = For(x, Range(0, 10, 2), [Assignment(y, x * y)])
|
||||
sol = fcode(f)
|
||||
assert sol == (" do x = 0, 9, 2\n"
|
||||
" y = x*y\n"
|
||||
" end do")
|
||||
|
||||
|
||||
def test_fcode_Declaration():
|
||||
def check(expr, ref, **kwargs):
|
||||
assert fcode(expr, standard=95, source_format='free', **kwargs) == ref
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
var1 = Variable.deduced(i)
|
||||
dcl1 = Declaration(var1)
|
||||
check(dcl1, "integer*4 :: i")
|
||||
|
||||
|
||||
x, y = symbols('x y')
|
||||
var2 = Variable(x, float32, value=42, attrs={value_const})
|
||||
dcl2b = Declaration(var2)
|
||||
check(dcl2b, 'real*4, parameter :: x = 42')
|
||||
|
||||
var3 = Variable(y, type=bool_)
|
||||
dcl3 = Declaration(var3)
|
||||
check(dcl3, 'logical :: y')
|
||||
|
||||
check(float32, "real*4")
|
||||
check(float64, "real*8")
|
||||
check(real, "real*4", type_aliases={real: float32})
|
||||
check(real, "real*8", type_aliases={real: float64})
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert(fcode(A[0, 0]) == " A(1, 1)")
|
||||
assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)")
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert(fcode(F) == " (A - B)(1, 1)")
|
||||
|
||||
|
||||
def test_aug_assign():
|
||||
x = symbols('x')
|
||||
assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1'
|
||||
|
||||
|
||||
def test_While():
|
||||
x = symbols('x')
|
||||
assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == (
|
||||
'do while (abs(x) > 1)\n'
|
||||
' x = x - 1\n'
|
||||
'end do'
|
||||
)
|
||||
|
||||
|
||||
def test_FunctionPrototype_print():
|
||||
x = symbols('x')
|
||||
n = symbols('n', integer=True)
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
fp1 = FunctionPrototype(real, 'power', [vx, vn])
|
||||
# Should be changed to proper test once multi-line generation is working
|
||||
# see https://github.com/sympy/sympy/issues/15824
|
||||
raises(NotImplementedError, lambda: fcode(fp1))
|
||||
|
||||
|
||||
def test_FunctionDefinition_print():
|
||||
x = symbols('x')
|
||||
n = symbols('n', integer=True)
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
body = [Assignment(x, x**n), Return(x)]
|
||||
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
# Should be changed to proper test once multi-line generation is working
|
||||
# see https://github.com/sympy/sympy/issues/15824
|
||||
raises(NotImplementedError, lambda: fcode(fd1))
|
||||
@@ -0,0 +1,998 @@
|
||||
from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma,
|
||||
Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt
|
||||
from sympy.testing.pytest import raises, warns_deprecated_sympy
|
||||
from sympy.printing.glsl import GLSLPrinter
|
||||
from sympy.printing.str import StrPrinter
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.matrices import Matrix, MatrixSymbol
|
||||
from sympy.core import Tuple
|
||||
from sympy.printing.glsl import glsl_code
|
||||
import textwrap
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
assert glsl_code(Abs(x)) == "abs(x)"
|
||||
|
||||
def test_print_without_operators():
|
||||
assert glsl_code(x*y,use_operators = False) == 'mul(x, y)'
|
||||
assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)'
|
||||
assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
|
||||
assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
|
||||
assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))'
|
||||
assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))'
|
||||
assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))'
|
||||
|
||||
def test_glsl_code_sqrt():
|
||||
assert glsl_code(sqrt(x)) == "sqrt(x)"
|
||||
assert glsl_code(x**0.5) == "sqrt(x)"
|
||||
assert glsl_code(sqrt(x)) == "sqrt(x)"
|
||||
|
||||
|
||||
def test_glsl_code_Pow():
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert glsl_code(x**3) == "pow(x, 3.0)"
|
||||
assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))"
|
||||
assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)"
|
||||
assert glsl_code(x**-1.0) == '1.0/x'
|
||||
|
||||
|
||||
def test_glsl_code_Relational():
|
||||
assert glsl_code(Eq(x, y)) == "x == y"
|
||||
assert glsl_code(Ne(x, y)) == "x != y"
|
||||
assert glsl_code(Le(x, y)) == "x <= y"
|
||||
assert glsl_code(Lt(x, y)) == "x < y"
|
||||
assert glsl_code(Gt(x, y)) == "x > y"
|
||||
assert glsl_code(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_glsl_code_constants_mathh():
|
||||
assert glsl_code(exp(1)) == "float E = 2.71828183;\nE"
|
||||
assert glsl_code(pi) == "float pi = 3.14159265;\npi"
|
||||
# assert glsl_code(oo) == "Number.POSITIVE_INFINITY"
|
||||
# assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY"
|
||||
|
||||
|
||||
def test_glsl_code_constants_other():
|
||||
assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio"
|
||||
assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan"
|
||||
assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma"
|
||||
|
||||
|
||||
def test_glsl_code_Rational():
|
||||
assert glsl_code(Rational(3, 7)) == "3.0/7.0"
|
||||
assert glsl_code(Rational(18, 9)) == "2"
|
||||
assert glsl_code(Rational(3, -7)) == "-3.0/7.0"
|
||||
assert glsl_code(Rational(-3, -7)) == "3.0/7.0"
|
||||
|
||||
|
||||
def test_glsl_code_Integer():
|
||||
assert glsl_code(Integer(67)) == "67"
|
||||
assert glsl_code(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_glsl_code_functions():
|
||||
assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
|
||||
|
||||
|
||||
def test_glsl_code_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert glsl_code(g(x)) == "2*x"
|
||||
g = implemented_function('g', Lambda(x, 2*x/Catalan))
|
||||
assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan"
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
assert glsl_code(g(A[i]), assign_to=A[i]) == (
|
||||
"for (int i=0; i<n; i++){\n"
|
||||
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
|
||||
def test_glsl_code_exceptions():
|
||||
assert glsl_code(ceiling(x)) == "ceil(x)"
|
||||
assert glsl_code(Abs(x)) == "abs(x)"
|
||||
|
||||
|
||||
def test_glsl_code_boolean():
|
||||
assert glsl_code(x & y) == "x && y"
|
||||
assert glsl_code(x | y) == "x || y"
|
||||
assert glsl_code(~x) == "!x"
|
||||
assert glsl_code(x & y & z) == "x && y && z"
|
||||
assert glsl_code(x | y | z) == "x || y || z"
|
||||
assert glsl_code((x & y) | z) == "z || x && y"
|
||||
assert glsl_code((x | y) & z) == "z && (x || y)"
|
||||
|
||||
|
||||
def test_glsl_code_Piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
p = glsl_code(expr)
|
||||
s = \
|
||||
"""\
|
||||
((x < 1) ? (
|
||||
x
|
||||
)
|
||||
: (
|
||||
pow(x, 2.0)
|
||||
))\
|
||||
"""
|
||||
assert p == s
|
||||
assert glsl_code(expr, assign_to="c") == (
|
||||
"if (x < 1) {\n"
|
||||
" c = x;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" c = pow(x, 2.0);\n"
|
||||
"}")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: glsl_code(expr))
|
||||
|
||||
|
||||
def test_glsl_code_Piecewise_deep():
|
||||
p = glsl_code(2*Piecewise((x, x < 1), (x**2, True)))
|
||||
s = \
|
||||
"""\
|
||||
2*((x < 1) ? (
|
||||
x
|
||||
)
|
||||
: (
|
||||
pow(x, 2.0)
|
||||
))\
|
||||
"""
|
||||
assert p == s
|
||||
|
||||
|
||||
def test_glsl_code_settings():
|
||||
raises(TypeError, lambda: glsl_code(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_glsl_code_Indexed():
|
||||
n, m, o = symbols('n m o', integer=True)
|
||||
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
|
||||
p = GLSLPrinter()
|
||||
p._not_c = set()
|
||||
|
||||
x = IndexedBase('x')[j]
|
||||
assert p._print_Indexed(x) == 'x[j]'
|
||||
A = IndexedBase('A')[i, j]
|
||||
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
|
||||
B = IndexedBase('B')[i, j, k]
|
||||
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
|
||||
|
||||
assert p._not_c == set()
|
||||
|
||||
def test_glsl_code_list_tuple_Tuple():
|
||||
assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)'
|
||||
assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)'
|
||||
assert glsl_code([1,2,3]) == glsl_code((1,2,3))
|
||||
assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3))
|
||||
|
||||
m = MatrixSymbol('A',3,4)
|
||||
assert glsl_code([m[0],m[1]])
|
||||
|
||||
def test_glsl_code_loops_matrix_vector():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0.0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = A[n*i + j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
|
||||
c = glsl_code(A[i, j]*x[j], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
expected = (
|
||||
'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
|
||||
' y[i_%(icount)i] = x[i_%(icount)i];\n'
|
||||
'}'
|
||||
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
||||
code = glsl_code(x[i], assign_to=y[i])
|
||||
assert code == expected
|
||||
|
||||
|
||||
def test_glsl_code_loops_add():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
z = IndexedBase('z')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = x[i] + z[i];\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = A[n*i + j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = glsl_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_glsl_code_loops_multiple_contractions():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0.0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' for (int l=0; l<p; l++){\n'
|
||||
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = glsl_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_glsl_code_loops_addfactor():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0.0;\n'
|
||||
'}\n'
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' for (int l=0; l<p; l++){\n'
|
||||
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = glsl_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_glsl_code_loops_multiple_terms():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
|
||||
s0 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' y[i] = 0.0;\n'
|
||||
'}\n'
|
||||
)
|
||||
s1 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s2 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int k=0; k<o; k++){\n'
|
||||
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s3 = (
|
||||
'for (int i=0; i<m; i++){\n'
|
||||
' for (int j=0; j<n; j++){\n'
|
||||
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
c = glsl_code(
|
||||
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
|
||||
assert (c == s0 + s1 + s2 + s3[:-1] or
|
||||
c == s0 + s1 + s3 + s2[:-1] or
|
||||
c == s0 + s2 + s1 + s3[:-1] or
|
||||
c == s0 + s2 + s3 + s1[:-1] or
|
||||
c == s0 + s3 + s1 + s2[:-1] or
|
||||
c == s0 + s3 + s2 + s1[:-1])
|
||||
|
||||
|
||||
def test_Matrix_printing():
|
||||
# Test returning a Matrix
|
||||
|
||||
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
assert glsl_code(mat, assign_to=A) == (
|
||||
'''A[0][0] = x*y;
|
||||
if (y > 0) {
|
||||
A[1][0] = x + 2;
|
||||
}
|
||||
else {
|
||||
A[1][0] = y;
|
||||
}
|
||||
A[2][0] = sin(z);''' )
|
||||
assert glsl_code(Matrix([A[0],A[1]]))
|
||||
# Test using MatrixElements in expressions
|
||||
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
|
||||
assert glsl_code(expr) == (
|
||||
'''((x > 0) ? (
|
||||
2*A[2][0]
|
||||
)
|
||||
: (
|
||||
A[2][0]
|
||||
)) + sin(A[1][0]) + A[0][0]''' )
|
||||
|
||||
# Test using MatrixElements in a Matrix
|
||||
q = MatrixSymbol('q', 5, 1)
|
||||
M = MatrixSymbol('M', 3, 3)
|
||||
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
|
||||
[q[1,0] + q[2,0], q[3, 0], 5],
|
||||
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
|
||||
assert glsl_code(m,M) == (
|
||||
'''M[0][0] = sin(q[1]);
|
||||
M[0][1] = 0;
|
||||
M[0][2] = cos(q[2]);
|
||||
M[1][0] = q[1] + q[2];
|
||||
M[1][1] = q[3];
|
||||
M[1][2] = 5;
|
||||
M[2][0] = 2*q[4]/q[1];
|
||||
M[2][1] = sqrt(q[0]) + 4;
|
||||
M[2][2] = 0;'''
|
||||
)
|
||||
|
||||
def test_Matrices_1x7():
|
||||
gl = glsl_code
|
||||
A = Matrix([1,2,3,4,5,6,7])
|
||||
assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
|
||||
assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
|
||||
|
||||
def test_Matrices_1x7_array_type_int():
|
||||
gl = glsl_code
|
||||
A = Matrix([1,2,3,4,5,6,7])
|
||||
assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)'
|
||||
|
||||
def test_Tuple_array_type_custom():
|
||||
gl = glsl_code
|
||||
A = symbols('a b c')
|
||||
assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)'
|
||||
|
||||
def test_Matrices_1x7_spread_assign_to_symbols():
|
||||
gl = glsl_code
|
||||
A = Matrix([1,2,3,4,5,6,7])
|
||||
assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g')
|
||||
assert gl(A, assign_to=assign_to) == textwrap.dedent('''\
|
||||
x.a = 1;
|
||||
x.b = 2;
|
||||
x.c = 3;
|
||||
x.d = 4;
|
||||
x.e = 5;
|
||||
x.f = 6;
|
||||
x.g = 7;'''
|
||||
)
|
||||
|
||||
def test_spread_assign_to_nested_symbols():
|
||||
gl = glsl_code
|
||||
expr = ((1,2,3), (1,2,3))
|
||||
assign_to = (symbols('a b c'), symbols('x y z'))
|
||||
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
|
||||
a = 1;
|
||||
b = 2;
|
||||
c = 3;
|
||||
x = 1;
|
||||
y = 2;
|
||||
z = 3;'''
|
||||
)
|
||||
|
||||
def test_spread_assign_to_deeply_nested_symbols():
|
||||
gl = glsl_code
|
||||
a, b, c, x, y, z = symbols('a b c x y z')
|
||||
expr = (((1,2),3), ((1,2),3))
|
||||
assign_to = (((a, b), c), ((x, y), z))
|
||||
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
|
||||
a = 1;
|
||||
b = 2;
|
||||
c = 3;
|
||||
x = 1;
|
||||
y = 2;
|
||||
z = 3;'''
|
||||
)
|
||||
|
||||
def test_matrix_of_tuples_spread_assign_to_symbols():
|
||||
gl = glsl_code
|
||||
with warns_deprecated_sympy():
|
||||
expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]])
|
||||
assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h'))
|
||||
assert gl(expr, assign_to) == textwrap.dedent('''\
|
||||
a = 1;
|
||||
b = 2;
|
||||
c = 3;
|
||||
d = 4;
|
||||
e = 5;
|
||||
f = 6;
|
||||
g = 7;
|
||||
h = 8;'''
|
||||
)
|
||||
|
||||
def test_cannot_assign_to_cause_mismatched_length():
|
||||
expr = (1, 2)
|
||||
assign_to = symbols('x y z')
|
||||
raises(ValueError, lambda: glsl_code(expr, assign_to))
|
||||
|
||||
def test_matrix_4x4_assign():
|
||||
gl = glsl_code
|
||||
expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4)
|
||||
assign_to = MatrixSymbol('X',4,4)
|
||||
assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
|
||||
X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0];
|
||||
X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1];
|
||||
X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2];
|
||||
X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3];
|
||||
X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0];
|
||||
X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1];
|
||||
X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2];
|
||||
X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3];
|
||||
X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0];
|
||||
X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1];
|
||||
X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2];
|
||||
X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3];
|
||||
X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0];
|
||||
X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1];
|
||||
X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2];
|
||||
X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];'''
|
||||
)
|
||||
|
||||
def test_1xN_vecs():
|
||||
gl = glsl_code
|
||||
for i in range(1,10):
|
||||
A = Matrix(range(i))
|
||||
assert gl(A.transpose()) == gl(A)
|
||||
assert gl(A,mat_transpose=True) == gl(A)
|
||||
if i > 1:
|
||||
if i <= 4:
|
||||
assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i)))
|
||||
else:
|
||||
assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i)))
|
||||
|
||||
def test_MxN_mats():
|
||||
generatedAssertions='def test_misc_mats():\n'
|
||||
for i in range(1,6):
|
||||
for j in range(1,6):
|
||||
A = Matrix([[x + y*j for x in range(j)] for y in range(i)])
|
||||
gl = glsl_code(A)
|
||||
glTransposed = glsl_code(A,mat_transpose=True)
|
||||
generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n'
|
||||
generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n'
|
||||
generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n'
|
||||
generatedAssertions+=' assert glsl_code(mat) == gl\n'
|
||||
generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n'
|
||||
if i == 1 and j == 1:
|
||||
assert gl == '0'
|
||||
elif i <= 4 and j <= 4 and i>1 and j>1:
|
||||
assert gl.startswith('mat%s' % j)
|
||||
assert glTransposed.startswith('mat%s' % i)
|
||||
elif i == 1 and j <= 4:
|
||||
assert gl.startswith('vec')
|
||||
elif j == 1 and i <= 4:
|
||||
assert gl.startswith('vec')
|
||||
elif i == 1:
|
||||
assert gl.startswith('float[%s]('% j*i)
|
||||
assert glTransposed.startswith('float[%s]('% j*i)
|
||||
elif j == 1:
|
||||
assert gl.startswith('float[%s]('% i*j)
|
||||
assert glTransposed.startswith('float[%s]('% i*j)
|
||||
else:
|
||||
assert gl.startswith('float[%s](' % (i*j))
|
||||
assert glTransposed.startswith('float[%s](' % (i*j))
|
||||
glNested = glsl_code(A,mat_nested=True)
|
||||
glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True)
|
||||
assert glNested.startswith('float[%s][%s]' % (i,j))
|
||||
assert glNestedTransposed.startswith('float[%s][%s]' % (j,i))
|
||||
generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n'
|
||||
generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n'
|
||||
generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n'
|
||||
generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n'
|
||||
generateAssertions = False # set this to true to write bake these generated tests to a file
|
||||
if generateAssertions:
|
||||
gen = open('test_glsl_generated_matrices.py','w')
|
||||
gen.write(generatedAssertions)
|
||||
gen.close()
|
||||
|
||||
|
||||
# these assertions were generated from the previous function
|
||||
# glsl has complicated rules and this makes it easier to look over all the cases
|
||||
def test_misc_mats():
|
||||
|
||||
mat = Matrix([[0]])
|
||||
|
||||
gl = '''0'''
|
||||
glTransposed = '''0'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([[0, 1]])
|
||||
|
||||
gl = '''vec2(0, 1)'''
|
||||
glTransposed = '''vec2(0, 1)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([[0, 1, 2]])
|
||||
|
||||
gl = '''vec3(0, 1, 2)'''
|
||||
glTransposed = '''vec3(0, 1, 2)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([[0, 1, 2, 3]])
|
||||
|
||||
gl = '''vec4(0, 1, 2, 3)'''
|
||||
glTransposed = '''vec4(0, 1, 2, 3)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([[0, 1, 2, 3, 4]])
|
||||
|
||||
gl = '''float[5](0, 1, 2, 3, 4)'''
|
||||
glTransposed = '''float[5](0, 1, 2, 3, 4)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0],
|
||||
[1]])
|
||||
|
||||
gl = '''vec2(0, 1)'''
|
||||
glTransposed = '''vec2(0, 1)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1],
|
||||
[2, 3]])
|
||||
|
||||
gl = '''mat2(0, 1, 2, 3)'''
|
||||
glTransposed = '''mat2(0, 2, 1, 3)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2],
|
||||
[3, 4, 5]])
|
||||
|
||||
gl = '''mat3x2(0, 1, 2, 3, 4, 5)'''
|
||||
glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2, 3],
|
||||
[4, 5, 6, 7]])
|
||||
|
||||
gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)'''
|
||||
glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2, 3, 4],
|
||||
[5, 6, 7, 8, 9]])
|
||||
|
||||
gl = '''float[10](
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9
|
||||
) /* a 2x5 matrix */'''
|
||||
glTransposed = '''float[10](
|
||||
0, 5,
|
||||
1, 6,
|
||||
2, 7,
|
||||
3, 8,
|
||||
4, 9
|
||||
) /* a 5x2 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[2][5](
|
||||
float[](0, 1, 2, 3, 4),
|
||||
float[](5, 6, 7, 8, 9)
|
||||
)'''
|
||||
glNestedTransposed = '''float[5][2](
|
||||
float[](0, 5),
|
||||
float[](1, 6),
|
||||
float[](2, 7),
|
||||
float[](3, 8),
|
||||
float[](4, 9)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0],
|
||||
[1],
|
||||
[2]])
|
||||
|
||||
gl = '''vec3(0, 1, 2)'''
|
||||
glTransposed = '''vec3(0, 1, 2)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[4, 5]])
|
||||
|
||||
gl = '''mat2x3(0, 1, 2, 3, 4, 5)'''
|
||||
glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]])
|
||||
|
||||
gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)'''
|
||||
glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]])
|
||||
|
||||
gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
|
||||
glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2, 3, 4],
|
||||
[ 5, 6, 7, 8, 9],
|
||||
[10, 11, 12, 13, 14]])
|
||||
|
||||
gl = '''float[15](
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14
|
||||
) /* a 3x5 matrix */'''
|
||||
glTransposed = '''float[15](
|
||||
0, 5, 10,
|
||||
1, 6, 11,
|
||||
2, 7, 12,
|
||||
3, 8, 13,
|
||||
4, 9, 14
|
||||
) /* a 5x3 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[3][5](
|
||||
float[]( 0, 1, 2, 3, 4),
|
||||
float[]( 5, 6, 7, 8, 9),
|
||||
float[](10, 11, 12, 13, 14)
|
||||
)'''
|
||||
glNestedTransposed = '''float[5][3](
|
||||
float[](0, 5, 10),
|
||||
float[](1, 6, 11),
|
||||
float[](2, 7, 12),
|
||||
float[](3, 8, 13),
|
||||
float[](4, 9, 14)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3]])
|
||||
|
||||
gl = '''vec4(0, 1, 2, 3)'''
|
||||
glTransposed = '''vec4(0, 1, 2, 3)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
[6, 7]])
|
||||
|
||||
gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)'''
|
||||
glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
[9, 10, 11]])
|
||||
|
||||
gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
|
||||
glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11],
|
||||
[12, 13, 14, 15]])
|
||||
|
||||
gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)'''
|
||||
glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2, 3, 4],
|
||||
[ 5, 6, 7, 8, 9],
|
||||
[10, 11, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19]])
|
||||
|
||||
gl = '''float[20](
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14,
|
||||
15, 16, 17, 18, 19
|
||||
) /* a 4x5 matrix */'''
|
||||
glTransposed = '''float[20](
|
||||
0, 5, 10, 15,
|
||||
1, 6, 11, 16,
|
||||
2, 7, 12, 17,
|
||||
3, 8, 13, 18,
|
||||
4, 9, 14, 19
|
||||
) /* a 5x4 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[4][5](
|
||||
float[]( 0, 1, 2, 3, 4),
|
||||
float[]( 5, 6, 7, 8, 9),
|
||||
float[](10, 11, 12, 13, 14),
|
||||
float[](15, 16, 17, 18, 19)
|
||||
)'''
|
||||
glNestedTransposed = '''float[5][4](
|
||||
float[](0, 5, 10, 15),
|
||||
float[](1, 6, 11, 16),
|
||||
float[](2, 7, 12, 17),
|
||||
float[](3, 8, 13, 18),
|
||||
float[](4, 9, 14, 19)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4]])
|
||||
|
||||
gl = '''float[5](0, 1, 2, 3, 4)'''
|
||||
glTransposed = '''float[5](0, 1, 2, 3, 4)'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]])
|
||||
|
||||
gl = '''float[10](
|
||||
0, 1,
|
||||
2, 3,
|
||||
4, 5,
|
||||
6, 7,
|
||||
8, 9
|
||||
) /* a 5x2 matrix */'''
|
||||
glTransposed = '''float[10](
|
||||
0, 2, 4, 6, 8,
|
||||
1, 3, 5, 7, 9
|
||||
) /* a 2x5 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[5][2](
|
||||
float[](0, 1),
|
||||
float[](2, 3),
|
||||
float[](4, 5),
|
||||
float[](6, 7),
|
||||
float[](8, 9)
|
||||
)'''
|
||||
glNestedTransposed = '''float[2][5](
|
||||
float[](0, 2, 4, 6, 8),
|
||||
float[](1, 3, 5, 7, 9)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
[ 6, 7, 8],
|
||||
[ 9, 10, 11],
|
||||
[12, 13, 14]])
|
||||
|
||||
gl = '''float[15](
|
||||
0, 1, 2,
|
||||
3, 4, 5,
|
||||
6, 7, 8,
|
||||
9, 10, 11,
|
||||
12, 13, 14
|
||||
) /* a 5x3 matrix */'''
|
||||
glTransposed = '''float[15](
|
||||
0, 3, 6, 9, 12,
|
||||
1, 4, 7, 10, 13,
|
||||
2, 5, 8, 11, 14
|
||||
) /* a 3x5 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[5][3](
|
||||
float[]( 0, 1, 2),
|
||||
float[]( 3, 4, 5),
|
||||
float[]( 6, 7, 8),
|
||||
float[]( 9, 10, 11),
|
||||
float[](12, 13, 14)
|
||||
)'''
|
||||
glNestedTransposed = '''float[3][5](
|
||||
float[](0, 3, 6, 9, 12),
|
||||
float[](1, 4, 7, 10, 13),
|
||||
float[](2, 5, 8, 11, 14)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11],
|
||||
[12, 13, 14, 15],
|
||||
[16, 17, 18, 19]])
|
||||
|
||||
gl = '''float[20](
|
||||
0, 1, 2, 3,
|
||||
4, 5, 6, 7,
|
||||
8, 9, 10, 11,
|
||||
12, 13, 14, 15,
|
||||
16, 17, 18, 19
|
||||
) /* a 5x4 matrix */'''
|
||||
glTransposed = '''float[20](
|
||||
0, 4, 8, 12, 16,
|
||||
1, 5, 9, 13, 17,
|
||||
2, 6, 10, 14, 18,
|
||||
3, 7, 11, 15, 19
|
||||
) /* a 4x5 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[5][4](
|
||||
float[]( 0, 1, 2, 3),
|
||||
float[]( 4, 5, 6, 7),
|
||||
float[]( 8, 9, 10, 11),
|
||||
float[](12, 13, 14, 15),
|
||||
float[](16, 17, 18, 19)
|
||||
)'''
|
||||
glNestedTransposed = '''float[4][5](
|
||||
float[](0, 4, 8, 12, 16),
|
||||
float[](1, 5, 9, 13, 17),
|
||||
float[](2, 6, 10, 14, 18),
|
||||
float[](3, 7, 11, 15, 19)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
|
||||
mat = Matrix([
|
||||
[ 0, 1, 2, 3, 4],
|
||||
[ 5, 6, 7, 8, 9],
|
||||
[10, 11, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19],
|
||||
[20, 21, 22, 23, 24]])
|
||||
|
||||
gl = '''float[25](
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14,
|
||||
15, 16, 17, 18, 19,
|
||||
20, 21, 22, 23, 24
|
||||
) /* a 5x5 matrix */'''
|
||||
glTransposed = '''float[25](
|
||||
0, 5, 10, 15, 20,
|
||||
1, 6, 11, 16, 21,
|
||||
2, 7, 12, 17, 22,
|
||||
3, 8, 13, 18, 23,
|
||||
4, 9, 14, 19, 24
|
||||
) /* a 5x5 matrix */'''
|
||||
|
||||
assert glsl_code(mat) == gl
|
||||
assert glsl_code(mat,mat_transpose=True) == glTransposed
|
||||
glNested = '''float[5][5](
|
||||
float[]( 0, 1, 2, 3, 4),
|
||||
float[]( 5, 6, 7, 8, 9),
|
||||
float[](10, 11, 12, 13, 14),
|
||||
float[](15, 16, 17, 18, 19),
|
||||
float[](20, 21, 22, 23, 24)
|
||||
)'''
|
||||
glNestedTransposed = '''float[5][5](
|
||||
float[](0, 5, 10, 15, 20),
|
||||
float[](1, 6, 11, 16, 21),
|
||||
float[](2, 7, 12, 17, 22),
|
||||
float[](3, 8, 13, 18, 23),
|
||||
float[](4, 9, 14, 19, 24)
|
||||
)'''
|
||||
|
||||
assert glsl_code(mat,mat_nested=True) == glNested
|
||||
assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
|
||||
@@ -0,0 +1,18 @@
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.printing.gtk import print_gtk
|
||||
from sympy.testing.pytest import XFAIL, raises
|
||||
|
||||
# this test fails if python-lxml isn't installed. We don't want to depend on
|
||||
# anything with SymPy
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_1():
|
||||
from sympy.abc import x
|
||||
print_gtk(x**2, start_viewer=False)
|
||||
print_gtk(x**2 + sin(x)/4, start_viewer=False)
|
||||
|
||||
|
||||
def test_settings():
|
||||
from sympy.abc import x
|
||||
raises(TypeError, lambda: print_gtk(x, method="garbage"))
|
||||
@@ -0,0 +1,370 @@
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.core.mod import Mod
|
||||
from sympy.core.relational import (Equality, Unequality)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.matrices.expressions.blockmatrix import BlockMatrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.matrices.expressions.special import Identity
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
|
||||
from sympy.abc import x, i, j, a, b, c, d
|
||||
from sympy.core import Function, Pow, Symbol
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
|
||||
from sympy.tensor.array import Array
|
||||
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
|
||||
PermuteDims, ArrayDiagonal
|
||||
from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions
|
||||
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
|
||||
|
||||
from sympy.testing.pytest import skip, raises
|
||||
from sympy.external import import_module
|
||||
|
||||
# Unlike NumPy which will aggressively promote operands to double precision,
|
||||
# jax always uses single precision. Double precision in jax can be
|
||||
# configured before the call to `import jax`, however this must be explicitly
|
||||
# configured and is not fully supported. Thus, the tests here have been modified
|
||||
# from the tests in test_numpy.py, only in the fact that they assert lambdify
|
||||
# function accuracy to only single precision accuracy.
|
||||
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
|
||||
|
||||
jax = import_module('jax')
|
||||
|
||||
if jax:
|
||||
deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype)
|
||||
JAX_DEFAULT_EPSILON = deafult_float_info.eps
|
||||
|
||||
|
||||
def test_jax_piecewise_regression():
|
||||
"""
|
||||
NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
|
||||
breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
|
||||
See gh-9747 and gh-9749 for details.
|
||||
"""
|
||||
printer = JaxPrinter()
|
||||
p = Piecewise((1, x < 0), (0, True))
|
||||
assert printer.doprint(p) == \
|
||||
'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)'
|
||||
assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}}
|
||||
|
||||
|
||||
def test_jax_logaddexp():
|
||||
lae = logaddexp(a, b)
|
||||
assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)'
|
||||
lae2 = logaddexp2(a, b)
|
||||
assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)'
|
||||
|
||||
|
||||
def test_jax_sum():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
s = Sum(x ** i, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'jax')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = jax.numpy.linspace(-1, +1, 10)
|
||||
assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
s = Sum(i * x, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'jax')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = jax.numpy.linspace(-1, +1, 10)
|
||||
assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
|
||||
def test_jax_multiple_sums():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
s = Sum((x + j) * i, (i, a, b), (j, c, d))
|
||||
f = lambdify((a, b, c, d, x), s, 'jax')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
c_, d_ = 11, 21
|
||||
x_ = jax.numpy.linspace(-1, +1, 10)
|
||||
assert jax.numpy.allclose(f(a_, b_, c_, d_, x_),
|
||||
sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
|
||||
|
||||
|
||||
def test_jax_codegen_einsum():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
|
||||
cg = convert_matrix_to_array(M * N)
|
||||
f = lambdify((M, N), cg, 'jax')
|
||||
|
||||
ma = jax.numpy.array([[1, 2], [3, 4]])
|
||||
mb = jax.numpy.array([[1,-2], [-1, 3]])
|
||||
assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all()
|
||||
|
||||
|
||||
def test_jax_codegen_extra():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
P = MatrixSymbol("P", 2, 2)
|
||||
Q = MatrixSymbol("Q", 2, 2)
|
||||
ma = jax.numpy.array([[1, 2], [3, 4]])
|
||||
mb = jax.numpy.array([[1,-2], [-1, 3]])
|
||||
mc = jax.numpy.array([[2, 0], [1, 2]])
|
||||
md = jax.numpy.array([[1,-1], [4, 7]])
|
||||
|
||||
cg = ArrayTensorProduct(M, N)
|
||||
f = lambdify((M, N), cg, 'jax')
|
||||
assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all()
|
||||
|
||||
cg = ArrayAdd(M, N)
|
||||
f = lambdify((M, N), cg, 'jax')
|
||||
assert (f(ma, mb) == ma+mb).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P)
|
||||
f = lambdify((M, N, P), cg, 'jax')
|
||||
assert (f(ma, mb, mc) == ma+mb+mc).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P, Q)
|
||||
f = lambdify((M, N, P, Q), cg, 'jax')
|
||||
assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
|
||||
|
||||
cg = PermuteDims(M, [1, 0])
|
||||
f = lambdify((M,), cg, 'jax')
|
||||
assert (f(ma) == ma.T).all()
|
||||
|
||||
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
|
||||
f = lambdify((M, N), cg, 'jax')
|
||||
assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
|
||||
|
||||
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
|
||||
f = lambdify((M, N), cg, 'jax')
|
||||
assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
|
||||
|
||||
|
||||
def test_jax_relational():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
e = Equality(x, 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [False, True, False])
|
||||
|
||||
e = Unequality(x, 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [True, False, True])
|
||||
|
||||
e = (x < 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [True, False, False])
|
||||
|
||||
e = (x <= 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [True, True, False])
|
||||
|
||||
e = (x > 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [False, False, True])
|
||||
|
||||
e = (x >= 1)
|
||||
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [False, True, True])
|
||||
|
||||
# Multi-condition expressions
|
||||
e = (x >= 1) & (x < 2)
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [False, True, False])
|
||||
|
||||
e = (x >= 1) | (x < 2)
|
||||
f = lambdify((x,), e, 'jax')
|
||||
x_ = jax.numpy.array([0, 1, 2])
|
||||
assert jax.numpy.array_equal(f(x_), [True, True, True])
|
||||
|
||||
def test_jax_mod():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
e = Mod(a, b)
|
||||
f = lambdify((a, b), e, 'jax')
|
||||
|
||||
a_ = jax.numpy.array([0, 1, 2, 3])
|
||||
b_ = 2
|
||||
assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
|
||||
|
||||
a_ = jax.numpy.array([0, 1, 2, 3])
|
||||
b_ = jax.numpy.array([2, 2, 2, 2])
|
||||
assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
|
||||
|
||||
a_ = jax.numpy.array([2, 3, 4, 5])
|
||||
b_ = jax.numpy.array([2, 3, 4, 5])
|
||||
assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0])
|
||||
|
||||
|
||||
def test_jax_pow():
|
||||
if not jax:
|
||||
skip('JAX not installed')
|
||||
|
||||
expr = Pow(2, -1, evaluate=False)
|
||||
f = lambdify([], expr, 'jax')
|
||||
assert f() == 0.5
|
||||
|
||||
|
||||
def test_jax_expm1():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
f = lambdify((a,), expm1(a), 'jax')
|
||||
assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_log1p():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
f = lambdify((a,), log1p(a), 'jax')
|
||||
assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON
|
||||
|
||||
def test_jax_hypot():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
def test_jax_log10():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_exp2():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_log2():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_Sqrt():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_sqrt():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_jax_matsolve():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
M = MatrixSymbol("M", 3, 3)
|
||||
x = MatrixSymbol("x", 3, 1)
|
||||
|
||||
expr = M**(-1) * x + x
|
||||
matsolve_expr = MatrixSolve(M, x) + x
|
||||
|
||||
f = lambdify((M, x), expr, 'jax')
|
||||
f_matsolve = lambdify((M, x), matsolve_expr, 'jax')
|
||||
|
||||
m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
|
||||
assert jax.numpy.linalg.matrix_rank(m0) == 3
|
||||
|
||||
x0 = jax.numpy.array([3, 4, 5])
|
||||
|
||||
assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0))
|
||||
|
||||
|
||||
def test_16857():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
a_1 = MatrixSymbol('a_1', 10, 3)
|
||||
a_2 = MatrixSymbol('a_2', 10, 3)
|
||||
a_3 = MatrixSymbol('a_3', 10, 3)
|
||||
a_4 = MatrixSymbol('a_4', 10, 3)
|
||||
A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
|
||||
assert A.shape == (20, 6)
|
||||
|
||||
printer = JaxPrinter()
|
||||
assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])'
|
||||
|
||||
|
||||
def test_issue_17006():
|
||||
if not jax:
|
||||
skip("JAX not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
|
||||
f = lambdify(M, M + Identity(2), 'jax')
|
||||
ma = jax.numpy.array([[1, 2], [3, 4]])
|
||||
mr = jax.numpy.array([[2, 2], [3, 5]])
|
||||
|
||||
assert (f(ma) == mr).all()
|
||||
|
||||
from sympy.core.symbol import symbols
|
||||
n = symbols('n', integer=True)
|
||||
N = MatrixSymbol("M", n, n)
|
||||
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax'))
|
||||
|
||||
|
||||
def test_jax_array():
|
||||
assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])'
|
||||
assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array([1, 2])'
|
||||
|
||||
|
||||
def test_jax_known_funcs_consts():
|
||||
assert _jax_known_constants['NaN'] == 'jax.numpy.nan'
|
||||
assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma'
|
||||
|
||||
assert _jax_known_functions['acos'] == 'jax.numpy.arccos'
|
||||
assert _jax_known_functions['log'] == 'jax.numpy.log'
|
||||
|
||||
|
||||
def test_jax_print_methods():
|
||||
prntr = JaxPrinter()
|
||||
assert hasattr(prntr, '_print_acos')
|
||||
assert hasattr(prntr, '_print_log')
|
||||
|
||||
|
||||
def test_jax_printmethod():
|
||||
printer = JaxPrinter()
|
||||
assert hasattr(printer, 'printmethod')
|
||||
assert printer.printmethod == '_jaxcode'
|
||||
|
||||
|
||||
def test_jax_custom_print_method():
|
||||
|
||||
class expm1(Function):
|
||||
|
||||
def _jaxcode(self, printer):
|
||||
x, = self.args
|
||||
function = f'expm1({printer._print(x)})'
|
||||
return printer._module_format(printer._module + '.' + function)
|
||||
|
||||
printer = JaxPrinter()
|
||||
assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)'
|
||||
@@ -0,0 +1,396 @@
|
||||
from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio,
|
||||
EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le,
|
||||
Lt, Gt, Ge, Mod)
|
||||
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
|
||||
sinh, cosh, tanh, asin, acos, acosh, Max, Min)
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.printing.jscode import JavascriptCodePrinter
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.matrices import Matrix, MatrixSymbol
|
||||
|
||||
from sympy.printing.jscode import jscode
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
assert jscode(Abs(x)) == "Math.abs(x)"
|
||||
|
||||
|
||||
def test_jscode_sqrt():
|
||||
assert jscode(sqrt(x)) == "Math.sqrt(x)"
|
||||
assert jscode(x**0.5) == "Math.sqrt(x)"
|
||||
assert jscode(x**(S.One/3)) == "Math.cbrt(x)"
|
||||
|
||||
|
||||
def test_jscode_Pow():
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert jscode(x**3) == "Math.pow(x, 3)"
|
||||
assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))"
|
||||
assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)"
|
||||
assert jscode(x**-1.0) == '1/x'
|
||||
|
||||
|
||||
def test_jscode_constants_mathh():
|
||||
assert jscode(exp(1)) == "Math.E"
|
||||
assert jscode(pi) == "Math.PI"
|
||||
assert jscode(oo) == "Number.POSITIVE_INFINITY"
|
||||
assert jscode(-oo) == "Number.NEGATIVE_INFINITY"
|
||||
|
||||
|
||||
def test_jscode_constants_other():
|
||||
assert jscode(
|
||||
2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
|
||||
assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
|
||||
assert jscode(
|
||||
2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
|
||||
|
||||
|
||||
def test_jscode_Rational():
|
||||
assert jscode(Rational(3, 7)) == "3/7"
|
||||
assert jscode(Rational(18, 9)) == "2"
|
||||
assert jscode(Rational(3, -7)) == "-3/7"
|
||||
assert jscode(Rational(-3, -7)) == "3/7"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert jscode(Eq(x, y)) == "x == y"
|
||||
assert jscode(Ne(x, y)) == "x != y"
|
||||
assert jscode(Le(x, y)) == "x <= y"
|
||||
assert jscode(Lt(x, y)) == "x < y"
|
||||
assert jscode(Gt(x, y)) == "x > y"
|
||||
assert jscode(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Mod():
|
||||
assert jscode(Mod(x, y)) == '((x % y) + y) % y'
|
||||
assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)'
|
||||
p1, p2 = symbols('p1 p2', positive=True)
|
||||
assert jscode(Mod(p1, p2)) == 'p1 % p2'
|
||||
assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
|
||||
assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
|
||||
assert jscode(-Mod(p1, p2)) == '-(p1 % p2)'
|
||||
assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)'
|
||||
|
||||
|
||||
def test_jscode_Integer():
|
||||
assert jscode(Integer(67)) == "67"
|
||||
assert jscode(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_jscode_functions():
|
||||
assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))"
|
||||
assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)"
|
||||
assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)"
|
||||
assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)"
|
||||
assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)"
|
||||
|
||||
|
||||
def test_jscode_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert jscode(g(x)) == "2*x"
|
||||
g = implemented_function('g', Lambda(x, 2*x/Catalan))
|
||||
assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
assert jscode(g(A[i]), assign_to=A[i]) == (
|
||||
"for (var i=0; i<n; i++){\n"
|
||||
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
|
||||
def test_jscode_exceptions():
|
||||
assert jscode(ceiling(x)) == "Math.ceil(x)"
|
||||
assert jscode(Abs(x)) == "Math.abs(x)"
|
||||
|
||||
|
||||
def test_jscode_boolean():
|
||||
assert jscode(x & y) == "x && y"
|
||||
assert jscode(x | y) == "x || y"
|
||||
assert jscode(~x) == "!x"
|
||||
assert jscode(x & y & z) == "x && y && z"
|
||||
assert jscode(x | y | z) == "x || y || z"
|
||||
assert jscode((x & y) | z) == "z || x && y"
|
||||
assert jscode((x | y) & z) == "z && (x || y)"
|
||||
|
||||
|
||||
def test_jscode_Piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
p = jscode(expr)
|
||||
s = \
|
||||
"""\
|
||||
((x < 1) ? (
|
||||
x
|
||||
)
|
||||
: (
|
||||
Math.pow(x, 2)
|
||||
))\
|
||||
"""
|
||||
assert p == s
|
||||
assert jscode(expr, assign_to="c") == (
|
||||
"if (x < 1) {\n"
|
||||
" c = x;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" c = Math.pow(x, 2);\n"
|
||||
"}")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: jscode(expr))
|
||||
|
||||
|
||||
def test_jscode_Piecewise_deep():
|
||||
p = jscode(2*Piecewise((x, x < 1), (x**2, True)))
|
||||
s = \
|
||||
"""\
|
||||
2*((x < 1) ? (
|
||||
x
|
||||
)
|
||||
: (
|
||||
Math.pow(x, 2)
|
||||
))\
|
||||
"""
|
||||
assert p == s
|
||||
|
||||
|
||||
def test_jscode_settings():
|
||||
raises(TypeError, lambda: jscode(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_jscode_Indexed():
|
||||
n, m, o = symbols('n m o', integer=True)
|
||||
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
|
||||
p = JavascriptCodePrinter()
|
||||
p._not_c = set()
|
||||
|
||||
x = IndexedBase('x')[j]
|
||||
assert p._print_Indexed(x) == 'x[j]'
|
||||
A = IndexedBase('A')[i, j]
|
||||
assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
|
||||
B = IndexedBase('B')[i, j, k]
|
||||
assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
|
||||
|
||||
assert p._not_c == set()
|
||||
|
||||
|
||||
def test_jscode_loops_matrix_vector():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' y[i] = A[n*i + j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = jscode(A[i, j]*x[j], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
expected = (
|
||||
'for (var i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
|
||||
' y[i_%(icount)i] = x[i_%(icount)i];\n'
|
||||
'}'
|
||||
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
||||
code = jscode(x[i], assign_to=y[i])
|
||||
assert code == expected
|
||||
|
||||
|
||||
def test_jscode_loops_add():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
z = IndexedBase('z')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' y[i] = x[i] + z[i];\n'
|
||||
'}\n'
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' y[i] = A[n*i + j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = jscode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_jscode_loops_multiple_contractions():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' for (var k=0; k<o; k++){\n'
|
||||
' for (var l=0; l<p; l++){\n'
|
||||
' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = jscode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_jscode_loops_addfactor():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' for (var k=0; k<o; k++){\n'
|
||||
' for (var l=0; l<p; l++){\n'
|
||||
' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = jscode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_jscode_loops_multiple_terms():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
|
||||
s0 = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
)
|
||||
s1 = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' for (var k=0; k<o; k++){\n'
|
||||
' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s2 = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var k=0; k<o; k++){\n'
|
||||
' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s3 = (
|
||||
'for (var i=0; i<m; i++){\n'
|
||||
' for (var j=0; j<n; j++){\n'
|
||||
' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
c = jscode(
|
||||
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
|
||||
assert (c == s0 + s1 + s2 + s3[:-1] or
|
||||
c == s0 + s1 + s3 + s2[:-1] or
|
||||
c == s0 + s2 + s1 + s3[:-1] or
|
||||
c == s0 + s2 + s3 + s1[:-1] or
|
||||
c == s0 + s3 + s1 + s2[:-1] or
|
||||
c == s0 + s3 + s2 + s1[:-1])
|
||||
|
||||
|
||||
def test_Matrix_printing():
|
||||
# Test returning a Matrix
|
||||
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
assert jscode(mat, A) == (
|
||||
"A[0] = x*y;\n"
|
||||
"if (y > 0) {\n"
|
||||
" A[1] = x + 2;\n"
|
||||
"}\n"
|
||||
"else {\n"
|
||||
" A[1] = y;\n"
|
||||
"}\n"
|
||||
"A[2] = Math.sin(z);")
|
||||
# Test using MatrixElements in expressions
|
||||
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
|
||||
assert jscode(expr) == (
|
||||
"((x > 0) ? (\n"
|
||||
" 2*A[2]\n"
|
||||
")\n"
|
||||
": (\n"
|
||||
" A[2]\n"
|
||||
")) + Math.sin(A[1]) + A[0]")
|
||||
# Test using MatrixElements in a Matrix
|
||||
q = MatrixSymbol('q', 5, 1)
|
||||
M = MatrixSymbol('M', 3, 3)
|
||||
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
|
||||
[q[1,0] + q[2,0], q[3, 0], 5],
|
||||
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
|
||||
assert jscode(m, M) == (
|
||||
"M[0] = Math.sin(q[1]);\n"
|
||||
"M[1] = 0;\n"
|
||||
"M[2] = Math.cos(q[2]);\n"
|
||||
"M[3] = q[1] + q[2];\n"
|
||||
"M[4] = q[3];\n"
|
||||
"M[5] = 5;\n"
|
||||
"M[6] = 2*q[4]/q[1];\n"
|
||||
"M[7] = Math.sqrt(q[0]) + 4;\n"
|
||||
"M[8] = 0;")
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert(jscode(A[0, 0]) == "A[0]")
|
||||
assert(jscode(3 * A[0, 0]) == "3*A[0]")
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert(jscode(F) == "(A - B)[0]")
|
||||
@@ -0,0 +1,390 @@
|
||||
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
|
||||
Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
|
||||
from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
|
||||
HadamardProduct, SparseMatrix)
|
||||
from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
|
||||
besselk, hankel1, hankel2, airyai,
|
||||
airybi, airyaiprime, airybiprime)
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
from sympy.printing.julia import julia_code
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_Integer():
|
||||
assert julia_code(Integer(67)) == "67"
|
||||
assert julia_code(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
assert julia_code(Rational(3, 7)) == "3 // 7"
|
||||
assert julia_code(Rational(18, 9)) == "2"
|
||||
assert julia_code(Rational(3, -7)) == "-3 // 7"
|
||||
assert julia_code(Rational(-3, -7)) == "3 // 7"
|
||||
assert julia_code(x + Rational(3, 7)) == "x + 3 // 7"
|
||||
assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert julia_code(Eq(x, y)) == "x == y"
|
||||
assert julia_code(Ne(x, y)) == "x != y"
|
||||
assert julia_code(Le(x, y)) == "x <= y"
|
||||
assert julia_code(Lt(x, y)) == "x < y"
|
||||
assert julia_code(Gt(x, y)) == "x > y"
|
||||
assert julia_code(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Function():
|
||||
assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)"
|
||||
assert julia_code(abs(x)) == "abs(x)"
|
||||
assert julia_code(ceiling(x)) == "ceil(x)"
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert julia_code(x**3) == "x .^ 3"
|
||||
assert julia_code(x**(y**3)) == "x .^ (y .^ 3)"
|
||||
assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)'
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)"
|
||||
# For issue 14160
|
||||
assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
|
||||
evaluate=False)) == '-2 * x ./ (y .* y)'
|
||||
|
||||
|
||||
def test_basic_ops():
|
||||
assert julia_code(x*y) == "x .* y"
|
||||
assert julia_code(x + y) == "x + y"
|
||||
assert julia_code(x - y) == "x - y"
|
||||
assert julia_code(-x) == "-x"
|
||||
|
||||
|
||||
def test_1_over_x_and_sqrt():
|
||||
# 1.0 and 0.5 would do something different in regular StrPrinter,
|
||||
# but these are exact in IEEE floating point so no different here.
|
||||
assert julia_code(1/x) == '1 ./ x'
|
||||
assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x'
|
||||
assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)'
|
||||
assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)'
|
||||
assert julia_code(sqrt(x)) == 'sqrt(x)'
|
||||
assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)'
|
||||
assert julia_code(1/pi) == '1 / pi'
|
||||
assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi'
|
||||
assert julia_code(pi**-0.5) == '1 / sqrt(pi)'
|
||||
|
||||
|
||||
def test_mix_number_mult_symbols():
|
||||
assert julia_code(3*x) == "3 * x"
|
||||
assert julia_code(pi*x) == "pi * x"
|
||||
assert julia_code(3/x) == "3 ./ x"
|
||||
assert julia_code(pi/x) == "pi ./ x"
|
||||
assert julia_code(x/3) == "x / 3"
|
||||
assert julia_code(x/pi) == "x / pi"
|
||||
assert julia_code(x*y) == "x .* y"
|
||||
assert julia_code(3*x*y) == "3 * x .* y"
|
||||
assert julia_code(3*pi*x*y) == "3 * pi * x .* y"
|
||||
assert julia_code(x/y) == "x ./ y"
|
||||
assert julia_code(3*x/y) == "3 * x ./ y"
|
||||
assert julia_code(x*y/z) == "x .* y ./ z"
|
||||
assert julia_code(x/y*z) == "x .* z ./ y"
|
||||
assert julia_code(1/x/y) == "1 ./ (x .* y)"
|
||||
assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)"
|
||||
assert julia_code(3*pi/x) == "3 * pi ./ x"
|
||||
assert julia_code(S(3)/5) == "3 // 5"
|
||||
assert julia_code(S(3)/5*x) == "(3 // 5) * x"
|
||||
assert julia_code(x/y/z) == "x ./ (y .* z)"
|
||||
assert julia_code((x+y)/z) == "(x + y) ./ z"
|
||||
assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)"
|
||||
assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma"
|
||||
assert julia_code(x/3/pi) == "x / (3 * pi)"
|
||||
assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi"
|
||||
|
||||
|
||||
def test_mix_number_pow_symbols():
|
||||
assert julia_code(pi**3) == 'pi ^ 3'
|
||||
assert julia_code(x**2) == 'x .^ 2'
|
||||
assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)'
|
||||
assert julia_code(x**y) == 'x .^ y'
|
||||
assert julia_code(x**(y**z)) == 'x .^ (y .^ z)'
|
||||
assert julia_code((x**y)**z) == '(x .^ y) .^ z'
|
||||
|
||||
|
||||
def test_imag():
|
||||
I = S('I')
|
||||
assert julia_code(I) == "im"
|
||||
assert julia_code(5*I) == "5im"
|
||||
assert julia_code((S(3)/2)*I) == "(3 // 2) * im"
|
||||
assert julia_code(3+4*I) == "3 + 4im"
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert julia_code(pi) == "pi"
|
||||
assert julia_code(oo) == "Inf"
|
||||
assert julia_code(-oo) == "-Inf"
|
||||
assert julia_code(S.NegativeInfinity) == "-Inf"
|
||||
assert julia_code(S.NaN) == "NaN"
|
||||
assert julia_code(S.Exp1) == "e"
|
||||
assert julia_code(exp(1)) == "e"
|
||||
|
||||
|
||||
def test_constants_other():
|
||||
assert julia_code(2*GoldenRatio) == "2 * golden"
|
||||
assert julia_code(2*Catalan) == "2 * catalan"
|
||||
assert julia_code(2*EulerGamma) == "2 * eulergamma"
|
||||
|
||||
|
||||
def test_boolean():
|
||||
assert julia_code(x & y) == "x && y"
|
||||
assert julia_code(x | y) == "x || y"
|
||||
assert julia_code(~x) == "!x"
|
||||
assert julia_code(x & y & z) == "x && y && z"
|
||||
assert julia_code(x | y | z) == "x || y || z"
|
||||
assert julia_code((x & y) | z) == "z || x && y"
|
||||
assert julia_code((x | y) & z) == "z && (x || y)"
|
||||
|
||||
def test_sinc():
|
||||
assert julia_code(sinc(x)) == 'sinc(x / pi)'
|
||||
assert julia_code(sinc(x + 3)) == 'sinc((x + 3) / pi)'
|
||||
assert julia_code(sinc(pi * (x + 3))) == 'sinc(x + 3)'
|
||||
|
||||
def test_Matrices():
|
||||
assert julia_code(Matrix(1, 1, [10])) == "[10]"
|
||||
A = Matrix([[1, sin(x/2), abs(x)],
|
||||
[0, 1, pi],
|
||||
[0, exp(1), ceiling(x)]])
|
||||
expected = ("[1 sin(x / 2) abs(x);\n"
|
||||
"0 1 pi;\n"
|
||||
"0 e ceil(x)]")
|
||||
assert julia_code(A) == expected
|
||||
# row and columns
|
||||
assert julia_code(A[:,0]) == "[1, 0, 0]"
|
||||
assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]"
|
||||
# empty matrices
|
||||
assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)'
|
||||
assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)'
|
||||
# annoying to read but correct
|
||||
assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
|
||||
|
||||
|
||||
def test_vector_entries_hadamard():
|
||||
# For a row or column, user might to use the other dimension
|
||||
A = Matrix([[1, sin(2/x), 3*pi/x/5]])
|
||||
assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]"
|
||||
assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]"
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_Matrices_entries_not_hadamard():
|
||||
# For Matrix with col >= 2, row >= 2, they need to be scalars
|
||||
# FIXME: is it worth worrying about this? Its not wrong, just
|
||||
# leave it user's responsibility to put scalar data for x.
|
||||
A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
|
||||
expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
|
||||
"1 2 x*y]") # <- we give x.*y
|
||||
assert julia_code(A) == expected
|
||||
|
||||
|
||||
def test_MatrixSymbol():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
B = MatrixSymbol('B', n, n)
|
||||
assert julia_code(A*B) == "A * B"
|
||||
assert julia_code(B*A) == "B * A"
|
||||
assert julia_code(2*A*B) == "2 * A * B"
|
||||
assert julia_code(B*2*A) == "2 * B * A"
|
||||
assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)"
|
||||
assert julia_code(A**(x**2)) == "A ^ (x .^ 2)"
|
||||
assert julia_code(A**3) == "A ^ 3"
|
||||
assert julia_code(A**S.Half) == "A ^ (1 // 2)"
|
||||
|
||||
|
||||
def test_special_matrices():
|
||||
assert julia_code(6*Identity(3)) == "6 * eye(3)"
|
||||
|
||||
|
||||
def test_containers():
|
||||
assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||||
"Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
|
||||
assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
|
||||
assert julia_code([1]) == "Any[1]"
|
||||
assert julia_code((1,)) == "(1,)"
|
||||
assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
|
||||
assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))"
|
||||
# scalar, matrix, empty matrix and empty list
|
||||
assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
|
||||
|
||||
|
||||
def test_julia_noninline():
|
||||
source = julia_code((x+y)/Catalan, assign_to='me', inline=False)
|
||||
expected = (
|
||||
"const Catalan = %s\n"
|
||||
"me = (x + y) / Catalan"
|
||||
) % Catalan.evalf(17)
|
||||
assert source == expected
|
||||
|
||||
|
||||
def test_julia_piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))"
|
||||
assert julia_code(expr, assign_to="r") == (
|
||||
"r = ((x < 1) ? (x) : (x .^ 2))")
|
||||
assert julia_code(expr, assign_to="r", inline=False) == (
|
||||
"if (x < 1)\n"
|
||||
" r = x\n"
|
||||
"else\n"
|
||||
" r = x .^ 2\n"
|
||||
"end")
|
||||
expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
|
||||
expected = ("((x < 1) ? (x .^ 2) :\n"
|
||||
"(x < 2) ? (x .^ 3) :\n"
|
||||
"(x < 3) ? (x .^ 4) : (x .^ 5))")
|
||||
assert julia_code(expr) == expected
|
||||
assert julia_code(expr, assign_to="r") == "r = " + expected
|
||||
assert julia_code(expr, assign_to="r", inline=False) == (
|
||||
"if (x < 1)\n"
|
||||
" r = x .^ 2\n"
|
||||
"elseif (x < 2)\n"
|
||||
" r = x .^ 3\n"
|
||||
"elseif (x < 3)\n"
|
||||
" r = x .^ 4\n"
|
||||
"else\n"
|
||||
" r = x .^ 5\n"
|
||||
"end")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: julia_code(expr))
|
||||
|
||||
|
||||
def test_julia_piecewise_times_const():
|
||||
pw = Piecewise((x, x < 1), (x**2, True))
|
||||
assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))"
|
||||
assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x"
|
||||
assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)"
|
||||
assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3"
|
||||
|
||||
|
||||
def test_julia_matrix_assign_to():
|
||||
A = Matrix([[1, 2, 3]])
|
||||
assert julia_code(A, assign_to='a') == "a = [1 2 3]"
|
||||
A = Matrix([[1, 2], [3, 4]])
|
||||
assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]"
|
||||
|
||||
|
||||
def test_julia_matrix_assign_to_more():
|
||||
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
|
||||
A = Matrix([[1, 2, 3]])
|
||||
B = MatrixSymbol('B', 1, 3)
|
||||
C = MatrixSymbol('C', 2, 3)
|
||||
assert julia_code(A, assign_to=B) == "B = [1 2 3]"
|
||||
raises(ValueError, lambda: julia_code(A, assign_to=x))
|
||||
raises(ValueError, lambda: julia_code(A, assign_to=C))
|
||||
|
||||
|
||||
def test_julia_matrix_1x1():
|
||||
A = Matrix([[3]])
|
||||
B = MatrixSymbol('B', 1, 1)
|
||||
C = MatrixSymbol('C', 1, 2)
|
||||
assert julia_code(A, assign_to=B) == "B = [3]"
|
||||
# FIXME?
|
||||
#assert julia_code(A, assign_to=x) == "x = [3]"
|
||||
raises(ValueError, lambda: julia_code(A, assign_to=C))
|
||||
|
||||
|
||||
def test_julia_matrix_elements():
|
||||
A = Matrix([[x, 2, x*y]])
|
||||
assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
|
||||
A = MatrixSymbol('AA', 1, 3)
|
||||
assert julia_code(A) == "AA"
|
||||
assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
|
||||
"sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
|
||||
assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
|
||||
|
||||
|
||||
def test_julia_boolean():
|
||||
assert julia_code(True) == "true"
|
||||
assert julia_code(S.true) == "true"
|
||||
assert julia_code(False) == "false"
|
||||
assert julia_code(S.false) == "false"
|
||||
|
||||
|
||||
def test_julia_not_supported():
|
||||
with raises(NotImplementedError):
|
||||
julia_code(S.ComplexInfinity)
|
||||
|
||||
f = Function('f')
|
||||
assert julia_code(f(x).diff(x), strict=False) == (
|
||||
"# Not supported in Julia:\n"
|
||||
"# Derivative\n"
|
||||
"Derivative(f(x), x)"
|
||||
)
|
||||
|
||||
|
||||
def test_trick_indent_with_end_else_words():
|
||||
# words starting with "end" or "else" do not confuse the indenter
|
||||
t1 = S('endless')
|
||||
t2 = S('elsewhere')
|
||||
pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
|
||||
assert julia_code(pw, inline=False) == (
|
||||
"if (x < 0)\n"
|
||||
" endless\n"
|
||||
"elseif (x <= 1)\n"
|
||||
" elsewhere\n"
|
||||
"else\n"
|
||||
" 1\n"
|
||||
"end")
|
||||
|
||||
|
||||
def test_haramard():
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
B = MatrixSymbol('B', 3, 3)
|
||||
v = MatrixSymbol('v', 3, 1)
|
||||
h = MatrixSymbol('h', 1, 3)
|
||||
C = HadamardProduct(A, B)
|
||||
assert julia_code(C) == "A .* B"
|
||||
assert julia_code(C*v) == "(A .* B) * v"
|
||||
assert julia_code(h*C*v) == "h * (A .* B) * v"
|
||||
assert julia_code(C*A) == "(A .* B) * A"
|
||||
# mixing Hadamard and scalar strange b/c we vectorize scalars
|
||||
assert julia_code(C*x*y) == "(x .* y) * (A .* B)"
|
||||
|
||||
|
||||
def test_sparse():
|
||||
M = SparseMatrix(5, 6, {})
|
||||
M[2, 2] = 10
|
||||
M[1, 2] = 20
|
||||
M[1, 3] = 22
|
||||
M[0, 3] = 30
|
||||
M[3, 0] = x*y
|
||||
assert julia_code(M) == (
|
||||
"sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)"
|
||||
)
|
||||
|
||||
|
||||
def test_specfun():
|
||||
n = Symbol('n')
|
||||
for f in [besselj, bessely, besseli, besselk]:
|
||||
assert julia_code(f(n, x)) == f.__name__ + '(n, x)'
|
||||
for f in [airyai, airyaiprime, airybi, airybiprime]:
|
||||
assert julia_code(f(x)) == f.__name__ + '(x)'
|
||||
assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)'
|
||||
assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)'
|
||||
assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2'
|
||||
assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2'
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert(julia_code(A[0, 0]) == "A[1,1]")
|
||||
assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]")
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert(julia_code(F) == "(A - B)[1,1]")
|
||||
@@ -0,0 +1,246 @@
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.matrices.dense import MutableDenseMatrix as Matrix
|
||||
from sympy.sets.sets import Interval
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
from sympy.printing.tensorflow import TensorflowPrinter
|
||||
from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter
|
||||
|
||||
|
||||
x, y, z = symbols("x,y,z")
|
||||
i, a, b = symbols("i,a,b")
|
||||
j, c, d = symbols("j,c,d")
|
||||
|
||||
|
||||
def test_basic():
|
||||
assert lambdarepr(x*y) == "x*y"
|
||||
assert lambdarepr(x + y) in ["y + x", "x + y"]
|
||||
assert lambdarepr(x**y) == "x**y"
|
||||
|
||||
|
||||
def test_matrix():
|
||||
# Test printing a Matrix that has an element that is printed differently
|
||||
# with the LambdaPrinter than with the StrPrinter.
|
||||
e = x % 2
|
||||
assert lambdarepr(e) != str(e)
|
||||
assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])'
|
||||
|
||||
|
||||
def test_piecewise():
|
||||
# In each case, test eval() the lambdarepr() to make sure there are a
|
||||
# correct number of parentheses. It will give a SyntaxError if there aren't.
|
||||
|
||||
h = "lambda x: "
|
||||
|
||||
p = Piecewise((x, x < 0))
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((x) if (x < 0) else None)"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x < 1),
|
||||
(2, x < 2),
|
||||
(0, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x < 1),
|
||||
(2, x < 2),
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((1) if (x < 1) else (2) if (x < 2) else None)"
|
||||
|
||||
p = Piecewise(
|
||||
(x, x < 1),
|
||||
(x**2, Interval(3, 4, True, False).contains(x)),
|
||||
(0, True),
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))"
|
||||
|
||||
p = Piecewise(
|
||||
(x**2, x < 0),
|
||||
(x, x < 1),
|
||||
(2 - x, x >= 1),
|
||||
(0, True), evaluate=False
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
|
||||
" else (2 - x) if (x >= 1) else (0))"
|
||||
|
||||
p = Piecewise(
|
||||
(x**2, x < 0),
|
||||
(x, x < 1),
|
||||
(2 - x, x >= 1), evaluate=False
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
|
||||
" else (2 - x) if (x >= 1) else None)"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x >= 1),
|
||||
(2, x >= 2),
|
||||
(3, x >= 3),
|
||||
(4, x >= 4),
|
||||
(5, x >= 5),
|
||||
(6, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\
|
||||
" else (4) if (x >= 4) else (5) if (x >= 5) else (6))"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x <= 1),
|
||||
(2, x <= 2),
|
||||
(3, x <= 3),
|
||||
(4, x <= 4),
|
||||
(5, x <= 5),
|
||||
(6, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\
|
||||
" else (4) if (x <= 4) else (5) if (x <= 5) else (6))"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x > 1),
|
||||
(2, x > 2),
|
||||
(3, x > 3),
|
||||
(4, x > 4),
|
||||
(5, x > 5),
|
||||
(6, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\
|
||||
" else (4) if (x > 4) else (5) if (x > 5) else (6))"
|
||||
|
||||
p = Piecewise(
|
||||
(1, x < 1),
|
||||
(2, x < 2),
|
||||
(3, x < 3),
|
||||
(4, x < 4),
|
||||
(5, x < 5),
|
||||
(6, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\
|
||||
" else (4) if (x < 4) else (5) if (x < 5) else (6))"
|
||||
|
||||
p = Piecewise(
|
||||
(Piecewise(
|
||||
(1, x > 0),
|
||||
(2, True)
|
||||
), y > 0),
|
||||
(3, True)
|
||||
)
|
||||
l = lambdarepr(p)
|
||||
eval(h + l)
|
||||
assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))"
|
||||
|
||||
|
||||
def test_sum__1():
|
||||
# In each case, test eval() the lambdarepr() to make sure that
|
||||
# it evaluates to the same results as the symbolic expression
|
||||
s = Sum(x ** i, (i, a, b))
|
||||
l = lambdarepr(s)
|
||||
assert l == "(builtins.sum(x**i for i in range(a, b+1)))"
|
||||
|
||||
args = x, a, b
|
||||
f = lambdify(args, s)
|
||||
v = 2, 3, 8
|
||||
assert f(*v) == s.subs(zip(args, v)).doit()
|
||||
|
||||
def test_sum__2():
|
||||
s = Sum(i * x, (i, a, b))
|
||||
l = lambdarepr(s)
|
||||
assert l == "(builtins.sum(i*x for i in range(a, b+1)))"
|
||||
|
||||
args = x, a, b
|
||||
f = lambdify(args, s)
|
||||
v = 2, 3, 8
|
||||
assert f(*v) == s.subs(zip(args, v)).doit()
|
||||
|
||||
|
||||
def test_multiple_sums():
|
||||
s = Sum(i * x + j, (i, a, b), (j, c, d))
|
||||
|
||||
l = lambdarepr(s)
|
||||
assert l == "(builtins.sum(i*x + j for j in range(c, d+1) for i in range(a, b+1)))"
|
||||
|
||||
args = x, a, b, c, d
|
||||
f = lambdify(args, s)
|
||||
vals = 2, 3, 4, 5, 6
|
||||
f_ref = s.subs(zip(args, vals)).doit()
|
||||
f_res = f(*vals)
|
||||
assert f_res == f_ref
|
||||
|
||||
|
||||
def test_sqrt():
|
||||
prntr = LambdaPrinter({'standard' : 'python3'})
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)'
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
|
||||
|
||||
|
||||
def test_settings():
|
||||
raises(TypeError, lambda: lambdarepr(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_numexpr():
|
||||
# test ITE rewrite as Piecewise
|
||||
from sympy.logic.boolalg import ITE
|
||||
expr = ITE(x > 0, True, False, evaluate=False)
|
||||
assert NumExprPrinter().doprint(expr) == \
|
||||
"numexpr.evaluate('where((x > 0), True, False)', truediv=True)"
|
||||
|
||||
from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment
|
||||
func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)])
|
||||
expected = "def foo(x):\n"\
|
||||
" y = numexpr.evaluate('x', truediv=True)\n"\
|
||||
" return numexpr.evaluate('y**2', truediv=True)"
|
||||
assert NumExprPrinter().doprint(func_def) == expected
|
||||
|
||||
|
||||
class CustomPrintedObject(Expr):
|
||||
def _lambdacode(self, printer):
|
||||
return 'lambda'
|
||||
|
||||
def _tensorflowcode(self, printer):
|
||||
return 'tensorflow'
|
||||
|
||||
def _numpycode(self, printer):
|
||||
return 'numpy'
|
||||
|
||||
def _numexprcode(self, printer):
|
||||
return 'numexpr'
|
||||
|
||||
def _mpmathcode(self, printer):
|
||||
return 'mpmath'
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
# In each case, printmethod is called to test
|
||||
# its working
|
||||
|
||||
obj = CustomPrintedObject()
|
||||
assert LambdaPrinter().doprint(obj) == 'lambda'
|
||||
assert TensorflowPrinter().doprint(obj) == 'tensorflow'
|
||||
assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)"
|
||||
|
||||
assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \
|
||||
"numexpr.evaluate('where((x >= 0), y, z)', truediv=True)"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,224 @@
|
||||
from sympy.external import import_module
|
||||
from sympy.testing.pytest import raises
|
||||
import ctypes
|
||||
|
||||
|
||||
if import_module('llvmlite'):
|
||||
import sympy.printing.llvmjitcode as g
|
||||
else:
|
||||
disabled = True
|
||||
|
||||
import sympy
|
||||
from sympy.abc import a, b, n
|
||||
|
||||
|
||||
# copied from numpy.isclose documentation
|
||||
def isclose(a, b):
|
||||
rtol = 1e-5
|
||||
atol = 1e-8
|
||||
return abs(a-b) <= atol + rtol*abs(b)
|
||||
|
||||
|
||||
def test_simple_expr():
|
||||
e = a + 1.0
|
||||
f = g.llvm_callable([a], e)
|
||||
res = float(e.subs({a: 4.0}).evalf())
|
||||
jit_res = f(4.0)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_two_arg():
|
||||
e = 4.0*a + b + 3.0
|
||||
f = g.llvm_callable([a, b], e)
|
||||
res = float(e.subs({a: 4.0, b: 3.0}).evalf())
|
||||
jit_res = f(4.0, 3.0)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_func():
|
||||
e = 4.0*sympy.exp(-a)
|
||||
f = g.llvm_callable([a], e)
|
||||
res = float(e.subs({a: 1.5}).evalf())
|
||||
jit_res = f(1.5)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_two_func():
|
||||
e = 4.0*sympy.exp(-a) + sympy.exp(b)
|
||||
f = g.llvm_callable([a, b], e)
|
||||
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
|
||||
jit_res = f(1.5, 2.0)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_two_sqrt():
|
||||
e = 4.0*sympy.sqrt(a) + sympy.sqrt(b)
|
||||
f = g.llvm_callable([a, b], e)
|
||||
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
|
||||
jit_res = f(1.5, 2.0)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_two_pow():
|
||||
e = a**1.5 + b**7
|
||||
f = g.llvm_callable([a, b], e)
|
||||
res = float(e.subs({a: 1.5, b: 2.0}).evalf())
|
||||
jit_res = f(1.5, 2.0)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_callback():
|
||||
e = a + 1.2
|
||||
f = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
|
||||
m = ctypes.c_int(1)
|
||||
array_type = ctypes.c_double * 1
|
||||
inp = {a: 2.2}
|
||||
array = array_type(inp[a])
|
||||
jit_res = f(m, array)
|
||||
|
||||
res = float(e.subs(inp).evalf())
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_callback_cubature():
|
||||
e = a + 1.2
|
||||
f = g.llvm_callable([a], e, callback_type='cubature')
|
||||
m = ctypes.c_int(1)
|
||||
array_type = ctypes.c_double * 1
|
||||
inp = {a: 2.2}
|
||||
array = array_type(inp[a])
|
||||
out_array = array_type(0.0)
|
||||
jit_ret = f(m, array, None, m, out_array)
|
||||
|
||||
assert jit_ret == 0
|
||||
|
||||
res = float(e.subs(inp).evalf())
|
||||
|
||||
assert isclose(out_array[0], res)
|
||||
|
||||
|
||||
def test_callback_two():
|
||||
e = 3*a*b
|
||||
f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test')
|
||||
m = ctypes.c_int(2)
|
||||
array_type = ctypes.c_double * 2
|
||||
inp = {a: 0.2, b: 1.7}
|
||||
array = array_type(inp[a], inp[b])
|
||||
jit_res = f(m, array)
|
||||
|
||||
res = float(e.subs(inp).evalf())
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_callback_alt_two():
|
||||
d = sympy.IndexedBase('d')
|
||||
e = 3*d[0]*d[1]
|
||||
f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test')
|
||||
m = ctypes.c_int(2)
|
||||
array_type = ctypes.c_double * 2
|
||||
inp = {d[0]: 0.2, d[1]: 1.7}
|
||||
array = array_type(inp[d[0]], inp[d[1]])
|
||||
jit_res = f(m, array)
|
||||
|
||||
res = float(e.subs(inp).evalf())
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def test_multiple_statements():
|
||||
# Match return from CSE
|
||||
e = [[(b, 4.0*a)], [b + 5]]
|
||||
f = g.llvm_callable([a], e)
|
||||
b_val = e[0][0][1].subs({a: 1.5})
|
||||
res = float(e[1][0].subs({b: b_val}).evalf())
|
||||
jit_res = f(1.5)
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
|
||||
m = ctypes.c_int(1)
|
||||
array_type = ctypes.c_double * 1
|
||||
array = array_type(1.5)
|
||||
jit_callback_res = f_callback(m, array)
|
||||
assert isclose(jit_callback_res, res)
|
||||
|
||||
|
||||
def test_cse():
|
||||
e = a*a + b*b + sympy.exp(-a*a - b*b)
|
||||
e2 = sympy.cse(e)
|
||||
f = g.llvm_callable([a, b], e2)
|
||||
res = float(e.subs({a: 2.3, b: 0.1}).evalf())
|
||||
jit_res = f(2.3, 0.1)
|
||||
|
||||
assert isclose(jit_res, res)
|
||||
|
||||
|
||||
def eval_cse(e, sub_dict):
|
||||
tmp_dict = {}
|
||||
for tmp_name, tmp_expr in e[0]:
|
||||
e2 = tmp_expr.subs(sub_dict)
|
||||
e3 = e2.subs(tmp_dict)
|
||||
tmp_dict[tmp_name] = e3
|
||||
return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]]
|
||||
|
||||
|
||||
def test_cse_multiple():
|
||||
e1 = a*a
|
||||
e2 = a*a + b*b
|
||||
e3 = sympy.cse([e1, e2])
|
||||
|
||||
raises(NotImplementedError,
|
||||
lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate'))
|
||||
|
||||
f = g.llvm_callable([a, b], e3)
|
||||
jit_res = f(0.1, 1.5)
|
||||
assert len(jit_res) == 2
|
||||
res = eval_cse(e3, {a: 0.1, b: 1.5})
|
||||
assert isclose(res[0], jit_res[0])
|
||||
assert isclose(res[1], jit_res[1])
|
||||
|
||||
|
||||
def test_callback_cubature_multiple():
|
||||
e1 = a*a
|
||||
e2 = a*a + b*b
|
||||
e3 = sympy.cse([e1, e2, 4*e2])
|
||||
f = g.llvm_callable([a, b], e3, callback_type='cubature')
|
||||
|
||||
# Number of input variables
|
||||
ndim = 2
|
||||
# Number of output expression values
|
||||
outdim = 3
|
||||
|
||||
m = ctypes.c_int(ndim)
|
||||
fdim = ctypes.c_int(outdim)
|
||||
array_type = ctypes.c_double * ndim
|
||||
out_array_type = ctypes.c_double * outdim
|
||||
inp = {a: 0.2, b: 1.5}
|
||||
array = array_type(inp[a], inp[b])
|
||||
out_array = out_array_type()
|
||||
jit_ret = f(m, array, None, fdim, out_array)
|
||||
|
||||
assert jit_ret == 0
|
||||
|
||||
res = eval_cse(e3, inp)
|
||||
|
||||
assert isclose(out_array[0], res[0])
|
||||
assert isclose(out_array[1], res[1])
|
||||
assert isclose(out_array[2], res[2])
|
||||
|
||||
|
||||
def test_symbol_not_found():
|
||||
e = a*a + b
|
||||
raises(LookupError, lambda: g.llvm_callable([a], e))
|
||||
|
||||
|
||||
def test_bad_callback():
|
||||
e = a
|
||||
raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback'))
|
||||
@@ -0,0 +1,381 @@
|
||||
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
|
||||
Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
|
||||
from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
|
||||
HadamardProduct, SparseMatrix)
|
||||
from sympy.functions.special.bessel import besseli
|
||||
|
||||
from sympy.printing.maple import maple_code
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_Integer():
|
||||
assert maple_code(Integer(67)) == "67"
|
||||
assert maple_code(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
assert maple_code(Rational(3, 7)) == "3/7"
|
||||
assert maple_code(Rational(18, 9)) == "2"
|
||||
assert maple_code(Rational(3, -7)) == "-3/7"
|
||||
assert maple_code(Rational(-3, -7)) == "3/7"
|
||||
assert maple_code(x + Rational(3, 7)) == "x + 3/7"
|
||||
assert maple_code(Rational(3, 7) * x) == '(3/7)*x'
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert maple_code(Eq(x, y)) == "x = y"
|
||||
assert maple_code(Ne(x, y)) == "x <> y"
|
||||
assert maple_code(Le(x, y)) == "x <= y"
|
||||
assert maple_code(Lt(x, y)) == "x < y"
|
||||
assert maple_code(Gt(x, y)) == "x > y"
|
||||
assert maple_code(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Function():
|
||||
assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)"
|
||||
assert maple_code(abs(x)) == "abs(x)"
|
||||
assert maple_code(ceiling(x)) == "ceil(x)"
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert maple_code(x ** 3) == "x^3"
|
||||
assert maple_code(x ** (y ** 3)) == "x^(y^3)"
|
||||
|
||||
assert maple_code((x ** 3) ** y) == "(x^3)^y"
|
||||
assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)'
|
||||
|
||||
g = implemented_function('g', Lambda(x, 2 * x))
|
||||
assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \
|
||||
"(3.5*2*x)^(-x + y^x)/(x^2 + y)"
|
||||
# For issue 14160
|
||||
assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False),
|
||||
evaluate=False)) == '-2*x/(y*y)'
|
||||
|
||||
|
||||
def test_basic_ops():
|
||||
assert maple_code(x * y) == "x*y"
|
||||
assert maple_code(x + y) == "x + y"
|
||||
assert maple_code(x - y) == "x - y"
|
||||
assert maple_code(-x) == "-x"
|
||||
|
||||
|
||||
def test_1_over_x_and_sqrt():
|
||||
# 1.0 and 0.5 would do something different in regular StrPrinter,
|
||||
# but these are exact in IEEE floating point so no different here.
|
||||
assert maple_code(1 / x) == '1/x'
|
||||
assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x'
|
||||
assert maple_code(1 / sqrt(x)) == '1/sqrt(x)'
|
||||
assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)'
|
||||
assert maple_code(sqrt(x)) == 'sqrt(x)'
|
||||
assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)'
|
||||
assert maple_code(1 / pi) == '1/Pi'
|
||||
assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi'
|
||||
assert maple_code(pi ** -0.5) == '1/sqrt(Pi)'
|
||||
|
||||
|
||||
def test_mix_number_mult_symbols():
|
||||
assert maple_code(3 * x) == "3*x"
|
||||
assert maple_code(pi * x) == "Pi*x"
|
||||
assert maple_code(3 / x) == "3/x"
|
||||
assert maple_code(pi / x) == "Pi/x"
|
||||
assert maple_code(x / 3) == '(1/3)*x'
|
||||
assert maple_code(x / pi) == "x/Pi"
|
||||
assert maple_code(x * y) == "x*y"
|
||||
assert maple_code(3 * x * y) == "3*x*y"
|
||||
assert maple_code(3 * pi * x * y) == "3*Pi*x*y"
|
||||
assert maple_code(x / y) == "x/y"
|
||||
assert maple_code(3 * x / y) == "3*x/y"
|
||||
assert maple_code(x * y / z) == "x*y/z"
|
||||
assert maple_code(x / y * z) == "x*z/y"
|
||||
assert maple_code(1 / x / y) == "1/(x*y)"
|
||||
assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)"
|
||||
assert maple_code(3 * pi / x) == "3*Pi/x"
|
||||
assert maple_code(S(3) / 5) == "3/5"
|
||||
assert maple_code(S(3) / 5 * x) == '(3/5)*x'
|
||||
assert maple_code(x / y / z) == "x/(y*z)"
|
||||
assert maple_code((x + y) / z) == "(x + y)/z"
|
||||
assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)"
|
||||
assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma'
|
||||
assert maple_code(x / 3 / pi) == '(1/3)*x/Pi'
|
||||
assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi'
|
||||
|
||||
|
||||
def test_mix_number_pow_symbols():
|
||||
assert maple_code(pi ** 3) == 'Pi^3'
|
||||
assert maple_code(x ** 2) == 'x^2'
|
||||
|
||||
assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)'
|
||||
assert maple_code(x ** y) == 'x^y'
|
||||
|
||||
assert maple_code(x ** (y ** z)) == 'x^(y^z)'
|
||||
assert maple_code((x ** y) ** z) == '(x^y)^z'
|
||||
|
||||
|
||||
def test_imag():
|
||||
I = S('I')
|
||||
assert maple_code(I) == "I"
|
||||
assert maple_code(5 * I) == "5*I"
|
||||
|
||||
assert maple_code((S(3) / 2) * I) == "(3/2)*I"
|
||||
assert maple_code(3 + 4 * I) == "3 + 4*I"
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert maple_code(pi) == "Pi"
|
||||
assert maple_code(oo) == "infinity"
|
||||
assert maple_code(-oo) == "-infinity"
|
||||
assert maple_code(S.NegativeInfinity) == "-infinity"
|
||||
assert maple_code(S.NaN) == "undefined"
|
||||
assert maple_code(S.Exp1) == "exp(1)"
|
||||
assert maple_code(exp(1)) == "exp(1)"
|
||||
|
||||
|
||||
def test_constants_other():
|
||||
assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))'
|
||||
assert maple_code(2 * Catalan) == '2*Catalan'
|
||||
assert maple_code(2 * EulerGamma) == "2*gamma"
|
||||
|
||||
|
||||
def test_boolean():
|
||||
assert maple_code(x & y) == "x and y"
|
||||
assert maple_code(x | y) == "x or y"
|
||||
assert maple_code(~x) == "not x"
|
||||
assert maple_code(x & y & z) == "x and y and z"
|
||||
assert maple_code(x | y | z) == "x or y or z"
|
||||
assert maple_code((x & y) | z) == "z or x and y"
|
||||
assert maple_code((x | y) & z) == "z and (x or y)"
|
||||
|
||||
|
||||
def test_Matrices():
|
||||
assert maple_code(Matrix(1, 1, [10])) == \
|
||||
'Matrix([[10]], storage = rectangular)'
|
||||
|
||||
A = Matrix([[1, sin(x / 2), abs(x)],
|
||||
[0, 1, pi],
|
||||
[0, exp(1), ceiling(x)]])
|
||||
expected = \
|
||||
'Matrix(' \
|
||||
'[[1, sin((1/2)*x), abs(x)],' \
|
||||
' [0, 1, Pi],' \
|
||||
' [0, exp(1), ceil(x)]], ' \
|
||||
'storage = rectangular)'
|
||||
assert maple_code(A) == expected
|
||||
|
||||
# row and columns
|
||||
assert maple_code(A[:, 0]) == \
|
||||
'Matrix([[1], [0], [0]], storage = rectangular)'
|
||||
assert maple_code(A[0, :]) == \
|
||||
'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)'
|
||||
assert maple_code(Matrix([[x, x - y, -y]])) == \
|
||||
'Matrix([[x, x - y, -y]], storage = rectangular)'
|
||||
|
||||
# empty matrices
|
||||
assert maple_code(Matrix(0, 0, [])) == \
|
||||
'Matrix([], storage = rectangular)'
|
||||
assert maple_code(Matrix(0, 3, [])) == \
|
||||
'Matrix([], storage = rectangular)'
|
||||
|
||||
def test_SparseMatrices():
|
||||
assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)'
|
||||
|
||||
|
||||
def test_vector_entries_hadamard():
|
||||
# For a row or column, user might to use the other dimension
|
||||
A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]])
|
||||
assert maple_code(A) == \
|
||||
'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)'
|
||||
assert maple_code(A.T) == \
|
||||
'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)'
|
||||
|
||||
|
||||
def test_Matrices_entries_not_hadamard():
|
||||
A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]])
|
||||
expected = \
|
||||
'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \
|
||||
'storage = rectangular)'
|
||||
assert maple_code(A) == expected
|
||||
|
||||
|
||||
def test_MatrixSymbol():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
B = MatrixSymbol('B', n, n)
|
||||
assert maple_code(A * B) == "A.B"
|
||||
assert maple_code(B * A) == "B.A"
|
||||
assert maple_code(2 * A * B) == "2*A.B"
|
||||
assert maple_code(B * 2 * A) == "2*B.A"
|
||||
|
||||
assert maple_code(
|
||||
A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)"
|
||||
|
||||
assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)"
|
||||
assert maple_code(A ** 3) == "MatrixPower(A, 3)"
|
||||
assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)"
|
||||
|
||||
|
||||
def test_special_matrices():
|
||||
assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)"
|
||||
assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)'
|
||||
|
||||
|
||||
def test_containers():
|
||||
assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||||
"[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]"
|
||||
|
||||
assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]"
|
||||
assert maple_code([1]) == "[1]"
|
||||
assert maple_code((1,)) == "[1]"
|
||||
assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]"
|
||||
assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]"
|
||||
# scalar, matrix, empty matrix and empty list
|
||||
|
||||
assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \
|
||||
"[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]"
|
||||
|
||||
|
||||
def test_maple_noninline():
|
||||
source = maple_code((x + y)/Catalan, assign_to='me', inline=False)
|
||||
expected = "me := (x + y)/Catalan"
|
||||
|
||||
assert source == expected
|
||||
|
||||
|
||||
def test_maple_matrix_assign_to():
|
||||
A = Matrix([[1, 2, 3]])
|
||||
assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)"
|
||||
A = Matrix([[1, 2], [3, 4]])
|
||||
assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)"
|
||||
|
||||
|
||||
def test_maple_matrix_assign_to_more():
|
||||
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
|
||||
A = Matrix([[1, 2, 3]])
|
||||
B = MatrixSymbol('B', 1, 3)
|
||||
C = MatrixSymbol('C', 2, 3)
|
||||
assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)"
|
||||
raises(ValueError, lambda: maple_code(A, assign_to=x))
|
||||
raises(ValueError, lambda: maple_code(A, assign_to=C))
|
||||
|
||||
|
||||
def test_maple_matrix_1x1():
|
||||
A = Matrix([[3]])
|
||||
assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)"
|
||||
|
||||
|
||||
def test_maple_matrix_elements():
|
||||
A = Matrix([[x, 2, x * y]])
|
||||
|
||||
assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2"
|
||||
AA = MatrixSymbol('AA', 1, 3)
|
||||
assert maple_code(AA) == "AA"
|
||||
|
||||
assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \
|
||||
"sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]"
|
||||
assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]"
|
||||
|
||||
|
||||
def test_maple_boolean():
|
||||
assert maple_code(True) == "true"
|
||||
assert maple_code(S.true) == "true"
|
||||
assert maple_code(False) == "false"
|
||||
assert maple_code(S.false) == "false"
|
||||
|
||||
|
||||
def test_sparse():
|
||||
M = SparseMatrix(5, 6, {})
|
||||
M[2, 2] = 10
|
||||
M[1, 2] = 20
|
||||
M[1, 3] = 22
|
||||
M[0, 3] = 30
|
||||
M[3, 0] = x * y
|
||||
assert maple_code(M) == \
|
||||
'Matrix([[0, 0, 0, 30, 0, 0],' \
|
||||
' [0, 0, 20, 22, 0, 0],' \
|
||||
' [0, 0, 10, 0, 0, 0],' \
|
||||
' [x*y, 0, 0, 0, 0, 0],' \
|
||||
' [0, 0, 0, 0, 0, 0]], ' \
|
||||
'storage = sparse)'
|
||||
|
||||
# Not an important point.
|
||||
def test_maple_not_supported():
|
||||
with raises(NotImplementedError):
|
||||
maple_code(S.ComplexInfinity)
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
|
||||
assert (maple_code(A[0, 0]) == "A[1, 1]")
|
||||
assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]")
|
||||
|
||||
F = A-B
|
||||
|
||||
assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]")
|
||||
|
||||
|
||||
def test_hadamard():
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
B = MatrixSymbol('B', 3, 3)
|
||||
v = MatrixSymbol('v', 3, 1)
|
||||
h = MatrixSymbol('h', 1, 3)
|
||||
C = HadamardProduct(A, B)
|
||||
assert maple_code(C) == "A*B"
|
||||
|
||||
assert maple_code(C * v) == "(A*B).v"
|
||||
# HadamardProduct is higher than dot product.
|
||||
|
||||
assert maple_code(h * C * v) == "h.(A*B).v"
|
||||
|
||||
assert maple_code(C * A) == "(A*B).A"
|
||||
# mixing Hadamard and scalar strange b/c we vectorize scalars
|
||||
|
||||
assert maple_code(C * x * y) == "x*y*(A*B)"
|
||||
|
||||
|
||||
def test_maple_piecewise():
|
||||
expr = Piecewise((x, x < 1), (x ** 2, True))
|
||||
|
||||
assert maple_code(expr) == "piecewise(x < 1, x, x^2)"
|
||||
assert maple_code(expr, assign_to="r") == (
|
||||
"r := piecewise(x < 1, x, x^2)")
|
||||
|
||||
expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True))
|
||||
expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)"
|
||||
assert maple_code(expr) == expected
|
||||
assert maple_code(expr, assign_to="r") == "r := " + expected
|
||||
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: maple_code(expr))
|
||||
|
||||
|
||||
def test_maple_piecewise_times_const():
|
||||
pw = Piecewise((x, x < 1), (x ** 2, True))
|
||||
|
||||
assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)"
|
||||
assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x"
|
||||
assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)"
|
||||
assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)"
|
||||
|
||||
|
||||
def test_maple_derivatives():
|
||||
f = Function('f')
|
||||
assert maple_code(f(x).diff(x)) == 'diff(f(x), x)'
|
||||
assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)'
|
||||
|
||||
|
||||
def test_automatic_rewrites():
|
||||
assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))'
|
||||
assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))'
|
||||
|
||||
|
||||
def test_specfun():
|
||||
assert maple_code('asin(x)') == 'arcsin(x)'
|
||||
assert maple_code(besseli(x, y)) == 'BesselI(x, y)'
|
||||
@@ -0,0 +1,287 @@
|
||||
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple,
|
||||
Derivative, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.integrals import Integral
|
||||
from sympy.concrete import Sum
|
||||
from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max,
|
||||
Min, gamma, polygamma, loggamma, erf, erfi, erfc,
|
||||
erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li,
|
||||
Shi, Chi, uppergamma, beta, subfactorial, erf2inv,
|
||||
factorial, factorial2, catalan, RisingFactorial,
|
||||
FallingFactorial, harmonic, atan2, sec, acsc,
|
||||
hermite, laguerre, assoc_laguerre, jacobi,
|
||||
gegenbauer, chebyshevt, chebyshevu, legendre,
|
||||
assoc_legendre, Li, LambertW)
|
||||
|
||||
from sympy.printing.mathematica import mathematica_code as mcode
|
||||
|
||||
x, y, z, w = symbols('x,y,z,w')
|
||||
f = Function('f')
|
||||
|
||||
|
||||
def test_Integer():
|
||||
assert mcode(Integer(67)) == "67"
|
||||
assert mcode(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
assert mcode(Rational(3, 7)) == "3/7"
|
||||
assert mcode(Rational(18, 9)) == "2"
|
||||
assert mcode(Rational(3, -7)) == "-3/7"
|
||||
assert mcode(Rational(-3, -7)) == "3/7"
|
||||
assert mcode(x + Rational(3, 7)) == "x + 3/7"
|
||||
assert mcode(Rational(3, 7)*x) == "(3/7)*x"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert mcode(Eq(x, y)) == "x == y"
|
||||
assert mcode(Ne(x, y)) == "x != y"
|
||||
assert mcode(Le(x, y)) == "x <= y"
|
||||
assert mcode(Lt(x, y)) == "x < y"
|
||||
assert mcode(Gt(x, y)) == "x > y"
|
||||
assert mcode(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Function():
|
||||
assert mcode(f(x, y, z)) == "f[x, y, z]"
|
||||
assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]"
|
||||
assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]"
|
||||
assert mcode(atan2(y, x)) == "ArcTan[x, y]"
|
||||
assert mcode(conjugate(x)) == "Conjugate[x]"
|
||||
assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]"
|
||||
assert mcode(fresnelc(x)) == "FresnelC[x]"
|
||||
assert mcode(fresnels(x)) == "FresnelS[x]"
|
||||
assert mcode(gamma(x)) == "Gamma[x]"
|
||||
assert mcode(uppergamma(x, y)) == "Gamma[x, y]"
|
||||
assert mcode(polygamma(x, y)) == "PolyGamma[x, y]"
|
||||
assert mcode(loggamma(x)) == "LogGamma[x]"
|
||||
assert mcode(erf(x)) == "Erf[x]"
|
||||
assert mcode(erfc(x)) == "Erfc[x]"
|
||||
assert mcode(erfi(x)) == "Erfi[x]"
|
||||
assert mcode(erf2(x, y)) == "Erf[x, y]"
|
||||
assert mcode(expint(x, y)) == "ExpIntegralE[x, y]"
|
||||
assert mcode(erfcinv(x)) == "InverseErfc[x]"
|
||||
assert mcode(erfinv(x)) == "InverseErf[x]"
|
||||
assert mcode(erf2inv(x, y)) == "InverseErf[x, y]"
|
||||
assert mcode(Ei(x)) == "ExpIntegralEi[x]"
|
||||
assert mcode(Ci(x)) == "CosIntegral[x]"
|
||||
assert mcode(li(x)) == "LogIntegral[x]"
|
||||
assert mcode(Si(x)) == "SinIntegral[x]"
|
||||
assert mcode(Shi(x)) == "SinhIntegral[x]"
|
||||
assert mcode(Chi(x)) == "CoshIntegral[x]"
|
||||
assert mcode(beta(x, y)) == "Beta[x, y]"
|
||||
assert mcode(factorial(x)) == "Factorial[x]"
|
||||
assert mcode(factorial2(x)) == "Factorial2[x]"
|
||||
assert mcode(subfactorial(x)) == "Subfactorial[x]"
|
||||
assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]"
|
||||
assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]"
|
||||
assert mcode(catalan(x)) == "CatalanNumber[x]"
|
||||
assert mcode(harmonic(x)) == "HarmonicNumber[x]"
|
||||
assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]"
|
||||
assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]"
|
||||
assert mcode(LambertW(x)) == "ProductLog[x]"
|
||||
assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]"
|
||||
assert mcode(LambertW(x, y)) == "ProductLog[y, x]"
|
||||
|
||||
|
||||
def test_special_polynomials():
|
||||
assert mcode(hermite(x, y)) == "HermiteH[x, y]"
|
||||
assert mcode(laguerre(x, y)) == "LaguerreL[x, y]"
|
||||
assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]"
|
||||
assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]"
|
||||
assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]"
|
||||
assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]"
|
||||
assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]"
|
||||
assert mcode(legendre(x, y)) == "LegendreP[x, y]"
|
||||
assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]"
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert mcode(x**3) == "x^3"
|
||||
assert mcode(x**(y**3)) == "x^(y^3)"
|
||||
assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"(3.5*f[x])^(-x + y^x)/(x^2 + y)"
|
||||
assert mcode(x**-1.0) == 'x^(-1.0)'
|
||||
assert mcode(x**Rational(2, 3)) == 'x^(2/3)'
|
||||
|
||||
|
||||
def test_Mul():
|
||||
A, B, C, D = symbols('A B C D', commutative=False)
|
||||
assert mcode(x*y*z) == "x*y*z"
|
||||
assert mcode(x*y*A) == "x*y*A"
|
||||
assert mcode(x*y*A*B) == "x*y*A**B"
|
||||
assert mcode(x*y*A*B*C) == "x*y*A**B**C"
|
||||
assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A"
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert mcode(S.Zero) == "0"
|
||||
assert mcode(S.One) == "1"
|
||||
assert mcode(S.NegativeOne) == "-1"
|
||||
assert mcode(S.Half) == "1/2"
|
||||
assert mcode(S.ImaginaryUnit) == "I"
|
||||
|
||||
assert mcode(oo) == "Infinity"
|
||||
assert mcode(S.NegativeInfinity) == "-Infinity"
|
||||
assert mcode(S.ComplexInfinity) == "ComplexInfinity"
|
||||
assert mcode(S.NaN) == "Indeterminate"
|
||||
|
||||
assert mcode(S.Exp1) == "E"
|
||||
assert mcode(pi) == "Pi"
|
||||
assert mcode(S.GoldenRatio) == "GoldenRatio"
|
||||
assert mcode(S.TribonacciConstant) == \
|
||||
"(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
|
||||
"(1/3)*(3*33^(1/2) + 19)^(1/3))"
|
||||
assert mcode(2*S.TribonacciConstant) == \
|
||||
"2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
|
||||
"(1/3)*(3*33^(1/2) + 19)^(1/3))"
|
||||
assert mcode(S.EulerGamma) == "EulerGamma"
|
||||
assert mcode(S.Catalan) == "Catalan"
|
||||
|
||||
|
||||
def test_containers():
|
||||
assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||||
"{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
|
||||
assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
|
||||
assert mcode([1]) == "{1}"
|
||||
assert mcode((1,)) == "{1}"
|
||||
assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
|
||||
|
||||
|
||||
def test_matrices():
|
||||
from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \
|
||||
ImmutableDenseMatrix, ImmutableSparseMatrix
|
||||
A = MutableDenseMatrix(
|
||||
[[1, -1, 0, 0],
|
||||
[0, 1, -1, 0],
|
||||
[0, 0, 1, -1],
|
||||
[0, 0, 0, 1]]
|
||||
)
|
||||
B = MutableSparseMatrix(A)
|
||||
C = ImmutableDenseMatrix(A)
|
||||
D = ImmutableSparseMatrix(A)
|
||||
|
||||
assert mcode(C) == mcode(A) == \
|
||||
"{{1, -1, 0, 0}, " \
|
||||
"{0, 1, -1, 0}, " \
|
||||
"{0, 0, 1, -1}, " \
|
||||
"{0, 0, 0, 1}}"
|
||||
|
||||
assert mcode(D) == mcode(B) == \
|
||||
"SparseArray[{" \
|
||||
"{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \
|
||||
"{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \
|
||||
"}, {4, 4}]"
|
||||
|
||||
# Trivial cases of matrices
|
||||
assert mcode(MutableDenseMatrix(0, 0, [])) == '{}'
|
||||
assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]'
|
||||
assert mcode(MutableDenseMatrix(0, 3, [])) == '{}'
|
||||
assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]'
|
||||
assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}'
|
||||
assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]'
|
||||
|
||||
def test_NDArray():
|
||||
from sympy.tensor.array import (
|
||||
MutableDenseNDimArray, ImmutableDenseNDimArray,
|
||||
MutableSparseNDimArray, ImmutableSparseNDimArray)
|
||||
|
||||
example = MutableDenseNDimArray(
|
||||
[[[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12]],
|
||||
[[13, 14, 15, 16],
|
||||
[17, 18, 19, 20],
|
||||
[21, 22, 23, 24]]]
|
||||
)
|
||||
|
||||
assert mcode(example) == \
|
||||
"{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
|
||||
"{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
|
||||
|
||||
example = ImmutableDenseNDimArray(example)
|
||||
|
||||
assert mcode(example) == \
|
||||
"{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
|
||||
"{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
|
||||
|
||||
example = MutableSparseNDimArray(example)
|
||||
|
||||
assert mcode(example) == \
|
||||
"SparseArray[{" \
|
||||
"{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
|
||||
"{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
|
||||
"{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
|
||||
"{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
|
||||
"{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
|
||||
"{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
|
||||
"{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
|
||||
"{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
|
||||
"}, {2, 3, 4}]"
|
||||
|
||||
example = ImmutableSparseNDimArray(example)
|
||||
|
||||
assert mcode(example) == \
|
||||
"SparseArray[{" \
|
||||
"{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
|
||||
"{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
|
||||
"{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
|
||||
"{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
|
||||
"{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
|
||||
"{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
|
||||
"{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
|
||||
"{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
|
||||
"}, {2, 3, 4}]"
|
||||
|
||||
|
||||
def test_Integral():
|
||||
assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]"
|
||||
assert mcode(Integral(exp(-x**2 - y**2),
|
||||
(x, -oo, oo),
|
||||
(y, -oo, oo))) == \
|
||||
"Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
|
||||
"{y, -Infinity, Infinity}]]"
|
||||
|
||||
|
||||
def test_Derivative():
|
||||
assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]"
|
||||
assert mcode(Derivative(x, x)) == "Hold[D[x, x]]"
|
||||
assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]"
|
||||
assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]"
|
||||
assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]"
|
||||
|
||||
|
||||
def test_Sum():
|
||||
assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]"
|
||||
assert mcode(Sum(exp(-x**2 - y**2),
|
||||
(x, -oo, oo),
|
||||
(y, -oo, oo))) == \
|
||||
"Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
|
||||
"{y, -Infinity, Infinity}]]"
|
||||
|
||||
|
||||
def test_comment():
|
||||
from sympy.printing.mathematica import MCodePrinter
|
||||
assert MCodePrinter()._get_comment("Hello World") == \
|
||||
"(* Hello World *)"
|
||||
|
||||
|
||||
def test_userfuncs():
|
||||
# Dictionary mutation test
|
||||
some_function = symbols("some_function", cls=Function)
|
||||
my_user_functions = {"some_function": "SomeFunction"}
|
||||
assert mcode(
|
||||
some_function(z),
|
||||
user_functions=my_user_functions) == \
|
||||
'SomeFunction[z]'
|
||||
assert mcode(
|
||||
some_function(z),
|
||||
user_functions=my_user_functions) == \
|
||||
'SomeFunction[z]'
|
||||
|
||||
# List argument test
|
||||
my_user_functions = \
|
||||
{"some_function": [(lambda x: True, "SomeOtherFunction")]}
|
||||
assert mcode(
|
||||
some_function(z),
|
||||
user_functions=my_user_functions) == \
|
||||
'SomeOtherFunction[z]'
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,381 @@
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.core.mod import Mod
|
||||
from sympy.core.relational import (Equality, Unequality)
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.special.gamma_functions import polygamma
|
||||
from sympy.functions.special.error_functions import (Si, Ci)
|
||||
from sympy.matrices import Matrix
|
||||
from sympy.matrices.expressions.blockmatrix import BlockMatrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.matrices.expressions.special import Identity
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy import symbols, Min, Max
|
||||
|
||||
from sympy.abc import x, i, j, a, b, c, d
|
||||
from sympy.core import Pow
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
|
||||
from sympy.tensor.array import Array
|
||||
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
|
||||
PermuteDims, ArrayDiagonal
|
||||
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \
|
||||
_numpy_known_functions, _scipy_known_constants, _scipy_known_functions
|
||||
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
|
||||
|
||||
from sympy.testing.pytest import skip, raises
|
||||
from sympy.external import import_module
|
||||
|
||||
np = import_module('numpy')
|
||||
jax = import_module('jax')
|
||||
|
||||
if np:
|
||||
deafult_float_info = np.finfo(np.array([]).dtype)
|
||||
NUMPY_DEFAULT_EPSILON = deafult_float_info.eps
|
||||
|
||||
def test_numpy_piecewise_regression():
|
||||
"""
|
||||
NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
|
||||
breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
|
||||
See gh-9747 and gh-9749 for details.
|
||||
"""
|
||||
printer = NumPyPrinter()
|
||||
p = Piecewise((1, x < 0), (0, True))
|
||||
assert printer.doprint(p) == \
|
||||
'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)'
|
||||
assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}}
|
||||
|
||||
def test_numpy_logaddexp():
|
||||
lae = logaddexp(a, b)
|
||||
assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)'
|
||||
lae2 = logaddexp2(a, b)
|
||||
assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)'
|
||||
|
||||
|
||||
def test_sum():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
s = Sum(x ** i, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'numpy')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = np.linspace(-1, +1, 10)
|
||||
assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
s = Sum(i * x, (i, a, b))
|
||||
f = lambdify((a, b, x), s, 'numpy')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
x_ = np.linspace(-1, +1, 10)
|
||||
assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
|
||||
|
||||
|
||||
def test_multiple_sums():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
s = Sum((x + j) * i, (i, a, b), (j, c, d))
|
||||
f = lambdify((a, b, c, d, x), s, 'numpy')
|
||||
|
||||
a_, b_ = 0, 10
|
||||
c_, d_ = 11, 21
|
||||
x_ = np.linspace(-1, +1, 10)
|
||||
assert np.allclose(f(a_, b_, c_, d_, x_),
|
||||
sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
|
||||
|
||||
|
||||
def test_codegen_einsum():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
|
||||
cg = convert_matrix_to_array(M * N)
|
||||
f = lambdify((M, N), cg, 'numpy')
|
||||
|
||||
ma = np.array([[1, 2], [3, 4]])
|
||||
mb = np.array([[1,-2], [-1, 3]])
|
||||
assert (f(ma, mb) == np.matmul(ma, mb)).all()
|
||||
|
||||
|
||||
def test_codegen_extra():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
P = MatrixSymbol("P", 2, 2)
|
||||
Q = MatrixSymbol("Q", 2, 2)
|
||||
ma = np.array([[1, 2], [3, 4]])
|
||||
mb = np.array([[1,-2], [-1, 3]])
|
||||
mc = np.array([[2, 0], [1, 2]])
|
||||
md = np.array([[1,-1], [4, 7]])
|
||||
|
||||
cg = ArrayTensorProduct(M, N)
|
||||
f = lambdify((M, N), cg, 'numpy')
|
||||
assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all()
|
||||
|
||||
cg = ArrayAdd(M, N)
|
||||
f = lambdify((M, N), cg, 'numpy')
|
||||
assert (f(ma, mb) == ma+mb).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P)
|
||||
f = lambdify((M, N, P), cg, 'numpy')
|
||||
assert (f(ma, mb, mc) == ma+mb+mc).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P, Q)
|
||||
f = lambdify((M, N, P, Q), cg, 'numpy')
|
||||
assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
|
||||
|
||||
cg = PermuteDims(M, [1, 0])
|
||||
f = lambdify((M,), cg, 'numpy')
|
||||
assert (f(ma) == ma.T).all()
|
||||
|
||||
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
|
||||
f = lambdify((M, N), cg, 'numpy')
|
||||
assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
|
||||
|
||||
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
|
||||
f = lambdify((M, N), cg, 'numpy')
|
||||
assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
|
||||
|
||||
|
||||
def test_relational():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
e = Equality(x, 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [False, True, False])
|
||||
|
||||
e = Unequality(x, 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [True, False, True])
|
||||
|
||||
e = (x < 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [True, False, False])
|
||||
|
||||
e = (x <= 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [True, True, False])
|
||||
|
||||
e = (x > 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [False, False, True])
|
||||
|
||||
e = (x >= 1)
|
||||
|
||||
f = lambdify((x,), e)
|
||||
x_ = np.array([0, 1, 2])
|
||||
assert np.array_equal(f(x_), [False, True, True])
|
||||
|
||||
|
||||
def test_mod():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
e = Mod(a, b)
|
||||
f = lambdify((a, b), e)
|
||||
|
||||
a_ = np.array([0, 1, 2, 3])
|
||||
b_ = 2
|
||||
assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
|
||||
|
||||
a_ = np.array([0, 1, 2, 3])
|
||||
b_ = np.array([2, 2, 2, 2])
|
||||
assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
|
||||
|
||||
a_ = np.array([2, 3, 4, 5])
|
||||
b_ = np.array([2, 3, 4, 5])
|
||||
assert np.array_equal(f(a_, b_), [0, 0, 0, 0])
|
||||
|
||||
|
||||
def test_pow():
|
||||
if not np:
|
||||
skip('NumPy not installed')
|
||||
|
||||
expr = Pow(2, -1, evaluate=False)
|
||||
f = lambdify([], expr, 'numpy')
|
||||
assert f() == 0.5
|
||||
|
||||
|
||||
def test_expm1():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
f = lambdify((a,), expm1(a), 'numpy')
|
||||
assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_log1p():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
f = lambdify((a,), log1p(a), 'numpy')
|
||||
assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON
|
||||
|
||||
def test_hypot():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
def test_log10():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_exp2():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_log2():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_Sqrt():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_sqrt():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
|
||||
|
||||
|
||||
def test_matsolve():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
M = MatrixSymbol("M", 3, 3)
|
||||
x = MatrixSymbol("x", 3, 1)
|
||||
|
||||
expr = M**(-1) * x + x
|
||||
matsolve_expr = MatrixSolve(M, x) + x
|
||||
|
||||
f = lambdify((M, x), expr)
|
||||
f_matsolve = lambdify((M, x), matsolve_expr)
|
||||
|
||||
m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
|
||||
assert np.linalg.matrix_rank(m0) == 3
|
||||
|
||||
x0 = np.array([3, 4, 5])
|
||||
|
||||
assert np.allclose(f_matsolve(m0, x0), f(m0, x0))
|
||||
|
||||
|
||||
def test_16857():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
a_1 = MatrixSymbol('a_1', 10, 3)
|
||||
a_2 = MatrixSymbol('a_2', 10, 3)
|
||||
a_3 = MatrixSymbol('a_3', 10, 3)
|
||||
a_4 = MatrixSymbol('a_4', 10, 3)
|
||||
A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
|
||||
assert A.shape == (20, 6)
|
||||
|
||||
printer = NumPyPrinter()
|
||||
assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])'
|
||||
|
||||
|
||||
def test_issue_17006():
|
||||
if not np:
|
||||
skip("NumPy not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
|
||||
f = lambdify(M, M + Identity(2))
|
||||
ma = np.array([[1, 2], [3, 4]])
|
||||
mr = np.array([[2, 2], [3, 5]])
|
||||
|
||||
assert (f(ma) == mr).all()
|
||||
|
||||
from sympy.core.symbol import symbols
|
||||
n = symbols('n', integer=True)
|
||||
N = MatrixSymbol("M", n, n)
|
||||
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n)))
|
||||
|
||||
def test_jax_tuple_compatibility():
|
||||
if not jax:
|
||||
skip("Jax not installed")
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
expr = Max(x, y, z) + Min(x, y, z)
|
||||
func = lambdify((x, y, z), expr, 'jax')
|
||||
input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6)
|
||||
input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2)
|
||||
assert np.allclose(func(*input_tuple1), func(*input_array1))
|
||||
assert np.allclose(func(*input_tuple2), func(*input_array2))
|
||||
|
||||
def test_numpy_array():
|
||||
p = NumPyPrinter()
|
||||
assert p.doprint(Array([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])'
|
||||
assert p.doprint(Array([1, 2])) == 'numpy.array([1, 2])'
|
||||
assert p.doprint(Array([[[1, 2, 3]]])) == 'numpy.array([[[1, 2, 3]]])'
|
||||
assert p.doprint(Array([], (0,))) == 'numpy.zeros((0,))'
|
||||
assert p.doprint(Array([], (0, 0))) == 'numpy.zeros((0, 0))'
|
||||
assert p.doprint(Array([], (0, 1))) == 'numpy.zeros((0, 1))'
|
||||
assert p.doprint(Array([], (1, 0))) == 'numpy.zeros((1, 0))'
|
||||
assert p.doprint(Array([1], ())) == 'numpy.array(1)'
|
||||
|
||||
def test_numpy_matrix():
|
||||
p = NumPyPrinter()
|
||||
assert p.doprint(Matrix([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])'
|
||||
assert p.doprint(Matrix([1, 2])) == 'numpy.array([[1], [2]])'
|
||||
assert p.doprint(Matrix(0, 0, [])) == 'numpy.zeros((0, 0))'
|
||||
assert p.doprint(Matrix(0, 1, [])) == 'numpy.zeros((0, 1))'
|
||||
assert p.doprint(Matrix(1, 0, [])) == 'numpy.zeros((1, 0))'
|
||||
|
||||
def test_numpy_known_funcs_consts():
|
||||
assert _numpy_known_constants['NaN'] == 'numpy.nan'
|
||||
assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma'
|
||||
|
||||
assert _numpy_known_functions['acos'] == 'numpy.arccos'
|
||||
assert _numpy_known_functions['log'] == 'numpy.log'
|
||||
|
||||
def test_scipy_known_funcs_consts():
|
||||
assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio'
|
||||
assert _scipy_known_constants['Pi'] == 'scipy.constants.pi'
|
||||
|
||||
assert _scipy_known_functions['erf'] == 'scipy.special.erf'
|
||||
assert _scipy_known_functions['factorial'] == 'scipy.special.factorial'
|
||||
|
||||
def test_numpy_print_methods():
|
||||
prntr = NumPyPrinter()
|
||||
assert hasattr(prntr, '_print_acos')
|
||||
assert hasattr(prntr, '_print_log')
|
||||
|
||||
def test_scipy_print_methods():
|
||||
prntr = SciPyPrinter()
|
||||
assert hasattr(prntr, '_print_acos')
|
||||
assert hasattr(prntr, '_print_log')
|
||||
assert hasattr(prntr, '_print_erf')
|
||||
assert hasattr(prntr, '_print_factorial')
|
||||
assert hasattr(prntr, '_print_chebyshevt')
|
||||
k = Symbol('k', integer=True, nonnegative=True)
|
||||
x = Symbol('x', real=True)
|
||||
assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)"
|
||||
assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]"
|
||||
assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]"
|
||||
@@ -0,0 +1,515 @@
|
||||
from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
|
||||
Tuple, Symbol, EulerGamma, GoldenRatio, Catalan,
|
||||
Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu,
|
||||
chebyshevt, conjugate, DiracDelta, exp, expint,
|
||||
factorial, floor, harmonic, Heaviside, im,
|
||||
laguerre, LambertW, log, Max, Min, Piecewise,
|
||||
polylog, re, RisingFactorial, sign, sinc, sqrt,
|
||||
zeta, binomial, legendre, dirichlet_eta,
|
||||
riemann_xi)
|
||||
from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot,
|
||||
atan, asec, acsc, sinh, cosh, tanh, coth, csch,
|
||||
sech, asinh, acosh, atanh, acoth, asech, acsch)
|
||||
from sympy.testing.pytest import raises, XFAIL
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
|
||||
HadamardProduct, SparseMatrix, HadamardPower)
|
||||
from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
|
||||
besselk, hankel1, hankel2, airyai,
|
||||
airybi, airyaiprime, airybiprime)
|
||||
from sympy.functions.special.gamma_functions import (gamma, lowergamma,
|
||||
uppergamma, loggamma,
|
||||
polygamma)
|
||||
from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi,
|
||||
erfcinv, erfinv, fresnelc,
|
||||
fresnels, li, Shi, Si, Li,
|
||||
erf2, Ei)
|
||||
from sympy.printing.octave import octave_code, octave_code as mcode
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_Integer():
|
||||
assert mcode(Integer(67)) == "67"
|
||||
assert mcode(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
assert mcode(Rational(3, 7)) == "3/7"
|
||||
assert mcode(Rational(18, 9)) == "2"
|
||||
assert mcode(Rational(3, -7)) == "-3/7"
|
||||
assert mcode(Rational(-3, -7)) == "3/7"
|
||||
assert mcode(x + Rational(3, 7)) == "x + 3/7"
|
||||
assert mcode(Rational(3, 7)*x) == "3*x/7"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert mcode(Eq(x, y)) == "x == y"
|
||||
assert mcode(Ne(x, y)) == "x != y"
|
||||
assert mcode(Le(x, y)) == "x <= y"
|
||||
assert mcode(Lt(x, y)) == "x < y"
|
||||
assert mcode(Gt(x, y)) == "x > y"
|
||||
assert mcode(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Function():
|
||||
assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)"
|
||||
assert mcode(sign(x)) == "sign(x)"
|
||||
assert mcode(exp(x)) == "exp(x)"
|
||||
assert mcode(log(x)) == "log(x)"
|
||||
assert mcode(factorial(x)) == "factorial(x)"
|
||||
assert mcode(floor(x)) == "floor(x)"
|
||||
assert mcode(atan2(y, x)) == "atan2(y, x)"
|
||||
assert mcode(beta(x, y)) == 'beta(x, y)'
|
||||
assert mcode(polylog(x, y)) == 'polylog(x, y)'
|
||||
assert mcode(harmonic(x)) == 'harmonic(x)'
|
||||
assert mcode(bernoulli(x)) == "bernoulli(x)"
|
||||
assert mcode(bernoulli(x, y)) == "bernoulli(x, y)"
|
||||
assert mcode(legendre(x, y)) == "legendre(x, y)"
|
||||
|
||||
|
||||
def test_Function_change_name():
|
||||
assert mcode(abs(x)) == "abs(x)"
|
||||
assert mcode(ceiling(x)) == "ceil(x)"
|
||||
assert mcode(arg(x)) == "angle(x)"
|
||||
assert mcode(im(x)) == "imag(x)"
|
||||
assert mcode(re(x)) == "real(x)"
|
||||
assert mcode(conjugate(x)) == "conj(x)"
|
||||
assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)"
|
||||
assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)"
|
||||
assert mcode(laguerre(x, y)) == "laguerreL(x, y)"
|
||||
assert mcode(Chi(x)) == "coshint(x)"
|
||||
assert mcode(Shi(x)) == "sinhint(x)"
|
||||
assert mcode(Ci(x)) == "cosint(x)"
|
||||
assert mcode(Si(x)) == "sinint(x)"
|
||||
assert mcode(li(x)) == "logint(x)"
|
||||
assert mcode(loggamma(x)) == "gammaln(x)"
|
||||
assert mcode(polygamma(x, y)) == "psi(x, y)"
|
||||
assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)"
|
||||
assert mcode(DiracDelta(x)) == "dirac(x)"
|
||||
assert mcode(DiracDelta(x, 3)) == "dirac(3, x)"
|
||||
assert mcode(Heaviside(x)) == "heaviside(x, 1/2)"
|
||||
assert mcode(Heaviside(x, y)) == "heaviside(x, y)"
|
||||
assert mcode(binomial(x, y)) == "bincoeff(x, y)"
|
||||
assert mcode(Mod(x, y)) == "mod(x, y)"
|
||||
|
||||
|
||||
def test_minmax():
|
||||
assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)"
|
||||
assert mcode(Max(x, y, z)) == "max(x, max(y, z))"
|
||||
assert mcode(Min(x, y, z)) == "min(x, min(y, z))"
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert mcode(x**3) == "x.^3"
|
||||
assert mcode(x**(y**3)) == "x.^(y.^3)"
|
||||
assert mcode(x**Rational(2, 3)) == 'x.^(2/3)'
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"(3.5*2*x).^(-x + y.^x)./(x.^2 + y)"
|
||||
# For issue 14160
|
||||
assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
|
||||
evaluate=False)) == '-2*x./(y.*y)'
|
||||
|
||||
|
||||
def test_basic_ops():
|
||||
assert mcode(x*y) == "x.*y"
|
||||
assert mcode(x + y) == "x + y"
|
||||
assert mcode(x - y) == "x - y"
|
||||
assert mcode(-x) == "-x"
|
||||
|
||||
|
||||
def test_1_over_x_and_sqrt():
|
||||
# 1.0 and 0.5 would do something different in regular StrPrinter,
|
||||
# but these are exact in IEEE floating point so no different here.
|
||||
assert mcode(1/x) == '1./x'
|
||||
assert mcode(x**-1) == mcode(x**-1.0) == '1./x'
|
||||
assert mcode(1/sqrt(x)) == '1./sqrt(x)'
|
||||
assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)'
|
||||
assert mcode(sqrt(x)) == 'sqrt(x)'
|
||||
assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)'
|
||||
assert mcode(1/pi) == '1/pi'
|
||||
assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi'
|
||||
assert mcode(pi**-0.5) == '1/sqrt(pi)'
|
||||
|
||||
|
||||
def test_mix_number_mult_symbols():
|
||||
assert mcode(3*x) == "3*x"
|
||||
assert mcode(pi*x) == "pi*x"
|
||||
assert mcode(3/x) == "3./x"
|
||||
assert mcode(pi/x) == "pi./x"
|
||||
assert mcode(x/3) == "x/3"
|
||||
assert mcode(x/pi) == "x/pi"
|
||||
assert mcode(x*y) == "x.*y"
|
||||
assert mcode(3*x*y) == "3*x.*y"
|
||||
assert mcode(3*pi*x*y) == "3*pi*x.*y"
|
||||
assert mcode(x/y) == "x./y"
|
||||
assert mcode(3*x/y) == "3*x./y"
|
||||
assert mcode(x*y/z) == "x.*y./z"
|
||||
assert mcode(x/y*z) == "x.*z./y"
|
||||
assert mcode(1/x/y) == "1./(x.*y)"
|
||||
assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)"
|
||||
assert mcode(3*pi/x) == "3*pi./x"
|
||||
assert mcode(S(3)/5) == "3/5"
|
||||
assert mcode(S(3)/5*x) == "3*x/5"
|
||||
assert mcode(x/y/z) == "x./(y.*z)"
|
||||
assert mcode((x+y)/z) == "(x + y)./z"
|
||||
assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)"
|
||||
assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17)
|
||||
assert mcode(x/3/pi) == "x/(3*pi)"
|
||||
assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)"
|
||||
|
||||
|
||||
def test_mix_number_pow_symbols():
|
||||
assert mcode(pi**3) == 'pi^3'
|
||||
assert mcode(x**2) == 'x.^2'
|
||||
assert mcode(x**(pi**3)) == 'x.^(pi^3)'
|
||||
assert mcode(x**y) == 'x.^y'
|
||||
assert mcode(x**(y**z)) == 'x.^(y.^z)'
|
||||
assert mcode((x**y)**z) == '(x.^y).^z'
|
||||
|
||||
|
||||
def test_imag():
|
||||
I = S('I')
|
||||
assert mcode(I) == "1i"
|
||||
assert mcode(5*I) == "5i"
|
||||
assert mcode((S(3)/2)*I) == "3*1i/2"
|
||||
assert mcode(3+4*I) == "3 + 4i"
|
||||
assert mcode(sqrt(3)*I) == "sqrt(3)*1i"
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert mcode(pi) == "pi"
|
||||
assert mcode(oo) == "inf"
|
||||
assert mcode(-oo) == "-inf"
|
||||
assert mcode(S.NegativeInfinity) == "-inf"
|
||||
assert mcode(S.NaN) == "NaN"
|
||||
assert mcode(S.Exp1) == "exp(1)"
|
||||
assert mcode(exp(1)) == "exp(1)"
|
||||
|
||||
|
||||
def test_constants_other():
|
||||
assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2"
|
||||
assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17)
|
||||
assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17)
|
||||
|
||||
|
||||
def test_boolean():
|
||||
assert mcode(x & y) == "x & y"
|
||||
assert mcode(x | y) == "x | y"
|
||||
assert mcode(~x) == "~x"
|
||||
assert mcode(x & y & z) == "x & y & z"
|
||||
assert mcode(x | y | z) == "x | y | z"
|
||||
assert mcode((x & y) | z) == "z | x & y"
|
||||
assert mcode((x | y) & z) == "z & (x | y)"
|
||||
|
||||
|
||||
def test_KroneckerDelta():
|
||||
from sympy.functions import KroneckerDelta
|
||||
assert mcode(KroneckerDelta(x, y)) == "double(x == y)"
|
||||
assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))"
|
||||
assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)"
|
||||
|
||||
|
||||
def test_Matrices():
|
||||
assert mcode(Matrix(1, 1, [10])) == "10"
|
||||
A = Matrix([[1, sin(x/2), abs(x)],
|
||||
[0, 1, pi],
|
||||
[0, exp(1), ceiling(x)]])
|
||||
expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]"
|
||||
assert mcode(A) == expected
|
||||
# row and columns
|
||||
assert mcode(A[:,0]) == "[1; 0; 0]"
|
||||
assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]"
|
||||
# empty matrices
|
||||
assert mcode(Matrix(0, 0, [])) == '[]'
|
||||
assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)'
|
||||
# annoying to read but correct
|
||||
assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
|
||||
|
||||
|
||||
def test_vector_entries_hadamard():
|
||||
# For a row or column, user might to use the other dimension
|
||||
A = Matrix([[1, sin(2/x), 3*pi/x/5]])
|
||||
assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]"
|
||||
assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]"
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_Matrices_entries_not_hadamard():
|
||||
# For Matrix with col >= 2, row >= 2, they need to be scalars
|
||||
# FIXME: is it worth worrying about this? Its not wrong, just
|
||||
# leave it user's responsibility to put scalar data for x.
|
||||
A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
|
||||
expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
|
||||
"1 2 x*y]") # <- we give x.*y
|
||||
assert mcode(A) == expected
|
||||
|
||||
|
||||
def test_MatrixSymbol():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
B = MatrixSymbol('B', n, n)
|
||||
assert mcode(A*B) == "A*B"
|
||||
assert mcode(B*A) == "B*A"
|
||||
assert mcode(2*A*B) == "2*A*B"
|
||||
assert mcode(B*2*A) == "2*B*A"
|
||||
assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)"
|
||||
assert mcode(A**(x**2)) == "A^(x.^2)"
|
||||
assert mcode(A**3) == "A^3"
|
||||
assert mcode(A**S.Half) == "A^(1/2)"
|
||||
|
||||
|
||||
def test_MatrixSolve():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
x = MatrixSymbol('x', n, 1)
|
||||
assert mcode(MatrixSolve(A, x)) == "A \\ x"
|
||||
|
||||
def test_special_matrices():
|
||||
assert mcode(6*Identity(3)) == "6*eye(3)"
|
||||
|
||||
|
||||
def test_containers():
|
||||
assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||||
"{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
|
||||
assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
|
||||
assert mcode([1]) == "{1}"
|
||||
assert mcode((1,)) == "{1}"
|
||||
assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
|
||||
assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}"
|
||||
# scalar, matrix, empty matrix and empty list
|
||||
assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}"
|
||||
|
||||
|
||||
def test_octave_noninline():
|
||||
source = mcode((x+y)/Catalan, assign_to='me', inline=False)
|
||||
expected = (
|
||||
"Catalan = %s;\n"
|
||||
"me = (x + y)/Catalan;"
|
||||
) % Catalan.evalf(17)
|
||||
assert source == expected
|
||||
|
||||
|
||||
def test_octave_piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))"
|
||||
assert mcode(expr, assign_to="r") == (
|
||||
"r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));")
|
||||
assert mcode(expr, assign_to="r", inline=False) == (
|
||||
"if (x < 1)\n"
|
||||
" r = x;\n"
|
||||
"else\n"
|
||||
" r = x.^2;\n"
|
||||
"end")
|
||||
expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
|
||||
expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n"
|
||||
"(x < 2).*(x.^3) + (~(x < 2)).*( ...\n"
|
||||
"(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))")
|
||||
assert mcode(expr) == expected
|
||||
assert mcode(expr, assign_to="r") == "r = " + expected + ";"
|
||||
assert mcode(expr, assign_to="r", inline=False) == (
|
||||
"if (x < 1)\n"
|
||||
" r = x.^2;\n"
|
||||
"elseif (x < 2)\n"
|
||||
" r = x.^3;\n"
|
||||
"elseif (x < 3)\n"
|
||||
" r = x.^4;\n"
|
||||
"else\n"
|
||||
" r = x.^5;\n"
|
||||
"end")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: mcode(expr))
|
||||
|
||||
|
||||
def test_octave_piecewise_times_const():
|
||||
pw = Piecewise((x, x < 1), (x**2, True))
|
||||
assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))"
|
||||
assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x"
|
||||
assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)"
|
||||
assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3"
|
||||
|
||||
|
||||
def test_octave_matrix_assign_to():
|
||||
A = Matrix([[1, 2, 3]])
|
||||
assert mcode(A, assign_to='a') == "a = [1 2 3];"
|
||||
A = Matrix([[1, 2], [3, 4]])
|
||||
assert mcode(A, assign_to='A') == "A = [1 2; 3 4];"
|
||||
|
||||
|
||||
def test_octave_matrix_assign_to_more():
|
||||
# assigning to Symbol or MatrixSymbol requires lhs/rhs match
|
||||
A = Matrix([[1, 2, 3]])
|
||||
B = MatrixSymbol('B', 1, 3)
|
||||
C = MatrixSymbol('C', 2, 3)
|
||||
assert mcode(A, assign_to=B) == "B = [1 2 3];"
|
||||
raises(ValueError, lambda: mcode(A, assign_to=x))
|
||||
raises(ValueError, lambda: mcode(A, assign_to=C))
|
||||
|
||||
|
||||
def test_octave_matrix_1x1():
|
||||
A = Matrix([[3]])
|
||||
B = MatrixSymbol('B', 1, 1)
|
||||
C = MatrixSymbol('C', 1, 2)
|
||||
assert mcode(A, assign_to=B) == "B = 3;"
|
||||
# FIXME?
|
||||
#assert mcode(A, assign_to=x) == "x = 3;"
|
||||
raises(ValueError, lambda: mcode(A, assign_to=C))
|
||||
|
||||
|
||||
def test_octave_matrix_elements():
|
||||
A = Matrix([[x, 2, x*y]])
|
||||
assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2"
|
||||
A = MatrixSymbol('AA', 1, 3)
|
||||
assert mcode(A) == "AA"
|
||||
assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
|
||||
"sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)"
|
||||
assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)"
|
||||
|
||||
|
||||
def test_octave_boolean():
|
||||
assert mcode(True) == "true"
|
||||
assert mcode(S.true) == "true"
|
||||
assert mcode(False) == "false"
|
||||
assert mcode(S.false) == "false"
|
||||
|
||||
|
||||
def test_octave_not_supported():
|
||||
with raises(NotImplementedError):
|
||||
mcode(S.ComplexInfinity)
|
||||
f = Function('f')
|
||||
assert mcode(f(x).diff(x), strict=False) == (
|
||||
"% Not supported in Octave:\n"
|
||||
"% Derivative\n"
|
||||
"Derivative(f(x), x)"
|
||||
)
|
||||
|
||||
|
||||
def test_octave_not_supported_not_on_whitelist():
|
||||
from sympy.functions.special.polynomials import assoc_laguerre
|
||||
with raises(NotImplementedError):
|
||||
mcode(assoc_laguerre(x, y, z))
|
||||
|
||||
|
||||
def test_octave_expint():
|
||||
assert mcode(expint(1, x)) == "expint(x)"
|
||||
with raises(NotImplementedError):
|
||||
mcode(expint(2, x))
|
||||
assert mcode(expint(y, x), strict=False) == (
|
||||
"% Not supported in Octave:\n"
|
||||
"% expint\n"
|
||||
"expint(y, x)"
|
||||
)
|
||||
|
||||
|
||||
def test_trick_indent_with_end_else_words():
|
||||
# words starting with "end" or "else" do not confuse the indenter
|
||||
t1 = S('endless')
|
||||
t2 = S('elsewhere')
|
||||
pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
|
||||
assert mcode(pw, inline=False) == (
|
||||
"if (x < 0)\n"
|
||||
" endless\n"
|
||||
"elseif (x <= 1)\n"
|
||||
" elsewhere\n"
|
||||
"else\n"
|
||||
" 1\n"
|
||||
"end")
|
||||
|
||||
|
||||
def test_hadamard():
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
B = MatrixSymbol('B', 3, 3)
|
||||
v = MatrixSymbol('v', 3, 1)
|
||||
h = MatrixSymbol('h', 1, 3)
|
||||
C = HadamardProduct(A, B)
|
||||
n = Symbol('n')
|
||||
assert mcode(C) == "A.*B"
|
||||
assert mcode(C*v) == "(A.*B)*v"
|
||||
assert mcode(h*C*v) == "h*(A.*B)*v"
|
||||
assert mcode(C*A) == "(A.*B)*A"
|
||||
# mixing Hadamard and scalar strange b/c we vectorize scalars
|
||||
assert mcode(C*x*y) == "(x.*y)*(A.*B)"
|
||||
|
||||
# Testing HadamardPower:
|
||||
assert mcode(HadamardPower(A, n)) == "A.**n"
|
||||
assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)"
|
||||
assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)"
|
||||
|
||||
|
||||
def test_sparse():
|
||||
M = SparseMatrix(5, 6, {})
|
||||
M[2, 2] = 10
|
||||
M[1, 2] = 20
|
||||
M[1, 3] = 22
|
||||
M[0, 3] = 30
|
||||
M[3, 0] = x*y
|
||||
assert mcode(M) == (
|
||||
"sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)"
|
||||
)
|
||||
|
||||
|
||||
def test_sinc():
|
||||
assert mcode(sinc(x)) == 'sinc(x/pi)'
|
||||
assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)'
|
||||
assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)'
|
||||
|
||||
|
||||
def test_trigfun():
|
||||
for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc,
|
||||
sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth,
|
||||
asech, acsch):
|
||||
assert octave_code(f(x) == f.__name__ + '(x)')
|
||||
|
||||
|
||||
def test_specfun():
|
||||
n = Symbol('n')
|
||||
for f in [besselj, bessely, besseli, besselk]:
|
||||
assert octave_code(f(n, x)) == f.__name__ + '(n, x)'
|
||||
for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma):
|
||||
assert octave_code(f(x)) == f.__name__ + '(x)'
|
||||
assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)'
|
||||
assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)'
|
||||
assert octave_code(airyai(x)) == 'airy(0, x)'
|
||||
assert octave_code(airyaiprime(x)) == 'airy(1, x)'
|
||||
assert octave_code(airybi(x)) == 'airy(2, x)'
|
||||
assert octave_code(airybiprime(x)) == 'airy(3, x)'
|
||||
assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))'
|
||||
assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))'
|
||||
assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))'
|
||||
assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2'
|
||||
assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2'
|
||||
assert octave_code(LambertW(x)) == 'lambertw(x)'
|
||||
assert octave_code(LambertW(x, n)) == 'lambertw(n, x)'
|
||||
|
||||
# Automatic rewrite
|
||||
assert octave_code(Ei(x)) == '(logint(exp(x)))'
|
||||
assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))'
|
||||
assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)'
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert mcode(A[0, 0]) == "A(1, 1)"
|
||||
assert mcode(3 * A[0, 0]) == "3*A(1, 1)"
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert mcode(F) == "(A - B)(1, 1)"
|
||||
|
||||
|
||||
def test_zeta_printing_issue_14820():
|
||||
assert octave_code(zeta(x)) == 'zeta(x)'
|
||||
with raises(NotImplementedError):
|
||||
octave_code(zeta(x, y))
|
||||
|
||||
|
||||
def test_automatic_rewrite():
|
||||
assert octave_code(Li(x)) == '(logint(x) - logint(2))'
|
||||
assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))'
|
||||
@@ -0,0 +1,128 @@
|
||||
from sympy.concrete.products import Product
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.core.function import Derivative, Function
|
||||
from sympy.core.numbers import Integer, Rational, Float, oo
|
||||
from sympy.core.relational import Rel
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions import sin
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.series.order import Order
|
||||
|
||||
from sympy.printing.precedence import precedence, PRECEDENCE
|
||||
|
||||
x, y = symbols("x,y")
|
||||
|
||||
|
||||
def test_Add():
|
||||
assert precedence(x + y) == PRECEDENCE["Add"]
|
||||
assert precedence(x*y + 1) == PRECEDENCE["Add"]
|
||||
|
||||
|
||||
def test_Function():
|
||||
assert precedence(sin(x)) == PRECEDENCE["Func"]
|
||||
|
||||
def test_Derivative():
|
||||
assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"]
|
||||
|
||||
def test_Integral():
|
||||
assert precedence(Integral(x, y)) == PRECEDENCE["Atom"]
|
||||
|
||||
|
||||
def test_Mul():
|
||||
assert precedence(x*y) == PRECEDENCE["Mul"]
|
||||
assert precedence(-x*y) == PRECEDENCE["Add"]
|
||||
|
||||
|
||||
def test_Number():
|
||||
assert precedence(Integer(0)) == PRECEDENCE["Atom"]
|
||||
assert precedence(Integer(1)) == PRECEDENCE["Atom"]
|
||||
assert precedence(Integer(-1)) == PRECEDENCE["Add"]
|
||||
assert precedence(Integer(10)) == PRECEDENCE["Atom"]
|
||||
assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"]
|
||||
assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"]
|
||||
assert precedence(Float(5)) == PRECEDENCE["Atom"]
|
||||
assert precedence(Float(-5)) == PRECEDENCE["Add"]
|
||||
assert precedence(oo) == PRECEDENCE["Atom"]
|
||||
assert precedence(-oo) == PRECEDENCE["Add"]
|
||||
|
||||
|
||||
def test_Order():
|
||||
assert precedence(Order(x)) == PRECEDENCE["Atom"]
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert precedence(x**y) == PRECEDENCE["Pow"]
|
||||
assert precedence(-x**y) == PRECEDENCE["Add"]
|
||||
assert precedence(x**-y) == PRECEDENCE["Pow"]
|
||||
|
||||
|
||||
def test_Product():
|
||||
assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"]
|
||||
|
||||
|
||||
def test_Sum():
|
||||
assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
|
||||
|
||||
|
||||
def test_Symbol():
|
||||
assert precedence(x) == PRECEDENCE["Atom"]
|
||||
|
||||
|
||||
def test_And_Or():
|
||||
# precedence relations between logical operators, ...
|
||||
assert precedence(x & y) > precedence(x | y)
|
||||
assert precedence(~y) > precedence(x & y)
|
||||
# ... and with other operators (cfr. other programming languages)
|
||||
assert precedence(x + y) > precedence(x | y)
|
||||
assert precedence(x + y) > precedence(x & y)
|
||||
assert precedence(x*y) > precedence(x | y)
|
||||
assert precedence(x*y) > precedence(x & y)
|
||||
assert precedence(~y) > precedence(x*y)
|
||||
assert precedence(~y) > precedence(x - y)
|
||||
# double checks
|
||||
assert precedence(x & y) == PRECEDENCE["And"]
|
||||
assert precedence(x | y) == PRECEDENCE["Or"]
|
||||
assert precedence(~y) == PRECEDENCE["Not"]
|
||||
|
||||
|
||||
def test_custom_function_precedence_comparison():
|
||||
"""
|
||||
Test cases for custom functions with different precedence values,
|
||||
specifically handling:
|
||||
1. Functions with precedence < PRECEDENCE["Mul"] (50)
|
||||
2. Functions with precedence = Func (70)
|
||||
|
||||
Key distinction:
|
||||
1. Lower precedence functions (45) need parentheses: -2*(x F y)
|
||||
2. Higher precedence functions (70) don't: -2*x F y
|
||||
"""
|
||||
class LowPrecedenceF(Function):
|
||||
precedence = PRECEDENCE["Mul"] - 5
|
||||
def _sympystr(self, printer):
|
||||
return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}"
|
||||
|
||||
class HighPrecedenceF(Function):
|
||||
precedence = PRECEDENCE["Func"]
|
||||
def _sympystr(self, printer):
|
||||
return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}"
|
||||
|
||||
def test_low_precedence():
|
||||
expr1 = 2 * LowPrecedenceF(x, y)
|
||||
assert str(expr1) == "2*(x F y)"
|
||||
|
||||
expr2 = -2 * LowPrecedenceF(x, y)
|
||||
assert str(expr2) == "-2*(x F y)"
|
||||
|
||||
def test_high_precedence():
|
||||
expr1 = 2 * HighPrecedenceF(x, y)
|
||||
assert str(expr1) == "2*x F y"
|
||||
|
||||
expr2 = -2 * HighPrecedenceF(x, y)
|
||||
assert str(expr2) == "-2*x F y"
|
||||
|
||||
test_low_precedence()
|
||||
test_high_precedence()
|
||||
@@ -0,0 +1,38 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.printing.preview import preview
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def test_preview():
|
||||
x = Symbol('x')
|
||||
obj = BytesIO()
|
||||
try:
|
||||
preview(x, output='png', viewer='BytesIO', outputbuffer=obj)
|
||||
except RuntimeError:
|
||||
pass # latex not installed on CI server
|
||||
|
||||
|
||||
def test_preview_unicode_symbol():
|
||||
# issue 9107
|
||||
a = Symbol('α')
|
||||
obj = BytesIO()
|
||||
try:
|
||||
preview(a, output='png', viewer='BytesIO', outputbuffer=obj)
|
||||
except RuntimeError:
|
||||
pass # latex not installed on CI server
|
||||
|
||||
|
||||
def test_preview_latex_construct_in_expr():
|
||||
# see PR 9801
|
||||
x = Symbol('x')
|
||||
pw = Piecewise((1, Eq(x, 0)), (0, True))
|
||||
obj = BytesIO()
|
||||
try:
|
||||
preview(pw, output='png', viewer='BytesIO', outputbuffer=obj)
|
||||
except RuntimeError:
|
||||
pass # latex not installed on CI server
|
||||
@@ -0,0 +1,493 @@
|
||||
from sympy import Not
|
||||
from sympy.codegen import Assignment
|
||||
from sympy.codegen.ast import none
|
||||
from sympy.codegen.cfunctions import expm1, log1p
|
||||
from sympy.codegen.scipy_nodes import cosm1
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
|
||||
from sympy.core.function import Derivative
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.core.singleton import S
|
||||
from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh
|
||||
from sympy.functions.elementary.trigonometric import atan2
|
||||
from sympy.logic import And, Or
|
||||
from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
|
||||
from sympy.printing.codeprinter import PrintMethodNotImplementedError
|
||||
from sympy.printing.pycode import (
|
||||
MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
|
||||
)
|
||||
from sympy.printing.tensorflow import TensorflowPrinter
|
||||
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
|
||||
from sympy.testing.pytest import raises, skip
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray
|
||||
from sympy.external import import_module
|
||||
from sympy.functions.special.gamma_functions import loggamma
|
||||
|
||||
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
p = IndexedBase("p")
|
||||
|
||||
|
||||
def test_PythonCodePrinter():
|
||||
prntr = PythonCodePrinter()
|
||||
|
||||
assert not prntr.module_imports
|
||||
|
||||
assert prntr.doprint(x**y) == 'x**y'
|
||||
assert prntr.doprint(Mod(x, 2)) == 'x % 2'
|
||||
assert prntr.doprint(-Mod(x, y)) == '-(x % y)'
|
||||
assert prntr.doprint(Mod(-x, y)) == '(-x) % y'
|
||||
assert prntr.doprint(And(x, y)) == 'x and y'
|
||||
assert prntr.doprint(Or(x, y)) == 'x or y'
|
||||
assert prntr.doprint(1/(x+y)) == '1/(x + y)'
|
||||
assert prntr.doprint(Not(x)) == 'not x'
|
||||
assert not prntr.module_imports
|
||||
|
||||
assert prntr.doprint(pi) == 'math.pi'
|
||||
assert prntr.module_imports == {'math': {'pi'}}
|
||||
|
||||
assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)'
|
||||
assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)'
|
||||
assert prntr.module_imports == {'math': {'pi', 'sqrt'}}
|
||||
|
||||
assert prntr.doprint(acos(x)) == 'math.acos(x)'
|
||||
assert prntr.doprint(cot(x)) == '(1/math.tan(x))'
|
||||
assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))'
|
||||
assert prntr.doprint(asec(x)) == '(math.acos(1/x))'
|
||||
assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))'
|
||||
|
||||
assert prntr.doprint(Assignment(x, 2)) == 'x = 2'
|
||||
assert prntr.doprint(Piecewise((1, Eq(x, 0)),
|
||||
(2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)'
|
||||
assert prntr.doprint(Piecewise((2, Le(x, 0)),
|
||||
(3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\
|
||||
' (3) if (x > 0) else None)'
|
||||
assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))'
|
||||
assert prntr.doprint(p[0, 1]) == 'p[0, 1]'
|
||||
assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)'
|
||||
|
||||
assert prntr.doprint((2,3)) == "(2, 3)"
|
||||
assert prntr.doprint([2,3]) == "[2, 3]"
|
||||
|
||||
assert prntr.doprint(Min(x, y)) == "min(x, y)"
|
||||
assert prntr.doprint(Max(x, y)) == "max(x, y)"
|
||||
|
||||
|
||||
def test_PythonCodePrinter_standard():
|
||||
prntr = PythonCodePrinter()
|
||||
|
||||
assert prntr.standard == 'python3'
|
||||
|
||||
raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'}))
|
||||
|
||||
|
||||
def test_CmathPrinter():
|
||||
p = CmathPrinter()
|
||||
|
||||
assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)'
|
||||
assert p.doprint(log(x)) == 'cmath.log(x)'
|
||||
|
||||
assert p.doprint(sin(x)) == 'cmath.sin(x)'
|
||||
assert p.doprint(cos(x)) == 'cmath.cos(x)'
|
||||
assert p.doprint(tan(x)) == 'cmath.tan(x)'
|
||||
|
||||
assert p.doprint(asin(x)) == 'cmath.asin(x)'
|
||||
assert p.doprint(acos(x)) == 'cmath.acos(x)'
|
||||
assert p.doprint(atan(x)) == 'cmath.atan(x)'
|
||||
|
||||
assert p.doprint(sinh(x)) == 'cmath.sinh(x)'
|
||||
assert p.doprint(cosh(x)) == 'cmath.cosh(x)'
|
||||
assert p.doprint(tanh(x)) == 'cmath.tanh(x)'
|
||||
|
||||
assert p.doprint(asinh(x)) == 'cmath.asinh(x)'
|
||||
assert p.doprint(acosh(x)) == 'cmath.acosh(x)'
|
||||
assert p.doprint(atanh(x)) == 'cmath.atanh(x)'
|
||||
|
||||
|
||||
def test_MpmathPrinter():
|
||||
p = MpmathPrinter()
|
||||
assert p.doprint(sign(x)) == 'mpmath.sign(x)'
|
||||
assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)'
|
||||
|
||||
assert p.doprint(S.Exp1) == 'mpmath.e'
|
||||
assert p.doprint(S.Pi) == 'mpmath.pi'
|
||||
assert p.doprint(S.GoldenRatio) == 'mpmath.phi'
|
||||
assert p.doprint(S.EulerGamma) == 'mpmath.euler'
|
||||
assert p.doprint(S.NaN) == 'mpmath.nan'
|
||||
assert p.doprint(S.Infinity) == 'mpmath.inf'
|
||||
assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf'
|
||||
assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)'
|
||||
|
||||
|
||||
def test_NumPyPrinter():
|
||||
from sympy.core.function import Lambda
|
||||
from sympy.matrices.expressions.adjoint import Adjoint
|
||||
from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf)
|
||||
from sympy.matrices.expressions.funcmatrix import FunctionMatrix
|
||||
from sympy.matrices.expressions.hadamard import HadamardProduct
|
||||
from sympy.matrices.expressions.kronecker import KroneckerProduct
|
||||
from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix)
|
||||
from sympy.abc import a, b
|
||||
p = NumPyPrinter()
|
||||
assert p.doprint(sign(x)) == 'numpy.sign(x)'
|
||||
A = MatrixSymbol("A", 2, 2)
|
||||
B = MatrixSymbol("B", 2, 2)
|
||||
C = MatrixSymbol("C", 1, 5)
|
||||
D = MatrixSymbol("D", 3, 4)
|
||||
assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)"
|
||||
assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)"
|
||||
assert p.doprint(Identity(3)) == "numpy.eye(3)"
|
||||
|
||||
u = MatrixSymbol('x', 2, 1)
|
||||
v = MatrixSymbol('y', 2, 1)
|
||||
assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)'
|
||||
assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y'
|
||||
|
||||
assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))"
|
||||
assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))"
|
||||
assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \
|
||||
"numpy.fromfunction(lambda a, b: a + b, (4, 5))"
|
||||
assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)"
|
||||
assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)"
|
||||
assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))"
|
||||
assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))"
|
||||
assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)"
|
||||
assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))"
|
||||
|
||||
# Workaround for numpy negative integer power errors
|
||||
assert p.doprint(x**-1) == 'x**(-1.0)'
|
||||
assert p.doprint(x**-2) == 'x**(-2.0)'
|
||||
|
||||
expr = Pow(2, -1, evaluate=False)
|
||||
assert p.doprint(expr) == "2**(-1.0)"
|
||||
|
||||
assert p.doprint(S.Exp1) == 'numpy.e'
|
||||
assert p.doprint(S.Pi) == 'numpy.pi'
|
||||
assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma'
|
||||
assert p.doprint(S.NaN) == 'numpy.nan'
|
||||
assert p.doprint(S.Infinity) == 'numpy.inf'
|
||||
assert p.doprint(S.NegativeInfinity) == '-numpy.inf'
|
||||
|
||||
# Function rewriting operator precedence fix
|
||||
assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2'
|
||||
|
||||
|
||||
def test_issue_18770():
|
||||
numpy = import_module('numpy')
|
||||
if not numpy:
|
||||
skip("numpy not installed.")
|
||||
|
||||
from sympy.functions.elementary.miscellaneous import (Max, Min)
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
|
||||
expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1)
|
||||
func = lambdify(x, expr1, "numpy")
|
||||
assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all()
|
||||
assert func(4) == 3
|
||||
|
||||
expr1 = Max(x**2, x**3)
|
||||
func = lambdify(x,expr1, "numpy")
|
||||
assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all()
|
||||
assert func(4) == 64
|
||||
|
||||
|
||||
def test_SciPyPrinter():
|
||||
p = SciPyPrinter()
|
||||
expr = acos(x)
|
||||
assert 'numpy' not in p.module_imports
|
||||
assert p.doprint(expr) == 'numpy.arccos(x)'
|
||||
assert 'numpy' in p.module_imports
|
||||
assert not any(m.startswith('scipy') for m in p.module_imports)
|
||||
smat = SparseMatrix(2, 5, {(0, 1): 3})
|
||||
assert p.doprint(smat) == \
|
||||
'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
|
||||
assert 'scipy.sparse' in p.module_imports
|
||||
|
||||
assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio'
|
||||
assert p.doprint(S.Pi) == 'scipy.constants.pi'
|
||||
assert p.doprint(S.Exp1) == 'numpy.e'
|
||||
|
||||
|
||||
def test_pycode_reserved_words():
|
||||
s1, s2 = symbols('if else')
|
||||
raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True))
|
||||
py_str = pycode(s1 + s2)
|
||||
assert py_str in ('else_ + if_', 'if_ + else_')
|
||||
|
||||
|
||||
def test_issue_20762():
|
||||
# Make sure pycode removes curly braces from subscripted variables
|
||||
a_b, b, a_11 = symbols('a_{b} b a_{11}')
|
||||
expr = a_b*b
|
||||
assert pycode(expr) == 'a_b*b'
|
||||
expr = a_11*b
|
||||
assert pycode(expr) == 'a_11*b'
|
||||
|
||||
|
||||
def test_sqrt():
|
||||
prntr = PythonCodePrinter()
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)'
|
||||
assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)'
|
||||
|
||||
prntr = PythonCodePrinter({'standard' : 'python3'})
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
|
||||
assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)'
|
||||
|
||||
prntr = MpmathPrinter()
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)'
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == \
|
||||
"x**(mpmath.mpf(1)/mpmath.mpf(2))"
|
||||
|
||||
prntr = NumPyPrinter()
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
|
||||
|
||||
prntr = SymPyPrinter()
|
||||
assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)'
|
||||
assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
|
||||
|
||||
|
||||
def test_frac():
|
||||
from sympy.functions.elementary.integers import frac
|
||||
|
||||
expr = frac(x)
|
||||
prntr = NumPyPrinter()
|
||||
assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
|
||||
|
||||
prntr = PythonCodePrinter()
|
||||
assert prntr.doprint(expr) == 'x % 1'
|
||||
|
||||
prntr = MpmathPrinter()
|
||||
assert prntr.doprint(expr) == 'mpmath.frac(x)'
|
||||
|
||||
prntr = SymPyPrinter()
|
||||
assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)'
|
||||
|
||||
|
||||
class CustomPrintedObject(Expr):
|
||||
def _numpycode(self, printer):
|
||||
return 'numpy'
|
||||
|
||||
def _mpmathcode(self, printer):
|
||||
return 'mpmath'
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
obj = CustomPrintedObject()
|
||||
assert NumPyPrinter().doprint(obj) == 'numpy'
|
||||
assert MpmathPrinter().doprint(obj) == 'mpmath'
|
||||
|
||||
|
||||
def test_codegen_ast_nodes():
|
||||
assert pycode(none) == 'None'
|
||||
|
||||
|
||||
def test_issue_14283():
|
||||
prntr = PythonCodePrinter()
|
||||
|
||||
assert prntr.doprint(zoo) == "math.nan"
|
||||
assert prntr.doprint(-oo) == "float('-inf')"
|
||||
|
||||
|
||||
def test_NumPyPrinter_print_seq():
|
||||
n = NumPyPrinter()
|
||||
|
||||
assert n._print_seq(range(2)) == '(0, 1,)'
|
||||
|
||||
|
||||
def test_issue_16535_16536():
|
||||
from sympy.functions.special.gamma_functions import (lowergamma, uppergamma)
|
||||
|
||||
a = symbols('a')
|
||||
expr1 = lowergamma(a, x)
|
||||
expr2 = uppergamma(a, x)
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
|
||||
assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'
|
||||
|
||||
p_numpy = NumPyPrinter()
|
||||
p_pycode = PythonCodePrinter({'strict': False})
|
||||
|
||||
for expr in [expr1, expr2]:
|
||||
with raises(NotImplementedError):
|
||||
p_numpy.doprint(expr1)
|
||||
assert "Not supported" in p_pycode.doprint(expr)
|
||||
|
||||
|
||||
def test_Integral():
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.integrals.integrals import Integral
|
||||
|
||||
single = Integral(exp(-x), (x, 0, oo))
|
||||
double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
|
||||
indefinite = Integral(x**2, x)
|
||||
evaluateat = Integral(x**2, (x, 1))
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]'
|
||||
assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'
|
||||
raises(NotImplementedError, lambda: prntr.doprint(indefinite))
|
||||
raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
|
||||
|
||||
prntr = MpmathPrinter()
|
||||
assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))'
|
||||
assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'
|
||||
raises(NotImplementedError, lambda: prntr.doprint(indefinite))
|
||||
raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
|
||||
|
||||
|
||||
def test_fresnel_integrals():
|
||||
from sympy.functions.special.error_functions import (fresnelc, fresnels)
|
||||
|
||||
expr1 = fresnelc(x)
|
||||
expr2 = fresnels(x)
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]'
|
||||
assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]'
|
||||
|
||||
p_numpy = NumPyPrinter()
|
||||
p_pycode = PythonCodePrinter()
|
||||
p_mpmath = MpmathPrinter()
|
||||
for expr in [expr1, expr2]:
|
||||
with raises(NotImplementedError):
|
||||
p_numpy.doprint(expr)
|
||||
with raises(NotImplementedError):
|
||||
p_pycode.doprint(expr)
|
||||
|
||||
assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)'
|
||||
assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)'
|
||||
|
||||
|
||||
def test_beta():
|
||||
from sympy.functions.special.beta_functions import beta
|
||||
|
||||
expr = beta(x, y)
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr) == 'scipy.special.beta(x, y)'
|
||||
|
||||
prntr = NumPyPrinter()
|
||||
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
|
||||
|
||||
prntr = PythonCodePrinter()
|
||||
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
|
||||
|
||||
prntr = PythonCodePrinter({'allow_unknown_functions': True})
|
||||
assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
|
||||
|
||||
prntr = MpmathPrinter()
|
||||
assert prntr.doprint(expr) == 'mpmath.beta(x, y)'
|
||||
|
||||
def test_airy():
|
||||
from sympy.functions.special.bessel import (airyai, airybi)
|
||||
|
||||
expr1 = airyai(x)
|
||||
expr2 = airybi(x)
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]'
|
||||
assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]'
|
||||
|
||||
prntr = NumPyPrinter({'strict': False})
|
||||
assert "Not supported" in prntr.doprint(expr1)
|
||||
assert "Not supported" in prntr.doprint(expr2)
|
||||
|
||||
prntr = PythonCodePrinter({'strict': False})
|
||||
assert "Not supported" in prntr.doprint(expr1)
|
||||
assert "Not supported" in prntr.doprint(expr2)
|
||||
|
||||
def test_airy_prime():
|
||||
from sympy.functions.special.bessel import (airyaiprime, airybiprime)
|
||||
|
||||
expr1 = airyaiprime(x)
|
||||
expr2 = airybiprime(x)
|
||||
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]'
|
||||
assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]'
|
||||
|
||||
prntr = NumPyPrinter({'strict': False})
|
||||
assert "Not supported" in prntr.doprint(expr1)
|
||||
assert "Not supported" in prntr.doprint(expr2)
|
||||
|
||||
prntr = PythonCodePrinter({'strict': False})
|
||||
assert "Not supported" in prntr.doprint(expr1)
|
||||
assert "Not supported" in prntr.doprint(expr2)
|
||||
|
||||
|
||||
def test_numerical_accuracy_functions():
|
||||
prntr = SciPyPrinter()
|
||||
assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)'
|
||||
assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)'
|
||||
assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)'
|
||||
|
||||
def test_array_printer():
|
||||
A = ArraySymbol('A', (4,4,6,6,6))
|
||||
I = IndexedBase('I')
|
||||
i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5))
|
||||
|
||||
prntr = NumPyPrinter()
|
||||
assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))'
|
||||
assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))'
|
||||
assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)'
|
||||
assert prntr.doprint(I) == 'I'
|
||||
assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)'
|
||||
assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)'
|
||||
assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)'
|
||||
assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
|
||||
|
||||
prntr = TensorflowPrinter()
|
||||
assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))'
|
||||
assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))'
|
||||
assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)'
|
||||
assert prntr.doprint(I) == 'I'
|
||||
assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)'
|
||||
assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)'
|
||||
assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)'
|
||||
assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
|
||||
|
||||
|
||||
def test_custom_Derivative_methods():
|
||||
class MyPrinter(SciPyPrinter):
|
||||
def _print_Derivative_cosm1(self, args, seq_orders):
|
||||
arg, = args
|
||||
order, = seq_orders
|
||||
return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order)
|
||||
|
||||
def _print_Derivative_atan2(self, args, seq_orders):
|
||||
arg1, arg2 = args
|
||||
ord1, ord2 = seq_orders
|
||||
return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % (
|
||||
self._print(arg1), self._print(arg2), ord1, ord2
|
||||
)
|
||||
|
||||
p = MyPrinter()
|
||||
cosm1_1 = cosm1(x).diff(x, evaluate=False)
|
||||
assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)'
|
||||
atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False)
|
||||
assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)'
|
||||
|
||||
try:
|
||||
p.doprint(expm1(x).diff(x, evaluate=False))
|
||||
except PrintMethodNotImplementedError as e:
|
||||
assert '_print_Derivative_expm1' in repr(e)
|
||||
else:
|
||||
assert False # should have thrown
|
||||
|
||||
try:
|
||||
p.doprint(Derivative(cosm1(x**2),x))
|
||||
except ValueError as e:
|
||||
assert '_print_Derivative(' in repr(e)
|
||||
else:
|
||||
assert False # should have thrown
|
||||
@@ -0,0 +1,203 @@
|
||||
from sympy.core.function import (Derivative, Function)
|
||||
from sympy.core.numbers import (I, Rational, oo, pi)
|
||||
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.complexes import (Abs, conjugate)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.series.limits import limit
|
||||
|
||||
from sympy.printing.python import python
|
||||
|
||||
from sympy.testing.pytest import raises, XFAIL
|
||||
|
||||
x, y = symbols('x,y')
|
||||
th = Symbol('theta')
|
||||
ph = Symbol('phi')
|
||||
|
||||
|
||||
def test_python_basic():
|
||||
# Simple numbers/symbols
|
||||
assert python(-Rational(1)/2) == "e = Rational(-1, 2)"
|
||||
assert python(-Rational(13)/22) == "e = Rational(-13, 22)"
|
||||
assert python(oo) == "e = oo"
|
||||
|
||||
# Powers
|
||||
assert python(x**2) == "x = Symbol(\'x\')\ne = x**2"
|
||||
assert python(1/x) == "x = Symbol('x')\ne = 1/x"
|
||||
assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2"
|
||||
assert python(
|
||||
x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)"
|
||||
|
||||
# Sums of terms
|
||||
assert python(x**2 + x + 1) in [
|
||||
"x = Symbol('x')\ne = 1 + x + x**2",
|
||||
"x = Symbol('x')\ne = x + x**2 + 1",
|
||||
"x = Symbol('x')\ne = x**2 + x + 1", ]
|
||||
assert python(1 - x) in [
|
||||
"x = Symbol('x')\ne = 1 - x",
|
||||
"x = Symbol('x')\ne = -x + 1"]
|
||||
assert python(1 - 2*x) in [
|
||||
"x = Symbol('x')\ne = 1 - 2*x",
|
||||
"x = Symbol('x')\ne = -2*x + 1"]
|
||||
assert python(1 - Rational(3, 2)*y/x) in [
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x",
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1",
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"]
|
||||
|
||||
# Multiplication
|
||||
assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y"
|
||||
assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y"
|
||||
assert python((x + 2)/y) in [
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)",
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)",
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)",
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y",
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"]
|
||||
assert python((1 + x)*y) in [
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)",
|
||||
"y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ]
|
||||
|
||||
# Check for proper placement of negative sign
|
||||
assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)"
|
||||
assert python(1 - Rational(3, 2)*(x + 1)) in [
|
||||
"x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)",
|
||||
"x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)",
|
||||
"x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)"
|
||||
]
|
||||
|
||||
|
||||
def test_python_keyword_symbol_name_escaping():
|
||||
# Check for escaping of keywords
|
||||
assert python(
|
||||
5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_"
|
||||
assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) ==
|
||||
"lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__")
|
||||
assert (python(5*Symbol("for") + Function("for_")(8)) ==
|
||||
"for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)")
|
||||
|
||||
|
||||
def test_python_keyword_function_name_escaping():
|
||||
assert python(
|
||||
5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)"
|
||||
|
||||
|
||||
def test_python_relational():
|
||||
assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)"
|
||||
assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y"
|
||||
assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y"
|
||||
assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y"
|
||||
assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y"
|
||||
assert python(Ne(x/(y + 1), y**2)) in [
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)",
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"]
|
||||
|
||||
|
||||
def test_python_functions():
|
||||
# Simple
|
||||
assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)"
|
||||
assert python(sqrt(2)) == 'e = sqrt(2)'
|
||||
assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)'
|
||||
assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)'
|
||||
assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)'
|
||||
assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)'
|
||||
assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)"
|
||||
assert python(
|
||||
Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))",
|
||||
"x = Symbol('x')\ne = Abs(x/(x**2 + 1))"]
|
||||
|
||||
# Univariate/Multivariate functions
|
||||
f = Function('f')
|
||||
assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)"
|
||||
assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)"
|
||||
assert python(f(x/(y + 1), y)) in [
|
||||
"x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)",
|
||||
"x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"]
|
||||
|
||||
# Nesting of square roots
|
||||
assert python(sqrt((sqrt(x + 1)) + 1)) in [
|
||||
"x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))",
|
||||
"x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"]
|
||||
|
||||
# Nesting of powers
|
||||
assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [
|
||||
"x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)",
|
||||
"x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"]
|
||||
|
||||
# Function powers
|
||||
assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2"
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_python_functions_conjugates():
|
||||
a, b = map(Symbol, 'ab')
|
||||
assert python( conjugate(a + b*I) ) == '_ _\na - I*b'
|
||||
assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne '
|
||||
|
||||
|
||||
def test_python_derivatives():
|
||||
# Simple
|
||||
f_1 = Derivative(log(x), x, evaluate=False)
|
||||
assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)"
|
||||
|
||||
f_2 = Derivative(log(x), x, evaluate=False) + x
|
||||
assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)"
|
||||
|
||||
# Multiple symbols
|
||||
f_3 = Derivative(log(x) + x**2, x, y, evaluate=False)
|
||||
assert python(f_3) == \
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)"
|
||||
|
||||
f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2
|
||||
assert python(f_4) in [
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)",
|
||||
"x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"]
|
||||
|
||||
|
||||
def test_python_integrals():
|
||||
# Simple
|
||||
f_1 = Integral(log(x), x)
|
||||
assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)"
|
||||
|
||||
f_2 = Integral(x**2, x)
|
||||
assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)"
|
||||
|
||||
# Double nesting of pow
|
||||
f_3 = Integral(x**(2**x), x)
|
||||
assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)"
|
||||
|
||||
# Definite integrals
|
||||
f_4 = Integral(x**2, (x, 1, 2))
|
||||
assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))"
|
||||
|
||||
f_5 = Integral(x**2, (x, Rational(1, 2), 10))
|
||||
assert python(
|
||||
f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))"
|
||||
|
||||
# Nested integrals
|
||||
f_6 = Integral(x**2*y**2, x, y)
|
||||
assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)"
|
||||
|
||||
|
||||
def test_python_matrix():
|
||||
p = python(Matrix([[x**2+1, 1], [y, x+y]]))
|
||||
s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])"
|
||||
assert p == s
|
||||
|
||||
def test_python_limits():
|
||||
assert python(limit(x, x, oo)) == 'e = oo'
|
||||
assert python(limit(x**2, x, 0)) == 'e = 0'
|
||||
|
||||
def test_issue_20762():
|
||||
# Make sure Python removes curly braces from subscripted variables
|
||||
a_b = Symbol('a_{b}')
|
||||
b = Symbol('b')
|
||||
expr = a_b*b
|
||||
assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b"
|
||||
|
||||
|
||||
def test_settings():
|
||||
raises(TypeError, lambda: python(x, method="garbage"))
|
||||
@@ -0,0 +1,476 @@
|
||||
from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer,
|
||||
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy)
|
||||
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
|
||||
gamma, sign, Max, Min, factorial, beta)
|
||||
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
|
||||
from sympy.sets import Range
|
||||
from sympy.logic import ITE
|
||||
from sympy.codegen import For, aug_assign, Assignment
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.printing.rcode import RCodePrinter
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.matrices import Matrix, MatrixSymbol
|
||||
|
||||
from sympy.printing.rcode import rcode
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
class fabs(Abs):
|
||||
def _rcode(self, printer):
|
||||
return "abs(%s)" % printer._print(self.args[0])
|
||||
|
||||
assert rcode(fabs(x)) == "abs(x)"
|
||||
|
||||
|
||||
def test_rcode_sqrt():
|
||||
assert rcode(sqrt(x)) == "sqrt(x)"
|
||||
assert rcode(x**0.5) == "sqrt(x)"
|
||||
assert rcode(sqrt(x)) == "sqrt(x)"
|
||||
|
||||
|
||||
def test_rcode_Pow():
|
||||
assert rcode(x**3) == "x^3"
|
||||
assert rcode(x**(y**3)) == "x^(y^3)"
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"(3.5*2*x)^(-x + y^x)/(x^2 + y)"
|
||||
assert rcode(x**-1.0) == '1.0/x'
|
||||
assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)'
|
||||
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
|
||||
(lambda base, exp: not exp.is_integer, "pow")]
|
||||
assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
|
||||
assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)'
|
||||
|
||||
|
||||
def test_rcode_Max():
|
||||
# Test for gh-11926
|
||||
assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
|
||||
|
||||
|
||||
def test_rcode_constants_mathh():
|
||||
assert rcode(exp(1)) == "exp(1)"
|
||||
assert rcode(pi) == "pi"
|
||||
assert rcode(oo) == "Inf"
|
||||
assert rcode(-oo) == "-Inf"
|
||||
|
||||
|
||||
def test_rcode_constants_other():
|
||||
assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio"
|
||||
assert rcode(
|
||||
2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan"
|
||||
assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma"
|
||||
|
||||
|
||||
def test_rcode_Rational():
|
||||
assert rcode(Rational(3, 7)) == "3.0/7.0"
|
||||
assert rcode(Rational(18, 9)) == "2"
|
||||
assert rcode(Rational(3, -7)) == "-3.0/7.0"
|
||||
assert rcode(Rational(-3, -7)) == "3.0/7.0"
|
||||
assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0"
|
||||
assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x"
|
||||
|
||||
|
||||
def test_rcode_Integer():
|
||||
assert rcode(Integer(67)) == "67"
|
||||
assert rcode(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_rcode_functions():
|
||||
assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)"
|
||||
assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)"
|
||||
assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))"
|
||||
|
||||
|
||||
def test_rcode_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert rcode(g(x)) == "2*x"
|
||||
g = implemented_function('g', Lambda(x, 2*x/Catalan))
|
||||
assert rcode(
|
||||
g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n()
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
res=rcode(g(A[i]), assign_to=A[i])
|
||||
ref=(
|
||||
"for (i in 1:n){\n"
|
||||
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
|
||||
"}"
|
||||
)
|
||||
assert res == ref
|
||||
|
||||
|
||||
def test_rcode_exceptions():
|
||||
assert rcode(ceiling(x)) == "ceiling(x)"
|
||||
assert rcode(Abs(x)) == "abs(x)"
|
||||
assert rcode(gamma(x)) == "gamma(x)"
|
||||
|
||||
|
||||
def test_rcode_user_functions():
|
||||
x = symbols('x', integer=False)
|
||||
n = symbols('n', integer=True)
|
||||
custom_functions = {
|
||||
"ceiling": "myceil",
|
||||
"Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
|
||||
}
|
||||
assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)"
|
||||
assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)"
|
||||
assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)"
|
||||
|
||||
|
||||
def test_rcode_boolean():
|
||||
assert rcode(True) == "True"
|
||||
assert rcode(S.true) == "True"
|
||||
assert rcode(False) == "False"
|
||||
assert rcode(S.false) == "False"
|
||||
assert rcode(x & y) == "x & y"
|
||||
assert rcode(x | y) == "x | y"
|
||||
assert rcode(~x) == "!x"
|
||||
assert rcode(x & y & z) == "x & y & z"
|
||||
assert rcode(x | y | z) == "x | y | z"
|
||||
assert rcode((x & y) | z) == "z | x & y"
|
||||
assert rcode((x | y) & z) == "z & (x | y)"
|
||||
|
||||
def test_rcode_Relational():
|
||||
assert rcode(Eq(x, y)) == "x == y"
|
||||
assert rcode(Ne(x, y)) == "x != y"
|
||||
assert rcode(Le(x, y)) == "x <= y"
|
||||
assert rcode(Lt(x, y)) == "x < y"
|
||||
assert rcode(Gt(x, y)) == "x > y"
|
||||
assert rcode(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_rcode_Piecewise():
|
||||
expr = Piecewise((x, x < 1), (x**2, True))
|
||||
res=rcode(expr)
|
||||
ref="ifelse(x < 1,x,x^2)"
|
||||
assert res == ref
|
||||
tau=Symbol("tau")
|
||||
res=rcode(expr,tau)
|
||||
ref="tau = ifelse(x < 1,x,x^2);"
|
||||
assert res == ref
|
||||
|
||||
expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True))
|
||||
assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))"
|
||||
res = rcode(expr, assign_to='c')
|
||||
assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));"
|
||||
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
#expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
#raises(ValueError, lambda: rcode(expr))
|
||||
expr = 2*Piecewise((x, x < 1), (x**2, x<2))
|
||||
assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))"
|
||||
|
||||
|
||||
def test_rcode_sinc():
|
||||
from sympy.functions.elementary.trigonometric import sinc
|
||||
expr = sinc(x)
|
||||
res = rcode(expr)
|
||||
ref = "(ifelse(x != 0,sin(x)/x,1))"
|
||||
assert res == ref
|
||||
|
||||
|
||||
def test_rcode_Piecewise_deep():
|
||||
p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
|
||||
assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))"
|
||||
expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
|
||||
p = rcode(expr)
|
||||
ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1"
|
||||
assert p == ref
|
||||
|
||||
ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;"
|
||||
p = rcode(expr, assign_to='c')
|
||||
assert p == ref
|
||||
|
||||
|
||||
def test_rcode_ITE():
|
||||
expr = ITE(x < 1, y, z)
|
||||
p = rcode(expr)
|
||||
ref="ifelse(x < 1,y,z)"
|
||||
assert p == ref
|
||||
|
||||
|
||||
def test_rcode_settings():
|
||||
raises(TypeError, lambda: rcode(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_rcode_Indexed():
|
||||
n, m, o = symbols('n m o', integer=True)
|
||||
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
|
||||
p = RCodePrinter()
|
||||
p._not_r = set()
|
||||
|
||||
x = IndexedBase('x')[j]
|
||||
assert p._print_Indexed(x) == 'x[j]'
|
||||
A = IndexedBase('A')[i, j]
|
||||
assert p._print_Indexed(A) == 'A[i, j]'
|
||||
B = IndexedBase('B')[i, j, k]
|
||||
assert p._print_Indexed(B) == 'B[i, j, k]'
|
||||
|
||||
assert p._not_r == set()
|
||||
|
||||
def test_rcode_Indexed_without_looking_for_contraction():
|
||||
len_y = 5
|
||||
y = IndexedBase('y', shape=(len_y,))
|
||||
x = IndexedBase('x', shape=(len_y,))
|
||||
Dy = IndexedBase('Dy', shape=(len_y-1,))
|
||||
i = Idx('i', len_y-1)
|
||||
e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
|
||||
code0 = rcode(e.rhs, assign_to=e.lhs, contract=False)
|
||||
assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
|
||||
|
||||
|
||||
def test_rcode_loops_matrix_vector():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (i in 1:m){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' y[i] = A[i, j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = rcode(A[i, j]*x[j], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
# the following line could also be
|
||||
# [Dummy(s, integer=True) for s in 'im']
|
||||
# or [Dummy(integer=True) for s in 'im']
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
expected = (
|
||||
'for (i_%(icount)i in 1:m_%(mcount)i){\n'
|
||||
' y[i_%(icount)i] = x[i_%(icount)i];\n'
|
||||
'}'
|
||||
) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
|
||||
code = rcode(x[i], assign_to=y[i])
|
||||
assert code == expected
|
||||
|
||||
|
||||
def test_rcode_loops_add():
|
||||
n, m = symbols('n m', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
z = IndexedBase('z')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
s = (
|
||||
'for (i in 1:m){\n'
|
||||
' y[i] = x[i] + z[i];\n'
|
||||
'}\n'
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' y[i] = A[i, j]*x[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_rcode_loops_multiple_contractions():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (i in 1:m){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' for (k in 1:o){\n'
|
||||
' for (l in 1:p){\n'
|
||||
' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_rcode_loops_addfactor():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
s = (
|
||||
'for (i in 1:m){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' for (k in 1:o){\n'
|
||||
' for (l in 1:p){\n'
|
||||
' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}'
|
||||
)
|
||||
c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
|
||||
assert c == s
|
||||
|
||||
|
||||
def test_rcode_loops_multiple_terms():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
|
||||
s0 = (
|
||||
'for (i in 1:m){\n'
|
||||
' y[i] = 0;\n'
|
||||
'}\n'
|
||||
)
|
||||
s1 = (
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' for (k in 1:o){\n'
|
||||
' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n'
|
||||
' }\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s2 = (
|
||||
'for (i in 1:m){\n'
|
||||
' for (k in 1:o){\n'
|
||||
' y[i] = a[i, k]*b[k] + y[i];\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
s3 = (
|
||||
'for (i in 1:m){\n'
|
||||
' for (j in 1:n){\n'
|
||||
' y[i] = a[i, j]*b[j] + y[i];\n'
|
||||
' }\n'
|
||||
'}\n'
|
||||
)
|
||||
c = rcode(
|
||||
b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
|
||||
|
||||
ref={}
|
||||
ref[0] = s0 + s1 + s2 + s3[:-1]
|
||||
ref[1] = s0 + s1 + s3 + s2[:-1]
|
||||
ref[2] = s0 + s2 + s1 + s3[:-1]
|
||||
ref[3] = s0 + s2 + s3 + s1[:-1]
|
||||
ref[4] = s0 + s3 + s1 + s2[:-1]
|
||||
ref[5] = s0 + s3 + s2 + s1[:-1]
|
||||
|
||||
assert (c == ref[0] or
|
||||
c == ref[1] or
|
||||
c == ref[2] or
|
||||
c == ref[3] or
|
||||
c == ref[4] or
|
||||
c == ref[5])
|
||||
|
||||
|
||||
def test_dereference_printing():
|
||||
expr = x + y + sin(z) + z
|
||||
assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
|
||||
|
||||
|
||||
def test_Matrix_printing():
|
||||
# Test returning a Matrix
|
||||
mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
p = rcode(mat, A)
|
||||
assert p == (
|
||||
"A[0] = x*y;\n"
|
||||
"A[1] = ifelse(y > 0,x + 2,y);\n"
|
||||
"A[2] = sin(z);")
|
||||
# Test using MatrixElements in expressions
|
||||
expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
|
||||
p = rcode(expr)
|
||||
assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]")
|
||||
# Test using MatrixElements in a Matrix
|
||||
q = MatrixSymbol('q', 5, 1)
|
||||
M = MatrixSymbol('M', 3, 3)
|
||||
m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
|
||||
[q[1,0] + q[2,0], q[3, 0], 5],
|
||||
[2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
|
||||
assert rcode(m, M) == (
|
||||
"M[0] = sin(q[1]);\n"
|
||||
"M[1] = 0;\n"
|
||||
"M[2] = cos(q[2]);\n"
|
||||
"M[3] = q[1] + q[2];\n"
|
||||
"M[4] = q[3];\n"
|
||||
"M[5] = 5;\n"
|
||||
"M[6] = 2*q[4]/q[1];\n"
|
||||
"M[7] = sqrt(q[0]) + 4;\n"
|
||||
"M[8] = 0;")
|
||||
|
||||
|
||||
def test_rcode_sgn():
|
||||
|
||||
expr = sign(x) * y
|
||||
assert rcode(expr) == 'y*sign(x)'
|
||||
p = rcode(expr, 'z')
|
||||
assert p == 'z = y*sign(x);'
|
||||
|
||||
p = rcode(sign(2 * x + x**2) * x + x**2)
|
||||
assert p == "x^2 + x*sign(x^2 + 2*x)"
|
||||
|
||||
expr = sign(cos(x))
|
||||
p = rcode(expr)
|
||||
assert p == 'sign(cos(x))'
|
||||
|
||||
def test_rcode_Assignment():
|
||||
assert rcode(Assignment(x, y + z)) == 'x = y + z;'
|
||||
assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;'
|
||||
|
||||
|
||||
def test_rcode_For():
|
||||
f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
|
||||
sol = rcode(f)
|
||||
assert sol == ("for(x in seq(from=0, to=9, by=2){\n"
|
||||
" y *= x;\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
# test cases for issue #11821
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert(rcode(A[0, 0]) == "A[0]")
|
||||
assert(rcode(3 * A[0, 0]) == "3*A[0]")
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert(rcode(F) == "(A - B)[0]")
|
||||
@@ -0,0 +1,382 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from sympy.external.gmpy import GROUND_TYPES
|
||||
from sympy.testing.pytest import raises, warns_deprecated_sympy
|
||||
from sympy.assumptions.ask import Q
|
||||
from sympy.core.function import (Function, WildFunction)
|
||||
from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, Symbol, Wild, symbols)
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.miscellaneous import (root, sqrt)
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.functions.special.delta_functions import Heaviside
|
||||
from sympy.logic.boolalg import (false, true)
|
||||
from sympy.matrices.dense import (Matrix, ones)
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.matrices.immutable import ImmutableDenseMatrix
|
||||
from sympy.combinatorics import Cycle, Permutation
|
||||
from sympy.core.symbol import Str
|
||||
from sympy.geometry import Point, Ellipse
|
||||
from sympy.printing import srepr
|
||||
from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly
|
||||
from sympy.polys.polyclasses import DMP
|
||||
from sympy.polys.agca.extensions import FiniteExtension
|
||||
|
||||
x, y = symbols('x,y')
|
||||
|
||||
# eval(srepr(expr)) == expr has to succeed in the right environment. The right
|
||||
# environment is the scope of "from sympy import *" for most cases.
|
||||
ENV: dict[str, Any] = {"Str": Str}
|
||||
exec("from sympy import *", ENV)
|
||||
|
||||
|
||||
def sT(expr, string, import_stmt=None, **kwargs):
|
||||
"""
|
||||
sT := sreprTest
|
||||
|
||||
Tests that srepr delivers the expected string and that
|
||||
the condition eval(srepr(expr))==expr holds.
|
||||
"""
|
||||
if import_stmt is None:
|
||||
ENV2 = ENV
|
||||
else:
|
||||
ENV2 = ENV.copy()
|
||||
exec(import_stmt, ENV2)
|
||||
|
||||
assert srepr(expr, **kwargs) == string
|
||||
assert eval(string, ENV2) == expr
|
||||
|
||||
|
||||
def test_printmethod():
|
||||
class R(Abs):
|
||||
def _sympyrepr(self, printer):
|
||||
return "foo(%s)" % printer._print(self.args[0])
|
||||
assert srepr(R(x)) == "foo(Symbol('x'))"
|
||||
|
||||
|
||||
def test_Add():
|
||||
sT(x + y, "Add(Symbol('x'), Symbol('y'))")
|
||||
assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))"
|
||||
assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))"
|
||||
assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))"
|
||||
|
||||
|
||||
def test_more_than_255_args_issue_10259():
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.mul import Mul
|
||||
for op in (Add, Mul):
|
||||
expr = op(*symbols('x:256'))
|
||||
assert eval(srepr(expr)) == expr
|
||||
|
||||
|
||||
def test_Function():
|
||||
sT(Function("f")(x), "Function('f')(Symbol('x'))")
|
||||
# test unapplied Function
|
||||
sT(Function('f'), "Function('f')")
|
||||
|
||||
sT(sin(x), "sin(Symbol('x'))")
|
||||
sT(sin, "sin")
|
||||
|
||||
|
||||
def test_Heaviside():
|
||||
sT(Heaviside(x), "Heaviside(Symbol('x'))")
|
||||
sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))")
|
||||
|
||||
|
||||
def test_Geometry():
|
||||
sT(Point(0, 0), "Point2D(Integer(0), Integer(0))")
|
||||
sT(Ellipse(Point(0, 0), 5, 1),
|
||||
"Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))")
|
||||
# TODO more tests
|
||||
|
||||
|
||||
def test_Singletons():
|
||||
sT(S.Catalan, 'Catalan')
|
||||
sT(S.ComplexInfinity, 'zoo')
|
||||
sT(S.EulerGamma, 'EulerGamma')
|
||||
sT(S.Exp1, 'E')
|
||||
sT(S.GoldenRatio, 'GoldenRatio')
|
||||
sT(S.TribonacciConstant, 'TribonacciConstant')
|
||||
sT(S.Half, 'Rational(1, 2)')
|
||||
sT(S.ImaginaryUnit, 'I')
|
||||
sT(S.Infinity, 'oo')
|
||||
sT(S.NaN, 'nan')
|
||||
sT(S.NegativeInfinity, '-oo')
|
||||
sT(S.NegativeOne, 'Integer(-1)')
|
||||
sT(S.One, 'Integer(1)')
|
||||
sT(S.Pi, 'pi')
|
||||
sT(S.Zero, 'Integer(0)')
|
||||
sT(S.Complexes, 'Complexes')
|
||||
sT(S.EmptySequence, 'EmptySequence')
|
||||
sT(S.EmptySet, 'EmptySet')
|
||||
# sT(S.IdentityFunction, 'Lambda(_x, _x)')
|
||||
sT(S.Naturals, 'Naturals')
|
||||
sT(S.Naturals0, 'Naturals0')
|
||||
sT(S.Rationals, 'Rationals')
|
||||
sT(S.Reals, 'Reals')
|
||||
sT(S.UniversalSet, 'UniversalSet')
|
||||
|
||||
|
||||
def test_Integer():
|
||||
sT(Integer(4), "Integer(4)")
|
||||
|
||||
|
||||
def test_list():
|
||||
sT([x, Integer(4)], "[Symbol('x'), Integer(4)]")
|
||||
|
||||
|
||||
def test_Matrix():
|
||||
for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]:
|
||||
sT(cls([[x**+1, 1], [y, x + y]]),
|
||||
"%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
|
||||
|
||||
sT(cls(), "%s([])" % name)
|
||||
|
||||
sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
|
||||
|
||||
|
||||
def test_empty_Matrix():
|
||||
sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])")
|
||||
sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])")
|
||||
sT(ones(0, 0), "MutableDenseMatrix([])")
|
||||
|
||||
|
||||
def test_Rational():
|
||||
sT(Rational(1, 3), "Rational(1, 3)")
|
||||
sT(Rational(-1, 3), "Rational(-1, 3)")
|
||||
|
||||
|
||||
def test_Float():
|
||||
sT(Float('1.23', dps=3), "Float('1.22998', precision=13)")
|
||||
sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)")
|
||||
sT(Float('1.234567890123456789', dps=19),
|
||||
"Float('1.234567890123456789013', precision=66)")
|
||||
sT(Float('0.60038617995049726', dps=15),
|
||||
"Float('0.60038617995049726', precision=53)")
|
||||
|
||||
sT(Float('1.23', precision=13), "Float('1.22998', precision=13)")
|
||||
sT(Float('1.23456789', precision=33),
|
||||
"Float('1.23456788994', precision=33)")
|
||||
sT(Float('1.234567890123456789', precision=66),
|
||||
"Float('1.234567890123456789013', precision=66)")
|
||||
sT(Float('0.60038617995049726', precision=53),
|
||||
"Float('0.60038617995049726', precision=53)")
|
||||
|
||||
sT(Float('0.60038617995049726', 15),
|
||||
"Float('0.60038617995049726', precision=53)")
|
||||
|
||||
|
||||
def test_Symbol():
|
||||
sT(x, "Symbol('x')")
|
||||
sT(y, "Symbol('y')")
|
||||
sT(Symbol('x', negative=True), "Symbol('x', negative=True)")
|
||||
|
||||
|
||||
def test_Symbol_two_assumptions():
|
||||
x = Symbol('x', negative=0, integer=1)
|
||||
# order could vary
|
||||
s1 = "Symbol('x', integer=True, negative=False)"
|
||||
s2 = "Symbol('x', negative=False, integer=True)"
|
||||
assert srepr(x) in (s1, s2)
|
||||
assert eval(srepr(x), ENV) == x
|
||||
|
||||
|
||||
def test_Symbol_no_special_commutative_treatment():
|
||||
sT(Symbol('x'), "Symbol('x')")
|
||||
sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)")
|
||||
sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)")
|
||||
sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)")
|
||||
sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)")
|
||||
|
||||
|
||||
def test_Wild():
|
||||
sT(Wild('x', even=True), "Wild('x', even=True)")
|
||||
|
||||
|
||||
def test_Dummy():
|
||||
d = Dummy('d')
|
||||
sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index))
|
||||
|
||||
|
||||
def test_Dummy_assumption():
|
||||
d = Dummy('d', nonzero=True)
|
||||
assert d == eval(srepr(d))
|
||||
s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index)
|
||||
s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index)
|
||||
assert srepr(d) in (s1, s2)
|
||||
|
||||
|
||||
def test_Dummy_from_Symbol():
|
||||
# should not get the full dictionary of assumptions
|
||||
n = Symbol('n', integer=True)
|
||||
d = n.as_dummy()
|
||||
assert srepr(d
|
||||
) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index)
|
||||
|
||||
|
||||
def test_tuple():
|
||||
sT((x,), "(Symbol('x'),)")
|
||||
sT((x, y), "(Symbol('x'), Symbol('y'))")
|
||||
|
||||
|
||||
def test_WildFunction():
|
||||
sT(WildFunction('w'), "WildFunction('w')")
|
||||
|
||||
|
||||
def test_settins():
|
||||
raises(TypeError, lambda: srepr(x, method="garbage"))
|
||||
|
||||
|
||||
def test_Mul():
|
||||
sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))")
|
||||
assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))"
|
||||
assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))"
|
||||
|
||||
|
||||
def test_AlgebraicNumber():
|
||||
a = AlgebraicNumber(sqrt(2))
|
||||
sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])")
|
||||
a = AlgebraicNumber(root(-2, 3))
|
||||
sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])")
|
||||
|
||||
|
||||
def test_PolyRing():
|
||||
assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)"
|
||||
assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)"
|
||||
assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
|
||||
|
||||
|
||||
def test_FracField():
|
||||
assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)"
|
||||
assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)"
|
||||
assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
|
||||
|
||||
|
||||
def test_PolyElement():
|
||||
R, x, y = ring("x,y", ZZ)
|
||||
assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])"
|
||||
|
||||
|
||||
def test_FracElement():
|
||||
F, x, y = field("x,y", ZZ)
|
||||
assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])"
|
||||
|
||||
|
||||
def test_FractionField():
|
||||
assert srepr(QQ.frac_field(x)) == \
|
||||
"FractionField(FracField((Symbol('x'),), QQ, lex))"
|
||||
assert srepr(QQ.frac_field(x, y, order=grlex)) == \
|
||||
"FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))"
|
||||
|
||||
|
||||
def test_PolynomialRingBase():
|
||||
assert srepr(ZZ.old_poly_ring(x)) == \
|
||||
"GlobalPolynomialRing(ZZ, Symbol('x'))"
|
||||
assert srepr(ZZ[x].old_poly_ring(y)) == \
|
||||
"GlobalPolynomialRing(ZZ[x], Symbol('y'))"
|
||||
assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \
|
||||
"GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))"
|
||||
|
||||
|
||||
def test_DMP():
|
||||
p1 = DMP([1, 2], ZZ)
|
||||
p2 = ZZ.old_poly_ring(x)([1, 2])
|
||||
if GROUND_TYPES != 'flint':
|
||||
assert srepr(p1) == "DMP_Python([1, 2], ZZ)"
|
||||
assert srepr(p2) == "DMP_Python([1, 2], ZZ)"
|
||||
else:
|
||||
assert srepr(p1) == "DUP_Flint([1, 2], ZZ)"
|
||||
assert srepr(p2) == "DUP_Flint([1, 2], ZZ)"
|
||||
|
||||
|
||||
def test_FiniteExtension():
|
||||
assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \
|
||||
"FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))"
|
||||
|
||||
|
||||
def test_ExtensionElement():
|
||||
A = FiniteExtension(Poly(x**2 + 1, x))
|
||||
if GROUND_TYPES != 'flint':
|
||||
ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
|
||||
else:
|
||||
ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
|
||||
assert srepr(A.generator) == ans
|
||||
|
||||
def test_BooleanAtom():
|
||||
assert srepr(true) == "true"
|
||||
assert srepr(false) == "false"
|
||||
|
||||
|
||||
def test_Integers():
|
||||
sT(S.Integers, "Integers")
|
||||
|
||||
|
||||
def test_Naturals():
|
||||
sT(S.Naturals, "Naturals")
|
||||
|
||||
|
||||
def test_Naturals0():
|
||||
sT(S.Naturals0, "Naturals0")
|
||||
|
||||
|
||||
def test_Reals():
|
||||
sT(S.Reals, "Reals")
|
||||
|
||||
|
||||
def test_matrix_expressions():
|
||||
n = symbols('n', integer=True)
|
||||
A = MatrixSymbol("A", n, n)
|
||||
B = MatrixSymbol("B", n, n)
|
||||
sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))")
|
||||
sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
|
||||
sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
|
||||
|
||||
|
||||
def test_Cycle():
|
||||
# FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2))
|
||||
# adds keys to the Cycle dict (GH-17661)
|
||||
#import_stmt = "from sympy.combinatorics import Cycle"
|
||||
#sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt)
|
||||
assert srepr(Cycle(1, 2)) == "Cycle(1, 2)"
|
||||
|
||||
|
||||
def test_Permutation():
|
||||
import_stmt = "from sympy.combinatorics import Permutation"
|
||||
sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False)
|
||||
sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True)
|
||||
|
||||
with warns_deprecated_sympy():
|
||||
old_print_cyclic = Permutation.print_cyclic
|
||||
Permutation.print_cyclic = False
|
||||
sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt)
|
||||
Permutation.print_cyclic = old_print_cyclic
|
||||
|
||||
def test_dict():
|
||||
from sympy.abc import x, y, z
|
||||
d = {}
|
||||
assert srepr(d) == "{}"
|
||||
d = {x: y}
|
||||
assert srepr(d) == "{Symbol('x'): Symbol('y')}"
|
||||
d = {x: y, y: z}
|
||||
assert srepr(d) in (
|
||||
"{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}",
|
||||
"{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}",
|
||||
)
|
||||
d = {x: {y: z}}
|
||||
assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}"
|
||||
|
||||
def test_set():
|
||||
from sympy.abc import x, y
|
||||
s = set()
|
||||
assert srepr(s) == "set()"
|
||||
s = {x, y}
|
||||
assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}")
|
||||
|
||||
def test_Predicate():
|
||||
sT(Q.even, "Q.even")
|
||||
|
||||
def test_AppliedPredicate():
|
||||
sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))")
|
||||
@@ -0,0 +1,363 @@
|
||||
from sympy.core import (S, pi, oo, symbols, Rational, Integer,
|
||||
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy,
|
||||
Eq, Ne, Le, Lt, Gt, Ge, Mod)
|
||||
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
|
||||
sign, floor)
|
||||
from sympy.logic import ITE
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.utilities.lambdify import implemented_function
|
||||
from sympy.tensor import IndexedBase, Idx
|
||||
from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix
|
||||
|
||||
from sympy.printing.codeprinter import rust_code
|
||||
|
||||
x, y, z = symbols('x,y,z', integer=False, real=True)
|
||||
k, m, n = symbols('k,m,n', integer=True)
|
||||
|
||||
|
||||
def test_Integer():
|
||||
assert rust_code(Integer(42)) == "42"
|
||||
assert rust_code(Integer(-56)) == "-56"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
assert rust_code(Eq(x, y)) == "x == y"
|
||||
assert rust_code(Ne(x, y)) == "x != y"
|
||||
assert rust_code(Le(x, y)) == "x <= y"
|
||||
assert rust_code(Lt(x, y)) == "x < y"
|
||||
assert rust_code(Gt(x, y)) == "x > y"
|
||||
assert rust_code(Ge(x, y)) == "x >= y"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
assert rust_code(Rational(3, 7)) == "3_f64/7.0"
|
||||
assert rust_code(Rational(18, 9)) == "2"
|
||||
assert rust_code(Rational(3, -7)) == "-3_f64/7.0"
|
||||
assert rust_code(Rational(-3, -7)) == "3_f64/7.0"
|
||||
assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0"
|
||||
assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x"
|
||||
|
||||
|
||||
def test_basic_ops():
|
||||
assert rust_code(x + y) == "x + y"
|
||||
assert rust_code(x - y) == "x - y"
|
||||
assert rust_code(x * y) == "x*y"
|
||||
assert rust_code(x / y) == "x*y.recip()"
|
||||
assert rust_code(-x) == "-x"
|
||||
assert rust_code(2 * x) == "2.0*x"
|
||||
assert rust_code(y + 2) == "y + 2.0"
|
||||
assert rust_code(x + n) == "n as f64 + x"
|
||||
|
||||
def test_printmethod():
|
||||
class fabs(Abs):
|
||||
def _rust_code(self, printer):
|
||||
return "%s.fabs()" % printer._print(self.args[0])
|
||||
assert rust_code(fabs(x)) == "x.fabs()"
|
||||
a = MatrixSymbol("a", 1, 3)
|
||||
assert rust_code(a[0,0]) == 'a[0]'
|
||||
|
||||
|
||||
def test_Functions():
|
||||
assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())"
|
||||
assert rust_code(abs(x)) == "x.abs()"
|
||||
assert rust_code(ceiling(x)) == "x.ceil()"
|
||||
assert rust_code(floor(x)) == "x.floor()"
|
||||
|
||||
# Automatic rewrite
|
||||
assert rust_code(Mod(x, 3)) == 'x - 3.0*((1_f64/3.0)*x).floor()'
|
||||
|
||||
|
||||
def test_Pow():
|
||||
assert rust_code(1/x) == "x.recip()"
|
||||
assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()"
|
||||
assert rust_code(sqrt(x)) == "x.sqrt()"
|
||||
assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()"
|
||||
|
||||
assert rust_code(1/sqrt(x)) == "x.sqrt().recip()"
|
||||
assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()"
|
||||
|
||||
assert rust_code(1/pi) == "PI.recip()"
|
||||
assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()"
|
||||
assert rust_code(pi**-0.5) == "PI.sqrt().recip()"
|
||||
|
||||
assert rust_code(x**Rational(1, 3)) == "x.cbrt()"
|
||||
assert rust_code(2**x) == "x.exp2()"
|
||||
assert rust_code(exp(x)) == "x.exp()"
|
||||
assert rust_code(x**3) == "x.powi(3)"
|
||||
assert rust_code(x**(y**3)) == "x.powf(y.powi(3))"
|
||||
assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)"
|
||||
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
|
||||
"(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)"
|
||||
_cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1),
|
||||
(lambda base, exp: not exp.is_integer, "pow", 1)]
|
||||
assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)'
|
||||
assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)'
|
||||
|
||||
|
||||
def test_constants():
|
||||
assert rust_code(pi) == "PI"
|
||||
assert rust_code(oo) == "INFINITY"
|
||||
assert rust_code(S.Infinity) == "INFINITY"
|
||||
assert rust_code(-oo) == "NEG_INFINITY"
|
||||
assert rust_code(S.NegativeInfinity) == "NEG_INFINITY"
|
||||
assert rust_code(S.NaN) == "NAN"
|
||||
assert rust_code(exp(1)) == "E"
|
||||
assert rust_code(S.Exp1) == "E"
|
||||
|
||||
|
||||
def test_constants_other():
|
||||
assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio" % GoldenRatio.evalf(17)
|
||||
assert rust_code(
|
||||
2*Catalan) == "const Catalan: f64 = %s;\n2.0*Catalan" % Catalan.evalf(17)
|
||||
assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2.0*EulerGamma" % EulerGamma.evalf(17)
|
||||
|
||||
|
||||
def test_boolean():
|
||||
assert rust_code(True) == "true"
|
||||
assert rust_code(S.true) == "true"
|
||||
assert rust_code(False) == "false"
|
||||
assert rust_code(S.false) == "false"
|
||||
assert rust_code(k & m) == "k && m"
|
||||
assert rust_code(k | m) == "k || m"
|
||||
assert rust_code(~k) == "!k"
|
||||
assert rust_code(k & m & n) == "k && m && n"
|
||||
assert rust_code(k | m | n) == "k || m || n"
|
||||
assert rust_code((k & m) | n) == "n || k && m"
|
||||
assert rust_code((k | m) & n) == "n && (k || m)"
|
||||
|
||||
|
||||
def test_Piecewise():
|
||||
expr = Piecewise((x, x < 1), (x + 2, True))
|
||||
assert rust_code(expr) == (
|
||||
"if (x < 1.0) {\n"
|
||||
" x\n"
|
||||
"} else {\n"
|
||||
" x + 2.0\n"
|
||||
"}")
|
||||
assert rust_code(expr, assign_to="r") == (
|
||||
"r = if (x < 1.0) {\n"
|
||||
" x\n"
|
||||
"} else {\n"
|
||||
" x + 2.0\n"
|
||||
"};")
|
||||
assert rust_code(expr, assign_to="r", inline=True) == (
|
||||
"r = if (x < 1.0) { x } else { x + 2.0 };")
|
||||
expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
|
||||
assert rust_code(expr, inline=True) == (
|
||||
"if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
|
||||
assert rust_code(expr, assign_to="r", inline=True) == (
|
||||
"r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };")
|
||||
assert rust_code(expr, assign_to="r") == (
|
||||
"r = if (x < 1.0) {\n"
|
||||
" x\n"
|
||||
"} else if (x < 5.0) {\n"
|
||||
" x + 1.0\n"
|
||||
"} else {\n"
|
||||
" x + 2.0\n"
|
||||
"};")
|
||||
expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
|
||||
assert rust_code(expr, inline=True) == (
|
||||
"2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
|
||||
expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42
|
||||
assert rust_code(expr, inline=True) == (
|
||||
"2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0")
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
|
||||
raises(ValueError, lambda: rust_code(expr))
|
||||
|
||||
|
||||
def test_dereference_printing():
|
||||
expr = x + y + sin(z) + z
|
||||
assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()"
|
||||
|
||||
|
||||
def test_sign():
|
||||
expr = sign(x) * y
|
||||
assert rust_code(expr) == "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64"
|
||||
assert rust_code(expr, assign_to='r') == "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;"
|
||||
|
||||
expr = sign(x + y) + 42
|
||||
assert rust_code(expr) == "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42"
|
||||
assert rust_code(expr, assign_to='r') == "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;"
|
||||
|
||||
expr = sign(cos(x))
|
||||
assert rust_code(expr) == "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })"
|
||||
|
||||
|
||||
def test_reserved_words():
|
||||
|
||||
x, y = symbols("x if")
|
||||
|
||||
expr = sin(y)
|
||||
assert rust_code(expr) == "if_.sin()"
|
||||
assert rust_code(expr, dereference=[y]) == "(*if_).sin()"
|
||||
assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"
|
||||
|
||||
with raises(ValueError):
|
||||
rust_code(expr, error_on_reserved=True)
|
||||
|
||||
|
||||
def test_ITE():
|
||||
ekpr = ITE(k < 1, m, n)
|
||||
assert rust_code(ekpr) == (
|
||||
"if (k < 1) {\n"
|
||||
" m\n"
|
||||
"} else {\n"
|
||||
" n\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_Indexed():
|
||||
n, m, o = symbols('n m o', integer=True)
|
||||
i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
|
||||
|
||||
x = IndexedBase('x')[j]
|
||||
assert rust_code(x) == "x[j]"
|
||||
|
||||
A = IndexedBase('A')[i, j]
|
||||
assert rust_code(A) == "A[m*i + j]"
|
||||
|
||||
B = IndexedBase('B')[i, j, k]
|
||||
assert rust_code(B) == "B[m*o*i + o*j + k]"
|
||||
|
||||
|
||||
def test_dummy_loops():
|
||||
i, m = symbols('i m', integer=True, cls=Dummy)
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
i = Idx(i, m)
|
||||
|
||||
assert rust_code(x[i], assign_to=y[i]) == (
|
||||
"for i in 0..m {\n"
|
||||
" y[i] = x[i];\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_loops():
|
||||
m, n = symbols('m n', integer=True)
|
||||
A = IndexedBase('A')
|
||||
x = IndexedBase('x')
|
||||
y = IndexedBase('y')
|
||||
z = IndexedBase('z')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
|
||||
assert rust_code(A[i, j]*x[j], assign_to=y[i]) == (
|
||||
"for i in 0..m {\n"
|
||||
" y[i] = 0;\n"
|
||||
"}\n"
|
||||
"for i in 0..m {\n"
|
||||
" for j in 0..n {\n"
|
||||
" y[i] = A[n*i + j]*x[j] + y[i];\n"
|
||||
" }\n"
|
||||
"}")
|
||||
|
||||
assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == (
|
||||
"for i in 0..m {\n"
|
||||
" y[i] = x[i] + z[i];\n"
|
||||
"}\n"
|
||||
"for i in 0..m {\n"
|
||||
" for j in 0..n {\n"
|
||||
" y[i] = A[n*i + j]*x[j] + y[i];\n"
|
||||
" }\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_loops_multiple_contractions():
|
||||
n, m, o, p = symbols('n m o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == (
|
||||
"for i in 0..m {\n"
|
||||
" y[i] = 0;\n"
|
||||
"}\n"
|
||||
"for i in 0..m {\n"
|
||||
" for j in 0..n {\n"
|
||||
" for k in 0..o {\n"
|
||||
" for l in 0..p {\n"
|
||||
" y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_loops_addfactor():
|
||||
m, n, o, p = symbols('m n o p', integer=True)
|
||||
a = IndexedBase('a')
|
||||
b = IndexedBase('b')
|
||||
c = IndexedBase('c')
|
||||
y = IndexedBase('y')
|
||||
i = Idx('i', m)
|
||||
j = Idx('j', n)
|
||||
k = Idx('k', o)
|
||||
l = Idx('l', p)
|
||||
|
||||
code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
|
||||
assert code == (
|
||||
"for i in 0..m {\n"
|
||||
" y[i] = 0;\n"
|
||||
"}\n"
|
||||
"for i in 0..m {\n"
|
||||
" for j in 0..n {\n"
|
||||
" for k in 0..o {\n"
|
||||
" for l in 0..p {\n"
|
||||
" y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_settings():
|
||||
raises(TypeError, lambda: rust_code(sin(x), method="garbage"))
|
||||
|
||||
|
||||
def test_inline_function():
|
||||
x = symbols('x')
|
||||
g = implemented_function('g', Lambda(x, 2*x))
|
||||
assert rust_code(g(x)) == "2*x"
|
||||
|
||||
g = implemented_function('g', Lambda(x, 2*x/Catalan))
|
||||
assert rust_code(g(x)) == (
|
||||
"const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17))
|
||||
|
||||
A = IndexedBase('A')
|
||||
i = Idx('i', symbols('n', integer=True))
|
||||
g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
|
||||
assert rust_code(g(A[i]), assign_to=A[i]) == (
|
||||
"for i in 0..n {\n"
|
||||
" A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_user_functions():
|
||||
x = symbols('x', integer=False)
|
||||
n = symbols('n', integer=True)
|
||||
custom_functions = {
|
||||
"ceiling": "ceil",
|
||||
"Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)],
|
||||
}
|
||||
assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()"
|
||||
assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)"
|
||||
assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)"
|
||||
|
||||
|
||||
def test_matrix():
|
||||
assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]'
|
||||
with raises(ValueError):
|
||||
rust_code(Matrix([[1, 2, 3]]))
|
||||
|
||||
|
||||
def test_sparse_matrix():
|
||||
# gh-15791
|
||||
with raises(NotImplementedError):
|
||||
rust_code(SparseMatrix([[1, 2, 3]]))
|
||||
@@ -0,0 +1,553 @@
|
||||
import contextlib
|
||||
import itertools
|
||||
import re
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import sympy
|
||||
from sympy import Add, Implies, sqrt
|
||||
from sympy.core import Mul, Pow
|
||||
from sympy.core import (S, pi, symbols, Function, Rational, Integer,
|
||||
Symbol, Eq, Ne, Le, Lt, Gt, Ge)
|
||||
from sympy.functions import Piecewise, exp, sin, cos
|
||||
from sympy.assumptions.ask import Q
|
||||
from sympy.printing.smtlib import smtlib_code
|
||||
from sympy.testing.pytest import raises, Failed
|
||||
|
||||
x, y, z = symbols('x,y,z')
|
||||
|
||||
|
||||
class _W(Enum):
|
||||
DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE)
|
||||
WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE)
|
||||
WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _check_warns(expected: typing.Iterable[_W]):
|
||||
warns: typing.List[str] = []
|
||||
log_warn = warns.append
|
||||
yield log_warn
|
||||
|
||||
errors = []
|
||||
for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
|
||||
if not e:
|
||||
errors += [f"[{i}] Received unexpected warning `{w}`."]
|
||||
elif not w:
|
||||
errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
|
||||
elif not e.value.match(w):
|
||||
errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
|
||||
|
||||
if errors: raise Failed('\n'.join(errors))
|
||||
|
||||
|
||||
def test_Integer():
|
||||
with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
|
||||
assert smtlib_code(Integer(67), log_warn=w) == "67"
|
||||
assert smtlib_code(Integer(-1), log_warn=w) == "-1"
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(Integer(67)) == "67"
|
||||
assert smtlib_code(Integer(-1)) == "-1"
|
||||
|
||||
|
||||
def test_Rational():
|
||||
with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
|
||||
assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
|
||||
assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
|
||||
assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
|
||||
assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
|
||||
assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
|
||||
assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
|
||||
"(* (/ 3 7) x)"
|
||||
|
||||
|
||||
def test_Relational():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
|
||||
assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
|
||||
assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
|
||||
assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
|
||||
assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
|
||||
assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
|
||||
assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
|
||||
|
||||
|
||||
def test_AppliedBinaryRelation():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
|
||||
assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
|
||||
assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
|
||||
assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
|
||||
assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
|
||||
assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
|
||||
assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
|
||||
|
||||
raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w))
|
||||
|
||||
|
||||
def test_AppliedPredicate():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w:
|
||||
assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))"
|
||||
assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))"
|
||||
assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))"
|
||||
assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))"
|
||||
assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))"
|
||||
assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))"
|
||||
|
||||
def test_Function():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
abs(x),
|
||||
symbol_table={x: int, y: bool},
|
||||
known_types={int: "INTEGER_TYPE"},
|
||||
known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
|
||||
log_warn=w
|
||||
) == "(declare-const x INTEGER_TYPE)\n" \
|
||||
"(ABSOLUTE_VALUE_OF x)"
|
||||
|
||||
my_fun1 = Function('f1')
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
my_fun1(x),
|
||||
symbol_table={my_fun1: Callable[[bool], float]},
|
||||
log_warn=w
|
||||
) == "(declare-const x Bool)\n" \
|
||||
"(declare-fun f1 (Bool) Real)\n" \
|
||||
"(f1 x)"
|
||||
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(
|
||||
my_fun1(x),
|
||||
symbol_table={my_fun1: Callable[[bool], bool]},
|
||||
log_warn=w
|
||||
) == "(declare-const x Bool)\n" \
|
||||
"(declare-fun f1 (Bool) Bool)\n" \
|
||||
"(assert (f1 x))"
|
||||
|
||||
assert smtlib_code(
|
||||
Eq(my_fun1(x, z), y),
|
||||
symbol_table={my_fun1: Callable[[int, bool], bool]},
|
||||
log_warn=w
|
||||
) == "(declare-const x Int)\n" \
|
||||
"(declare-const y Bool)\n" \
|
||||
"(declare-const z Bool)\n" \
|
||||
"(declare-fun f1 (Int Bool) Bool)\n" \
|
||||
"(assert (= (f1 x z) y))"
|
||||
|
||||
assert smtlib_code(
|
||||
Eq(my_fun1(x, z), y),
|
||||
symbol_table={my_fun1: Callable[[int, bool], bool]},
|
||||
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
|
||||
log_warn=w
|
||||
) == "(declare-const x Int)\n" \
|
||||
"(declare-const y Bool)\n" \
|
||||
"(declare-const z Bool)\n" \
|
||||
"(assert (== (MY_KNOWN_FUN x z) y))"
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
|
||||
assert smtlib_code(
|
||||
Eq(my_fun1(x, z), y),
|
||||
known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
|
||||
log_warn=w
|
||||
) == "(declare-const x Real)\n" \
|
||||
"(declare-const y Real)\n" \
|
||||
"(declare-const z Real)\n" \
|
||||
"(assert (== (MY_KNOWN_FUN x z) y))"
|
||||
|
||||
|
||||
def test_Pow():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
|
||||
|
||||
a = Symbol('a', integer=True)
|
||||
b = Symbol('b', real=True)
|
||||
c = Symbol('c')
|
||||
|
||||
def g(x): return 2 * x
|
||||
|
||||
# if x=1, y=2, then expr=2.333...
|
||||
expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
|
||||
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(
|
||||
[
|
||||
Eq(a < 2, c),
|
||||
Eq(b > a, c),
|
||||
c & True,
|
||||
Eq(expr, 2 + Rational(1, 3))
|
||||
],
|
||||
log_warn=w
|
||||
) == '(declare-const a Int)\n' \
|
||||
'(declare-const b Real)\n' \
|
||||
'(declare-const c Bool)\n' \
|
||||
'(assert (= (< a 2) c))\n' \
|
||||
'(assert (= (> b a) c))\n' \
|
||||
'(assert c)\n' \
|
||||
'(assert (= ' \
|
||||
'(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
|
||||
'(/ 7 3)' \
|
||||
'))'
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
|
||||
log_warn=w
|
||||
) == '(declare-const b Real)\n' \
|
||||
'(declare-const c Real)\n' \
|
||||
'(* -2 c (pow (* b b) -1))'
|
||||
|
||||
|
||||
def test_basic_ops():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
|
||||
|
||||
# with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
|
||||
# todo: implement re-write, currently does '(+ x (* -1 y))' instead
|
||||
# assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
|
||||
|
||||
|
||||
def test_quantifier_extensions():
|
||||
from sympy.logic.boolalg import Boolean
|
||||
from sympy import Interval, Tuple, sympify
|
||||
|
||||
# start For-all quantifier class example
|
||||
class ForAll(Boolean):
|
||||
def _smtlib(self, printer):
|
||||
bound_symbol_declarations = [
|
||||
printer._s_expr(sym.name, [
|
||||
printer._known_types[printer.symbol_table[sym]],
|
||||
Interval(start, end)
|
||||
]) for sym, start, end in self.limits
|
||||
]
|
||||
return printer._s_expr('forall', [
|
||||
printer._s_expr('', bound_symbol_declarations),
|
||||
self.function
|
||||
])
|
||||
|
||||
@property
|
||||
def bound_symbols(self):
|
||||
return {s for s, _, _ in self.limits}
|
||||
|
||||
@property
|
||||
def free_symbols(self):
|
||||
bound_symbol_names = {s.name for s in self.bound_symbols}
|
||||
return {
|
||||
s for s in self.function.free_symbols
|
||||
if s.name not in bound_symbol_names
|
||||
}
|
||||
|
||||
def __new__(cls, *args):
|
||||
limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))]
|
||||
function = [sympify(a) for a in args if isinstance(a, Boolean)]
|
||||
assert len(limits) + len(function) == len(args)
|
||||
assert len(function) == 1
|
||||
function = function[0]
|
||||
|
||||
if isinstance(function, ForAll): return ForAll.__new__(
|
||||
ForAll, *(limits + function.limits), function.function
|
||||
)
|
||||
inst = Boolean.__new__(cls)
|
||||
inst._args = tuple(limits + [function])
|
||||
inst.limits = limits
|
||||
inst.function = function
|
||||
return inst
|
||||
|
||||
# end For-All Quantifier class example
|
||||
|
||||
f = Function('f')
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||||
assert smtlib_code(
|
||||
ForAll((x, -42, +21), Eq(f(x), f(x))),
|
||||
symbol_table={f: Callable[[float], float]},
|
||||
log_warn=w
|
||||
) == '(assert (forall ( (x Real [-42, 21])) true))'
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
|
||||
assert smtlib_code(
|
||||
ForAll(
|
||||
(x, -42, +21), (y, -100, 3),
|
||||
Implies(Eq(x, y), Eq(f(x), f(y)))
|
||||
),
|
||||
symbol_table={f: Callable[[float], float]},
|
||||
log_warn=w
|
||||
) == '(declare-fun f (Real) Real)\n' \
|
||||
'(assert (' \
|
||||
'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
|
||||
'(=> (= x y) (= (f x) (f y)))' \
|
||||
'))'
|
||||
|
||||
a = Symbol('a', integer=True)
|
||||
b = Symbol('b', real=True)
|
||||
c = Symbol('c')
|
||||
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(
|
||||
ForAll(
|
||||
(a, 2, 100), ForAll(
|
||||
(b, 2, 100),
|
||||
Implies(a < b, sqrt(a) < b) | c
|
||||
)),
|
||||
log_warn=w
|
||||
) == '(declare-const c Bool)\n' \
|
||||
'(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
|
||||
'(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
|
||||
'))'
|
||||
|
||||
|
||||
def test_mix_number_mult_symbols():
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
1 / pi,
|
||||
known_constants={pi: "MY_PI"},
|
||||
log_warn=w
|
||||
) == '(pow MY_PI -1)'
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
[
|
||||
Eq(pi, 3.14, evaluate=False),
|
||||
1 / pi,
|
||||
],
|
||||
known_constants={pi: "MY_PI"},
|
||||
log_warn=w
|
||||
) == '(assert (= MY_PI 3.14))\n' \
|
||||
'(pow MY_PI -1)'
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||||
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||||
known_constants={
|
||||
S.Pi: 'p', S.GoldenRatio: 'g',
|
||||
S.Exp1: 'e'
|
||||
},
|
||||
known_functions={
|
||||
Add: 'plus',
|
||||
exp: 'exp'
|
||||
},
|
||||
precision=3,
|
||||
log_warn=w
|
||||
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||||
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||||
known_constants={
|
||||
S.Pi: 'p'
|
||||
},
|
||||
known_functions={
|
||||
Add: 'plus',
|
||||
exp: 'exp'
|
||||
},
|
||||
precision=3,
|
||||
log_warn=w
|
||||
) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||||
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||||
known_functions={Add: 'plus'},
|
||||
precision=3,
|
||||
log_warn=w
|
||||
) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
|
||||
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Add(S.Zero, S.One, S.NegativeOne, S.Half,
|
||||
S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
|
||||
known_constants={S.Exp1: 'e'},
|
||||
known_functions={Add: 'plus'},
|
||||
precision=3,
|
||||
log_warn=w
|
||||
) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
|
||||
|
||||
|
||||
def test_boolean():
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Bool)\n' \
|
||||
'(assert (and x y))'
|
||||
assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Bool)\n' \
|
||||
'(assert (or x y))'
|
||||
assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
|
||||
'(assert (not x))'
|
||||
assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Bool)\n' \
|
||||
'(declare-const z Bool)\n' \
|
||||
'(assert (and x y z))'
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||||
assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Bool)\n' \
|
||||
'(declare-const z Real)\n' \
|
||||
'(assert (or (> z 3) (and x (not y))))'
|
||||
|
||||
f = Function('f')
|
||||
g = Function('g')
|
||||
h = Function('h')
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
|
||||
assert smtlib_code(
|
||||
[Gt(f(x), y),
|
||||
Lt(y, g(z))],
|
||||
symbol_table={
|
||||
f: Callable[[bool], int], g: Callable[[bool], int],
|
||||
}, log_warn=w
|
||||
) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Real)\n' \
|
||||
'(declare-const z Bool)\n' \
|
||||
'(declare-fun f (Bool) Int)\n' \
|
||||
'(declare-fun g (Bool) Int)\n' \
|
||||
'(assert (> (f x) y))\n' \
|
||||
'(assert (< y (g z)))'
|
||||
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(
|
||||
[Eq(f(x), y),
|
||||
Lt(y, g(z))],
|
||||
symbol_table={
|
||||
f: Callable[[bool], int], g: Callable[[bool], int],
|
||||
}, log_warn=w
|
||||
) == '(declare-const x Bool)\n' \
|
||||
'(declare-const y Int)\n' \
|
||||
'(declare-const z Bool)\n' \
|
||||
'(declare-fun f (Bool) Int)\n' \
|
||||
'(declare-fun g (Bool) Int)\n' \
|
||||
'(assert (= (f x) y))\n' \
|
||||
'(assert (< y (g z)))'
|
||||
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(
|
||||
[Eq(f(x), y),
|
||||
Eq(g(f(x)), z),
|
||||
Eq(h(g(f(x))), x)],
|
||||
symbol_table={
|
||||
f: Callable[[float], int],
|
||||
g: Callable[[int], bool],
|
||||
h: Callable[[bool], float]
|
||||
},
|
||||
log_warn=w
|
||||
) == '(declare-const x Real)\n' \
|
||||
'(declare-const y Int)\n' \
|
||||
'(declare-const z Bool)\n' \
|
||||
'(declare-fun f (Real) Int)\n' \
|
||||
'(declare-fun g (Int) Bool)\n' \
|
||||
'(declare-fun h (Bool) Real)\n' \
|
||||
'(assert (= (f x) y))\n' \
|
||||
'(assert (= (g (f x)) z))\n' \
|
||||
'(assert (= (h (g (f x))) x))'
|
||||
|
||||
|
||||
# todo: make smtlib_code support arrays
|
||||
# def test_containers():
|
||||
# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
|
||||
# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
|
||||
# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
|
||||
# assert julia_code([1]) == "Any[1]"
|
||||
# assert julia_code((1,)) == "(1,)"
|
||||
# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
|
||||
# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
|
||||
# # scalar, matrix, empty matrix and empty list
|
||||
# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
|
||||
|
||||
def test_smtlib_piecewise():
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Piecewise((x, x < 1),
|
||||
(x ** 2, True)),
|
||||
auto_declare=False,
|
||||
log_warn=w
|
||||
) == '(ite (< x 1) x (pow x 2))'
|
||||
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(
|
||||
Piecewise((x ** 2, x < 1),
|
||||
(x ** 3, x < 2),
|
||||
(x ** 4, x < 3),
|
||||
(x ** 5, True)),
|
||||
auto_declare=False,
|
||||
log_warn=w
|
||||
) == '(ite (< x 1) (pow x 2) ' \
|
||||
'(ite (< x 2) (pow x 3) ' \
|
||||
'(ite (< x 3) (pow x 4) ' \
|
||||
'(pow x 5))))'
|
||||
|
||||
# Check that Piecewise without a True (default) condition error
|
||||
expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
|
||||
|
||||
|
||||
def test_smtlib_piecewise_times_const():
|
||||
pw = Piecewise((x, x < 1), (x ** 2, True))
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
|
||||
|
||||
|
||||
# todo: make smtlib_code support arrays / matrices ?
|
||||
# def test_smtlib_matrix_assign_to():
|
||||
# A = Matrix([[1, 2, 3]])
|
||||
# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
|
||||
# A = Matrix([[1, 2], [3, 4]])
|
||||
# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
|
||||
|
||||
# def test_julia_matrix_1x1():
|
||||
# A = Matrix([[3]])
|
||||
# B = MatrixSymbol('B', 1, 1)
|
||||
# C = MatrixSymbol('C', 1, 2)
|
||||
# assert julia_code(A, assign_to=B) == "B = [3]"
|
||||
# raises(ValueError, lambda: julia_code(A, assign_to=C))
|
||||
|
||||
# def test_julia_matrix_elements():
|
||||
# A = Matrix([[x, 2, x * y]])
|
||||
# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
|
||||
# A = MatrixSymbol('AA', 1, 3)
|
||||
# assert julia_code(A) == "AA"
|
||||
# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
|
||||
# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
|
||||
# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
|
||||
|
||||
def test_smtlib_boolean():
|
||||
with _check_warns([]) as w:
|
||||
assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
|
||||
assert smtlib_code(True, log_warn=w) == '(assert true)'
|
||||
assert smtlib_code(S.true, log_warn=w) == '(assert true)'
|
||||
assert smtlib_code(S.false, log_warn=w) == '(assert false)'
|
||||
assert smtlib_code(False, log_warn=w) == '(assert false)'
|
||||
assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
|
||||
|
||||
|
||||
def test_not_supported():
|
||||
f = Function('f')
|
||||
with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
|
||||
raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
|
||||
with _check_warns([_W.WILL_NOT_ASSERT]) as w:
|
||||
raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))
|
||||
|
||||
|
||||
def test_Float():
|
||||
assert smtlib_code(0.0) == "0.0"
|
||||
assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))'
|
||||
assert smtlib_code(5.3) == "5.3"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,182 @@
|
||||
from sympy.core.singleton import S
|
||||
from sympy.printing.tableform import TableForm
|
||||
from sympy.printing.latex import latex
|
||||
from sympy.abc import x
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
def test_TableForm():
|
||||
s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
|
||||
headings="automatic"))
|
||||
assert s == (
|
||||
' | 1 2\n'
|
||||
'-------\n'
|
||||
'1 | a b\n'
|
||||
'2 | c d\n'
|
||||
'3 | e '
|
||||
)
|
||||
s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
|
||||
headings="automatic", wipe_zeros=False))
|
||||
assert s == dedent('''\
|
||||
| 1 2
|
||||
-------
|
||||
1 | a b
|
||||
2 | c d
|
||||
3 | e 0''')
|
||||
s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]],
|
||||
headings=("automatic", None)))
|
||||
assert s == (
|
||||
'1 | x**2 b \n'
|
||||
'2 | c x**2\n'
|
||||
'3 | e f '
|
||||
)
|
||||
s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]],
|
||||
headings=(None, "automatic")))
|
||||
assert s == dedent('''\
|
||||
1 2
|
||||
---
|
||||
a b
|
||||
c d
|
||||
e f''')
|
||||
s = str(TableForm([[5, 7], [4, 2], [10, 3]],
|
||||
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]]))
|
||||
assert s == (
|
||||
' | y1 y2\n'
|
||||
'---------------\n'
|
||||
'Group A | 5 7 \n'
|
||||
'Group B | 4 2 \n'
|
||||
'Group C | 10 3 '
|
||||
)
|
||||
raises(
|
||||
ValueError,
|
||||
lambda:
|
||||
TableForm(
|
||||
[[5, 7], [4, 2], [10, 3]],
|
||||
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
|
||||
alignments="middle")
|
||||
)
|
||||
s = str(TableForm([[5, 7], [4, 2], [10, 3]],
|
||||
headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
|
||||
alignments="right"))
|
||||
assert s == dedent('''\
|
||||
| y1 y2
|
||||
---------------
|
||||
Group A | 5 7
|
||||
Group B | 4 2
|
||||
Group C | 10 3''')
|
||||
|
||||
# other alignment permutations
|
||||
d = [[1, 100], [100, 1]]
|
||||
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l')
|
||||
assert str(s) == (
|
||||
'xxx | 1 100\n'
|
||||
' x | 100 1 '
|
||||
)
|
||||
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr')
|
||||
assert str(s) == dedent('''\
|
||||
xxx | 1 100
|
||||
x | 100 1''')
|
||||
s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr')
|
||||
assert str(s) == dedent('''\
|
||||
xxx | 1 100
|
||||
x | 100 1''')
|
||||
|
||||
s = TableForm(d, headings=(('xxx', 'x'), None))
|
||||
assert str(s) == (
|
||||
'xxx | 1 100\n'
|
||||
' x | 100 1 '
|
||||
)
|
||||
|
||||
raises(ValueError, lambda: TableForm(d, alignments='clr'))
|
||||
|
||||
#pad
|
||||
s = str(TableForm([[None, "-", 2], [1]], pad='?'))
|
||||
assert s == dedent('''\
|
||||
? - 2
|
||||
1 ? ?''')
|
||||
|
||||
|
||||
def test_TableForm_latex():
|
||||
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
|
||||
wipe_zeros=True, headings=("automatic", "automatic")))
|
||||
assert s == (
|
||||
'\\begin{tabular}{r l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & & $x^{3}$ \\\\\n'
|
||||
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
|
||||
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
|
||||
wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'))
|
||||
assert s == (
|
||||
'\\begin{tabular}{r l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & & $x^{3}$ \\\\\n'
|
||||
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
|
||||
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
|
||||
wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3))
|
||||
assert s == (
|
||||
'\\begin{tabular}{l l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & & $x^{3}$ \\\\\n'
|
||||
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
|
||||
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
|
||||
headings=("automatic", "automatic")))
|
||||
assert s == (
|
||||
'\\begin{tabular}{r l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & $a$ & $x^{3}$ \\\\\n'
|
||||
'2 & $c$ & $\\frac{1}{4}$ \\\\\n'
|
||||
'3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
|
||||
formats=['(%s)', None], headings=("automatic", "automatic")))
|
||||
assert s == (
|
||||
'\\begin{tabular}{r l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & (a) & $x^{3}$ \\\\\n'
|
||||
'2 & (c) & $\\frac{1}{4}$ \\\\\n'
|
||||
'3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
|
||||
def neg_in_paren(x, i, j):
|
||||
if i % 2:
|
||||
return ('(%s)' if x < 0 else '%s') % x
|
||||
else:
|
||||
pass # use default print
|
||||
s = latex(TableForm([[-1, 2], [-3, 4]],
|
||||
formats=[neg_in_paren]*2, headings=("automatic", "automatic")))
|
||||
assert s == (
|
||||
'\\begin{tabular}{r l l}\n'
|
||||
' & 1 & 2 \\\\\n'
|
||||
'\\hline\n'
|
||||
'1 & -1 & 2 \\\\\n'
|
||||
'2 & (-3) & 4 \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]]))
|
||||
assert s == (
|
||||
'\\begin{tabular}{l l}\n'
|
||||
'$a$ & $x^{3}$ \\\\\n'
|
||||
'$c$ & $\\frac{1}{4}$ \\\\\n'
|
||||
'$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
|
||||
'\\end{tabular}'
|
||||
)
|
||||
@@ -0,0 +1,493 @@
|
||||
import random
|
||||
from sympy.core.function import Derivative
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy import Piecewise
|
||||
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
|
||||
PermuteDims, ArrayDiagonal
|
||||
from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
|
||||
from sympy.external import import_module
|
||||
from sympy.functions import \
|
||||
Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \
|
||||
acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
|
||||
re, im, arg, erf, loggamma, log
|
||||
from sympy.codegen.cfunctions import isnan, isinf
|
||||
from sympy.matrices import Matrix, MatrixBase, eye, randMatrix
|
||||
from sympy.matrices.expressions import \
|
||||
Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace
|
||||
from sympy.printing.tensorflow import tensorflow_code
|
||||
from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
|
||||
tf = tensorflow = import_module("tensorflow")
|
||||
|
||||
if tensorflow:
|
||||
# Hide Tensorflow warnings
|
||||
import os
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
|
||||
M = MatrixSymbol("M", 3, 3)
|
||||
N = MatrixSymbol("N", 3, 3)
|
||||
P = MatrixSymbol("P", 3, 3)
|
||||
Q = MatrixSymbol("Q", 3, 3)
|
||||
|
||||
x, y, z, t = symbols("x y z t")
|
||||
|
||||
if tf is not None:
|
||||
llo = [list(range(i, i+3)) for i in range(0, 9, 3)]
|
||||
m3x3 = tf.constant(llo)
|
||||
m3x3sympy = Matrix(llo)
|
||||
|
||||
|
||||
def _compare_tensorflow_matrix(variables, expr, use_float=False):
|
||||
f = lambdify(variables, expr, 'tensorflow')
|
||||
if not use_float:
|
||||
random_matrices = [randMatrix(v.rows, v.cols) for v in variables]
|
||||
else:
|
||||
random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables]
|
||||
|
||||
graph = tf.Graph()
|
||||
r = None
|
||||
with graph.as_default():
|
||||
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
r = session.run(f(*random_variables))
|
||||
|
||||
e = expr.subs(dict(zip(variables, random_matrices)))
|
||||
e = e.doit()
|
||||
if e.is_Matrix:
|
||||
if not isinstance(e, MatrixBase):
|
||||
e = e.as_explicit()
|
||||
e = e.tolist()
|
||||
|
||||
if not use_float:
|
||||
assert (r == e).all()
|
||||
else:
|
||||
r = [i for row in r for i in row]
|
||||
e = [i for row in e for i in row]
|
||||
assert all(
|
||||
abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
|
||||
|
||||
|
||||
# Creating a custom inverse test.
|
||||
# See https://github.com/sympy/sympy/issues/18469
|
||||
def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False):
|
||||
f = lambdify(variables, expr, 'tensorflow')
|
||||
if not use_float:
|
||||
random_matrices = [eye(v.rows, v.cols)*4 for v in variables]
|
||||
else:
|
||||
random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables]
|
||||
|
||||
graph = tf.Graph()
|
||||
r = None
|
||||
with graph.as_default():
|
||||
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
r = session.run(f(*random_variables))
|
||||
|
||||
e = expr.subs(dict(zip(variables, random_matrices)))
|
||||
e = e.doit()
|
||||
if e.is_Matrix:
|
||||
if not isinstance(e, MatrixBase):
|
||||
e = e.as_explicit()
|
||||
e = e.tolist()
|
||||
|
||||
if not use_float:
|
||||
assert (r == e).all()
|
||||
else:
|
||||
r = [i for row in r for i in row]
|
||||
e = [i for row in e for i in row]
|
||||
assert all(
|
||||
abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
|
||||
|
||||
|
||||
def _compare_tensorflow_matrix_scalar(variables, expr):
|
||||
f = lambdify(variables, expr, 'tensorflow')
|
||||
random_matrices = [
|
||||
randMatrix(v.rows, v.cols).evalf() / 100 for v in variables]
|
||||
|
||||
graph = tf.Graph()
|
||||
r = None
|
||||
with graph.as_default():
|
||||
random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
r = session.run(f(*random_variables))
|
||||
|
||||
e = expr.subs(dict(zip(variables, random_matrices)))
|
||||
e = e.doit()
|
||||
assert abs(r-e) < 10**-6
|
||||
|
||||
|
||||
def _compare_tensorflow_scalar(
|
||||
variables, expr, rng=lambda: random.randint(0, 10)):
|
||||
f = lambdify(variables, expr, 'tensorflow')
|
||||
rvs = [rng() for v in variables]
|
||||
|
||||
graph = tf.Graph()
|
||||
r = None
|
||||
with graph.as_default():
|
||||
tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
r = session.run(f(*tf_rvs))
|
||||
|
||||
e = expr.subs(dict(zip(variables, rvs))).evalf().doit()
|
||||
assert abs(r-e) < 10**-6
|
||||
|
||||
|
||||
def _compare_tensorflow_relational(
|
||||
variables, expr, rng=lambda: random.randint(0, 10)):
|
||||
f = lambdify(variables, expr, 'tensorflow')
|
||||
rvs = [rng() for v in variables]
|
||||
|
||||
graph = tf.Graph()
|
||||
r = None
|
||||
with graph.as_default():
|
||||
tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
r = session.run(f(*tf_rvs))
|
||||
|
||||
e = expr.subs(dict(zip(variables, rvs))).doit()
|
||||
assert r == e
|
||||
|
||||
|
||||
def test_tensorflow_printing():
|
||||
assert tensorflow_code(eye(3)) == \
|
||||
"tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])"
|
||||
|
||||
expr = Matrix([[x, sin(y)], [exp(z), -t]])
|
||||
assert tensorflow_code(expr) == \
|
||||
"tensorflow.Variable(" \
|
||||
"[[x, tensorflow.math.sin(y)]," \
|
||||
" [tensorflow.math.exp(z), -t]])"
|
||||
|
||||
|
||||
# This (random) test is XFAIL because it fails occasionally
|
||||
# See https://github.com/sympy/sympy/issues/18469
|
||||
@XFAIL
|
||||
def test_tensorflow_math():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
expr = Abs(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.abs(x)"
|
||||
_compare_tensorflow_scalar((x,), expr)
|
||||
|
||||
expr = sign(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.sign(x)"
|
||||
_compare_tensorflow_scalar((x,), expr)
|
||||
|
||||
expr = ceiling(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.ceil(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = floor(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.floor(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = exp(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.exp(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = sqrt(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = x ** 4
|
||||
assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = cos(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.cos(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = acos(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.acos(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95))
|
||||
|
||||
expr = sin(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.sin(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = asin(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.asin(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = tan(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.tan(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = atan(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.atan(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = atan2(y, x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)"
|
||||
_compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random())
|
||||
|
||||
expr = cosh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.cosh(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = acosh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.acosh(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
|
||||
|
||||
expr = sinh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.sinh(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
|
||||
|
||||
expr = asinh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.asinh(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
|
||||
|
||||
expr = tanh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.tanh(x)"
|
||||
_compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
|
||||
|
||||
expr = atanh(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.atanh(x)"
|
||||
_compare_tensorflow_scalar(
|
||||
(x,), expr, rng=lambda: random.uniform(-.5, .5))
|
||||
|
||||
expr = erf(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.erf(x)"
|
||||
_compare_tensorflow_scalar(
|
||||
(x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = loggamma(x)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)"
|
||||
_compare_tensorflow_scalar(
|
||||
(x,), expr, rng=lambda: random.random())
|
||||
|
||||
|
||||
def test_tensorflow_complexes():
|
||||
assert tensorflow_code(re(x)) == "tensorflow.math.real(x)"
|
||||
assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)"
|
||||
assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)"
|
||||
|
||||
|
||||
def test_tensorflow_relational():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
expr = Eq(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
expr = Ne(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
expr = Ge(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
expr = Gt(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
expr = Le(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
expr = Lt(x, y)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.less(x, y)"
|
||||
_compare_tensorflow_relational((x, y), expr)
|
||||
|
||||
|
||||
# This (random) test is XFAIL because it fails occasionally
|
||||
# See https://github.com/sympy/sympy/issues/18469
|
||||
@XFAIL
|
||||
def test_tensorflow_matrices():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
expr = M
|
||||
assert tensorflow_code(expr) == "M"
|
||||
_compare_tensorflow_matrix((M,), expr)
|
||||
|
||||
expr = M + N
|
||||
assert tensorflow_code(expr) == "tensorflow.math.add(M, N)"
|
||||
_compare_tensorflow_matrix((M, N), expr)
|
||||
|
||||
expr = M * N
|
||||
assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)"
|
||||
_compare_tensorflow_matrix((M, N), expr)
|
||||
|
||||
expr = HadamardProduct(M, N)
|
||||
assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)"
|
||||
_compare_tensorflow_matrix((M, N), expr)
|
||||
|
||||
expr = M*N*P*Q
|
||||
assert tensorflow_code(expr) == \
|
||||
"tensorflow.linalg.matmul(" \
|
||||
"tensorflow.linalg.matmul(" \
|
||||
"tensorflow.linalg.matmul(M, N), P), Q)"
|
||||
_compare_tensorflow_matrix((M, N, P, Q), expr)
|
||||
|
||||
expr = M**3
|
||||
assert tensorflow_code(expr) == \
|
||||
"tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)"
|
||||
_compare_tensorflow_matrix((M,), expr)
|
||||
|
||||
expr = Trace(M)
|
||||
assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)"
|
||||
_compare_tensorflow_matrix((M,), expr)
|
||||
|
||||
expr = Determinant(M)
|
||||
assert tensorflow_code(expr) == "tensorflow.linalg.det(M)"
|
||||
_compare_tensorflow_matrix_scalar((M,), expr)
|
||||
|
||||
expr = Inverse(M)
|
||||
assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)"
|
||||
_compare_tensorflow_matrix_inverse((M,), expr, use_float=True)
|
||||
|
||||
expr = M.T
|
||||
assert tensorflow_code(expr, tensorflow_version='1.14') == \
|
||||
"tensorflow.linalg.matrix_transpose(M)"
|
||||
assert tensorflow_code(expr, tensorflow_version='1.13') == \
|
||||
"tensorflow.matrix_transpose(M)"
|
||||
|
||||
_compare_tensorflow_matrix((M,), expr)
|
||||
|
||||
|
||||
def test_codegen_einsum():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
graph = tf.Graph()
|
||||
with graph.as_default():
|
||||
session = tf.compat.v1.Session(graph=graph)
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
|
||||
cg = convert_matrix_to_array(M * N)
|
||||
f = lambdify((M, N), cg, 'tensorflow')
|
||||
|
||||
ma = tf.constant([[1, 2], [3, 4]])
|
||||
mb = tf.constant([[1,-2], [-1, 3]])
|
||||
y = session.run(f(ma, mb))
|
||||
c = session.run(tf.matmul(ma, mb))
|
||||
assert (y == c).all()
|
||||
|
||||
|
||||
def test_codegen_extra():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
graph = tf.Graph()
|
||||
with graph.as_default():
|
||||
session = tf.compat.v1.Session()
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
P = MatrixSymbol("P", 2, 2)
|
||||
Q = MatrixSymbol("Q", 2, 2)
|
||||
ma = tf.constant([[1, 2], [3, 4]])
|
||||
mb = tf.constant([[1,-2], [-1, 3]])
|
||||
mc = tf.constant([[2, 0], [1, 2]])
|
||||
md = tf.constant([[1,-1], [4, 7]])
|
||||
|
||||
cg = ArrayTensorProduct(M, N)
|
||||
assert tensorflow_code(cg) == \
|
||||
'tensorflow.linalg.einsum("ab,cd", M, N)'
|
||||
f = lambdify((M, N), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb))
|
||||
c = session.run(tf.einsum("ij,kl", ma, mb))
|
||||
assert (y == c).all()
|
||||
|
||||
cg = ArrayAdd(M, N)
|
||||
assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)'
|
||||
f = lambdify((M, N), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb))
|
||||
c = session.run(ma + mb)
|
||||
assert (y == c).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P)
|
||||
assert tensorflow_code(cg) == \
|
||||
'tensorflow.math.add(tensorflow.math.add(M, N), P)'
|
||||
f = lambdify((M, N, P), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb, mc))
|
||||
c = session.run(ma + mb + mc)
|
||||
assert (y == c).all()
|
||||
|
||||
cg = ArrayAdd(M, N, P, Q)
|
||||
assert tensorflow_code(cg) == \
|
||||
'tensorflow.math.add(' \
|
||||
'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)'
|
||||
f = lambdify((M, N, P, Q), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb, mc, md))
|
||||
c = session.run(ma + mb + mc + md)
|
||||
assert (y == c).all()
|
||||
|
||||
cg = PermuteDims(M, [1, 0])
|
||||
assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])'
|
||||
f = lambdify((M,), cg, 'tensorflow')
|
||||
y = session.run(f(ma))
|
||||
c = session.run(tf.transpose(ma))
|
||||
assert (y == c).all()
|
||||
|
||||
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
|
||||
assert tensorflow_code(cg) == \
|
||||
'tensorflow.transpose(' \
|
||||
'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])'
|
||||
f = lambdify((M, N), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb))
|
||||
c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0]))
|
||||
assert (y == c).all()
|
||||
|
||||
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
|
||||
assert tensorflow_code(cg) == \
|
||||
'tensorflow.linalg.einsum("ab,bc->acb", M, N)'
|
||||
f = lambdify((M, N), cg, 'tensorflow')
|
||||
y = session.run(f(ma, mb))
|
||||
c = session.run(tf.einsum("ab,bc->acb", ma, mb))
|
||||
assert (y == c).all()
|
||||
|
||||
|
||||
def test_MatrixElement_printing():
|
||||
A = MatrixSymbol("A", 1, 3)
|
||||
B = MatrixSymbol("B", 1, 3)
|
||||
C = MatrixSymbol("C", 1, 3)
|
||||
|
||||
assert tensorflow_code(A[0, 0]) == "A[0, 0]"
|
||||
assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]"
|
||||
|
||||
F = C[0, 0].subs(C, A - B)
|
||||
assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]"
|
||||
|
||||
|
||||
def test_tensorflow_Derivative():
|
||||
expr = Derivative(sin(x), x)
|
||||
assert tensorflow_code(expr) == \
|
||||
"tensorflow.gradients(tensorflow.math.sin(x), x)[0]"
|
||||
|
||||
def test_tensorflow_isnan_isinf():
|
||||
if not tf:
|
||||
skip("TensorFlow not installed")
|
||||
|
||||
# Test for isnan
|
||||
x = symbols("x")
|
||||
# Return 0 if x is of nan value, and 1 otherwise
|
||||
expression = Piecewise((0.0, isnan(x)), (1.0, True))
|
||||
printed_code = tensorflow_code(expression)
|
||||
expected_printed_code = "tensorflow.where(tensorflow.math.is_nan(x), 0.0, 1.0)"
|
||||
assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}"
|
||||
for _input, _expected in [(float('nan'), 0.0), (float('inf'), 1.0), (float('-inf'), 1.0), (1.0, 1.0)]:
|
||||
_output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input]))
|
||||
assert (_output == _expected).numpy().all()
|
||||
|
||||
# Test for isinf
|
||||
x = symbols("x")
|
||||
# Return 0 if x is of nan value, and 1 otherwise
|
||||
expression = Piecewise((0.0, isinf(x)), (1.0, True))
|
||||
printed_code = tensorflow_code(expression)
|
||||
expected_printed_code = "tensorflow.where(tensorflow.math.is_inf(x), 0.0, 1.0)"
|
||||
assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}"
|
||||
for _input, _expected in [(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 1.0), (1.0, 1.0)]:
|
||||
_output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input]))
|
||||
assert (_output == _expected).numpy().all()
|
||||
@@ -0,0 +1,639 @@
|
||||
"""
|
||||
Important note on tests in this module - the Theano printing functions use a
|
||||
global cache by default, which means that tests using it will modify global
|
||||
state and thus not be independent from each other. Instead of using the "cache"
|
||||
keyword argument each time, this module uses the theano_code_ and
|
||||
theano_function_ functions defined below which default to using a new, empty
|
||||
cache instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy
|
||||
|
||||
theanologger = logging.getLogger('theano.configdefaults')
|
||||
theanologger.setLevel(logging.CRITICAL)
|
||||
theano = import_module('theano')
|
||||
theanologger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
if theano:
|
||||
import numpy as np
|
||||
ts = theano.scalar
|
||||
tt = theano.tensor
|
||||
xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz']
|
||||
Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ']
|
||||
else:
|
||||
#bin/test will not execute any tests now
|
||||
disabled = True
|
||||
|
||||
import sympy as sy
|
||||
from sympy.core.singleton import S
|
||||
from sympy.abc import x, y, z, t
|
||||
from sympy.printing.theanocode import (theano_code, dim_handling,
|
||||
theano_function)
|
||||
|
||||
|
||||
# Default set of matrix symbols for testing - make square so we can both
|
||||
# multiply and perform elementwise operations between them.
|
||||
X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
|
||||
|
||||
# For testing AppliedUndef
|
||||
f_t = sy.Function('f')(t)
|
||||
|
||||
|
||||
def theano_code_(expr, **kwargs):
|
||||
""" Wrapper for theano_code that uses a new, empty cache by default. """
|
||||
kwargs.setdefault('cache', {})
|
||||
with warns_deprecated_sympy():
|
||||
return theano_code(expr, **kwargs)
|
||||
|
||||
def theano_function_(inputs, outputs, **kwargs):
|
||||
""" Wrapper for theano_function that uses a new, empty cache by default. """
|
||||
kwargs.setdefault('cache', {})
|
||||
with warns_deprecated_sympy():
|
||||
return theano_function(inputs, outputs, **kwargs)
|
||||
|
||||
|
||||
def fgraph_of(*exprs):
|
||||
""" Transform SymPy expressions into Theano Computation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
exprs
|
||||
SymPy expressions
|
||||
|
||||
Returns
|
||||
=======
|
||||
theano.gof.FunctionGraph
|
||||
"""
|
||||
outs = list(map(theano_code_, exprs))
|
||||
ins = theano.gof.graph.inputs(outs)
|
||||
ins, outs = theano.gof.graph.clone(ins, outs)
|
||||
return theano.gof.FunctionGraph(ins, outs)
|
||||
|
||||
|
||||
def theano_simplify(fgraph):
|
||||
""" Simplify a Theano Computation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
fgraph : theano.gof.FunctionGraph
|
||||
|
||||
Returns
|
||||
=======
|
||||
theano.gof.FunctionGraph
|
||||
"""
|
||||
mode = theano.compile.get_default_mode().excluding("fusion")
|
||||
fgraph = fgraph.clone()
|
||||
mode.optimizer.optimize(fgraph)
|
||||
return fgraph
|
||||
|
||||
|
||||
def theq(a, b):
|
||||
""" Test two Theano objects for equality.
|
||||
|
||||
Also accepts numeric types and lists/tuples of supported types.
|
||||
|
||||
Note - debugprint() has a bug where it will accept numeric types but does
|
||||
not respect the "file" argument and in this case and instead prints the number
|
||||
to stdout and returns an empty string. This can lead to tests passing where
|
||||
they should fail because any two numbers will always compare as equal. To
|
||||
prevent this we treat numbers as a separate case.
|
||||
"""
|
||||
numeric_types = (int, float, np.number)
|
||||
a_is_num = isinstance(a, numeric_types)
|
||||
b_is_num = isinstance(b, numeric_types)
|
||||
|
||||
# Compare numeric types using regular equality
|
||||
if a_is_num or b_is_num:
|
||||
if not (a_is_num and b_is_num):
|
||||
return False
|
||||
|
||||
return a == b
|
||||
|
||||
# Compare sequences element-wise
|
||||
a_is_seq = isinstance(a, (tuple, list))
|
||||
b_is_seq = isinstance(b, (tuple, list))
|
||||
|
||||
if a_is_seq or b_is_seq:
|
||||
if not (a_is_seq and b_is_seq) or type(a) != type(b):
|
||||
return False
|
||||
|
||||
return list(map(theq, a)) == list(map(theq, b))
|
||||
|
||||
# Otherwise, assume debugprint() can handle it
|
||||
astr = theano.printing.debugprint(a, file='str')
|
||||
bstr = theano.printing.debugprint(b, file='str')
|
||||
|
||||
# Check for bug mentioned above
|
||||
for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
|
||||
if argstr == '':
|
||||
raise TypeError(
|
||||
'theano.printing.debugprint(%s) returned empty string '
|
||||
'(%s is instance of %r)'
|
||||
% (argname, argname, type(argval))
|
||||
)
|
||||
|
||||
return astr == bstr
|
||||
|
||||
|
||||
def test_example_symbols():
|
||||
"""
|
||||
Check that the example symbols in this module print to their Theano
|
||||
equivalents, as many of the other tests depend on this.
|
||||
"""
|
||||
assert theq(xt, theano_code_(x))
|
||||
assert theq(yt, theano_code_(y))
|
||||
assert theq(zt, theano_code_(z))
|
||||
assert theq(Xt, theano_code_(X))
|
||||
assert theq(Yt, theano_code_(Y))
|
||||
assert theq(Zt, theano_code_(Z))
|
||||
|
||||
|
||||
def test_Symbol():
|
||||
""" Test printing a Symbol to a theano variable. """
|
||||
xx = theano_code_(x)
|
||||
assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable))
|
||||
assert xx.broadcastable == ()
|
||||
assert xx.name == x.name
|
||||
|
||||
xx2 = theano_code_(x, broadcastables={x: (False,)})
|
||||
assert xx2.broadcastable == (False,)
|
||||
assert xx2.name == x.name
|
||||
|
||||
def test_MatrixSymbol():
|
||||
""" Test printing a MatrixSymbol to a theano variable. """
|
||||
XX = theano_code_(X)
|
||||
assert isinstance(XX, tt.TensorVariable)
|
||||
assert XX.broadcastable == (False, False)
|
||||
|
||||
@SKIP # TODO - this is currently not checked but should be implemented
|
||||
def test_MatrixSymbol_wrong_dims():
|
||||
""" Test MatrixSymbol with invalid broadcastable. """
|
||||
bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
|
||||
for bc in bcs:
|
||||
with raises(ValueError):
|
||||
theano_code_(X, broadcastables={X: bc})
|
||||
|
||||
def test_AppliedUndef():
|
||||
""" Test printing AppliedUndef instance, which works similarly to Symbol. """
|
||||
ftt = theano_code_(f_t)
|
||||
assert isinstance(ftt, tt.TensorVariable)
|
||||
assert ftt.broadcastable == ()
|
||||
assert ftt.name == 'f_t'
|
||||
|
||||
|
||||
def test_add():
|
||||
expr = x + y
|
||||
comp = theano_code_(expr)
|
||||
assert comp.owner.op == theano.tensor.add
|
||||
|
||||
def test_trig():
|
||||
assert theq(theano_code_(sy.sin(x)), tt.sin(xt))
|
||||
assert theq(theano_code_(sy.tan(x)), tt.tan(xt))
|
||||
|
||||
def test_many():
|
||||
""" Test printing a complex expression with multiple symbols. """
|
||||
expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
|
||||
comp = theano_code_(expr)
|
||||
expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt)
|
||||
assert theq(comp, expected)
|
||||
|
||||
|
||||
def test_dtype():
|
||||
""" Test specifying specific data types through the dtype argument. """
|
||||
for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
|
||||
assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype
|
||||
|
||||
# "floatX" type
|
||||
assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
|
||||
|
||||
# Type promotion
|
||||
assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
|
||||
assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
|
||||
|
||||
|
||||
def test_broadcastables():
|
||||
""" Test the "broadcastables" argument when printing symbol-like objects. """
|
||||
|
||||
# No restrictions on shape
|
||||
for s in [x, f_t]:
|
||||
for bc in [(), (False,), (True,), (False, False), (True, False)]:
|
||||
assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc
|
||||
|
||||
# TODO - matrix broadcasting?
|
||||
|
||||
def test_broadcasting():
|
||||
""" Test "broadcastable" attribute after applying element-wise binary op. """
|
||||
|
||||
expr = x + y
|
||||
|
||||
cases = [
|
||||
[(), (), ()],
|
||||
[(False,), (False,), (False,)],
|
||||
[(True,), (False,), (False,)],
|
||||
[(False, True), (False, False), (False, False)],
|
||||
[(True, False), (False, False), (False, False)],
|
||||
]
|
||||
|
||||
for bc1, bc2, bc3 in cases:
|
||||
comp = theano_code_(expr, broadcastables={x: bc1, y: bc2})
|
||||
assert comp.broadcastable == bc3
|
||||
|
||||
|
||||
def test_MatMul():
|
||||
expr = X*Y*Z
|
||||
expr_t = theano_code_(expr)
|
||||
assert isinstance(expr_t.owner.op, tt.Dot)
|
||||
assert theq(expr_t, Xt.dot(Yt).dot(Zt))
|
||||
|
||||
def test_Transpose():
|
||||
assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle)
|
||||
|
||||
def test_MatAdd():
|
||||
expr = X+Y+Z
|
||||
assert isinstance(theano_code_(expr).owner.op, tt.Elemwise)
|
||||
|
||||
|
||||
def test_Rationals():
|
||||
assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3))
|
||||
assert theq(theano_code_(S.Half), tt.true_div(1, 2))
|
||||
|
||||
def test_Integers():
|
||||
assert theano_code_(sy.Integer(3)) == 3
|
||||
|
||||
def test_factorial():
|
||||
n = sy.Symbol('n')
|
||||
assert theano_code_(sy.factorial(n))
|
||||
|
||||
def test_Derivative():
|
||||
simp = lambda expr: theano_simplify(fgraph_of(expr))
|
||||
assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
|
||||
simp(theano.grad(tt.sin(xt), xt)))
|
||||
|
||||
|
||||
def test_theano_function_simple():
|
||||
""" Test theano_function() with single output. """
|
||||
f = theano_function_([x, y], [x+y])
|
||||
assert f(2, 3) == 5
|
||||
|
||||
def test_theano_function_multi():
|
||||
""" Test theano_function() with multiple outputs. """
|
||||
f = theano_function_([x, y], [x+y, x-y])
|
||||
o1, o2 = f(2, 3)
|
||||
assert o1 == 5
|
||||
assert o2 == -1
|
||||
|
||||
def test_theano_function_numpy():
|
||||
""" Test theano_function() vs Numpy implementation. """
|
||||
f = theano_function_([x, y], [x+y], dim=1,
|
||||
dtypes={x: 'float64', y: 'float64'})
|
||||
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
|
||||
|
||||
f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
|
||||
dim=1)
|
||||
xx = np.arange(3).astype('float64')
|
||||
yy = 2*np.arange(3).astype('float64')
|
||||
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
|
||||
|
||||
|
||||
def test_theano_function_matrix():
|
||||
m = sy.Matrix([[x, y], [z, x + y + z]])
|
||||
expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
|
||||
f = theano_function_([x, y, z], [m])
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
|
||||
f = theano_function_([x, y, z], [m], scalar=True)
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
|
||||
f = theano_function_([x, y, z], [m, m])
|
||||
assert isinstance(f(1.0, 2.0, 3.0), type([]))
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
|
||||
np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
|
||||
|
||||
def test_dim_handling():
|
||||
assert dim_handling([x], dim=2) == {x: (False, False)}
|
||||
assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
|
||||
y: (False, False)}
|
||||
assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
|
||||
|
||||
def test_theano_function_kwargs():
|
||||
"""
|
||||
Test passing additional kwargs from theano_function() to theano.function().
|
||||
"""
|
||||
import numpy as np
|
||||
f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
|
||||
dtypes={x: 'float64', y: 'float64', z: 'float64'})
|
||||
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
|
||||
|
||||
f = theano_function_([x, y, z], [x+y],
|
||||
dtypes={x: 'float64', y: 'float64', z: 'float64'},
|
||||
dim=1, on_unused_input='ignore')
|
||||
xx = np.arange(3).astype('float64')
|
||||
yy = 2*np.arange(3).astype('float64')
|
||||
zz = 2*np.arange(3).astype('float64')
|
||||
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
|
||||
|
||||
def test_theano_function_scalar():
|
||||
""" Test the "scalar" argument to theano_function(). """
|
||||
|
||||
args = [
|
||||
([x, y], [x + y], None, [0]), # Single 0d output
|
||||
([X, Y], [X + Y], None, [2]), # Single 2d output
|
||||
([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
|
||||
([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
|
||||
([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
|
||||
]
|
||||
|
||||
# Create and test functions with and without the scalar setting
|
||||
for inputs, outputs, in_dims, out_dims in args:
|
||||
for scalar in [False, True]:
|
||||
|
||||
f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar)
|
||||
|
||||
# Check the theano_function attribute is set whether wrapped or not
|
||||
assert isinstance(f.theano_function, theano.compile.function_module.Function)
|
||||
|
||||
# Feed in inputs of the appropriate size and get outputs
|
||||
in_values = [
|
||||
np.ones([1 if bc else 5 for bc in i.type.broadcastable])
|
||||
for i in f.theano_function.input_storage
|
||||
]
|
||||
out_values = f(*in_values)
|
||||
if not isinstance(out_values, list):
|
||||
out_values = [out_values]
|
||||
|
||||
# Check output types and shapes
|
||||
assert len(out_dims) == len(out_values)
|
||||
for d, value in zip(out_dims, out_values):
|
||||
|
||||
if scalar and d == 0:
|
||||
# Should have been converted to a scalar value
|
||||
assert isinstance(value, np.number)
|
||||
|
||||
else:
|
||||
# Otherwise should be an array
|
||||
assert isinstance(value, np.ndarray)
|
||||
assert value.ndim == d
|
||||
|
||||
def test_theano_function_bad_kwarg():
|
||||
"""
|
||||
Passing an unknown keyword argument to theano_function() should raise an
|
||||
exception.
|
||||
"""
|
||||
raises(Exception, lambda : theano_function_([x], [x+1], foobar=3))
|
||||
|
||||
|
||||
def test_slice():
|
||||
assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3)
|
||||
|
||||
def theq_slice(s1, s2):
|
||||
for attr in ['start', 'stop', 'step']:
|
||||
a1 = getattr(s1, attr)
|
||||
a2 = getattr(s2, attr)
|
||||
if a1 is None or a2 is None:
|
||||
if not (a1 is None or a2 is None):
|
||||
return False
|
||||
elif not theq(a1, a2):
|
||||
return False
|
||||
return True
|
||||
|
||||
dtypes = {x: 'int32', y: 'int32'}
|
||||
assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
|
||||
assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
|
||||
|
||||
def test_MatrixSlice():
|
||||
from theano import Constant
|
||||
|
||||
cache = {}
|
||||
|
||||
n = sy.Symbol('n', integer=True)
|
||||
X = sy.MatrixSymbol('X', n, n)
|
||||
|
||||
Y = X[1:2:3, 4:5:6]
|
||||
Yt = theano_code_(Y, cache=cache)
|
||||
|
||||
s = ts.Scalar('int64')
|
||||
assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
|
||||
assert Yt.owner.inputs[0] == theano_code_(X, cache=cache)
|
||||
# == doesn't work in theano like it does in SymPy. You have to use
|
||||
# equals.
|
||||
assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7))
|
||||
|
||||
k = sy.Symbol('k')
|
||||
theano_code_(k, dtypes={k: 'int32'})
|
||||
start, stop, step = 4, k, 2
|
||||
Y = X[start:stop:step]
|
||||
Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'})
|
||||
# assert Yt.owner.op.idx_list[0].stop == kt
|
||||
|
||||
def test_BlockMatrix():
|
||||
n = sy.Symbol('n', integer=True)
|
||||
A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
|
||||
At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D))
|
||||
Block = sy.BlockMatrix([[A, B], [C, D]])
|
||||
Blockt = theano_code_(Block)
|
||||
solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)),
|
||||
tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))]
|
||||
assert any(theq(Blockt, solution) for solution in solutions)
|
||||
|
||||
@SKIP
|
||||
def test_BlockMatrix_Inverse_execution():
|
||||
k, n = 2, 4
|
||||
dtype = 'float32'
|
||||
A = sy.MatrixSymbol('A', n, k)
|
||||
B = sy.MatrixSymbol('B', n, n)
|
||||
inputs = A, B
|
||||
output = B.I*A
|
||||
|
||||
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
|
||||
B: [(n//2, n//2), (n//2, n//2)]}
|
||||
cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
|
||||
cutoutput = output.subs(dict(zip(inputs, cutinputs)))
|
||||
|
||||
dtypes = dict(zip(inputs, [dtype]*len(inputs)))
|
||||
f = theano_function_(inputs, [output], dtypes=dtypes, cache={})
|
||||
fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)],
|
||||
dtypes=dtypes, cache={})
|
||||
|
||||
ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
|
||||
ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
|
||||
np.eye(n).astype(dtype)]
|
||||
ninputs[1] += np.ones(B.shape)*1e-5
|
||||
|
||||
assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
|
||||
|
||||
def test_DenseMatrix():
|
||||
t = sy.Symbol('theta')
|
||||
for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
|
||||
X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
|
||||
tX = theano_code_(X)
|
||||
assert isinstance(tX, tt.TensorVariable)
|
||||
assert tX.owner.op == tt.join_
|
||||
|
||||
|
||||
def test_cache_basic():
|
||||
""" Test single symbol-like objects are cached when printed by themselves. """
|
||||
|
||||
# Pairs of objects which should be considered equivalent with respect to caching
|
||||
pairs = [
|
||||
(x, sy.Symbol('x')),
|
||||
(X, sy.MatrixSymbol('X', *X.shape)),
|
||||
(f_t, sy.Function('f')(sy.Symbol('t'))),
|
||||
]
|
||||
|
||||
for s1, s2 in pairs:
|
||||
cache = {}
|
||||
st = theano_code_(s1, cache=cache)
|
||||
|
||||
# Test hit with same instance
|
||||
assert theano_code_(s1, cache=cache) is st
|
||||
|
||||
# Test miss with same instance but new cache
|
||||
assert theano_code_(s1, cache={}) is not st
|
||||
|
||||
# Test hit with different but equivalent instance
|
||||
assert theano_code_(s2, cache=cache) is st
|
||||
|
||||
def test_global_cache():
|
||||
""" Test use of the global cache. """
|
||||
from sympy.printing.theanocode import global_cache
|
||||
|
||||
backup = dict(global_cache)
|
||||
try:
|
||||
# Temporarily empty global cache
|
||||
global_cache.clear()
|
||||
|
||||
for s in [x, X, f_t]:
|
||||
with warns_deprecated_sympy():
|
||||
st = theano_code(s)
|
||||
assert theano_code(s) is st
|
||||
|
||||
finally:
|
||||
# Restore global cache
|
||||
global_cache.update(backup)
|
||||
|
||||
def test_cache_types_distinct():
|
||||
"""
|
||||
Test that symbol-like objects of different types (Symbol, MatrixSymbol,
|
||||
AppliedUndef) are distinguished by the cache even if they have the same
|
||||
name.
|
||||
"""
|
||||
symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
|
||||
|
||||
cache = {} # Single shared cache
|
||||
printed = {}
|
||||
|
||||
for s in symbols:
|
||||
st = theano_code_(s, cache=cache)
|
||||
assert st not in printed.values()
|
||||
printed[s] = st
|
||||
|
||||
# Check all printed objects are distinct
|
||||
assert len(set(map(id, printed.values()))) == len(symbols)
|
||||
|
||||
# Check retrieving
|
||||
for s, st in printed.items():
|
||||
with warns_deprecated_sympy():
|
||||
assert theano_code(s, cache=cache) is st
|
||||
|
||||
def test_symbols_are_created_once():
|
||||
"""
|
||||
Test that a symbol is cached and reused when it appears in an expression
|
||||
more than once.
|
||||
"""
|
||||
expr = sy.Add(x, x, evaluate=False)
|
||||
comp = theano_code_(expr)
|
||||
|
||||
assert theq(comp, xt + xt)
|
||||
assert not theq(comp, xt + theano_code_(x))
|
||||
|
||||
def test_cache_complex():
|
||||
"""
|
||||
Test caching on a complicated expression with multiple symbols appearing
|
||||
multiple times.
|
||||
"""
|
||||
expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
|
||||
symbol_names = {s.name for s in expr.free_symbols}
|
||||
expr_t = theano_code_(expr)
|
||||
|
||||
# Iterate through variables in the Theano computational graph that the
|
||||
# printed expression depends on
|
||||
seen = set()
|
||||
for v in theano.gof.graph.ancestors([expr_t]):
|
||||
# Owner-less, non-constant variables should be our symbols
|
||||
if v.owner is None and not isinstance(v, theano.gof.graph.Constant):
|
||||
# Check it corresponds to a symbol and appears only once
|
||||
assert v.name in symbol_names
|
||||
assert v.name not in seen
|
||||
seen.add(v.name)
|
||||
|
||||
# Check all were present
|
||||
assert seen == symbol_names
|
||||
|
||||
|
||||
def test_Piecewise():
|
||||
# A piecewise linear
|
||||
expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
|
||||
result = theano_code_(expr)
|
||||
assert result.owner.op == tt.switch
|
||||
|
||||
expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1))
|
||||
assert theq(result, expected)
|
||||
|
||||
expr = sy.Piecewise((x, x < 0))
|
||||
result = theano_code_(expr)
|
||||
expected = tt.switch(xt < 0, xt, np.nan)
|
||||
assert theq(result, expected)
|
||||
|
||||
expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
|
||||
(x, sy.Or(x>2, x<0)))
|
||||
result = theano_code_(expr)
|
||||
expected = tt.switch(tt.and_(xt>0,xt<2), 0, \
|
||||
tt.switch(tt.or_(xt>2, xt<0), xt, np.nan))
|
||||
assert theq(result, expected)
|
||||
|
||||
|
||||
def test_Relationals():
|
||||
assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt))
|
||||
# assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement
|
||||
assert theq(theano_code_(x > y), xt > yt)
|
||||
assert theq(theano_code_(x < y), xt < yt)
|
||||
assert theq(theano_code_(x >= y), xt >= yt)
|
||||
assert theq(theano_code_(x <= y), xt <= yt)
|
||||
|
||||
|
||||
def test_complexfunctions():
|
||||
with warns_deprecated_sympy():
|
||||
xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'})
|
||||
from sympy.functions.elementary.complexes import conjugate
|
||||
from theano.tensor import as_tensor_variable as atv
|
||||
from theano.tensor import complex as cplx
|
||||
with warns_deprecated_sympy():
|
||||
assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj()))
|
||||
assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
|
||||
|
||||
|
||||
def test_constantfunctions():
|
||||
with warns_deprecated_sympy():
|
||||
tf = theano_function_([],[1+1j])
|
||||
assert(tf()==1+1j)
|
||||
|
||||
|
||||
def test_Exp1():
|
||||
"""
|
||||
Test that exp(1) prints without error and evaluates close to SymPy's E
|
||||
"""
|
||||
# sy.exp(1) should yield same instance of E as sy.E (singleton), but extra
|
||||
# check added for sanity
|
||||
e_a = sy.exp(1)
|
||||
e_b = sy.E
|
||||
|
||||
np.testing.assert_allclose(float(e_a), np.e)
|
||||
np.testing.assert_allclose(float(e_b), np.e)
|
||||
|
||||
e = theano_code_(e_a)
|
||||
np.testing.assert_allclose(float(e_a), e.eval())
|
||||
|
||||
e = theano_code_(e_b)
|
||||
np.testing.assert_allclose(float(e_b), e.eval())
|
||||
@@ -0,0 +1,531 @@
|
||||
import random
|
||||
import math
|
||||
|
||||
from sympy import symbols, Derivative
|
||||
from sympy.printing.pytorch import torch_code
|
||||
from sympy import (eye, MatrixSymbol, Matrix)
|
||||
from sympy.tensor.array import NDimArray
|
||||
from sympy.tensor.array.expressions.array_expressions import (
|
||||
ArrayTensorProduct, ArrayAdd,
|
||||
PermuteDims, ArrayDiagonal, _CodegenArrayAbstract)
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
|
||||
from sympy.functions import \
|
||||
Abs, ceiling, exp, floor, sign, sin, asin, cos, \
|
||||
acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
|
||||
re, im, arg, erf, loggamma, sqrt
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.external import import_module
|
||||
from sympy.matrices.expressions import \
|
||||
Determinant, HadamardProduct, Inverse, Trace
|
||||
from sympy.matrices import randMatrix
|
||||
from sympy.matrices import Identity, ZeroMatrix, OneMatrix
|
||||
from sympy import conjugate, I
|
||||
from sympy import Heaviside, gamma, polygamma
|
||||
|
||||
|
||||
|
||||
torch = import_module("torch")
|
||||
|
||||
M = MatrixSymbol("M", 3, 3)
|
||||
N = MatrixSymbol("N", 3, 3)
|
||||
P = MatrixSymbol("P", 3, 3)
|
||||
Q = MatrixSymbol("Q", 3, 3)
|
||||
|
||||
x, y, z, t = symbols("x y z t")
|
||||
|
||||
if torch is not None:
|
||||
llo = [list(range(i, i + 3)) for i in range(0, 9, 3)]
|
||||
m3x3 = torch.tensor(llo, dtype=torch.float64)
|
||||
m3x3sympy = Matrix(llo)
|
||||
|
||||
|
||||
def _compare_torch_matrix(variables, expr):
|
||||
f = lambdify(variables, expr, 'torch')
|
||||
|
||||
random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables]
|
||||
random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices]
|
||||
r = f(*random_variables)
|
||||
e = expr.subs(dict(zip(variables, random_matrices))).doit()
|
||||
|
||||
if isinstance(e, _CodegenArrayAbstract):
|
||||
e = e.doit()
|
||||
|
||||
if hasattr(e, 'is_number') and e.is_number:
|
||||
if isinstance(r, torch.Tensor) and r.dim() == 0:
|
||||
r = r.item()
|
||||
e = float(e)
|
||||
assert abs(r - e) < 1e-6
|
||||
return
|
||||
|
||||
if e.is_Matrix or isinstance(e, NDimArray):
|
||||
e = torch.tensor(e.tolist(), dtype=torch.float64)
|
||||
assert torch.allclose(r, e, atol=1e-6)
|
||||
else:
|
||||
raise TypeError(f"Cannot compare {type(r)} with {type(e)}")
|
||||
|
||||
|
||||
def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)):
|
||||
f = lambdify(variables, expr, 'torch')
|
||||
rvs = [rng() for v in variables]
|
||||
t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
|
||||
r = f(*t_rvs)
|
||||
if isinstance(r, torch.Tensor):
|
||||
r = r.item()
|
||||
e = expr.subs(dict(zip(variables, rvs))).doit()
|
||||
assert abs(r - e) < 1e-6
|
||||
|
||||
|
||||
def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)):
|
||||
f = lambdify(variables, expr, 'torch')
|
||||
rvs = [rng() for v in variables]
|
||||
t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
|
||||
r = f(*t_rvs)
|
||||
e = bool(expr.subs(dict(zip(variables, rvs))).doit())
|
||||
assert r.item() == e
|
||||
|
||||
|
||||
def test_torch_math():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = Abs(x)
|
||||
assert torch_code(expr) == "torch.abs(x)"
|
||||
f = lambdify(x, expr, 'torch')
|
||||
ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64)
|
||||
y_abs = f(ma)
|
||||
c = torch.abs(ma)
|
||||
assert torch.all(y_abs == c)
|
||||
|
||||
expr = sign(x)
|
||||
assert torch_code(expr) == "torch.sign(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10))
|
||||
|
||||
expr = ceiling(x)
|
||||
assert torch_code(expr) == "torch.ceil(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = floor(x)
|
||||
assert torch_code(expr) == "torch.floor(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = exp(x)
|
||||
assert torch_code(expr) == "torch.exp(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
|
||||
|
||||
expr = sqrt(x)
|
||||
assert torch_code(expr) == "torch.sqrt(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = x ** 4
|
||||
assert torch_code(expr) == "torch.pow(x, 4)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = cos(x)
|
||||
assert torch_code(expr) == "torch.cos(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = acos(x)
|
||||
assert torch_code(expr) == "torch.acos(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
|
||||
|
||||
expr = sin(x)
|
||||
assert torch_code(expr) == "torch.sin(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.random())
|
||||
|
||||
expr = asin(x)
|
||||
assert torch_code(expr) == "torch.asin(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
|
||||
|
||||
expr = tan(x)
|
||||
assert torch_code(expr) == "torch.tan(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5))
|
||||
|
||||
expr = atan(x)
|
||||
assert torch_code(expr) == "torch.atan(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
|
||||
|
||||
expr = atan2(y, x)
|
||||
assert torch_code(expr) == "torch.atan2(y, x)"
|
||||
_compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5))
|
||||
|
||||
expr = cosh(x)
|
||||
assert torch_code(expr) == "torch.cosh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
|
||||
|
||||
expr = acosh(x)
|
||||
assert torch_code(expr) == "torch.acosh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5))
|
||||
|
||||
expr = sinh(x)
|
||||
assert torch_code(expr) == "torch.sinh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
|
||||
|
||||
expr = asinh(x)
|
||||
assert torch_code(expr) == "torch.asinh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
|
||||
|
||||
expr = tanh(x)
|
||||
assert torch_code(expr) == "torch.tanh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
|
||||
|
||||
expr = atanh(x)
|
||||
assert torch_code(expr) == "torch.atanh(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9))
|
||||
|
||||
expr = erf(x)
|
||||
assert torch_code(expr) == "torch.erf(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
|
||||
|
||||
expr = loggamma(x)
|
||||
assert torch_code(expr) == "torch.lgamma(x)"
|
||||
_compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5))
|
||||
|
||||
|
||||
def test_torch_complexes():
|
||||
assert torch_code(re(x)) == "torch.real(x)"
|
||||
assert torch_code(im(x)) == "torch.imag(x)"
|
||||
assert torch_code(arg(x)) == "torch.angle(x)"
|
||||
|
||||
|
||||
def test_torch_relational():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = Eq(x, y)
|
||||
assert torch_code(expr) == "torch.eq(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
expr = Ne(x, y)
|
||||
assert torch_code(expr) == "torch.ne(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
expr = Ge(x, y)
|
||||
assert torch_code(expr) == "torch.ge(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
expr = Gt(x, y)
|
||||
assert torch_code(expr) == "torch.gt(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
expr = Le(x, y)
|
||||
assert torch_code(expr) == "torch.le(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
expr = Lt(x, y)
|
||||
assert torch_code(expr) == "torch.lt(x, y)"
|
||||
_compare_torch_relational((x, y), expr)
|
||||
|
||||
|
||||
def test_torch_matrix():
|
||||
if torch is None:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = M
|
||||
assert torch_code(expr) == "M"
|
||||
f = lambdify((M,), expr, "torch")
|
||||
eye_mat = eye(3)
|
||||
eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
|
||||
assert torch.allclose(f(eye_tensor), eye_tensor)
|
||||
|
||||
expr = M * N
|
||||
assert torch_code(expr) == "torch.matmul(M, N)"
|
||||
_compare_torch_matrix((M, N), expr)
|
||||
|
||||
expr = M ** 3
|
||||
assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)"
|
||||
_compare_torch_matrix((M,), expr)
|
||||
|
||||
expr = M * N * P * Q
|
||||
assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)"
|
||||
_compare_torch_matrix((M, N, P, Q), expr)
|
||||
|
||||
expr = Trace(M)
|
||||
assert torch_code(expr) == "torch.trace(M)"
|
||||
_compare_torch_matrix((M,), expr)
|
||||
|
||||
expr = Determinant(M)
|
||||
assert torch_code(expr) == "torch.det(M)"
|
||||
_compare_torch_matrix((M,), expr)
|
||||
|
||||
expr = HadamardProduct(M, N)
|
||||
assert torch_code(expr) == "torch.mul(M, N)"
|
||||
_compare_torch_matrix((M, N), expr)
|
||||
|
||||
expr = Inverse(M)
|
||||
assert torch_code(expr) == "torch.linalg.inv(M)"
|
||||
|
||||
# For inverse, use a matrix that's guaranteed to be invertible
|
||||
eye_mat = eye(3)
|
||||
eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
|
||||
f = lambdify((M,), expr, "torch")
|
||||
result = f(eye_tensor)
|
||||
expected = torch.linalg.inv(eye_tensor)
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
|
||||
def test_torch_array_operations():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
M = MatrixSymbol("M", 2, 2)
|
||||
N = MatrixSymbol("N", 2, 2)
|
||||
P = MatrixSymbol("P", 2, 2)
|
||||
Q = MatrixSymbol("Q", 2, 2)
|
||||
|
||||
ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64)
|
||||
mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64)
|
||||
mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64)
|
||||
md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64)
|
||||
|
||||
cg = ArrayTensorProduct(M, N)
|
||||
assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)'
|
||||
f = lambdify((M, N), cg, 'torch')
|
||||
y = f(ma, mb)
|
||||
c = torch.einsum("ij,kl", ma, mb)
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = ArrayAdd(M, N)
|
||||
assert torch_code(cg) == 'torch.add(M, N)'
|
||||
f = lambdify((M, N), cg, 'torch')
|
||||
y = f(ma, mb)
|
||||
c = ma + mb
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = ArrayAdd(M, N, P)
|
||||
assert torch_code(cg) == 'torch.add(torch.add(M, N), P)'
|
||||
f = lambdify((M, N, P), cg, 'torch')
|
||||
y = f(ma, mb, mc)
|
||||
c = ma + mb + mc
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = ArrayAdd(M, N, P, Q)
|
||||
assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)'
|
||||
f = lambdify((M, N, P, Q), cg, 'torch')
|
||||
y = f(ma, mb, mc, md)
|
||||
c = ma + mb + mc + md
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = PermuteDims(M, [1, 0])
|
||||
assert torch_code(cg) == 'M.permute(1, 0)'
|
||||
f = lambdify((M,), cg, 'torch')
|
||||
y = f(ma)
|
||||
c = ma.T
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
|
||||
assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)'
|
||||
f = lambdify((M, N), cg, 'torch')
|
||||
y = f(ma, mb)
|
||||
c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0)
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
|
||||
assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)'
|
||||
f = lambdify((M, N), cg, 'torch')
|
||||
y = f(ma, mb)
|
||||
c = torch.einsum("ab,bc->acb", ma, mb)
|
||||
assert torch.allclose(y, c)
|
||||
|
||||
|
||||
def test_torch_derivative():
|
||||
"""Test derivative handling."""
|
||||
expr = Derivative(sin(x), x)
|
||||
assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]'
|
||||
|
||||
|
||||
def test_torch_printing_dtype():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
# matrix printing with default dtype
|
||||
expr = Matrix([[x, sin(y)], [exp(z), -t]])
|
||||
assert "dtype=torch.float64" in torch_code(expr)
|
||||
|
||||
# explicit dtype
|
||||
assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32")
|
||||
|
||||
# with requires_grad
|
||||
result = torch_code(expr, requires_grad=True)
|
||||
assert "requires_grad=True" in result
|
||||
assert "dtype=torch.float64" in result
|
||||
|
||||
# both
|
||||
result = torch_code(expr, requires_grad=True, dtype="torch.float32")
|
||||
assert "requires_grad=True" in result
|
||||
assert "dtype=torch.float32" in result
|
||||
|
||||
|
||||
def test_requires_grad():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = sin(x) + cos(y)
|
||||
f = lambdify([x, y], expr, 'torch')
|
||||
|
||||
# make sure the gradients flow
|
||||
x_val = torch.tensor(1.0, requires_grad=True)
|
||||
y_val = torch.tensor(2.0, requires_grad=True)
|
||||
result = f(x_val, y_val)
|
||||
assert result.requires_grad
|
||||
result.backward()
|
||||
|
||||
# x_val.grad should be cos(x_val) which is close to cos(1.0)
|
||||
assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6
|
||||
|
||||
# y_val.grad should be -sin(y_val) which is close to -sin(2.0)
|
||||
assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6
|
||||
|
||||
|
||||
def test_torch_multi_variable_derivatives():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
x, y, z = symbols("x y z")
|
||||
|
||||
expr = Derivative(sin(x), x)
|
||||
assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]"
|
||||
|
||||
expr = Derivative(sin(x), (x, 2))
|
||||
assert torch_code(
|
||||
expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
|
||||
|
||||
expr = Derivative(sin(x * y), x, y)
|
||||
result = torch_code(expr)
|
||||
expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]"
|
||||
normalized_result = result.replace(" ", "")
|
||||
normalized_expected = expected.replace(" ", "")
|
||||
assert normalized_result == normalized_expected
|
||||
|
||||
expr = Derivative(sin(x), x, x)
|
||||
result = torch_code(expr)
|
||||
expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
|
||||
assert result == expected
|
||||
|
||||
expr = Derivative(sin(x * y * z), x, (y, 2), z)
|
||||
result = torch_code(expr)
|
||||
expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]"
|
||||
normalized_result = result.replace(" ", "")
|
||||
normalized_expected = expected.replace(" ", "")
|
||||
assert normalized_result == normalized_expected
|
||||
|
||||
|
||||
def test_torch_derivative_lambdify():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
x = symbols("x")
|
||||
y = symbols("y")
|
||||
|
||||
expr = Derivative(x ** 2, x)
|
||||
f = lambdify(x, expr, 'torch')
|
||||
x_val = torch.tensor(2.0, requires_grad=True)
|
||||
result = f(x_val)
|
||||
assert torch.isclose(result, torch.tensor(4.0))
|
||||
|
||||
expr = Derivative(sin(x), (x, 2))
|
||||
f = lambdify(x, expr, 'torch')
|
||||
# Second derivative of sin(x) at x=0 is 0, not -1
|
||||
x_val = torch.tensor(0.0, requires_grad=True)
|
||||
result = f(x_val)
|
||||
assert torch.isclose(result, torch.tensor(0.0), atol=1e-5)
|
||||
|
||||
x_val = torch.tensor(math.pi / 2, requires_grad=True)
|
||||
result = f(x_val)
|
||||
assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5)
|
||||
|
||||
expr = Derivative(x * y ** 2, x, y)
|
||||
f = lambdify((x, y), expr, 'torch')
|
||||
x_val = torch.tensor(2.0, requires_grad=True)
|
||||
y_val = torch.tensor(3.0, requires_grad=True)
|
||||
result = f(x_val, y_val)
|
||||
assert torch.isclose(result, torch.tensor(6.0))
|
||||
|
||||
|
||||
def test_torch_special_matrices():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = Identity(3)
|
||||
assert torch_code(expr) == "torch.eye(3)"
|
||||
|
||||
n = symbols("n")
|
||||
expr = Identity(n)
|
||||
assert torch_code(expr) == "torch.eye(n, n)"
|
||||
|
||||
expr = ZeroMatrix(2, 3)
|
||||
assert torch_code(expr) == "torch.zeros((2, 3))"
|
||||
|
||||
m, n = symbols("m n")
|
||||
expr = ZeroMatrix(m, n)
|
||||
assert torch_code(expr) == "torch.zeros((m, n))"
|
||||
|
||||
expr = OneMatrix(2, 3)
|
||||
assert torch_code(expr) == "torch.ones((2, 3))"
|
||||
|
||||
expr = OneMatrix(m, n)
|
||||
assert torch_code(expr) == "torch.ones((m, n))"
|
||||
|
||||
|
||||
def test_torch_special_matrices_lambdify():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = Identity(3)
|
||||
f = lambdify([], expr, 'torch')
|
||||
result = f()
|
||||
expected = torch.eye(3)
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
expr = ZeroMatrix(2, 3)
|
||||
f = lambdify([], expr, 'torch')
|
||||
result = f()
|
||||
expected = torch.zeros((2, 3))
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
expr = OneMatrix(2, 3)
|
||||
f = lambdify([], expr, 'torch')
|
||||
result = f()
|
||||
expected = torch.ones((2, 3))
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
|
||||
def test_torch_complex_operations():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = conjugate(x)
|
||||
assert torch_code(expr) == "torch.conj(x)"
|
||||
|
||||
# SymPy distributes conjugate over addition and applies specific rules for each term
|
||||
expr = conjugate(sin(x) + I * cos(y))
|
||||
assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))"
|
||||
|
||||
expr = I
|
||||
assert torch_code(expr) == "1j"
|
||||
|
||||
expr = 2 * I + x
|
||||
assert torch_code(expr) == "x + 2*1j"
|
||||
|
||||
expr = exp(I * x)
|
||||
assert torch_code(expr) == "torch.exp(1j*x)"
|
||||
|
||||
|
||||
def test_torch_special_functions():
|
||||
if not torch:
|
||||
skip("PyTorch not installed")
|
||||
|
||||
expr = Heaviside(x)
|
||||
assert torch_code(expr) == "torch.heaviside(x, 1/2)"
|
||||
|
||||
expr = Heaviside(x, 0)
|
||||
assert torch_code(expr) == "torch.heaviside(x, 0)"
|
||||
|
||||
expr = gamma(x)
|
||||
assert torch_code(expr) == "torch.special.gamma(x)"
|
||||
|
||||
expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway
|
||||
assert torch_code(expr) == "torch.special.digamma(x)"
|
||||
|
||||
expr = gamma(sin(x))
|
||||
assert torch_code(expr) == "torch.special.gamma(torch.sin(x))"
|
||||
@@ -0,0 +1,196 @@
|
||||
from sympy.printing.tree import tree
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
|
||||
# Remove this flag after making _assumptions cache deterministic.
|
||||
@XFAIL
|
||||
def test_print_tree_MatAdd():
|
||||
from sympy.matrices.expressions import MatrixSymbol
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
B = MatrixSymbol('B', 3, 3)
|
||||
|
||||
test_str = [
|
||||
'MatAdd: A + B\n',
|
||||
'algebraic: False\n',
|
||||
'commutative: False\n',
|
||||
'complex: False\n',
|
||||
'composite: False\n',
|
||||
'even: False\n',
|
||||
'extended_negative: False\n',
|
||||
'extended_nonnegative: False\n',
|
||||
'extended_nonpositive: False\n',
|
||||
'extended_nonzero: False\n',
|
||||
'extended_positive: False\n',
|
||||
'extended_real: False\n',
|
||||
'imaginary: False\n',
|
||||
'integer: False\n',
|
||||
'irrational: False\n',
|
||||
'negative: False\n',
|
||||
'noninteger: False\n',
|
||||
'nonnegative: False\n',
|
||||
'nonpositive: False\n',
|
||||
'nonzero: False\n',
|
||||
'odd: False\n',
|
||||
'positive: False\n',
|
||||
'prime: False\n',
|
||||
'rational: False\n',
|
||||
'real: False\n',
|
||||
'transcendental: False\n',
|
||||
'zero: False\n',
|
||||
'+-MatrixSymbol: A\n',
|
||||
'| algebraic: False\n',
|
||||
'| commutative: False\n',
|
||||
'| complex: False\n',
|
||||
'| composite: False\n',
|
||||
'| even: False\n',
|
||||
'| extended_negative: False\n',
|
||||
'| extended_nonnegative: False\n',
|
||||
'| extended_nonpositive: False\n',
|
||||
'| extended_nonzero: False\n',
|
||||
'| extended_positive: False\n',
|
||||
'| extended_real: False\n',
|
||||
'| imaginary: False\n',
|
||||
'| integer: False\n',
|
||||
'| irrational: False\n',
|
||||
'| negative: False\n',
|
||||
'| noninteger: False\n',
|
||||
'| nonnegative: False\n',
|
||||
'| nonpositive: False\n',
|
||||
'| nonzero: False\n',
|
||||
'| odd: False\n',
|
||||
'| positive: False\n',
|
||||
'| prime: False\n',
|
||||
'| rational: False\n',
|
||||
'| real: False\n',
|
||||
'| transcendental: False\n',
|
||||
'| zero: False\n',
|
||||
'| +-Symbol: A\n',
|
||||
'| | commutative: True\n',
|
||||
'| +-Integer: 3\n',
|
||||
'| | algebraic: True\n',
|
||||
'| | commutative: True\n',
|
||||
'| | complex: True\n',
|
||||
'| | extended_negative: False\n',
|
||||
'| | extended_nonnegative: True\n',
|
||||
'| | extended_real: True\n',
|
||||
'| | finite: True\n',
|
||||
'| | hermitian: True\n',
|
||||
'| | imaginary: False\n',
|
||||
'| | infinite: False\n',
|
||||
'| | integer: True\n',
|
||||
'| | irrational: False\n',
|
||||
'| | negative: False\n',
|
||||
'| | noninteger: False\n',
|
||||
'| | nonnegative: True\n',
|
||||
'| | rational: True\n',
|
||||
'| | real: True\n',
|
||||
'| | transcendental: False\n',
|
||||
'| +-Integer: 3\n',
|
||||
'| algebraic: True\n',
|
||||
'| commutative: True\n',
|
||||
'| complex: True\n',
|
||||
'| extended_negative: False\n',
|
||||
'| extended_nonnegative: True\n',
|
||||
'| extended_real: True\n',
|
||||
'| finite: True\n',
|
||||
'| hermitian: True\n',
|
||||
'| imaginary: False\n',
|
||||
'| infinite: False\n',
|
||||
'| integer: True\n',
|
||||
'| irrational: False\n',
|
||||
'| negative: False\n',
|
||||
'| noninteger: False\n',
|
||||
'| nonnegative: True\n',
|
||||
'| rational: True\n',
|
||||
'| real: True\n',
|
||||
'| transcendental: False\n',
|
||||
'+-MatrixSymbol: B\n',
|
||||
' algebraic: False\n',
|
||||
' commutative: False\n',
|
||||
' complex: False\n',
|
||||
' composite: False\n',
|
||||
' even: False\n',
|
||||
' extended_negative: False\n',
|
||||
' extended_nonnegative: False\n',
|
||||
' extended_nonpositive: False\n',
|
||||
' extended_nonzero: False\n',
|
||||
' extended_positive: False\n',
|
||||
' extended_real: False\n',
|
||||
' imaginary: False\n',
|
||||
' integer: False\n',
|
||||
' irrational: False\n',
|
||||
' negative: False\n',
|
||||
' noninteger: False\n',
|
||||
' nonnegative: False\n',
|
||||
' nonpositive: False\n',
|
||||
' nonzero: False\n',
|
||||
' odd: False\n',
|
||||
' positive: False\n',
|
||||
' prime: False\n',
|
||||
' rational: False\n',
|
||||
' real: False\n',
|
||||
' transcendental: False\n',
|
||||
' zero: False\n',
|
||||
' +-Symbol: B\n',
|
||||
' | commutative: True\n',
|
||||
' +-Integer: 3\n',
|
||||
' | algebraic: True\n',
|
||||
' | commutative: True\n',
|
||||
' | complex: True\n',
|
||||
' | extended_negative: False\n',
|
||||
' | extended_nonnegative: True\n',
|
||||
' | extended_real: True\n',
|
||||
' | finite: True\n',
|
||||
' | hermitian: True\n',
|
||||
' | imaginary: False\n',
|
||||
' | infinite: False\n',
|
||||
' | integer: True\n',
|
||||
' | irrational: False\n',
|
||||
' | negative: False\n',
|
||||
' | noninteger: False\n',
|
||||
' | nonnegative: True\n',
|
||||
' | rational: True\n',
|
||||
' | real: True\n',
|
||||
' | transcendental: False\n',
|
||||
' +-Integer: 3\n',
|
||||
' algebraic: True\n',
|
||||
' commutative: True\n',
|
||||
' complex: True\n',
|
||||
' extended_negative: False\n',
|
||||
' extended_nonnegative: True\n',
|
||||
' extended_real: True\n',
|
||||
' finite: True\n',
|
||||
' hermitian: True\n',
|
||||
' imaginary: False\n',
|
||||
' infinite: False\n',
|
||||
' integer: True\n',
|
||||
' irrational: False\n',
|
||||
' negative: False\n',
|
||||
' noninteger: False\n',
|
||||
' nonnegative: True\n',
|
||||
' rational: True\n',
|
||||
' real: True\n',
|
||||
' transcendental: False\n'
|
||||
]
|
||||
|
||||
assert tree(A + B) == "".join(test_str)
|
||||
|
||||
|
||||
def test_print_tree_MatAdd_noassumptions():
|
||||
from sympy.matrices.expressions import MatrixSymbol
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
B = MatrixSymbol('B', 3, 3)
|
||||
|
||||
test_str = \
|
||||
"""MatAdd: A + B
|
||||
+-MatrixSymbol: A
|
||||
| +-Str: A
|
||||
| +-Integer: 3
|
||||
| +-Integer: 3
|
||||
+-MatrixSymbol: B
|
||||
+-Str: B
|
||||
+-Integer: 3
|
||||
+-Integer: 3
|
||||
"""
|
||||
|
||||
assert tree(A + B, assumptions=False) == test_str
|
||||
Reference in New Issue
Block a user