switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
"""The module helps converting SymPy expressions into shorter forms of them.
|
||||
|
||||
for example:
|
||||
the expression E**(pi*I) will be converted into -1
|
||||
the expression (x+x)**2 will be converted into 4*x**2
|
||||
"""
|
||||
from .simplify import (simplify, hypersimp, hypersimilar,
|
||||
logcombine, separatevars, posify, besselsimp, kroneckersimp,
|
||||
signsimp, nsimplify)
|
||||
|
||||
from .fu import FU, fu
|
||||
|
||||
from .sqrtdenest import sqrtdenest
|
||||
|
||||
from .cse_main import cse
|
||||
|
||||
from .epathtools import epath, EPath
|
||||
|
||||
from .hyperexpand import hyperexpand
|
||||
|
||||
from .radsimp import collect, rcollect, radsimp, collect_const, fraction, numer, denom
|
||||
|
||||
from .trigsimp import trigsimp, exptrigsimp
|
||||
|
||||
from .powsimp import powsimp, powdenest
|
||||
|
||||
from .combsimp import combsimp
|
||||
|
||||
from .gammasimp import gammasimp
|
||||
|
||||
from .ratsimp import ratsimp, ratsimpmodprime
|
||||
|
||||
__all__ = [
|
||||
'simplify', 'hypersimp', 'hypersimilar', 'logcombine', 'separatevars',
|
||||
'posify', 'besselsimp', 'kroneckersimp', 'signsimp',
|
||||
'nsimplify',
|
||||
|
||||
'FU', 'fu',
|
||||
|
||||
'sqrtdenest',
|
||||
|
||||
'cse',
|
||||
|
||||
'epath', 'EPath',
|
||||
|
||||
'hyperexpand',
|
||||
|
||||
'collect', 'rcollect', 'radsimp', 'collect_const', 'fraction', 'numer',
|
||||
'denom',
|
||||
|
||||
'trigsimp', 'exptrigsimp',
|
||||
|
||||
'powsimp', 'powdenest',
|
||||
|
||||
'combsimp',
|
||||
|
||||
'gammasimp',
|
||||
|
||||
'ratsimp', 'ratsimpmodprime',
|
||||
]
|
||||
@@ -0,0 +1,291 @@
|
||||
"""Module for differentiation using CSE."""
|
||||
|
||||
from sympy import cse, Matrix, Derivative, MatrixBase
|
||||
from sympy.utilities.iterables import iterable
|
||||
|
||||
|
||||
def _remove_cse_from_derivative(replacements, reduced_expressions):
|
||||
"""
|
||||
This function is designed to postprocess the output of a common subexpression
|
||||
elimination (CSE) operation. Specifically, it removes any CSE replacement
|
||||
symbols from the arguments of ``Derivative`` terms in the expression. This
|
||||
is necessary to ensure that the forward Jacobian function correctly handles
|
||||
derivative terms.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
Replacement symbols and relative common subexpressions that have been
|
||||
replaced during a CSE operation.
|
||||
|
||||
reduced_expressions : list of SymPy expressions
|
||||
The reduced expressions with all the replacements from the
|
||||
replacements list above.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
processed_replacements : list of (Symbol, expression) pairs
|
||||
Processed replacement list, in the same format of the
|
||||
``replacements`` input list.
|
||||
|
||||
processed_reduced : list of SymPy expressions
|
||||
Processed reduced list, in the same format of the
|
||||
``reduced_expressions`` input list.
|
||||
"""
|
||||
|
||||
def traverse(node, repl_dict):
|
||||
if isinstance(node, Derivative):
|
||||
return replace_all(node, repl_dict)
|
||||
if not node.args:
|
||||
return node
|
||||
new_args = [traverse(arg, repl_dict) for arg in node.args]
|
||||
return node.func(*new_args)
|
||||
|
||||
def replace_all(node, repl_dict):
|
||||
result = node
|
||||
while True:
|
||||
free_symbols = result.free_symbols
|
||||
symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict}
|
||||
if not symbols_dict:
|
||||
break
|
||||
result = result.xreplace(symbols_dict)
|
||||
return result
|
||||
|
||||
repl_dict = dict(replacements)
|
||||
processed_replacements = [
|
||||
(rep_sym, traverse(sub_exp, repl_dict))
|
||||
for rep_sym, sub_exp in replacements
|
||||
]
|
||||
processed_reduced = [
|
||||
red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp])
|
||||
for red_exp in reduced_expressions
|
||||
]
|
||||
|
||||
return processed_replacements, processed_reduced
|
||||
|
||||
|
||||
def _forward_jacobian_cse(replacements, reduced_expr, wrt):
|
||||
"""
|
||||
Core function to compute the Jacobian of an input Matrix of expressions
|
||||
through forward accumulation. Takes directly the output of a CSE operation
|
||||
(replacements and reduced_expr), and an iterable of variables (wrt) with
|
||||
respect to which to differentiate the reduced expression and returns the
|
||||
reduced Jacobian matrix and the ``replacements`` list.
|
||||
|
||||
The function also returns a list of precomputed free symbols for each
|
||||
subexpression, which are useful in the substitution process.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
Replacement symbols and relative common subexpressions that have been
|
||||
replaced during a CSE operation.
|
||||
|
||||
reduced_expr : list of SymPy expressions
|
||||
The reduced expressions with all the replacements from the
|
||||
replacements list above.
|
||||
|
||||
wrt : iterable
|
||||
Iterable of expressions with respect to which to compute the
|
||||
Jacobian matrix.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
Replacement symbols and relative common subexpressions that have been
|
||||
replaced during a CSE operation. Compared to the input replacement list,
|
||||
the output one doesn't contain replacement symbols inside
|
||||
``Derivative``'s arguments.
|
||||
|
||||
jacobian : list of SymPy expressions
|
||||
The list only contains one element, which is the Jacobian matrix with
|
||||
elements in reduced form (replacement symbols are present).
|
||||
|
||||
precomputed_fs: list
|
||||
List of sets, which store the free symbols present in each sub-expression.
|
||||
Useful in the substitution process.
|
||||
"""
|
||||
|
||||
if not isinstance(reduced_expr[0], MatrixBase):
|
||||
raise TypeError("``expr`` must be of matrix type")
|
||||
|
||||
if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
|
||||
raise TypeError("``expr`` must be a row or a column matrix")
|
||||
|
||||
if not iterable(wrt):
|
||||
raise TypeError("``wrt`` must be an iterable of variables")
|
||||
|
||||
elif not isinstance(wrt, MatrixBase):
|
||||
wrt = Matrix(wrt)
|
||||
|
||||
if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
|
||||
raise TypeError("``wrt`` must be a row or a column matrix")
|
||||
|
||||
replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr)
|
||||
|
||||
if replacements:
|
||||
rep_sym, sub_expr = map(Matrix, zip(*replacements))
|
||||
else:
|
||||
rep_sym, sub_expr = Matrix([]), Matrix([])
|
||||
|
||||
l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0])
|
||||
|
||||
f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt,
|
||||
{
|
||||
(i, j): diff_value
|
||||
for i, r in enumerate(reduced_expr[0])
|
||||
for j, w in enumerate(wrt)
|
||||
if (diff_value := r.diff(w)) != 0
|
||||
},
|
||||
)
|
||||
|
||||
if not replacements:
|
||||
return [], [f1], []
|
||||
|
||||
f2 = Matrix.from_dok(l_red, l_sub,
|
||||
{
|
||||
(i, j): diff_value
|
||||
for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]])
|
||||
for j, s in enumerate(rep_sym)
|
||||
if s in fs and (diff_value := r.diff(s)) != 0
|
||||
},
|
||||
)
|
||||
|
||||
rep_sym_set = set(rep_sym)
|
||||
precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ]
|
||||
|
||||
c_matrix = Matrix.from_dok(1, l_wrt,
|
||||
{(0, j): diff_value for j, w in enumerate(wrt)
|
||||
if (diff_value := sub_expr[0].diff(w)) != 0})
|
||||
|
||||
for i in range(1, l_sub):
|
||||
|
||||
bi_matrix = Matrix.from_dok(1, i,
|
||||
{(0, j): diff_value for j in range(i + 1)
|
||||
if rep_sym[j] in precomputed_fs[i]
|
||||
and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0})
|
||||
|
||||
ai_matrix = Matrix.from_dok(1, l_wrt,
|
||||
{(0, j): diff_value for j, w in enumerate(wrt)
|
||||
if (diff_value := sub_expr[i].diff(w)) != 0})
|
||||
|
||||
if bi_matrix._rep.nnz():
|
||||
ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
|
||||
c_matrix = Matrix.vstack(c_matrix, ci_matrix)
|
||||
else:
|
||||
c_matrix = Matrix.vstack(c_matrix, ai_matrix)
|
||||
|
||||
jacobian = f2.multiply(c_matrix).add(f1)
|
||||
jacobian = [reduced_expr[0].__class__(jacobian)]
|
||||
|
||||
return replacements, jacobian, precomputed_fs
|
||||
|
||||
|
||||
def _forward_jacobian_norm_in_cse_out(expr, wrt):
|
||||
"""
|
||||
Function to compute the Jacobian of an input Matrix of expressions through
|
||||
forward accumulation. Takes a sympy Matrix of expressions (expr) as input
|
||||
and an iterable of variables (wrt) with respect to which to compute the
|
||||
Jacobian matrix. The matrix is returned in reduced form (containing
|
||||
replacement symbols) along with the ``replacements`` list.
|
||||
|
||||
The function also returns a list of precomputed free symbols for each
|
||||
subexpression, which are useful in the substitution process.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : Matrix
|
||||
The vector to be differentiated.
|
||||
|
||||
wrt : iterable
|
||||
The vector with respect to which to perform the differentiation.
|
||||
Can be a matrix or an iterable of variables.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
Replacement symbols and relative common subexpressions that have been
|
||||
replaced during a CSE operation. The output replacement list doesn't
|
||||
contain replacement symbols inside ``Derivative``'s arguments.
|
||||
|
||||
jacobian : list of SymPy expressions
|
||||
The list only contains one element, which is the Jacobian matrix with
|
||||
elements in reduced form (replacement symbols are present).
|
||||
|
||||
precomputed_fs: list
|
||||
List of sets, which store the free symbols present in each
|
||||
sub-expression. Useful in the substitution process.
|
||||
"""
|
||||
|
||||
replacements, reduced_expr = cse(expr)
|
||||
replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt)
|
||||
|
||||
return replacements, jacobian, precomputed_fs
|
||||
|
||||
|
||||
def _forward_jacobian(expr, wrt):
|
||||
"""
|
||||
Function to compute the Jacobian of an input Matrix of expressions through
|
||||
forward accumulation. Takes a sympy Matrix of expressions (expr) as input
|
||||
and an iterable of variables (wrt) with respect to which to compute the
|
||||
Jacobian matrix.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Expressions often contain repeated subexpressions. Using a tree structure,
|
||||
these subexpressions are duplicated and differentiated multiple times,
|
||||
leading to inefficiency.
|
||||
|
||||
Instead, if a data structure called a directed acyclic graph (DAG) is used
|
||||
then each of these repeated subexpressions will only exist a single time.
|
||||
This function uses a combination of representing the expression as a DAG and
|
||||
a forward accumulation algorithm (repeated application of the chain rule
|
||||
symbolically) to more efficiently calculate the Jacobian matrix of a target
|
||||
expression ``expr`` with respect to an expression or set of expressions
|
||||
``wrt``.
|
||||
|
||||
Note that this function is intended to improve performance when
|
||||
differentiating large expressions that contain many common subexpressions.
|
||||
For small and simple expressions it is likely less performant than using
|
||||
SymPy's standard differentiation functions and methods.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : Matrix
|
||||
The vector to be differentiated.
|
||||
|
||||
wrt : iterable
|
||||
The vector with respect to which to do the differentiation.
|
||||
Can be a matrix or an iterable of variables.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
|
||||
"""
|
||||
|
||||
replacements, reduced_expr = cse(expr)
|
||||
|
||||
if replacements:
|
||||
rep_sym, _ = map(Matrix, zip(*replacements))
|
||||
else:
|
||||
rep_sym = Matrix([])
|
||||
|
||||
replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt)
|
||||
|
||||
if not replacements: return jacobian[0]
|
||||
|
||||
sub_rep = dict(replacements)
|
||||
for i, ik in enumerate(precomputed_fs):
|
||||
sub_dict = {j: sub_rep[j] for j in ik}
|
||||
sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict)
|
||||
|
||||
return jacobian[0].xreplace(sub_rep)
|
||||
@@ -0,0 +1,114 @@
|
||||
from sympy.core import Mul
|
||||
from sympy.core.function import count_ops
|
||||
from sympy.core.traversal import preorder_traversal, bottom_up
|
||||
from sympy.functions.combinatorial.factorials import binomial, factorial
|
||||
from sympy.functions import gamma
|
||||
from sympy.simplify.gammasimp import gammasimp, _gammasimp
|
||||
|
||||
from sympy.utilities.timeutils import timethis
|
||||
|
||||
|
||||
@timethis('combsimp')
|
||||
def combsimp(expr):
|
||||
r"""
|
||||
Simplify combinatorial expressions.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function takes as input an expression containing factorials,
|
||||
binomials, Pochhammer symbol and other "combinatorial" functions,
|
||||
and tries to minimize the number of those functions and reduce
|
||||
the size of their arguments.
|
||||
|
||||
The algorithm works by rewriting all combinatorial functions as
|
||||
gamma functions and applying gammasimp() except simplification
|
||||
steps that may make an integer argument non-integer. See docstring
|
||||
of gammasimp for more information.
|
||||
|
||||
Then it rewrites expression in terms of factorials and binomials by
|
||||
rewriting gammas as factorials and converting (a+b)!/a!b! into
|
||||
binomials.
|
||||
|
||||
If expression has gamma functions or combinatorial functions
|
||||
with non-integer argument, it is automatically passed to gammasimp.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify import combsimp
|
||||
>>> from sympy import factorial, binomial, symbols
|
||||
>>> n, k = symbols('n k', integer = True)
|
||||
|
||||
>>> combsimp(factorial(n)/factorial(n - 3))
|
||||
n*(n - 2)*(n - 1)
|
||||
>>> combsimp(binomial(n+1, k+1)/binomial(n, k))
|
||||
(n + 1)/(k + 1)
|
||||
|
||||
"""
|
||||
|
||||
expr = expr.rewrite(gamma, piecewise=False)
|
||||
if any(isinstance(node, gamma) and not node.args[0].is_integer
|
||||
for node in preorder_traversal(expr)):
|
||||
return gammasimp(expr)
|
||||
|
||||
expr = _gammasimp(expr, as_comb = True)
|
||||
expr = _gamma_as_comb(expr)
|
||||
return expr
|
||||
|
||||
|
||||
def _gamma_as_comb(expr):
|
||||
"""
|
||||
Helper function for combsimp.
|
||||
|
||||
Rewrites expression in terms of factorials and binomials
|
||||
"""
|
||||
|
||||
expr = expr.rewrite(factorial)
|
||||
|
||||
def f(rv):
|
||||
if not rv.is_Mul:
|
||||
return rv
|
||||
rvd = rv.as_powers_dict()
|
||||
nd_fact_args = [[], []] # numerator, denominator
|
||||
|
||||
for k in rvd:
|
||||
if isinstance(k, factorial) and rvd[k].is_Integer:
|
||||
if rvd[k].is_positive:
|
||||
nd_fact_args[0].extend([k.args[0]]*rvd[k])
|
||||
else:
|
||||
nd_fact_args[1].extend([k.args[0]]*-rvd[k])
|
||||
rvd[k] = 0
|
||||
if not nd_fact_args[0] or not nd_fact_args[1]:
|
||||
return rv
|
||||
|
||||
hit = False
|
||||
for m in range(2):
|
||||
i = 0
|
||||
while i < len(nd_fact_args[m]):
|
||||
ai = nd_fact_args[m][i]
|
||||
for j in range(i + 1, len(nd_fact_args[m])):
|
||||
aj = nd_fact_args[m][j]
|
||||
|
||||
sum = ai + aj
|
||||
if sum in nd_fact_args[1 - m]:
|
||||
hit = True
|
||||
|
||||
nd_fact_args[1 - m].remove(sum)
|
||||
del nd_fact_args[m][j]
|
||||
del nd_fact_args[m][i]
|
||||
|
||||
rvd[binomial(sum, ai if count_ops(ai) <
|
||||
count_ops(aj) else aj)] += (
|
||||
-1 if m == 0 else 1)
|
||||
break
|
||||
else:
|
||||
i += 1
|
||||
|
||||
if hit:
|
||||
return Mul(*([k**rvd[k] for k in rvd] + [factorial(k)
|
||||
for k in nd_fact_args[0]]))/Mul(*[factorial(k)
|
||||
for k in nd_fact_args[1]])
|
||||
return rv
|
||||
|
||||
return bottom_up(expr, f)
|
||||
@@ -0,0 +1,945 @@
|
||||
""" Tools for doing common subexpression elimination.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from sympy.core import Basic, Mul, Add, Pow, sympify
|
||||
from sympy.core.containers import Tuple, OrderedSet
|
||||
from sympy.core.exprtools import factor_terms
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.sorting import ordered
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
|
||||
SparseMatrix, ImmutableSparseMatrix)
|
||||
from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul,
|
||||
MatAdd, MatPow, Inverse)
|
||||
from sympy.matrices.expressions.matexpr import MatrixElement
|
||||
from sympy.polys.rootoftools import RootOf
|
||||
from sympy.utilities.iterables import numbered_symbols, sift, \
|
||||
topological_sort, iterable
|
||||
|
||||
from . import cse_opts
|
||||
|
||||
# (preprocessor, postprocessor) pairs which are commonly useful. They should
|
||||
# each take a SymPy expression and return a possibly transformed expression.
|
||||
# When used in the function ``cse()``, the target expressions will be transformed
|
||||
# by each of the preprocessor functions in order. After the common
|
||||
# subexpressions are eliminated, each resulting expression will have the
|
||||
# postprocessor functions transform them in *reverse* order in order to undo the
|
||||
# transformation if necessary. This allows the algorithm to operate on
|
||||
# a representation of the expressions that allows for more optimization
|
||||
# opportunities.
|
||||
# ``None`` can be used to specify no transformation for either the preprocessor or
|
||||
# postprocessor.
|
||||
|
||||
|
||||
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
|
||||
(factor_terms, None)]
|
||||
|
||||
# sometimes we want the output in a different format; non-trivial
|
||||
# transformations can be put here for users
|
||||
# ===============================================================
|
||||
|
||||
|
||||
def reps_toposort(r):
|
||||
"""Sort replacements ``r`` so (k1, v1) appears before (k2, v2)
|
||||
if k2 is in v1's free symbols. This orders items in the
|
||||
way that cse returns its results (hence, in order to use the
|
||||
replacements in a substitution option it would make sense
|
||||
to reverse the order).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.cse_main import reps_toposort
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy import Eq
|
||||
>>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
|
||||
... print(Eq(l, r))
|
||||
...
|
||||
Eq(y, 2)
|
||||
Eq(x, y + 1)
|
||||
|
||||
"""
|
||||
r = sympify(r)
|
||||
E = []
|
||||
for c1, (k1, v1) in enumerate(r):
|
||||
for c2, (k2, v2) in enumerate(r):
|
||||
if k1 in v2.free_symbols:
|
||||
E.append((c1, c2))
|
||||
return [r[i] for i in topological_sort((range(len(r)), E))]
|
||||
|
||||
|
||||
def cse_separate(r, e):
|
||||
"""Move expressions that are in the form (symbol, expr) out of the
|
||||
expressions and sort them into the replacements using the reps_toposort.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.cse_main import cse_separate
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy import cos, exp, cse, Eq, symbols
|
||||
>>> x0, x1 = symbols('x:2')
|
||||
>>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
|
||||
>>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
|
||||
... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
|
||||
... [x1 + exp(x1/x0) + cos(x0), z - 2]],
|
||||
... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
|
||||
... [x0 + exp(x0/x1) + cos(x1), z - 2]]]
|
||||
...
|
||||
True
|
||||
"""
|
||||
d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
|
||||
r = r + [w.args for w in d[True]]
|
||||
e = d[False]
|
||||
return [reps_toposort(r), e]
|
||||
|
||||
|
||||
def cse_release_variables(r, e):
|
||||
"""
|
||||
Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
|
||||
either an expression or None. The value of None is used when a
|
||||
symbol is no longer needed for subsequent expressions.
|
||||
|
||||
Use of such output can reduce the memory footprint of lambdified
|
||||
expressions that contain large, repeated subexpressions.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import cse
|
||||
>>> from sympy.simplify.cse_main import cse_release_variables
|
||||
>>> from sympy.abc import x, y
|
||||
>>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
|
||||
>>> defs, rvs = cse_release_variables(*cse(eqs))
|
||||
>>> for i in defs:
|
||||
... print(i)
|
||||
...
|
||||
(x0, x + y)
|
||||
(x1, (x0 - 1)**2)
|
||||
(x2, 2*x + 1)
|
||||
(_3, x0/x2 + x1)
|
||||
(_4, x2**x0)
|
||||
(x2, None)
|
||||
(_0, x1)
|
||||
(x1, None)
|
||||
(_2, x0)
|
||||
(x0, None)
|
||||
(_1, x)
|
||||
>>> print(rvs)
|
||||
(_0, _1, _2, _3, _4)
|
||||
"""
|
||||
if not r:
|
||||
return r, e
|
||||
|
||||
s, p = zip(*r)
|
||||
esyms = symbols('_:%d' % len(e))
|
||||
syms = list(esyms)
|
||||
s = list(s)
|
||||
in_use = set(s)
|
||||
p = list(p)
|
||||
# sort e so those with most sub-expressions appear first
|
||||
e = [(e[i], syms[i]) for i in range(len(e))]
|
||||
e, syms = zip(*sorted(e,
|
||||
key=lambda x: -sum(p[s.index(i)].count_ops()
|
||||
for i in x[0].free_symbols & in_use)))
|
||||
syms = list(syms)
|
||||
p += e
|
||||
rv = []
|
||||
i = len(p) - 1
|
||||
while i >= 0:
|
||||
_p = p.pop()
|
||||
c = in_use & _p.free_symbols
|
||||
if c: # sorting for canonical results
|
||||
rv.extend([(s, None) for s in sorted(c, key=str)])
|
||||
if i >= len(r):
|
||||
rv.append((syms.pop(), _p))
|
||||
else:
|
||||
rv.append((s[i], _p))
|
||||
in_use -= c
|
||||
i -= 1
|
||||
rv.reverse()
|
||||
return rv, esyms
|
||||
|
||||
|
||||
# ====end of cse postprocess idioms===========================
|
||||
|
||||
|
||||
def preprocess_for_cse(expr, optimizations):
|
||||
""" Preprocess an expression to optimize for common subexpression
|
||||
elimination.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : SymPy expression
|
||||
The target expression to optimize.
|
||||
optimizations : list of (callable, callable) pairs
|
||||
The (preprocessor, postprocessor) pairs.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
expr : SymPy expression
|
||||
The transformed expression.
|
||||
"""
|
||||
for pre, post in optimizations:
|
||||
if pre is not None:
|
||||
expr = pre(expr)
|
||||
return expr
|
||||
|
||||
|
||||
def postprocess_for_cse(expr, optimizations):
|
||||
"""Postprocess an expression after common subexpression elimination to
|
||||
return the expression to canonical SymPy form.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : SymPy expression
|
||||
The target expression to transform.
|
||||
optimizations : list of (callable, callable) pairs, optional
|
||||
The (preprocessor, postprocessor) pairs. The postprocessors will be
|
||||
applied in reversed order to undo the effects of the preprocessors
|
||||
correctly.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
expr : SymPy expression
|
||||
The transformed expression.
|
||||
"""
|
||||
for pre, post in reversed(optimizations):
|
||||
if post is not None:
|
||||
expr = post(expr)
|
||||
return expr
|
||||
|
||||
|
||||
class FuncArgTracker:
|
||||
"""
|
||||
A class which manages a mapping from functions to arguments and an inverse
|
||||
mapping from arguments to functions.
|
||||
"""
|
||||
|
||||
def __init__(self, funcs):
|
||||
# To minimize the number of symbolic comparisons, all function arguments
|
||||
# get assigned a value number.
|
||||
self.value_numbers = {}
|
||||
self.value_number_to_value = []
|
||||
|
||||
# Both of these maps use integer indices for arguments / functions.
|
||||
self.arg_to_funcset = []
|
||||
self.func_to_argset = []
|
||||
|
||||
for func_i, func in enumerate(funcs):
|
||||
func_argset = OrderedSet()
|
||||
|
||||
for func_arg in func.args:
|
||||
arg_number = self.get_or_add_value_number(func_arg)
|
||||
func_argset.add(arg_number)
|
||||
self.arg_to_funcset[arg_number].add(func_i)
|
||||
|
||||
self.func_to_argset.append(func_argset)
|
||||
|
||||
def get_args_in_value_order(self, argset):
|
||||
"""
|
||||
Return the list of arguments in sorted order according to their value
|
||||
numbers.
|
||||
"""
|
||||
return [self.value_number_to_value[argn] for argn in sorted(argset)]
|
||||
|
||||
def get_or_add_value_number(self, value):
|
||||
"""
|
||||
Return the value number for the given argument.
|
||||
"""
|
||||
nvalues = len(self.value_numbers)
|
||||
value_number = self.value_numbers.setdefault(value, nvalues)
|
||||
if value_number == nvalues:
|
||||
self.value_number_to_value.append(value)
|
||||
self.arg_to_funcset.append(OrderedSet())
|
||||
return value_number
|
||||
|
||||
def stop_arg_tracking(self, func_i):
|
||||
"""
|
||||
Remove the function func_i from the argument to function mapping.
|
||||
"""
|
||||
for arg in self.func_to_argset[func_i]:
|
||||
self.arg_to_funcset[arg].remove(func_i)
|
||||
|
||||
|
||||
def get_common_arg_candidates(self, argset, min_func_i=0):
|
||||
"""Return a dict whose keys are function numbers. The entries of the dict are
|
||||
the number of arguments said function has in common with
|
||||
``argset``. Entries have at least 2 items in common. All keys have
|
||||
value at least ``min_func_i``.
|
||||
"""
|
||||
count_map = defaultdict(lambda: 0)
|
||||
if not argset:
|
||||
return count_map
|
||||
|
||||
funcsets = [self.arg_to_funcset[arg] for arg in argset]
|
||||
# As an optimization below, we handle the largest funcset separately from
|
||||
# the others.
|
||||
largest_funcset = max(funcsets, key=len)
|
||||
|
||||
for funcset in funcsets:
|
||||
if largest_funcset is funcset:
|
||||
continue
|
||||
for func_i in funcset:
|
||||
if func_i >= min_func_i:
|
||||
count_map[func_i] += 1
|
||||
|
||||
# We pick the smaller of the two containers (count_map, largest_funcset)
|
||||
# to iterate over to reduce the number of iterations needed.
|
||||
(smaller_funcs_container,
|
||||
larger_funcs_container) = sorted(
|
||||
[largest_funcset, count_map],
|
||||
key=len)
|
||||
|
||||
for func_i in smaller_funcs_container:
|
||||
# Not already in count_map? It can't possibly be in the output, so
|
||||
# skip it.
|
||||
if count_map[func_i] < 1:
|
||||
continue
|
||||
|
||||
if func_i in larger_funcs_container:
|
||||
count_map[func_i] += 1
|
||||
|
||||
return {k: v for k, v in count_map.items() if v >= 2}
|
||||
|
||||
def get_subset_candidates(self, argset, restrict_to_funcset=None):
|
||||
"""
|
||||
Return a set of functions each of which whose argument list contains
|
||||
``argset``, optionally filtered only to contain functions in
|
||||
``restrict_to_funcset``.
|
||||
"""
|
||||
iarg = iter(argset)
|
||||
|
||||
indices = OrderedSet(
|
||||
fi for fi in self.arg_to_funcset[next(iarg)])
|
||||
|
||||
if restrict_to_funcset is not None:
|
||||
indices &= restrict_to_funcset
|
||||
|
||||
for arg in iarg:
|
||||
indices &= self.arg_to_funcset[arg]
|
||||
|
||||
return indices
|
||||
|
||||
def update_func_argset(self, func_i, new_argset):
|
||||
"""
|
||||
Update a function with a new set of arguments.
|
||||
"""
|
||||
new_args = OrderedSet(new_argset)
|
||||
old_args = self.func_to_argset[func_i]
|
||||
|
||||
for deleted_arg in old_args - new_args:
|
||||
self.arg_to_funcset[deleted_arg].remove(func_i)
|
||||
for added_arg in new_args - old_args:
|
||||
self.arg_to_funcset[added_arg].add(func_i)
|
||||
|
||||
self.func_to_argset[func_i].clear()
|
||||
self.func_to_argset[func_i].update(new_args)
|
||||
|
||||
|
||||
class Unevaluated:
|
||||
|
||||
def __init__(self, func, args):
|
||||
self.func = func
|
||||
self.args = args
|
||||
|
||||
def __str__(self):
|
||||
return "Uneval<{}>({})".format(
|
||||
self.func, ", ".join(str(a) for a in self.args))
|
||||
|
||||
def as_unevaluated_basic(self):
|
||||
return self.func(*self.args, evaluate=False)
|
||||
|
||||
@property
|
||||
def free_symbols(self):
|
||||
return set().union(*[a.free_symbols for a in self.args])
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def match_common_args(func_class, funcs, opt_subs):
|
||||
"""
|
||||
Recognize and extract common subexpressions of function arguments within a
|
||||
set of function calls. For instance, for the following function calls::
|
||||
|
||||
x + z + y
|
||||
sin(x + y)
|
||||
|
||||
this will extract a common subexpression of `x + y`::
|
||||
|
||||
w = x + y
|
||||
w + z
|
||||
sin(w)
|
||||
|
||||
The function we work with is assumed to be associative and commutative.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
func_class: class
|
||||
The function class (e.g. Add, Mul)
|
||||
funcs: list of functions
|
||||
A list of function calls.
|
||||
opt_subs: dict
|
||||
A dictionary of substitutions which this function may update.
|
||||
"""
|
||||
|
||||
# Sort to ensure that whole-function subexpressions come before the items
|
||||
# that use them.
|
||||
funcs = sorted(funcs, key=lambda f: len(f.args))
|
||||
arg_tracker = FuncArgTracker(funcs)
|
||||
|
||||
changed = OrderedSet()
|
||||
|
||||
for i in range(len(funcs)):
|
||||
common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
|
||||
arg_tracker.func_to_argset[i], min_func_i=i + 1)
|
||||
|
||||
# Sort the candidates in order of match size.
|
||||
# This makes us try combining smaller matches first.
|
||||
common_arg_candidates = OrderedSet(sorted(
|
||||
common_arg_candidates_counts.keys(),
|
||||
key=lambda k: (common_arg_candidates_counts[k], k)))
|
||||
|
||||
while common_arg_candidates:
|
||||
j = common_arg_candidates.pop(last=False)
|
||||
|
||||
com_args = arg_tracker.func_to_argset[i].intersection(
|
||||
arg_tracker.func_to_argset[j])
|
||||
|
||||
if len(com_args) <= 1:
|
||||
# This may happen if a set of common arguments was already
|
||||
# combined in a previous iteration.
|
||||
continue
|
||||
|
||||
# For all sets, replace the common symbols by the function
|
||||
# over them, to allow recursive matches.
|
||||
|
||||
diff_i = arg_tracker.func_to_argset[i].difference(com_args)
|
||||
if diff_i:
|
||||
# com_func needs to be unevaluated to allow for recursive matches.
|
||||
com_func = Unevaluated(
|
||||
func_class, arg_tracker.get_args_in_value_order(com_args))
|
||||
com_func_number = arg_tracker.get_or_add_value_number(com_func)
|
||||
arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
|
||||
changed.add(i)
|
||||
else:
|
||||
# Treat the whole expression as a CSE.
|
||||
#
|
||||
# The reason this needs to be done is somewhat subtle. Within
|
||||
# tree_cse(), to_eliminate only contains expressions that are
|
||||
# seen more than once. The problem is unevaluated expressions
|
||||
# do not compare equal to the evaluated equivalent. So
|
||||
# tree_cse() won't mark funcs[i] as a CSE if we use an
|
||||
# unevaluated version.
|
||||
com_func_number = arg_tracker.get_or_add_value_number(funcs[i])
|
||||
|
||||
diff_j = arg_tracker.func_to_argset[j].difference(com_args)
|
||||
arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
|
||||
changed.add(j)
|
||||
|
||||
for k in arg_tracker.get_subset_candidates(
|
||||
com_args, common_arg_candidates):
|
||||
diff_k = arg_tracker.func_to_argset[k].difference(com_args)
|
||||
arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
|
||||
changed.add(k)
|
||||
|
||||
if i in changed:
|
||||
opt_subs[funcs[i]] = Unevaluated(func_class,
|
||||
arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))
|
||||
|
||||
arg_tracker.stop_arg_tracking(i)
|
||||
|
||||
|
||||
def opt_cse(exprs, order='canonical'):
|
||||
"""Find optimization opportunities in Adds, Muls, Pows and negative
|
||||
coefficient Muls.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
exprs : list of SymPy expressions
|
||||
The expressions to optimize.
|
||||
order : string, 'none' or 'canonical'
|
||||
The order by which Mul and Add arguments are processed. For large
|
||||
expressions where speed is a concern, use the setting order='none'.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
opt_subs : dictionary of expression substitutions
|
||||
The expression substitutions which can be useful to optimize CSE.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.cse_main import opt_cse
|
||||
>>> from sympy.abc import x
|
||||
>>> opt_subs = opt_cse([x**-2])
|
||||
>>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
|
||||
>>> print((k, v.as_unevaluated_basic()))
|
||||
(x**(-2), 1/(x**2))
|
||||
"""
|
||||
opt_subs = {}
|
||||
|
||||
adds = OrderedSet()
|
||||
muls = OrderedSet()
|
||||
|
||||
seen_subexp = set()
|
||||
collapsible_subexp = set()
|
||||
|
||||
def _find_opts(expr):
|
||||
|
||||
if not isinstance(expr, (Basic, Unevaluated)):
|
||||
return
|
||||
|
||||
if expr.is_Atom or expr.is_Order:
|
||||
return
|
||||
|
||||
if iterable(expr):
|
||||
list(map(_find_opts, expr))
|
||||
return
|
||||
|
||||
if expr in seen_subexp:
|
||||
return expr
|
||||
seen_subexp.add(expr)
|
||||
|
||||
list(map(_find_opts, expr.args))
|
||||
|
||||
if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign():
|
||||
# XXX -expr does not always work rigorously for some expressions
|
||||
# containing UnevaluatedExpr.
|
||||
# https://github.com/sympy/sympy/issues/24818
|
||||
if isinstance(expr, Add):
|
||||
neg_expr = Add(*(-i for i in expr.args))
|
||||
else:
|
||||
neg_expr = -expr
|
||||
|
||||
if not neg_expr.is_Atom:
|
||||
opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
|
||||
seen_subexp.add(neg_expr)
|
||||
expr = neg_expr
|
||||
|
||||
if isinstance(expr, (Mul, MatMul)):
|
||||
if len(expr.args) == 1:
|
||||
collapsible_subexp.add(expr)
|
||||
else:
|
||||
muls.add(expr)
|
||||
|
||||
elif isinstance(expr, (Add, MatAdd)):
|
||||
if len(expr.args) == 1:
|
||||
collapsible_subexp.add(expr)
|
||||
else:
|
||||
adds.add(expr)
|
||||
|
||||
elif isinstance(expr, Inverse):
|
||||
# Do not want to treat `Inverse` as a `MatPow`
|
||||
pass
|
||||
|
||||
elif isinstance(expr, (Pow, MatPow)):
|
||||
base, exp = expr.base, expr.exp
|
||||
if exp.could_extract_minus_sign():
|
||||
opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))
|
||||
|
||||
for e in exprs:
|
||||
if isinstance(e, (Basic, Unevaluated)):
|
||||
_find_opts(e)
|
||||
|
||||
# Handle collapsing of multinary operations with single arguments
|
||||
edges = [(s, s.args[0]) for s in collapsible_subexp
|
||||
if s.args[0] in collapsible_subexp]
|
||||
for e in reversed(topological_sort((collapsible_subexp, edges))):
|
||||
opt_subs[e] = opt_subs.get(e.args[0], e.args[0])
|
||||
|
||||
# split muls into commutative
|
||||
commutative_muls = OrderedSet()
|
||||
for m in muls:
|
||||
c, nc = m.args_cnc(cset=False)
|
||||
if c:
|
||||
c_mul = m.func(*c)
|
||||
if nc:
|
||||
if c_mul == 1:
|
||||
new_obj = m.func(*nc)
|
||||
else:
|
||||
if isinstance(m, MatMul):
|
||||
new_obj = m.func(c_mul, *nc, evaluate=False)
|
||||
else:
|
||||
new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
|
||||
opt_subs[m] = new_obj
|
||||
if len(c) > 1:
|
||||
commutative_muls.add(c_mul)
|
||||
|
||||
match_common_args(Add, adds, opt_subs)
|
||||
match_common_args(Mul, commutative_muls, opt_subs)
|
||||
|
||||
return opt_subs
|
||||
|
||||
|
||||
def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
|
||||
"""Perform raw CSE on expression tree, taking opt_subs into account.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
exprs : list of SymPy expressions
|
||||
The expressions to reduce.
|
||||
symbols : infinite iterator yielding unique Symbols
|
||||
The symbols used to label the common subexpressions which are pulled
|
||||
out.
|
||||
opt_subs : dictionary of expression substitutions
|
||||
The expressions to be substituted before any CSE action is performed.
|
||||
order : string, 'none' or 'canonical'
|
||||
The order by which Mul and Add arguments are processed. For large
|
||||
expressions where speed is a concern, use the setting order='none'.
|
||||
ignore : iterable of Symbols
|
||||
Substitutions containing any Symbol from ``ignore`` will be ignored.
|
||||
"""
|
||||
if opt_subs is None:
|
||||
opt_subs = {}
|
||||
|
||||
## Find repeated sub-expressions
|
||||
|
||||
to_eliminate = set()
|
||||
|
||||
seen_subexp = set()
|
||||
excluded_symbols = set()
|
||||
|
||||
def _find_repeated(expr):
|
||||
if not isinstance(expr, (Basic, Unevaluated)):
|
||||
return
|
||||
|
||||
if isinstance(expr, RootOf):
|
||||
return
|
||||
|
||||
if isinstance(expr, Basic) and (
|
||||
expr.is_Atom or
|
||||
expr.is_Order or
|
||||
isinstance(expr, (MatrixSymbol, MatrixElement))):
|
||||
if expr.is_Symbol:
|
||||
excluded_symbols.add(expr.name)
|
||||
return
|
||||
|
||||
if iterable(expr):
|
||||
args = expr
|
||||
|
||||
else:
|
||||
if expr in seen_subexp:
|
||||
for ign in ignore:
|
||||
if ign in expr.free_symbols:
|
||||
break
|
||||
else:
|
||||
to_eliminate.add(expr)
|
||||
return
|
||||
|
||||
seen_subexp.add(expr)
|
||||
|
||||
if expr in opt_subs:
|
||||
expr = opt_subs[expr]
|
||||
|
||||
args = expr.args
|
||||
|
||||
list(map(_find_repeated, args))
|
||||
|
||||
for e in exprs:
|
||||
if isinstance(e, Basic):
|
||||
_find_repeated(e)
|
||||
|
||||
## Rebuild tree
|
||||
|
||||
# Remove symbols from the generator that conflict with names in the expressions.
|
||||
symbols = (_ for _ in symbols if _.name not in excluded_symbols)
|
||||
|
||||
replacements = []
|
||||
|
||||
subs = {}
|
||||
|
||||
def _rebuild(expr):
|
||||
if not isinstance(expr, (Basic, Unevaluated)):
|
||||
return expr
|
||||
|
||||
if not expr.args:
|
||||
return expr
|
||||
|
||||
if iterable(expr):
|
||||
new_args = [_rebuild(arg) for arg in expr.args]
|
||||
return expr.func(*new_args)
|
||||
|
||||
if expr in subs:
|
||||
return subs[expr]
|
||||
|
||||
orig_expr = expr
|
||||
if expr in opt_subs:
|
||||
expr = opt_subs[expr]
|
||||
|
||||
# If enabled, parse Muls and Adds arguments by order to ensure
|
||||
# replacement order independent from hashes
|
||||
if order != 'none':
|
||||
if isinstance(expr, (Mul, MatMul)):
|
||||
c, nc = expr.args_cnc()
|
||||
if c == [1]:
|
||||
args = nc
|
||||
else:
|
||||
args = list(ordered(c)) + nc
|
||||
elif isinstance(expr, (Add, MatAdd)):
|
||||
args = list(ordered(expr.args))
|
||||
else:
|
||||
args = expr.args
|
||||
else:
|
||||
args = expr.args
|
||||
|
||||
new_args = list(map(_rebuild, args))
|
||||
if isinstance(expr, Unevaluated) or new_args != args:
|
||||
new_expr = expr.func(*new_args)
|
||||
else:
|
||||
new_expr = expr
|
||||
|
||||
if orig_expr in to_eliminate:
|
||||
try:
|
||||
sym = next(symbols)
|
||||
except StopIteration:
|
||||
raise ValueError("Symbols iterator ran out of symbols.")
|
||||
|
||||
if isinstance(orig_expr, MatrixExpr):
|
||||
sym = MatrixSymbol(sym.name, orig_expr.rows,
|
||||
orig_expr.cols)
|
||||
|
||||
subs[orig_expr] = sym
|
||||
replacements.append((sym, new_expr))
|
||||
return sym
|
||||
|
||||
else:
|
||||
return new_expr
|
||||
|
||||
reduced_exprs = []
|
||||
for e in exprs:
|
||||
if isinstance(e, Basic):
|
||||
reduced_e = _rebuild(e)
|
||||
else:
|
||||
reduced_e = e
|
||||
reduced_exprs.append(reduced_e)
|
||||
return replacements, reduced_exprs
|
||||
|
||||
|
||||
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
|
||||
order='canonical', ignore=(), list=True):
|
||||
""" Perform common subexpression elimination on an expression.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
exprs : list of SymPy expressions, or a single SymPy expression
|
||||
The expressions to reduce.
|
||||
symbols : infinite iterator yielding unique Symbols
|
||||
The symbols used to label the common subexpressions which are pulled
|
||||
out. The ``numbered_symbols`` generator is useful. The default is a
|
||||
stream of symbols of the form "x0", "x1", etc. This must be an
|
||||
infinite iterator.
|
||||
optimizations : list of (callable, callable) pairs
|
||||
The (preprocessor, postprocessor) pairs of external optimization
|
||||
functions. Optionally 'basic' can be passed for a set of predefined
|
||||
basic optimizations. Such 'basic' optimizations were used by default
|
||||
in old implementation, however they can be really slow on larger
|
||||
expressions. Now, no pre or post optimizations are made by default.
|
||||
postprocess : a function which accepts the two return values of cse and
|
||||
returns the desired form of output from cse, e.g. if you want the
|
||||
replacements reversed the function might be the following lambda:
|
||||
lambda r, e: return reversed(r), e
|
||||
order : string, 'none' or 'canonical'
|
||||
The order by which Mul and Add arguments are processed. If set to
|
||||
'canonical', arguments will be canonically ordered. If set to 'none',
|
||||
ordering will be faster but dependent on expressions hashes, thus
|
||||
machine dependent and variable. For large expressions where speed is a
|
||||
concern, use the setting order='none'.
|
||||
ignore : iterable of Symbols
|
||||
Substitutions containing any Symbol from ``ignore`` will be ignored.
|
||||
list : bool, (default True)
|
||||
Returns expression in list or else with same type as input (when False).
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
All of the common subexpressions that were replaced. Subexpressions
|
||||
earlier in this list might show up in subexpressions later in this
|
||||
list.
|
||||
reduced_exprs : list of SymPy expressions
|
||||
The reduced expressions with all of the replacements above.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import cse, SparseMatrix
|
||||
>>> from sympy.abc import x, y, z, w
|
||||
>>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
|
||||
([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])
|
||||
|
||||
|
||||
List of expressions with recursive substitutions:
|
||||
|
||||
>>> m = SparseMatrix([x + y, x + y + z])
|
||||
>>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
|
||||
([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
|
||||
[x0],
|
||||
[x1]])])
|
||||
|
||||
Note: the type and mutability of input matrices is retained.
|
||||
|
||||
>>> isinstance(_[1][-1], SparseMatrix)
|
||||
True
|
||||
|
||||
The user may disallow substitutions containing certain symbols:
|
||||
|
||||
>>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
|
||||
([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
|
||||
|
||||
The default return value for the reduced expression(s) is a list, even if there is only
|
||||
one expression. The `list` flag preserves the type of the input in the output:
|
||||
|
||||
>>> cse(x)
|
||||
([], [x])
|
||||
>>> cse(x, list=False)
|
||||
([], x)
|
||||
"""
|
||||
if not list:
|
||||
return _cse_homogeneous(exprs,
|
||||
symbols=symbols, optimizations=optimizations,
|
||||
postprocess=postprocess, order=order, ignore=ignore)
|
||||
|
||||
if isinstance(exprs, (int, float)):
|
||||
exprs = sympify(exprs)
|
||||
|
||||
# Handle the case if just one expression was passed.
|
||||
if isinstance(exprs, (Basic, MatrixBase)):
|
||||
exprs = [exprs]
|
||||
|
||||
copy = exprs
|
||||
temp = []
|
||||
for e in exprs:
|
||||
if isinstance(e, (Matrix, ImmutableMatrix)):
|
||||
temp.append(Tuple(*e.flat()))
|
||||
elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
|
||||
temp.append(Tuple(*e.todok().items()))
|
||||
else:
|
||||
temp.append(e)
|
||||
exprs = temp
|
||||
del temp
|
||||
|
||||
if optimizations is None:
|
||||
optimizations = []
|
||||
elif optimizations == 'basic':
|
||||
optimizations = basic_optimizations
|
||||
|
||||
# Preprocess the expressions to give us better optimization opportunities.
|
||||
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
|
||||
|
||||
if symbols is None:
|
||||
symbols = numbered_symbols(cls=Symbol)
|
||||
else:
|
||||
# In case we get passed an iterable with an __iter__ method instead of
|
||||
# an actual iterator.
|
||||
symbols = iter(symbols)
|
||||
|
||||
# Find other optimization opportunities.
|
||||
opt_subs = opt_cse(reduced_exprs, order)
|
||||
|
||||
# Main CSE algorithm.
|
||||
replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
|
||||
order, ignore)
|
||||
|
||||
# Postprocess the expressions to return the expressions to canonical form.
|
||||
exprs = copy
|
||||
replacements = [(sym, postprocess_for_cse(subtree, optimizations))
|
||||
for sym, subtree in replacements]
|
||||
reduced_exprs = [postprocess_for_cse(e, optimizations)
|
||||
for e in reduced_exprs]
|
||||
|
||||
# Get the matrices back
|
||||
for i, e in enumerate(exprs):
|
||||
if isinstance(e, (Matrix, ImmutableMatrix)):
|
||||
reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
|
||||
if isinstance(e, ImmutableMatrix):
|
||||
reduced_exprs[i] = reduced_exprs[i].as_immutable()
|
||||
elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
|
||||
m = SparseMatrix(e.rows, e.cols, {})
|
||||
for k, v in reduced_exprs[i]:
|
||||
m[k] = v
|
||||
if isinstance(e, ImmutableSparseMatrix):
|
||||
m = m.as_immutable()
|
||||
reduced_exprs[i] = m
|
||||
|
||||
if postprocess is None:
|
||||
return replacements, reduced_exprs
|
||||
|
||||
return postprocess(replacements, reduced_exprs)
|
||||
|
||||
|
||||
def _cse_homogeneous(exprs, **kwargs):
|
||||
"""
|
||||
Same as ``cse`` but the ``reduced_exprs`` are returned
|
||||
with the same type as ``exprs`` or a sympified version of the same.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
exprs : an Expr, iterable of Expr or dictionary with Expr values
|
||||
the expressions in which repeated subexpressions will be identified
|
||||
kwargs : additional arguments for the ``cse`` function
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
replacements : list of (Symbol, expression) pairs
|
||||
All of the common subexpressions that were replaced. Subexpressions
|
||||
earlier in this list might show up in subexpressions later in this
|
||||
list.
|
||||
reduced_exprs : list of SymPy expressions
|
||||
The reduced expressions with all of the replacements above.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.cse_main import cse
|
||||
>>> from sympy import cos, Tuple, Matrix
|
||||
>>> from sympy.abc import x
|
||||
>>> output = lambda x: type(cse(x, list=False)[1])
|
||||
>>> output(1)
|
||||
<class 'sympy.core.numbers.One'>
|
||||
>>> output('cos(x)')
|
||||
<class 'str'>
|
||||
>>> output(cos(x))
|
||||
cos
|
||||
>>> output(Tuple(1, x))
|
||||
<class 'sympy.core.containers.Tuple'>
|
||||
>>> output(Matrix([[1,0], [0,1]]))
|
||||
<class 'sympy.matrices.dense.MutableDenseMatrix'>
|
||||
>>> output([1, x])
|
||||
<class 'list'>
|
||||
>>> output((1, x))
|
||||
<class 'tuple'>
|
||||
>>> output({1, x})
|
||||
<class 'set'>
|
||||
"""
|
||||
if isinstance(exprs, str):
|
||||
replacements, reduced_exprs = _cse_homogeneous(
|
||||
sympify(exprs), **kwargs)
|
||||
return replacements, repr(reduced_exprs)
|
||||
if isinstance(exprs, (list, tuple, set)):
|
||||
replacements, reduced_exprs = cse(exprs, **kwargs)
|
||||
return replacements, type(exprs)(reduced_exprs)
|
||||
if isinstance(exprs, dict):
|
||||
keys = list(exprs.keys()) # In order to guarantee the order of the elements.
|
||||
replacements, values = cse([exprs[k] for k in keys], **kwargs)
|
||||
reduced_exprs = dict(zip(keys, values))
|
||||
return replacements, reduced_exprs
|
||||
|
||||
try:
|
||||
replacements, (reduced_exprs,) = cse(exprs, **kwargs)
|
||||
except TypeError: # For example 'mpf' objects
|
||||
return [], exprs
|
||||
else:
|
||||
return replacements, reduced_exprs
|
||||
@@ -0,0 +1,52 @@
|
||||
""" Optimizations of the expression tree representation for better CSE
|
||||
opportunities.
|
||||
"""
|
||||
from sympy.core import Add, Basic, Mul
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.traversal import preorder_traversal
|
||||
|
||||
|
||||
def sub_pre(e):
|
||||
""" Replace y - x with -(x - y) if -1 can be extracted from y - x.
|
||||
"""
|
||||
# replacing Add, A, from which -1 can be extracted with -1*-A
|
||||
adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()]
|
||||
reps = {}
|
||||
ignore = set()
|
||||
for a in adds:
|
||||
na = -a
|
||||
if na.is_Mul: # e.g. MatExpr
|
||||
ignore.add(a)
|
||||
continue
|
||||
reps[a] = Mul._from_args([S.NegativeOne, na])
|
||||
|
||||
e = e.xreplace(reps)
|
||||
|
||||
# repeat again for persisting Adds but mark these with a leading 1, -1
|
||||
# e.g. y - x -> 1*-1*(x - y)
|
||||
if isinstance(e, Basic):
|
||||
negs = {}
|
||||
for a in sorted(e.atoms(Add), key=default_sort_key):
|
||||
if a in ignore:
|
||||
continue
|
||||
if a in reps:
|
||||
negs[a] = reps[a]
|
||||
elif a.could_extract_minus_sign():
|
||||
negs[a] = Mul._from_args([S.One, S.NegativeOne, -a])
|
||||
e = e.xreplace(negs)
|
||||
return e
|
||||
|
||||
|
||||
def sub_post(e):
|
||||
""" Replace 1*-1*x with -x.
|
||||
"""
|
||||
replacements = []
|
||||
for node in preorder_traversal(e):
|
||||
if isinstance(node, Mul) and \
|
||||
node.args[0] is S.One and node.args[1] is S.NegativeOne:
|
||||
replacements.append((node, -Mul._from_args(node.args[2:])))
|
||||
for node, replacement in replacements:
|
||||
e = e.xreplace({node: replacement})
|
||||
|
||||
return e
|
||||
@@ -0,0 +1,352 @@
|
||||
"""Tools for manipulation of expressions using paths. """
|
||||
|
||||
from sympy.core import Basic
|
||||
|
||||
|
||||
class EPath:
|
||||
r"""
|
||||
Manipulate expressions using paths.
|
||||
|
||||
EPath grammar in EBNF notation::
|
||||
|
||||
literal ::= /[A-Za-z_][A-Za-z_0-9]*/
|
||||
number ::= /-?\d+/
|
||||
type ::= literal
|
||||
attribute ::= literal "?"
|
||||
all ::= "*"
|
||||
slice ::= "[" number? (":" number? (":" number?)?)? "]"
|
||||
range ::= all | slice
|
||||
query ::= (type | attribute) ("|" (type | attribute))*
|
||||
selector ::= range | query range?
|
||||
path ::= "/" selector ("/" selector)*
|
||||
|
||||
See the docstring of the epath() function.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("_path", "_epath")
|
||||
|
||||
def __new__(cls, path):
|
||||
"""Construct new EPath. """
|
||||
if isinstance(path, EPath):
|
||||
return path
|
||||
|
||||
if not path:
|
||||
raise ValueError("empty EPath")
|
||||
|
||||
_path = path
|
||||
|
||||
if path[0] == '/':
|
||||
path = path[1:]
|
||||
else:
|
||||
raise NotImplementedError("non-root EPath")
|
||||
|
||||
epath = []
|
||||
|
||||
for selector in path.split('/'):
|
||||
selector = selector.strip()
|
||||
|
||||
if not selector:
|
||||
raise ValueError("empty selector")
|
||||
|
||||
index = 0
|
||||
|
||||
for c in selector:
|
||||
if c.isalnum() or c in ('_', '|', '?'):
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
|
||||
attrs = []
|
||||
types = []
|
||||
|
||||
if index:
|
||||
elements = selector[:index]
|
||||
selector = selector[index:]
|
||||
|
||||
for element in elements.split('|'):
|
||||
element = element.strip()
|
||||
|
||||
if not element:
|
||||
raise ValueError("empty element")
|
||||
|
||||
if element.endswith('?'):
|
||||
attrs.append(element[:-1])
|
||||
else:
|
||||
types.append(element)
|
||||
|
||||
span = None
|
||||
|
||||
if selector == '*':
|
||||
pass
|
||||
else:
|
||||
if selector.startswith('['):
|
||||
try:
|
||||
i = selector.index(']')
|
||||
except ValueError:
|
||||
raise ValueError("expected ']', got EOL")
|
||||
|
||||
_span, span = selector[1:i], []
|
||||
|
||||
if ':' not in _span:
|
||||
span = int(_span)
|
||||
else:
|
||||
for elt in _span.split(':', 3):
|
||||
if not elt:
|
||||
span.append(None)
|
||||
else:
|
||||
span.append(int(elt))
|
||||
|
||||
span = slice(*span)
|
||||
|
||||
selector = selector[i + 1:]
|
||||
|
||||
if selector:
|
||||
raise ValueError("trailing characters in selector")
|
||||
|
||||
epath.append((attrs, types, span))
|
||||
|
||||
obj = object.__new__(cls)
|
||||
|
||||
obj._path = _path
|
||||
obj._epath = epath
|
||||
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%r)" % (self.__class__.__name__, self._path)
|
||||
|
||||
def _get_ordered_args(self, expr):
|
||||
"""Sort ``expr.args`` using printing order. """
|
||||
if expr.is_Add:
|
||||
return expr.as_ordered_terms()
|
||||
elif expr.is_Mul:
|
||||
return expr.as_ordered_factors()
|
||||
else:
|
||||
return expr.args
|
||||
|
||||
def _hasattrs(self, expr, attrs) -> bool:
|
||||
"""Check if ``expr`` has any of ``attrs``. """
|
||||
return all(hasattr(expr, attr) for attr in attrs)
|
||||
|
||||
def _hastypes(self, expr, types):
|
||||
"""Check if ``expr`` is any of ``types``. """
|
||||
_types = [ cls.__name__ for cls in expr.__class__.mro() ]
|
||||
return bool(set(_types).intersection(types))
|
||||
|
||||
def _has(self, expr, attrs, types):
|
||||
"""Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """
|
||||
if not (attrs or types):
|
||||
return True
|
||||
|
||||
if attrs and self._hasattrs(expr, attrs):
|
||||
return True
|
||||
|
||||
if types and self._hastypes(expr, types):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def apply(self, expr, func, args=None, kwargs=None):
|
||||
"""
|
||||
Modify parts of an expression selected by a path.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.epathtools import EPath
|
||||
>>> from sympy import sin, cos, E
|
||||
>>> from sympy.abc import x, y, z, t
|
||||
|
||||
>>> path = EPath("/*/[0]/Symbol")
|
||||
>>> expr = [((x, 1), 2), ((3, y), z)]
|
||||
|
||||
>>> path.apply(expr, lambda expr: expr**2)
|
||||
[((x**2, 1), 2), ((3, y**2), z)]
|
||||
|
||||
>>> path = EPath("/*/*/Symbol")
|
||||
>>> expr = t + sin(x + 1) + cos(x + y + E)
|
||||
|
||||
>>> path.apply(expr, lambda expr: 2*expr)
|
||||
t + sin(2*x + 1) + cos(2*x + 2*y + E)
|
||||
|
||||
"""
|
||||
def _apply(path, expr, func):
|
||||
if not path:
|
||||
return func(expr)
|
||||
else:
|
||||
selector, path = path[0], path[1:]
|
||||
attrs, types, span = selector
|
||||
|
||||
if isinstance(expr, Basic):
|
||||
if not expr.is_Atom:
|
||||
args, basic = self._get_ordered_args(expr), True
|
||||
else:
|
||||
return expr
|
||||
elif hasattr(expr, '__iter__'):
|
||||
args, basic = expr, False
|
||||
else:
|
||||
return expr
|
||||
|
||||
args = list(args)
|
||||
|
||||
if span is not None:
|
||||
if isinstance(span, slice):
|
||||
indices = range(*span.indices(len(args)))
|
||||
else:
|
||||
indices = [span]
|
||||
else:
|
||||
indices = range(len(args))
|
||||
|
||||
for i in indices:
|
||||
try:
|
||||
arg = args[i]
|
||||
except IndexError:
|
||||
continue
|
||||
|
||||
if self._has(arg, attrs, types):
|
||||
args[i] = _apply(path, arg, func)
|
||||
|
||||
if basic:
|
||||
return expr.func(*args)
|
||||
else:
|
||||
return expr.__class__(args)
|
||||
|
||||
_args, _kwargs = args or (), kwargs or {}
|
||||
_func = lambda expr: func(expr, *_args, **_kwargs)
|
||||
|
||||
return _apply(self._epath, expr, _func)
|
||||
|
||||
def select(self, expr):
|
||||
"""
|
||||
Retrieve parts of an expression selected by a path.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.epathtools import EPath
|
||||
>>> from sympy import sin, cos, E
|
||||
>>> from sympy.abc import x, y, z, t
|
||||
|
||||
>>> path = EPath("/*/[0]/Symbol")
|
||||
>>> expr = [((x, 1), 2), ((3, y), z)]
|
||||
|
||||
>>> path.select(expr)
|
||||
[x, y]
|
||||
|
||||
>>> path = EPath("/*/*/Symbol")
|
||||
>>> expr = t + sin(x + 1) + cos(x + y + E)
|
||||
|
||||
>>> path.select(expr)
|
||||
[x, x, y]
|
||||
|
||||
"""
|
||||
result = []
|
||||
|
||||
def _select(path, expr):
|
||||
if not path:
|
||||
result.append(expr)
|
||||
else:
|
||||
selector, path = path[0], path[1:]
|
||||
attrs, types, span = selector
|
||||
|
||||
if isinstance(expr, Basic):
|
||||
args = self._get_ordered_args(expr)
|
||||
elif hasattr(expr, '__iter__'):
|
||||
args = expr
|
||||
else:
|
||||
return
|
||||
|
||||
if span is not None:
|
||||
if isinstance(span, slice):
|
||||
args = args[span]
|
||||
else:
|
||||
try:
|
||||
args = [args[span]]
|
||||
except IndexError:
|
||||
return
|
||||
|
||||
for arg in args:
|
||||
if self._has(arg, attrs, types):
|
||||
_select(path, arg)
|
||||
|
||||
_select(self._epath, expr)
|
||||
return result
|
||||
|
||||
|
||||
def epath(path, expr=None, func=None, args=None, kwargs=None):
|
||||
r"""
|
||||
Manipulate parts of an expression selected by a path.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function allows to manipulate large nested expressions in single
|
||||
line of code, utilizing techniques to those applied in XML processing
|
||||
standards (e.g. XPath).
|
||||
|
||||
If ``func`` is ``None``, :func:`epath` retrieves elements selected by
|
||||
the ``path``. Otherwise it applies ``func`` to each matching element.
|
||||
|
||||
Note that it is more efficient to create an EPath object and use the select
|
||||
and apply methods of that object, since this will compile the path string
|
||||
only once. This function should only be used as a convenient shortcut for
|
||||
interactive use.
|
||||
|
||||
This is the supported syntax:
|
||||
|
||||
* select all: ``/*``
|
||||
Equivalent of ``for arg in args:``.
|
||||
* select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]``
|
||||
Supports standard Python's slice syntax.
|
||||
* select by type: ``/list`` or ``/list|tuple``
|
||||
Emulates ``isinstance()``.
|
||||
* select by attribute: ``/__iter__?``
|
||||
Emulates ``hasattr()``.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
path : str | EPath
|
||||
A path as a string or a compiled EPath.
|
||||
expr : Basic | iterable
|
||||
An expression or a container of expressions.
|
||||
func : callable (optional)
|
||||
A callable that will be applied to matching parts.
|
||||
args : tuple (optional)
|
||||
Additional positional arguments to ``func``.
|
||||
kwargs : dict (optional)
|
||||
Additional keyword arguments to ``func``.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.epathtools import epath
|
||||
>>> from sympy import sin, cos, E
|
||||
>>> from sympy.abc import x, y, z, t
|
||||
|
||||
>>> path = "/*/[0]/Symbol"
|
||||
>>> expr = [((x, 1), 2), ((3, y), z)]
|
||||
|
||||
>>> epath(path, expr)
|
||||
[x, y]
|
||||
>>> epath(path, expr, lambda expr: expr**2)
|
||||
[((x**2, 1), 2), ((3, y**2), z)]
|
||||
|
||||
>>> path = "/*/*/Symbol"
|
||||
>>> expr = t + sin(x + 1) + cos(x + y + E)
|
||||
|
||||
>>> epath(path, expr)
|
||||
[x, x, y]
|
||||
>>> epath(path, expr, lambda expr: 2*expr)
|
||||
t + sin(2*x + 1) + cos(2*x + 2*y + E)
|
||||
|
||||
"""
|
||||
_epath = EPath(path)
|
||||
|
||||
if expr is None:
|
||||
return _epath
|
||||
if func is None:
|
||||
return _epath.select(expr)
|
||||
else:
|
||||
return _epath.apply(expr, func, args, kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,493 @@
|
||||
from sympy.core import Function, S, Mul, Pow, Add
|
||||
from sympy.core.sorting import ordered, default_sort_key
|
||||
from sympy.core.function import expand_func
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.functions import gamma, sqrt, sin
|
||||
from sympy.polys import factor, cancel
|
||||
from sympy.utilities.iterables import sift, uniq
|
||||
|
||||
|
||||
def gammasimp(expr):
|
||||
r"""
|
||||
Simplify expressions with gamma functions.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
This function takes as input an expression containing gamma
|
||||
functions or functions that can be rewritten in terms of gamma
|
||||
functions and tries to minimize the number of those functions and
|
||||
reduce the size of their arguments.
|
||||
|
||||
The algorithm works by rewriting all gamma functions as expressions
|
||||
involving rising factorials (Pochhammer symbols) and applies
|
||||
recurrence relations and other transformations applicable to rising
|
||||
factorials, to reduce their arguments, possibly letting the resulting
|
||||
rising factorial to cancel. Rising factorials with the second argument
|
||||
being an integer are expanded into polynomial forms and finally all
|
||||
other rising factorial are rewritten in terms of gamma functions.
|
||||
|
||||
Then the following two steps are performed.
|
||||
|
||||
1. Reduce the number of gammas by applying the reflection theorem
|
||||
gamma(x)*gamma(1-x) == pi/sin(pi*x).
|
||||
2. Reduce the number of gammas by applying the multiplication theorem
|
||||
gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x).
|
||||
|
||||
It then reduces the number of prefactors by absorbing them into gammas
|
||||
where possible and expands gammas with rational argument.
|
||||
|
||||
All transformation rules can be found (or were derived from) here:
|
||||
|
||||
.. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/
|
||||
.. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify import gammasimp
|
||||
>>> from sympy import gamma, Symbol
|
||||
>>> from sympy.abc import x
|
||||
>>> n = Symbol('n', integer = True)
|
||||
|
||||
>>> gammasimp(gamma(x)/gamma(x - 3))
|
||||
(x - 3)*(x - 2)*(x - 1)
|
||||
>>> gammasimp(gamma(n + 3))
|
||||
gamma(n + 3)
|
||||
|
||||
"""
|
||||
|
||||
expr = expr.rewrite(gamma)
|
||||
|
||||
# compute_ST will be looking for Functions and we don't want
|
||||
# it looking for non-gamma functions: issue 22606
|
||||
# so we mask free, non-gamma functions
|
||||
f = expr.atoms(Function)
|
||||
# take out gammas
|
||||
gammas = {i for i in f if isinstance(i, gamma)}
|
||||
if not gammas:
|
||||
return expr # avoid side effects like factoring
|
||||
f -= gammas
|
||||
# keep only those without bound symbols
|
||||
f = f & expr.as_dummy().atoms(Function)
|
||||
if f:
|
||||
dum, fun, simp = zip(*[
|
||||
(Dummy(), fi, fi.func(*[
|
||||
_gammasimp(a, as_comb=False) for a in fi.args]))
|
||||
for fi in ordered(f)])
|
||||
d = expr.xreplace(dict(zip(fun, dum)))
|
||||
return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp)))
|
||||
|
||||
return _gammasimp(expr, as_comb=False)
|
||||
|
||||
|
||||
def _gammasimp(expr, as_comb):
|
||||
"""
|
||||
Helper function for gammasimp and combsimp.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Simplifies expressions written in terms of gamma function. If
|
||||
as_comb is True, it tries to preserve integer arguments. See
|
||||
docstring of gammasimp for more information. This was part of
|
||||
combsimp() in combsimp.py.
|
||||
"""
|
||||
expr = expr.replace(gamma,
|
||||
lambda n: _rf(1, (n - 1).expand()))
|
||||
|
||||
if as_comb:
|
||||
expr = expr.replace(_rf,
|
||||
lambda a, b: gamma(b + 1))
|
||||
else:
|
||||
expr = expr.replace(_rf,
|
||||
lambda a, b: gamma(a + b)/gamma(a))
|
||||
|
||||
def rule_gamma(expr, level=0):
|
||||
""" Simplify products of gamma functions further. """
|
||||
|
||||
if expr.is_Atom:
|
||||
return expr
|
||||
|
||||
def gamma_rat(x):
|
||||
# helper to simplify ratios of gammas
|
||||
was = x.count(gamma)
|
||||
xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand()
|
||||
).replace(_rf, lambda a, b: gamma(a + b)/gamma(a)))
|
||||
if xx.count(gamma) < was:
|
||||
x = xx
|
||||
return x
|
||||
|
||||
def gamma_factor(x):
|
||||
# return True if there is a gamma factor in shallow args
|
||||
if isinstance(x, gamma):
|
||||
return True
|
||||
if x.is_Add or x.is_Mul:
|
||||
return any(gamma_factor(xi) for xi in x.args)
|
||||
if x.is_Pow and (x.exp.is_integer or x.base.is_positive):
|
||||
return gamma_factor(x.base)
|
||||
return False
|
||||
|
||||
# recursion step
|
||||
if level == 0:
|
||||
expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args])
|
||||
level += 1
|
||||
|
||||
if not expr.is_Mul:
|
||||
return expr
|
||||
|
||||
# non-commutative step
|
||||
if level == 1:
|
||||
args, nc = expr.args_cnc()
|
||||
if not args:
|
||||
return expr
|
||||
if nc:
|
||||
return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc)
|
||||
level += 1
|
||||
|
||||
# pure gamma handling, not factor absorption
|
||||
if level == 2:
|
||||
T, F = sift(expr.args, gamma_factor, binary=True)
|
||||
gamma_ind = Mul(*F)
|
||||
d = Mul(*T)
|
||||
|
||||
nd, dd = d.as_numer_denom()
|
||||
for ipass in range(2):
|
||||
args = list(ordered(Mul.make_args(nd)))
|
||||
for i, ni in enumerate(args):
|
||||
if ni.is_Add:
|
||||
ni, dd = Add(*[
|
||||
rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args]
|
||||
).as_numer_denom()
|
||||
args[i] = ni
|
||||
if not dd.has(gamma):
|
||||
break
|
||||
nd = Mul(*args)
|
||||
if ipass == 0 and not gamma_factor(nd):
|
||||
break
|
||||
nd, dd = dd, nd # now process in reversed order
|
||||
expr = gamma_ind*nd/dd
|
||||
if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))):
|
||||
return expr
|
||||
level += 1
|
||||
|
||||
# iteration until constant
|
||||
if level == 3:
|
||||
while True:
|
||||
was = expr
|
||||
expr = rule_gamma(expr, 4)
|
||||
if expr == was:
|
||||
return expr
|
||||
|
||||
numer_gammas = []
|
||||
denom_gammas = []
|
||||
numer_others = []
|
||||
denom_others = []
|
||||
def explicate(p):
|
||||
if p is S.One:
|
||||
return None, []
|
||||
b, e = p.as_base_exp()
|
||||
if e.is_Integer:
|
||||
if isinstance(b, gamma):
|
||||
return True, [b.args[0]]*e
|
||||
else:
|
||||
return False, [b]*e
|
||||
else:
|
||||
return False, [p]
|
||||
|
||||
newargs = list(ordered(expr.args))
|
||||
while newargs:
|
||||
n, d = newargs.pop().as_numer_denom()
|
||||
isg, l = explicate(n)
|
||||
if isg:
|
||||
numer_gammas.extend(l)
|
||||
elif isg is False:
|
||||
numer_others.extend(l)
|
||||
isg, l = explicate(d)
|
||||
if isg:
|
||||
denom_gammas.extend(l)
|
||||
elif isg is False:
|
||||
denom_others.extend(l)
|
||||
|
||||
# =========== level 2 work: pure gamma manipulation =========
|
||||
|
||||
if not as_comb:
|
||||
# Try to reduce the number of gamma factors by applying the
|
||||
# reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x)
|
||||
for gammas, numer, denom in [(
|
||||
numer_gammas, numer_others, denom_others),
|
||||
(denom_gammas, denom_others, numer_others)]:
|
||||
new = []
|
||||
while gammas:
|
||||
g1 = gammas.pop()
|
||||
if g1.is_integer:
|
||||
new.append(g1)
|
||||
continue
|
||||
for i, g2 in enumerate(gammas):
|
||||
n = g1 + g2 - 1
|
||||
if not n.is_Integer:
|
||||
continue
|
||||
numer.append(S.Pi)
|
||||
denom.append(sin(S.Pi*g1))
|
||||
gammas.pop(i)
|
||||
if n > 0:
|
||||
numer.extend(1 - g1 + k for k in range(n))
|
||||
elif n < 0:
|
||||
denom.extend(-g1 - k for k in range(-n))
|
||||
break
|
||||
else:
|
||||
new.append(g1)
|
||||
# /!\ updating IN PLACE
|
||||
gammas[:] = new
|
||||
|
||||
# Try to reduce the number of gammas by using the duplication
|
||||
# theorem to cancel an upper and lower: gamma(2*s)/gamma(s) =
|
||||
# 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could
|
||||
# be done with higher argument ratios like gamma(3*x)/gamma(x),
|
||||
# this would not reduce the number of gammas as in this case.
|
||||
for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others,
|
||||
denom_others),
|
||||
(denom_gammas, numer_gammas, denom_others,
|
||||
numer_others)]:
|
||||
|
||||
while True:
|
||||
for x in ng:
|
||||
for y in dg:
|
||||
n = x - 2*y
|
||||
if n.is_Integer:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
else:
|
||||
break
|
||||
ng.remove(x)
|
||||
dg.remove(y)
|
||||
if n > 0:
|
||||
no.extend(2*y + k for k in range(n))
|
||||
elif n < 0:
|
||||
do.extend(2*y - 1 - k for k in range(-n))
|
||||
ng.append(y + S.Half)
|
||||
no.append(2**(2*y - 1))
|
||||
do.append(sqrt(S.Pi))
|
||||
|
||||
# Try to reduce the number of gamma factors by applying the
|
||||
# multiplication theorem (used when n gammas with args differing
|
||||
# by 1/n mod 1 are encountered).
|
||||
#
|
||||
# run of 2 with args differing by 1/2
|
||||
#
|
||||
# >>> gammasimp(gamma(x)*gamma(x+S.Half))
|
||||
# 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x)
|
||||
#
|
||||
# run of 3 args differing by 1/3 (mod 1)
|
||||
#
|
||||
# >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3))
|
||||
# 6*3**(-3*x - 1/2)*pi*gamma(3*x)
|
||||
# >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3))
|
||||
# 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x)
|
||||
#
|
||||
def _run(coeffs):
|
||||
# find runs in coeffs such that the difference in terms (mod 1)
|
||||
# of t1, t2, ..., tn is 1/n
|
||||
u = list(uniq(coeffs))
|
||||
for i in range(len(u)):
|
||||
dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))])
|
||||
for one, j in dj:
|
||||
if one.p == 1 and one.q != 1:
|
||||
n = one.q
|
||||
got = [i]
|
||||
get = list(range(1, n))
|
||||
for d, j in dj:
|
||||
m = n*d
|
||||
if m.is_Integer and m in get:
|
||||
get.remove(m)
|
||||
got.append(j)
|
||||
if not get:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
for i, j in enumerate(got):
|
||||
c = u[j]
|
||||
coeffs.remove(c)
|
||||
got[i] = c
|
||||
return one.q, got[0], got[1:]
|
||||
|
||||
def _mult_thm(gammas, numer, denom):
|
||||
# pull off and analyze the leading coefficient from each gamma arg
|
||||
# looking for runs in those Rationals
|
||||
|
||||
# expr -> coeff + resid -> rats[resid] = coeff
|
||||
rats = {}
|
||||
for g in gammas:
|
||||
c, resid = g.as_coeff_Add()
|
||||
rats.setdefault(resid, []).append(c)
|
||||
|
||||
# look for runs in Rationals for each resid
|
||||
keys = sorted(rats, key=default_sort_key)
|
||||
for resid in keys:
|
||||
coeffs = sorted(rats[resid])
|
||||
new = []
|
||||
while True:
|
||||
run = _run(coeffs)
|
||||
if run is None:
|
||||
break
|
||||
|
||||
# process the sequence that was found:
|
||||
# 1) convert all the gamma functions to have the right
|
||||
# argument (could be off by an integer)
|
||||
# 2) append the factors corresponding to the theorem
|
||||
# 3) append the new gamma function
|
||||
|
||||
n, ui, other = run
|
||||
|
||||
# (1)
|
||||
for u in other:
|
||||
con = resid + u - 1
|
||||
for k in range(int(u - ui)):
|
||||
numer.append(con - k)
|
||||
|
||||
con = n*(resid + ui) # for (2) and (3)
|
||||
|
||||
# (2)
|
||||
numer.append((2*S.Pi)**(S(n - 1)/2)*
|
||||
n**(S.Half - con))
|
||||
# (3)
|
||||
new.append(con)
|
||||
|
||||
# restore resid to coeffs
|
||||
rats[resid] = [resid + c for c in coeffs] + new
|
||||
|
||||
# rebuild the gamma arguments
|
||||
g = []
|
||||
for resid in keys:
|
||||
g += rats[resid]
|
||||
# /!\ updating IN PLACE
|
||||
gammas[:] = g
|
||||
|
||||
for l, numer, denom in [(numer_gammas, numer_others, denom_others),
|
||||
(denom_gammas, denom_others, numer_others)]:
|
||||
_mult_thm(l, numer, denom)
|
||||
|
||||
# =========== level >= 2 work: factor absorption =========
|
||||
|
||||
if level >= 2:
|
||||
# Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1)
|
||||
# and gamma(x)/(x - 1) -> gamma(x - 1)
|
||||
# This code (in particular repeated calls to find_fuzzy) can be very
|
||||
# slow.
|
||||
def find_fuzzy(l, x):
|
||||
if not l:
|
||||
return
|
||||
S1, T1 = compute_ST(x)
|
||||
for y in l:
|
||||
S2, T2 = inv[y]
|
||||
if T1 != T2 or (not S1.intersection(S2) and
|
||||
(S1 != set() or S2 != set())):
|
||||
continue
|
||||
# XXX we want some simplification (e.g. cancel or
|
||||
# simplify) but no matter what it's slow.
|
||||
a = len(cancel(x/y).free_symbols)
|
||||
b = len(x.free_symbols)
|
||||
c = len(y.free_symbols)
|
||||
# TODO is there a better heuristic?
|
||||
if a == 0 and (b > 0 or c > 0):
|
||||
return y
|
||||
|
||||
# We thus try to avoid expensive calls by building the following
|
||||
# "invariants": For every factor or gamma function argument
|
||||
# - the set of free symbols S
|
||||
# - the set of functional components T
|
||||
# We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset
|
||||
# or S1 == S2 == emptyset)
|
||||
inv = {}
|
||||
|
||||
def compute_ST(expr):
|
||||
if expr in inv:
|
||||
return inv[expr]
|
||||
return (expr.free_symbols, expr.atoms(Function).union(
|
||||
{e.exp for e in expr.atoms(Pow)}))
|
||||
|
||||
def update_ST(expr):
|
||||
inv[expr] = compute_ST(expr)
|
||||
for expr in numer_gammas + denom_gammas + numer_others + denom_others:
|
||||
update_ST(expr)
|
||||
|
||||
for gammas, numer, denom in [(
|
||||
numer_gammas, numer_others, denom_others),
|
||||
(denom_gammas, denom_others, numer_others)]:
|
||||
new = []
|
||||
while gammas:
|
||||
g = gammas.pop()
|
||||
cont = True
|
||||
while cont:
|
||||
cont = False
|
||||
y = find_fuzzy(numer, g)
|
||||
if y is not None:
|
||||
numer.remove(y)
|
||||
if y != g:
|
||||
numer.append(y/g)
|
||||
update_ST(y/g)
|
||||
g += 1
|
||||
cont = True
|
||||
y = find_fuzzy(denom, g - 1)
|
||||
if y is not None:
|
||||
denom.remove(y)
|
||||
if y != g - 1:
|
||||
numer.append((g - 1)/y)
|
||||
update_ST((g - 1)/y)
|
||||
g -= 1
|
||||
cont = True
|
||||
new.append(g)
|
||||
# /!\ updating IN PLACE
|
||||
gammas[:] = new
|
||||
|
||||
# =========== rebuild expr ==================================
|
||||
|
||||
return Mul(*[gamma(g) for g in numer_gammas]) \
|
||||
/ Mul(*[gamma(g) for g in denom_gammas]) \
|
||||
* Mul(*numer_others) / Mul(*denom_others)
|
||||
|
||||
was = factor(expr)
|
||||
# (for some reason we cannot use Basic.replace in this case)
|
||||
expr = rule_gamma(was)
|
||||
if expr != was:
|
||||
expr = factor(expr)
|
||||
|
||||
expr = expr.replace(gamma,
|
||||
lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n))
|
||||
|
||||
return expr
|
||||
|
||||
|
||||
class _rf(Function):
|
||||
@classmethod
|
||||
def eval(cls, a, b):
|
||||
if b.is_Integer:
|
||||
if not b:
|
||||
return S.One
|
||||
|
||||
n = int(b)
|
||||
|
||||
if n > 0:
|
||||
return Mul(*[a + i for i in range(n)])
|
||||
elif n < 0:
|
||||
return 1/Mul(*[a - i for i in range(1, -n + 1)])
|
||||
else:
|
||||
if b.is_Add:
|
||||
c, _b = b.as_coeff_Add()
|
||||
|
||||
if c.is_Integer:
|
||||
if c > 0:
|
||||
return _rf(a, _b)*_rf(a + _b, c)
|
||||
elif c < 0:
|
||||
return _rf(a, _b)/_rf(a + _b + c, -c)
|
||||
|
||||
if a.is_Add:
|
||||
c, _a = a.as_coeff_Add()
|
||||
|
||||
if c.is_Integer:
|
||||
if c > 0:
|
||||
return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c)
|
||||
elif c < 0:
|
||||
return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,18 @@
|
||||
""" This module cooks up a docstring when imported. Its only purpose is to
|
||||
be displayed in the sphinx documentation. """
|
||||
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.functions.special.hyper import hyper
|
||||
from sympy.printing.latex import latex
|
||||
from sympy.simplify.hyperexpand import FormulaCollection
|
||||
|
||||
c = FormulaCollection()
|
||||
|
||||
doc = ""
|
||||
|
||||
for f in c.formulae:
|
||||
obj = Eq(hyper(f.func.ap, f.func.bq, f.z),
|
||||
f.closed_form.rewrite('nonrepsmall'))
|
||||
doc += ".. math::\n %s\n" % latex(obj)
|
||||
|
||||
__doc__ = doc
|
||||
@@ -0,0 +1,718 @@
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
from math import prod
|
||||
|
||||
from sympy.core.function import expand_log, count_ops, _coeff_isneg
|
||||
from sympy.core import sympify, Basic, Dummy, S, Add, Mul, Pow, expand_mul, factor_terms
|
||||
from sympy.core.sorting import ordered, default_sort_key
|
||||
from sympy.core.numbers import Integer, Rational, equal_valued
|
||||
from sympy.core.mul import _keep_coeff
|
||||
from sympy.core.rules import Transform
|
||||
from sympy.functions import exp_polar, exp, log, root, polarify, unpolarify
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.polys import lcm, gcd
|
||||
from sympy.ntheory.factor_ import multiplicity
|
||||
|
||||
|
||||
|
||||
def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops):
|
||||
"""
|
||||
Reduce expression by combining powers with similar bases and exponents.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
If ``deep`` is ``True`` then powsimp() will also simplify arguments of
|
||||
functions. By default ``deep`` is set to ``False``.
|
||||
|
||||
If ``force`` is ``True`` then bases will be combined without checking for
|
||||
assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true
|
||||
if x and y are both negative.
|
||||
|
||||
You can make powsimp() only combine bases or only combine exponents by
|
||||
changing combine='base' or combine='exp'. By default, combine='all',
|
||||
which does both. combine='base' will only combine::
|
||||
|
||||
a a a 2x x
|
||||
x * y => (x*y) as well as things like 2 => 4
|
||||
|
||||
and combine='exp' will only combine
|
||||
::
|
||||
|
||||
a b (a + b)
|
||||
x * x => x
|
||||
|
||||
combine='exp' will strictly only combine exponents in the way that used
|
||||
to be automatic. Also use deep=True if you need the old behavior.
|
||||
|
||||
When combine='all', 'exp' is evaluated first. Consider the first
|
||||
example below for when there could be an ambiguity relating to this.
|
||||
This is done so things like the second example can be completely
|
||||
combined. If you want 'base' combined first, do something like
|
||||
powsimp(powsimp(expr, combine='base'), combine='exp').
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import powsimp, exp, log, symbols
|
||||
>>> from sympy.abc import x, y, z, n
|
||||
>>> powsimp(x**y*x**z*y**z, combine='all')
|
||||
x**(y + z)*y**z
|
||||
>>> powsimp(x**y*x**z*y**z, combine='exp')
|
||||
x**(y + z)*y**z
|
||||
>>> powsimp(x**y*x**z*y**z, combine='base', force=True)
|
||||
x**y*(x*y)**z
|
||||
|
||||
>>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True)
|
||||
(n*x)**(y + z)
|
||||
>>> powsimp(x**z*x**y*n**z*n**y, combine='exp')
|
||||
n**(y + z)*x**(y + z)
|
||||
>>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True)
|
||||
(n*x)**y*(n*x)**z
|
||||
|
||||
>>> x, y = symbols('x y', positive=True)
|
||||
>>> powsimp(log(exp(x)*exp(y)))
|
||||
log(exp(x)*exp(y))
|
||||
>>> powsimp(log(exp(x)*exp(y)), deep=True)
|
||||
x + y
|
||||
|
||||
Radicals with Mul bases will be combined if combine='exp'
|
||||
|
||||
>>> from sympy import sqrt
|
||||
>>> x, y = symbols('x y')
|
||||
|
||||
Two radicals are automatically joined through Mul:
|
||||
|
||||
>>> a=sqrt(x*sqrt(y))
|
||||
>>> a*a**3 == a**4
|
||||
True
|
||||
|
||||
But if an integer power of that radical has been
|
||||
autoexpanded then Mul does not join the resulting factors:
|
||||
|
||||
>>> a**4 # auto expands to a Mul, no longer a Pow
|
||||
x**2*y
|
||||
>>> _*a # so Mul doesn't combine them
|
||||
x**2*y*sqrt(x*sqrt(y))
|
||||
>>> powsimp(_) # but powsimp will
|
||||
(x*sqrt(y))**(5/2)
|
||||
>>> powsimp(x*y*a) # but won't when doing so would violate assumptions
|
||||
x*y*sqrt(x*sqrt(y))
|
||||
|
||||
"""
|
||||
def recurse(arg, **kwargs):
|
||||
_deep = kwargs.get('deep', deep)
|
||||
_combine = kwargs.get('combine', combine)
|
||||
_force = kwargs.get('force', force)
|
||||
_measure = kwargs.get('measure', measure)
|
||||
return powsimp(arg, _deep, _combine, _force, _measure)
|
||||
|
||||
expr = sympify(expr)
|
||||
|
||||
if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or (
|
||||
expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))):
|
||||
return expr
|
||||
|
||||
if deep or expr.is_Add or expr.is_Mul and _y not in expr.args:
|
||||
expr = expr.func(*[recurse(w) for w in expr.args])
|
||||
|
||||
if expr.is_Pow:
|
||||
return recurse(expr*_y, deep=False)/_y
|
||||
|
||||
if not expr.is_Mul:
|
||||
return expr
|
||||
|
||||
# handle the Mul
|
||||
if combine in ('exp', 'all'):
|
||||
# Collect base/exp data, while maintaining order in the
|
||||
# non-commutative parts of the product
|
||||
c_powers = defaultdict(list)
|
||||
nc_part = []
|
||||
newexpr = []
|
||||
coeff = S.One
|
||||
for term in expr.args:
|
||||
if term.is_Rational:
|
||||
coeff *= term
|
||||
continue
|
||||
if term.is_Pow:
|
||||
term = _denest_pow(term)
|
||||
if term.is_commutative:
|
||||
b, e = term.as_base_exp()
|
||||
if deep:
|
||||
b, e = [recurse(i) for i in [b, e]]
|
||||
if b.is_Pow or isinstance(b, exp):
|
||||
# don't let smthg like sqrt(x**a) split into x**a, 1/2
|
||||
# or else it will be joined as x**(a/2) later
|
||||
b, e = b**e, S.One
|
||||
c_powers[b].append(e)
|
||||
else:
|
||||
# This is the logic that combines exponents for equal,
|
||||
# but non-commutative bases: A**x*A**y == A**(x+y).
|
||||
if nc_part:
|
||||
b1, e1 = nc_part[-1].as_base_exp()
|
||||
b2, e2 = term.as_base_exp()
|
||||
if (b1 == b2 and
|
||||
e1.is_commutative and e2.is_commutative):
|
||||
nc_part[-1] = Pow(b1, Add(e1, e2))
|
||||
continue
|
||||
nc_part.append(term)
|
||||
|
||||
# add up exponents of common bases
|
||||
for b, e in ordered(iter(c_powers.items())):
|
||||
# allow 2**x/4 -> 2**(x - 2); don't do this when b and e are
|
||||
# Numbers since autoevaluation will undo it, e.g.
|
||||
# 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4
|
||||
if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \
|
||||
coeff is not S.One and
|
||||
b not in (S.One, S.NegativeOne)):
|
||||
m = multiplicity(abs(b), abs(coeff))
|
||||
if m:
|
||||
e.append(m)
|
||||
coeff /= b**m
|
||||
c_powers[b] = Add(*e)
|
||||
if coeff is not S.One:
|
||||
if coeff in c_powers:
|
||||
c_powers[coeff] += S.One
|
||||
else:
|
||||
c_powers[coeff] = S.One
|
||||
|
||||
# convert to plain dictionary
|
||||
c_powers = dict(c_powers)
|
||||
|
||||
# check for base and inverted base pairs
|
||||
be = list(c_powers.items())
|
||||
skip = set() # skip if we already saw them
|
||||
for b, e in be:
|
||||
if b in skip:
|
||||
continue
|
||||
bpos = b.is_positive or b.is_polar
|
||||
if bpos:
|
||||
binv = 1/b
|
||||
#Special case for float 1
|
||||
if b.is_Float and equal_valued(b, 1):
|
||||
c_powers[b] = S.One
|
||||
continue
|
||||
if b != binv and binv in c_powers:
|
||||
if b.as_numer_denom()[0] is S.One:
|
||||
c_powers.pop(b)
|
||||
c_powers[binv] -= e
|
||||
else:
|
||||
skip.add(binv)
|
||||
e = c_powers.pop(binv)
|
||||
c_powers[b] -= e
|
||||
|
||||
# check for base and negated base pairs
|
||||
be = list(c_powers.items())
|
||||
_n = S.NegativeOne
|
||||
for b, e in be:
|
||||
if (b.is_Symbol or b.is_Add) and -b in c_powers and b in c_powers:
|
||||
if (b.is_positive is not None or e.is_integer):
|
||||
if e.is_integer or b.is_negative:
|
||||
c_powers[-b] += c_powers.pop(b)
|
||||
else: # (-b).is_positive so use its e
|
||||
e = c_powers.pop(-b)
|
||||
c_powers[b] += e
|
||||
if _n in c_powers:
|
||||
c_powers[_n] += e
|
||||
else:
|
||||
c_powers[_n] = e
|
||||
|
||||
# filter c_powers and convert to a list
|
||||
c_powers = [(b, e) for b, e in c_powers.items() if e]
|
||||
|
||||
# ==============================================================
|
||||
# check for Mul bases of Rational powers that can be combined with
|
||||
# separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) ->
|
||||
# (x*sqrt(x*y))**(3/2)
|
||||
# ---------------- helper functions
|
||||
|
||||
def ratq(x):
|
||||
'''Return Rational part of x's exponent as it appears in the bkey.
|
||||
'''
|
||||
return bkey(x)[0][1]
|
||||
|
||||
def bkey(b, e=None):
|
||||
'''Return (b**s, c.q), c.p where e -> c*s. If e is not given then
|
||||
it will be taken by using as_base_exp() on the input b.
|
||||
e.g.
|
||||
x**3/2 -> (x, 2), 3
|
||||
x**y -> (x**y, 1), 1
|
||||
x**(2*y/3) -> (x**y, 3), 2
|
||||
exp(x/2) -> (exp(a), 2), 1
|
||||
|
||||
'''
|
||||
if e is not None: # coming from c_powers or from below
|
||||
if e.is_Integer:
|
||||
return (b, S.One), e
|
||||
elif e.is_Rational:
|
||||
return (b, Integer(e.q)), Integer(e.p)
|
||||
else:
|
||||
c, m = e.as_coeff_Mul(rational=True)
|
||||
if c is not S.One:
|
||||
if m.is_integer:
|
||||
return (b, Integer(c.q)), m*Integer(c.p)
|
||||
return (b**m, Integer(c.q)), Integer(c.p)
|
||||
else:
|
||||
return (b**e, S.One), S.One
|
||||
else:
|
||||
return bkey(*b.as_base_exp())
|
||||
|
||||
def update(b):
|
||||
'''Decide what to do with base, b. If its exponent is now an
|
||||
integer multiple of the Rational denominator, then remove it
|
||||
and put the factors of its base in the common_b dictionary or
|
||||
update the existing bases if necessary. If it has been zeroed
|
||||
out, simply remove the base.
|
||||
'''
|
||||
newe, r = divmod(common_b[b], b[1])
|
||||
if not r:
|
||||
common_b.pop(b)
|
||||
if newe:
|
||||
for m in Mul.make_args(b[0]**newe):
|
||||
b, e = bkey(m)
|
||||
if b not in common_b:
|
||||
common_b[b] = 0
|
||||
common_b[b] += e
|
||||
if b[1] != 1:
|
||||
bases.append(b)
|
||||
# ---------------- end of helper functions
|
||||
|
||||
# assemble a dictionary of the factors having a Rational power
|
||||
common_b = {}
|
||||
done = []
|
||||
bases = []
|
||||
for b, e in c_powers:
|
||||
b, e = bkey(b, e)
|
||||
if b in common_b:
|
||||
common_b[b] = common_b[b] + e
|
||||
else:
|
||||
common_b[b] = e
|
||||
if b[1] != 1 and b[0].is_Mul:
|
||||
bases.append(b)
|
||||
bases.sort(key=default_sort_key) # this makes tie-breaking canonical
|
||||
bases.sort(key=measure, reverse=True) # handle longest first
|
||||
for base in bases:
|
||||
if base not in common_b: # it may have been removed already
|
||||
continue
|
||||
b, exponent = base
|
||||
last = False # True when no factor of base is a radical
|
||||
qlcm = 1 # the lcm of the radical denominators
|
||||
while True:
|
||||
bstart = b
|
||||
qstart = qlcm
|
||||
|
||||
bb = [] # list of factors
|
||||
ee = [] # (factor's expo. and it's current value in common_b)
|
||||
for bi in Mul.make_args(b):
|
||||
bib, bie = bkey(bi)
|
||||
if bib not in common_b or common_b[bib] < bie:
|
||||
ee = bb = [] # failed
|
||||
break
|
||||
ee.append([bie, common_b[bib]])
|
||||
bb.append(bib)
|
||||
if ee:
|
||||
# find the number of integral extractions possible
|
||||
# e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1
|
||||
min1 = ee[0][1]//ee[0][0]
|
||||
for i in range(1, len(ee)):
|
||||
rat = ee[i][1]//ee[i][0]
|
||||
if rat < 1:
|
||||
break
|
||||
min1 = min(min1, rat)
|
||||
else:
|
||||
# update base factor counts
|
||||
# e.g. if ee = [(2, 5), (3, 6)] then min1 = 2
|
||||
# and the new base counts will be 5-2*2 and 6-2*3
|
||||
for i in range(len(bb)):
|
||||
common_b[bb[i]] -= min1*ee[i][0]
|
||||
update(bb[i])
|
||||
# update the count of the base
|
||||
# e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y)
|
||||
# will increase by 4 to give bkey (x*sqrt(y), 2, 5)
|
||||
common_b[base] += min1*qstart*exponent
|
||||
if (last # no more radicals in base
|
||||
or len(common_b) == 1 # nothing left to join with
|
||||
or all(k[1] == 1 for k in common_b) # no rad's in common_b
|
||||
):
|
||||
break
|
||||
# see what we can exponentiate base by to remove any radicals
|
||||
# so we know what to search for
|
||||
# e.g. if base were x**(1/2)*y**(1/3) then we should
|
||||
# exponentiate by 6 and look for powers of x and y in the ratio
|
||||
# of 2 to 3
|
||||
qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)])
|
||||
if qlcm == 1:
|
||||
break # we are done
|
||||
b = bstart**qlcm
|
||||
qlcm *= qstart
|
||||
if all(ratq(bi) == 1 for bi in Mul.make_args(b)):
|
||||
last = True # we are going to be done after this next pass
|
||||
# this base no longer can find anything to join with and
|
||||
# since it was longer than any other we are done with it
|
||||
b, q = base
|
||||
done.append((b, common_b.pop(base)*Rational(1, q)))
|
||||
|
||||
# update c_powers and get ready to continue with powsimp
|
||||
c_powers = done
|
||||
# there may be terms still in common_b that were bases that were
|
||||
# identified as needing processing, so remove those, too
|
||||
for (b, q), e in common_b.items():
|
||||
if (b.is_Pow or isinstance(b, exp)) and \
|
||||
q is not S.One and not b.exp.is_Rational:
|
||||
b, be = b.as_base_exp()
|
||||
b = b**(be/q)
|
||||
else:
|
||||
b = root(b, q)
|
||||
c_powers.append((b, e))
|
||||
check = len(c_powers)
|
||||
c_powers = dict(c_powers)
|
||||
assert len(c_powers) == check # there should have been no duplicates
|
||||
# ==============================================================
|
||||
|
||||
# rebuild the expression
|
||||
newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()]))
|
||||
if combine == 'exp':
|
||||
return expr.func(newexpr, expr.func(*nc_part))
|
||||
else:
|
||||
return recurse(expr.func(*nc_part), combine='base') * \
|
||||
recurse(newexpr, combine='base')
|
||||
|
||||
elif combine == 'base':
|
||||
|
||||
# Build c_powers and nc_part. These must both be lists not
|
||||
# dicts because exp's are not combined.
|
||||
c_powers = []
|
||||
nc_part = []
|
||||
for term in expr.args:
|
||||
if term.is_commutative:
|
||||
c_powers.append(list(term.as_base_exp()))
|
||||
else:
|
||||
nc_part.append(term)
|
||||
|
||||
# Pull out numerical coefficients from exponent if assumptions allow
|
||||
# e.g., 2**(2*x) => 4**x
|
||||
for i in range(len(c_powers)):
|
||||
b, e = c_powers[i]
|
||||
if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar):
|
||||
continue
|
||||
exp_c, exp_t = e.as_coeff_Mul(rational=True)
|
||||
if exp_c is not S.One and exp_t is not S.One:
|
||||
c_powers[i] = [Pow(b, exp_c), exp_t]
|
||||
|
||||
# Combine bases whenever they have the same exponent and
|
||||
# assumptions allow
|
||||
# first gather the potential bases under the common exponent
|
||||
c_exp = defaultdict(list)
|
||||
for b, e in c_powers:
|
||||
if deep:
|
||||
e = recurse(e)
|
||||
if e.is_Add and (b.is_positive or e.is_integer):
|
||||
e = factor_terms(e)
|
||||
if _coeff_isneg(e):
|
||||
e = -e
|
||||
b = 1/b
|
||||
c_exp[e].append(b)
|
||||
del c_powers
|
||||
|
||||
# Merge back in the results of the above to form a new product
|
||||
c_powers = defaultdict(list)
|
||||
for e in c_exp:
|
||||
bases = c_exp[e]
|
||||
|
||||
# calculate the new base for e
|
||||
|
||||
if len(bases) == 1:
|
||||
new_base = bases[0]
|
||||
elif e.is_integer or force:
|
||||
new_base = expr.func(*bases)
|
||||
else:
|
||||
# see which ones can be joined
|
||||
unk = []
|
||||
nonneg = []
|
||||
neg = []
|
||||
for bi in bases:
|
||||
if bi.is_negative:
|
||||
neg.append(bi)
|
||||
elif bi.is_nonnegative:
|
||||
nonneg.append(bi)
|
||||
elif bi.is_polar:
|
||||
nonneg.append(
|
||||
bi) # polar can be treated like non-negative
|
||||
else:
|
||||
unk.append(bi)
|
||||
if len(unk) == 1 and not neg or len(neg) == 1 and not unk:
|
||||
# a single neg or a single unk can join the rest
|
||||
nonneg.extend(unk + neg)
|
||||
unk = neg = []
|
||||
elif neg:
|
||||
# their negative signs cancel in groups of 2*q if we know
|
||||
# that e = p/q else we have to treat them as unknown
|
||||
israt = False
|
||||
if e.is_Rational:
|
||||
israt = True
|
||||
else:
|
||||
p, d = e.as_numer_denom()
|
||||
if p.is_integer and d.is_integer:
|
||||
israt = True
|
||||
if israt:
|
||||
neg = [-w for w in neg]
|
||||
unk.extend([S.NegativeOne]*len(neg))
|
||||
else:
|
||||
unk.extend(neg)
|
||||
neg = []
|
||||
del israt
|
||||
|
||||
# these shouldn't be joined
|
||||
for b in unk:
|
||||
c_powers[b].append(e)
|
||||
# here is a new joined base
|
||||
new_base = expr.func(*(nonneg + neg))
|
||||
# if there are positive parts they will just get separated
|
||||
# again unless some change is made
|
||||
|
||||
def _terms(e):
|
||||
# return the number of terms of this expression
|
||||
# when multiplied out -- assuming no joining of terms
|
||||
if e.is_Add:
|
||||
return sum(_terms(ai) for ai in e.args)
|
||||
if e.is_Mul:
|
||||
return prod([_terms(mi) for mi in e.args])
|
||||
return 1
|
||||
xnew_base = expand_mul(new_base, deep=False)
|
||||
if len(Add.make_args(xnew_base)) < _terms(new_base):
|
||||
new_base = factor_terms(xnew_base)
|
||||
|
||||
c_powers[new_base].append(e)
|
||||
|
||||
# break out the powers from c_powers now
|
||||
c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e]
|
||||
|
||||
# we're done
|
||||
return expr.func(*(c_part + nc_part))
|
||||
|
||||
else:
|
||||
raise ValueError("combine must be one of ('all', 'exp', 'base').")
|
||||
|
||||
|
||||
def powdenest(eq, force=False, polar=False):
|
||||
r"""
|
||||
Collect exponents on powers as assumptions allow.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Given ``(bb**be)**e``, this can be simplified as follows:
|
||||
* if ``bb`` is positive, or
|
||||
* ``e`` is an integer, or
|
||||
* ``|be| < 1`` then this simplifies to ``bb**(be*e)``
|
||||
|
||||
Given a product of powers raised to a power, ``(bb1**be1 *
|
||||
bb2**be2...)**e``, simplification can be done as follows:
|
||||
|
||||
- if e is positive, the gcd of all bei can be joined with e;
|
||||
- all non-negative bb can be separated from those that are negative
|
||||
and their gcd can be joined with e; autosimplification already
|
||||
handles this separation.
|
||||
- integer factors from powers that have integers in the denominator
|
||||
of the exponent can be removed from any term and the gcd of such
|
||||
integers can be joined with e
|
||||
|
||||
Setting ``force`` to ``True`` will make symbols that are not explicitly
|
||||
negative behave as though they are positive, resulting in more
|
||||
denesting.
|
||||
|
||||
Setting ``polar`` to ``True`` will do simplifications on the Riemann surface of
|
||||
the logarithm, also resulting in more denestings.
|
||||
|
||||
When there are sums of logs in exp() then a product of powers may be
|
||||
obtained e.g. ``exp(3*(log(a) + 2*log(b)))`` - > ``a**3*b**6``.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import a, b, x, y, z
|
||||
>>> from sympy import Symbol, exp, log, sqrt, symbols, powdenest
|
||||
|
||||
>>> powdenest((x**(2*a/3))**(3*x))
|
||||
(x**(2*a/3))**(3*x)
|
||||
>>> powdenest(exp(3*x*log(2)))
|
||||
2**(3*x)
|
||||
|
||||
Assumptions may prevent expansion:
|
||||
|
||||
>>> powdenest(sqrt(x**2))
|
||||
sqrt(x**2)
|
||||
|
||||
>>> p = symbols('p', positive=True)
|
||||
>>> powdenest(sqrt(p**2))
|
||||
p
|
||||
|
||||
No other expansion is done.
|
||||
|
||||
>>> i, j = symbols('i,j', integer=True)
|
||||
>>> powdenest((x**x)**(i + j)) # -X-> (x**x)**i*(x**x)**j
|
||||
x**(x*(i + j))
|
||||
|
||||
But exp() will be denested by moving all non-log terms outside of
|
||||
the function; this may result in the collapsing of the exp to a power
|
||||
with a different base:
|
||||
|
||||
>>> powdenest(exp(3*y*log(x)))
|
||||
x**(3*y)
|
||||
>>> powdenest(exp(y*(log(a) + log(b))))
|
||||
(a*b)**y
|
||||
>>> powdenest(exp(3*(log(a) + log(b))))
|
||||
a**3*b**3
|
||||
|
||||
If assumptions allow, symbols can also be moved to the outermost exponent:
|
||||
|
||||
>>> i = Symbol('i', integer=True)
|
||||
>>> powdenest(((x**(2*i))**(3*y))**x)
|
||||
((x**(2*i))**(3*y))**x
|
||||
>>> powdenest(((x**(2*i))**(3*y))**x, force=True)
|
||||
x**(6*i*x*y)
|
||||
|
||||
>>> powdenest(((x**(2*a/3))**(3*y/i))**x)
|
||||
((x**(2*a/3))**(3*y/i))**x
|
||||
>>> powdenest((x**(2*i)*y**(4*i))**z, force=True)
|
||||
(x*y**2)**(2*i*z)
|
||||
|
||||
>>> n = Symbol('n', negative=True)
|
||||
|
||||
>>> powdenest((x**i)**y, force=True)
|
||||
x**(i*y)
|
||||
>>> powdenest((n**i)**x, force=True)
|
||||
(n**i)**x
|
||||
|
||||
"""
|
||||
from sympy.simplify.simplify import posify
|
||||
|
||||
if force:
|
||||
def _denest(b, e):
|
||||
if not isinstance(b, (Pow, exp)):
|
||||
return b.is_positive, Pow(b, e, evaluate=False)
|
||||
return _denest(b.base, b.exp*e)
|
||||
reps = []
|
||||
for p in eq.atoms(Pow, exp):
|
||||
if isinstance(p.base, (Pow, exp)):
|
||||
ok, dp = _denest(*p.args)
|
||||
if ok is not False:
|
||||
reps.append((p, dp))
|
||||
if reps:
|
||||
eq = eq.subs(reps)
|
||||
eq, reps = posify(eq)
|
||||
return powdenest(eq, force=False, polar=polar).xreplace(reps)
|
||||
|
||||
if polar:
|
||||
eq, rep = polarify(eq)
|
||||
return unpolarify(powdenest(unpolarify(eq, exponents_only=True)), rep)
|
||||
|
||||
new = powsimp(eq)
|
||||
return new.xreplace(Transform(
|
||||
_denest_pow, filter=lambda m: m.is_Pow or isinstance(m, exp)))
|
||||
|
||||
_y = Dummy('y')
|
||||
|
||||
|
||||
def _denest_pow(eq):
|
||||
"""
|
||||
Denest powers.
|
||||
|
||||
This is a helper function for powdenest that performs the actual
|
||||
transformation.
|
||||
"""
|
||||
from sympy.simplify.simplify import logcombine
|
||||
|
||||
b, e = eq.as_base_exp()
|
||||
if b.is_Pow or isinstance(b, exp) and e != 1:
|
||||
new = b._eval_power(e)
|
||||
if new is not None:
|
||||
eq = new
|
||||
b, e = new.as_base_exp()
|
||||
|
||||
# denest exp with log terms in exponent
|
||||
if b is S.Exp1 and e.is_Mul:
|
||||
logs = []
|
||||
other = []
|
||||
for ei in e.args:
|
||||
if any(isinstance(ai, log) for ai in Add.make_args(ei)):
|
||||
logs.append(ei)
|
||||
else:
|
||||
other.append(ei)
|
||||
logs = logcombine(Mul(*logs))
|
||||
return Pow(exp(logs), Mul(*other))
|
||||
|
||||
_, be = b.as_base_exp()
|
||||
if be is S.One and not (b.is_Mul or
|
||||
b.is_Rational and b.q != 1 or
|
||||
b.is_positive):
|
||||
return eq
|
||||
|
||||
# denest eq which is either pos**e or Pow**e or Mul**e or
|
||||
# Mul(b1**e1, b2**e2)
|
||||
|
||||
# handle polar numbers specially
|
||||
polars, nonpolars = [], []
|
||||
for bb in Mul.make_args(b):
|
||||
if bb.is_polar:
|
||||
polars.append(bb.as_base_exp())
|
||||
else:
|
||||
nonpolars.append(bb)
|
||||
if len(polars) == 1 and not polars[0][0].is_Mul:
|
||||
return Pow(polars[0][0], polars[0][1]*e)*powdenest(Mul(*nonpolars)**e)
|
||||
elif polars:
|
||||
return Mul(*[powdenest(bb**(ee*e)) for (bb, ee) in polars]) \
|
||||
*powdenest(Mul(*nonpolars)**e)
|
||||
|
||||
if b.is_Integer:
|
||||
# use log to see if there is a power here
|
||||
logb = expand_log(log(b))
|
||||
if logb.is_Mul:
|
||||
c, logb = logb.args
|
||||
e *= c
|
||||
base = logb.args[0]
|
||||
return Pow(base, e)
|
||||
|
||||
# if b is not a Mul or any factor is an atom then there is nothing to do
|
||||
if not b.is_Mul or any(s.is_Atom for s in Mul.make_args(b)):
|
||||
return eq
|
||||
|
||||
# let log handle the case of the base of the argument being a Mul, e.g.
|
||||
# sqrt(x**(2*i)*y**(6*i)) -> x**i*y**(3**i) if x and y are positive; we
|
||||
# will take the log, expand it, and then factor out the common powers that
|
||||
# now appear as coefficient. We do this manually since terms_gcd pulls out
|
||||
# fractions, terms_gcd(x+x*y/2) -> x*(y + 2)/2 and we don't want the 1/2;
|
||||
# gcd won't pull out numerators from a fraction: gcd(3*x, 9*x/2) -> x but
|
||||
# we want 3*x. Neither work with noncommutatives.
|
||||
|
||||
def nc_gcd(aa, bb):
|
||||
a, b = [i.as_coeff_Mul() for i in [aa, bb]]
|
||||
c = gcd(a[0], b[0]).as_numer_denom()[0]
|
||||
g = Mul(*(a[1].args_cnc(cset=True)[0] & b[1].args_cnc(cset=True)[0]))
|
||||
return _keep_coeff(c, g)
|
||||
|
||||
glogb = expand_log(log(b))
|
||||
if glogb.is_Add:
|
||||
args = glogb.args
|
||||
g = reduce(nc_gcd, args)
|
||||
if g != 1:
|
||||
cg, rg = g.as_coeff_Mul()
|
||||
glogb = _keep_coeff(cg, rg*Add(*[a/g for a in args]))
|
||||
|
||||
# now put the log back together again
|
||||
if isinstance(glogb, log) or not glogb.is_Mul:
|
||||
if glogb.args[0].is_Pow or isinstance(glogb.args[0], exp):
|
||||
glogb = _denest_pow(glogb.args[0])
|
||||
if (abs(glogb.exp) < 1) == True:
|
||||
return Pow(glogb.base, glogb.exp*e)
|
||||
return eq
|
||||
|
||||
# the log(b) was a Mul so join any adds with logcombine
|
||||
add = []
|
||||
other = []
|
||||
for a in glogb.args:
|
||||
if a.is_Add:
|
||||
add.append(a)
|
||||
else:
|
||||
other.append(a)
|
||||
return Pow(exp(logcombine(Mul(*add))), e*Mul(*other))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,222 @@
|
||||
from itertools import combinations_with_replacement
|
||||
from sympy.core import symbols, Add, Dummy
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.polys import cancel, ComputationFailed, parallel_poly_from_expr, reduced, Poly
|
||||
from sympy.polys.monomials import Monomial, monomial_div
|
||||
from sympy.polys.polyerrors import DomainError, PolificationFailed
|
||||
from sympy.utilities.misc import debug, debugf
|
||||
|
||||
def ratsimp(expr):
|
||||
"""
|
||||
Put an expression over a common denominator, cancel and reduce.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import ratsimp
|
||||
>>> from sympy.abc import x, y
|
||||
>>> ratsimp(1/x + 1/y)
|
||||
(x + y)/(x*y)
|
||||
"""
|
||||
|
||||
f, g = cancel(expr).as_numer_denom()
|
||||
try:
|
||||
Q, r = reduced(f, [g], field=True, expand=False)
|
||||
except ComputationFailed:
|
||||
return f/g
|
||||
|
||||
return Add(*Q) + cancel(r/g)
|
||||
|
||||
|
||||
def ratsimpmodprime(expr, G, *gens, quick=True, polynomial=False, **args):
|
||||
"""
|
||||
Simplifies a rational expression ``expr`` modulo the prime ideal
|
||||
generated by ``G``. ``G`` should be a Groebner basis of the
|
||||
ideal.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.ratsimp import ratsimpmodprime
|
||||
>>> from sympy.abc import x, y
|
||||
>>> eq = (x + y**5 + y)/(x - y)
|
||||
>>> ratsimpmodprime(eq, [x*y**5 - x - y], x, y, order='lex')
|
||||
(-x**2 - x*y - x - y)/(-x**2 + x*y)
|
||||
|
||||
If ``polynomial`` is ``False``, the algorithm computes a rational
|
||||
simplification which minimizes the sum of the total degrees of
|
||||
the numerator and the denominator.
|
||||
|
||||
If ``polynomial`` is ``True``, this function just brings numerator and
|
||||
denominator into a canonical form. This is much faster, but has
|
||||
potentially worse results.
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] M. Monagan, R. Pearce, Rational Simplification Modulo a Polynomial
|
||||
Ideal, https://dl.acm.org/doi/pdf/10.1145/1145768.1145809
|
||||
(specifically, the second algorithm)
|
||||
"""
|
||||
from sympy.solvers.solvers import solve
|
||||
|
||||
debug('ratsimpmodprime', expr)
|
||||
|
||||
# usual preparation of polynomials:
|
||||
|
||||
num, denom = cancel(expr).as_numer_denom()
|
||||
|
||||
try:
|
||||
polys, opt = parallel_poly_from_expr([num, denom] + G, *gens, **args)
|
||||
except PolificationFailed:
|
||||
return expr
|
||||
|
||||
domain = opt.domain
|
||||
|
||||
if domain.has_assoc_Field:
|
||||
opt.domain = domain.get_field()
|
||||
else:
|
||||
raise DomainError(
|
||||
"Cannot compute rational simplification over %s" % domain)
|
||||
|
||||
# compute only once
|
||||
leading_monomials = [g.LM(opt.order) for g in polys[2:]]
|
||||
tested = set()
|
||||
|
||||
def staircase(n):
|
||||
"""
|
||||
Compute all monomials with degree less than ``n`` that are
|
||||
not divisible by any element of ``leading_monomials``.
|
||||
"""
|
||||
if n == 0:
|
||||
return [1]
|
||||
S = []
|
||||
for mi in combinations_with_replacement(range(len(opt.gens)), n):
|
||||
m = [0]*len(opt.gens)
|
||||
for i in mi:
|
||||
m[i] += 1
|
||||
if all(monomial_div(m, lmg) is None for lmg in
|
||||
leading_monomials):
|
||||
S.append(m)
|
||||
|
||||
return [Monomial(s).as_expr(*opt.gens) for s in S] + staircase(n - 1)
|
||||
|
||||
def _ratsimpmodprime(a, b, allsol, N=0, D=0):
|
||||
r"""
|
||||
Computes a rational simplification of ``a/b`` which minimizes
|
||||
the sum of the total degrees of the numerator and the denominator.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The algorithm proceeds by looking at ``a * d - b * c`` modulo
|
||||
the ideal generated by ``G`` for some ``c`` and ``d`` with degree
|
||||
less than ``a`` and ``b`` respectively.
|
||||
The coefficients of ``c`` and ``d`` are indeterminates and thus
|
||||
the coefficients of the normalform of ``a * d - b * c`` are
|
||||
linear polynomials in these indeterminates.
|
||||
If these linear polynomials, considered as system of
|
||||
equations, have a nontrivial solution, then `\frac{a}{b}
|
||||
\equiv \frac{c}{d}` modulo the ideal generated by ``G``. So,
|
||||
by construction, the degree of ``c`` and ``d`` is less than
|
||||
the degree of ``a`` and ``b``, so a simpler representation
|
||||
has been found.
|
||||
After a simpler representation has been found, the algorithm
|
||||
tries to reduce the degree of the numerator and denominator
|
||||
and returns the result afterwards.
|
||||
|
||||
As an extension, if quick=False, we look at all possible degrees such
|
||||
that the total degree is less than *or equal to* the best current
|
||||
solution. We retain a list of all solutions of minimal degree, and try
|
||||
to find the best one at the end.
|
||||
"""
|
||||
c, d = a, b
|
||||
steps = 0
|
||||
|
||||
maxdeg = a.total_degree() + b.total_degree()
|
||||
if quick:
|
||||
bound = maxdeg - 1
|
||||
else:
|
||||
bound = maxdeg
|
||||
while N + D <= bound:
|
||||
if (N, D) in tested:
|
||||
break
|
||||
tested.add((N, D))
|
||||
|
||||
M1 = staircase(N)
|
||||
M2 = staircase(D)
|
||||
debugf('%s / %s: %s, %s', (N, D, M1, M2))
|
||||
|
||||
Cs = symbols("c:%d" % len(M1), cls=Dummy)
|
||||
Ds = symbols("d:%d" % len(M2), cls=Dummy)
|
||||
ng = Cs + Ds
|
||||
|
||||
c_hat = Poly(
|
||||
sum(Cs[i] * M1[i] for i in range(len(M1))), opt.gens + ng)
|
||||
d_hat = Poly(
|
||||
sum(Ds[i] * M2[i] for i in range(len(M2))), opt.gens + ng)
|
||||
|
||||
r = reduced(a * d_hat - b * c_hat, G, opt.gens + ng,
|
||||
order=opt.order, polys=True)[1]
|
||||
|
||||
S = Poly(r, gens=opt.gens).coeffs()
|
||||
sol = solve(S, Cs + Ds, particular=True, quick=True)
|
||||
|
||||
if sol and not all(s == 0 for s in sol.values()):
|
||||
c = c_hat.subs(sol)
|
||||
d = d_hat.subs(sol)
|
||||
|
||||
# The "free" variables occurring before as parameters
|
||||
# might still be in the substituted c, d, so set them
|
||||
# to the value chosen before:
|
||||
c = c.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds))))))
|
||||
d = d.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds))))))
|
||||
|
||||
c = Poly(c, opt.gens)
|
||||
d = Poly(d, opt.gens)
|
||||
if d == 0:
|
||||
raise ValueError('Ideal not prime?')
|
||||
|
||||
allsol.append((c_hat, d_hat, S, Cs + Ds))
|
||||
if N + D != maxdeg:
|
||||
allsol = [allsol[-1]]
|
||||
|
||||
break
|
||||
|
||||
steps += 1
|
||||
N += 1
|
||||
D += 1
|
||||
|
||||
if steps > 0:
|
||||
c, d, allsol = _ratsimpmodprime(c, d, allsol, N, D - steps)
|
||||
c, d, allsol = _ratsimpmodprime(c, d, allsol, N - steps, D)
|
||||
|
||||
return c, d, allsol
|
||||
|
||||
# preprocessing. this improves performance a bit when deg(num)
|
||||
# and deg(denom) are large:
|
||||
num = reduced(num, G, opt.gens, order=opt.order)[1]
|
||||
denom = reduced(denom, G, opt.gens, order=opt.order)[1]
|
||||
|
||||
if polynomial:
|
||||
return (num/denom).cancel()
|
||||
|
||||
c, d, allsol = _ratsimpmodprime(
|
||||
Poly(num, opt.gens, domain=opt.domain), Poly(denom, opt.gens, domain=opt.domain), [])
|
||||
if not quick and allsol:
|
||||
debugf('Looking for best minimal solution. Got: %s', len(allsol))
|
||||
newsol = []
|
||||
for c_hat, d_hat, S, ng in allsol:
|
||||
sol = solve(S, ng, particular=True, quick=False)
|
||||
# all values of sol should be numbers; if not, solve is broken
|
||||
newsol.append((c_hat.subs(sol), d_hat.subs(sol)))
|
||||
c, d = min(newsol, key=lambda x: len(x[0].terms()) + len(x[1].terms()))
|
||||
|
||||
if not domain.is_Field:
|
||||
cn, c = c.clear_denoms(convert=True)
|
||||
dn, d = d.clear_denoms(convert=True)
|
||||
r = Rational(cn, dn)
|
||||
else:
|
||||
r = Rational(1)
|
||||
|
||||
return (c*r.q)/(d*r.p)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,678 @@
|
||||
from sympy.core import Add, Expr, Mul, S, sympify
|
||||
from sympy.core.function import _mexpand, count_ops, expand_mul
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.functions import root, sign, sqrt
|
||||
from sympy.polys import Poly, PolynomialError
|
||||
|
||||
|
||||
def is_sqrt(expr):
|
||||
"""Return True if expr is a sqrt, otherwise False."""
|
||||
|
||||
return expr.is_Pow and expr.exp.is_Rational and abs(expr.exp) is S.Half
|
||||
|
||||
|
||||
def sqrt_depth(p) -> int:
|
||||
"""Return the maximum depth of any square root argument of p.
|
||||
|
||||
>>> from sympy.functions.elementary.miscellaneous import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import sqrt_depth
|
||||
|
||||
Neither of these square roots contains any other square roots
|
||||
so the depth is 1:
|
||||
|
||||
>>> sqrt_depth(1 + sqrt(2)*(1 + sqrt(3)))
|
||||
1
|
||||
|
||||
The sqrt(3) is contained within a square root so the depth is
|
||||
2:
|
||||
|
||||
>>> sqrt_depth(1 + sqrt(2)*sqrt(1 + sqrt(3)))
|
||||
2
|
||||
"""
|
||||
if p is S.ImaginaryUnit:
|
||||
return 1
|
||||
if p.is_Atom:
|
||||
return 0
|
||||
if p.is_Add or p.is_Mul:
|
||||
return max(sqrt_depth(x) for x in p.args)
|
||||
if is_sqrt(p):
|
||||
return sqrt_depth(p.base) + 1
|
||||
return 0
|
||||
|
||||
|
||||
def is_algebraic(p):
|
||||
"""Return True if p is comprised of only Rationals or square roots
|
||||
of Rationals and algebraic operations.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.functions.elementary.miscellaneous import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import is_algebraic
|
||||
>>> from sympy import cos
|
||||
>>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*sqrt(2))))
|
||||
True
|
||||
>>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*cos(2))))
|
||||
False
|
||||
"""
|
||||
|
||||
if p.is_Rational:
|
||||
return True
|
||||
elif p.is_Atom:
|
||||
return False
|
||||
elif is_sqrt(p) or p.is_Pow and p.exp.is_Integer:
|
||||
return is_algebraic(p.base)
|
||||
elif p.is_Add or p.is_Mul:
|
||||
return all(is_algebraic(x) for x in p.args)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _subsets(n):
|
||||
"""
|
||||
Returns all possible subsets of the set (0, 1, ..., n-1) except the
|
||||
empty set, listed in reversed lexicographical order according to binary
|
||||
representation, so that the case of the fourth root is treated last.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.sqrtdenest import _subsets
|
||||
>>> _subsets(2)
|
||||
[[1, 0], [0, 1], [1, 1]]
|
||||
|
||||
"""
|
||||
if n == 1:
|
||||
a = [[1]]
|
||||
elif n == 2:
|
||||
a = [[1, 0], [0, 1], [1, 1]]
|
||||
elif n == 3:
|
||||
a = [[1, 0, 0], [0, 1, 0], [1, 1, 0],
|
||||
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]]
|
||||
else:
|
||||
b = _subsets(n - 1)
|
||||
a0 = [x + [0] for x in b]
|
||||
a1 = [x + [1] for x in b]
|
||||
a = a0 + [[0]*(n - 1) + [1]] + a1
|
||||
return a
|
||||
|
||||
|
||||
def sqrtdenest(expr, max_iter=3):
|
||||
"""Denests sqrts in an expression that contain other square roots
|
||||
if possible, otherwise returns the expr unchanged. This is based on the
|
||||
algorithms of [1].
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.sqrtdenest import sqrtdenest
|
||||
>>> from sympy import sqrt
|
||||
>>> sqrtdenest(sqrt(5 + 2 * sqrt(6)))
|
||||
sqrt(2) + sqrt(3)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
sympy.solvers.solvers.unrad
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] https://web.archive.org/web/20210806201615/https://researcher.watson.ibm.com/researcher/files/us-fagin/symb85.pdf
|
||||
|
||||
.. [2] D. J. Jeffrey and A. D. Rich, 'Symplifying Square Roots of Square Roots
|
||||
by Denesting' (available at https://www.cybertester.com/data/denest.pdf)
|
||||
|
||||
"""
|
||||
expr = expand_mul(expr)
|
||||
for i in range(max_iter):
|
||||
z = _sqrtdenest0(expr)
|
||||
if expr == z:
|
||||
return expr
|
||||
expr = z
|
||||
return expr
|
||||
|
||||
|
||||
def _sqrt_match(p):
|
||||
"""Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to
|
||||
matching, sqrt(r) also has then maximal sqrt_depth among addends of p.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.functions.elementary.miscellaneous import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import _sqrt_match
|
||||
>>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5)))
|
||||
[1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)]
|
||||
"""
|
||||
from sympy.simplify.radsimp import split_surds
|
||||
|
||||
p = _mexpand(p)
|
||||
if p.is_Number:
|
||||
res = (p, S.Zero, S.Zero)
|
||||
elif p.is_Add:
|
||||
pargs = sorted(p.args, key=default_sort_key)
|
||||
sqargs = [x**2 for x in pargs]
|
||||
if all(sq.is_Rational and sq.is_positive for sq in sqargs):
|
||||
r, b, a = split_surds(p)
|
||||
res = a, b, r
|
||||
return list(res)
|
||||
# to make the process canonical, the argument is included in the tuple
|
||||
# so when the max is selected, it will be the largest arg having a
|
||||
# given depth
|
||||
v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)]
|
||||
nmax = max(v, key=default_sort_key)
|
||||
if nmax[0] == 0:
|
||||
res = []
|
||||
else:
|
||||
# select r
|
||||
depth, _, i = nmax
|
||||
r = pargs.pop(i)
|
||||
v.pop(i)
|
||||
b = S.One
|
||||
if r.is_Mul:
|
||||
bv = []
|
||||
rv = []
|
||||
for x in r.args:
|
||||
if sqrt_depth(x) < depth:
|
||||
bv.append(x)
|
||||
else:
|
||||
rv.append(x)
|
||||
b = Mul._from_args(bv)
|
||||
r = Mul._from_args(rv)
|
||||
# collect terms containing r
|
||||
a1 = []
|
||||
b1 = [b]
|
||||
for x in v:
|
||||
if x[0] < depth:
|
||||
a1.append(x[1])
|
||||
else:
|
||||
x1 = x[1]
|
||||
if x1 == r:
|
||||
b1.append(1)
|
||||
else:
|
||||
if x1.is_Mul:
|
||||
x1args = list(x1.args)
|
||||
if r in x1args:
|
||||
x1args.remove(r)
|
||||
b1.append(Mul(*x1args))
|
||||
else:
|
||||
a1.append(x[1])
|
||||
else:
|
||||
a1.append(x[1])
|
||||
a = Add(*a1)
|
||||
b = Add(*b1)
|
||||
res = (a, b, r**2)
|
||||
else:
|
||||
b, r = p.as_coeff_Mul()
|
||||
if is_sqrt(r):
|
||||
res = (S.Zero, b, r**2)
|
||||
else:
|
||||
res = []
|
||||
return list(res)
|
||||
|
||||
|
||||
class SqrtdenestStopIteration(StopIteration):
|
||||
pass
|
||||
|
||||
|
||||
def _sqrtdenest0(expr):
|
||||
"""Returns expr after denesting its arguments."""
|
||||
|
||||
if is_sqrt(expr):
|
||||
n, d = expr.as_numer_denom()
|
||||
if d is S.One: # n is a square root
|
||||
if n.base.is_Add:
|
||||
args = sorted(n.base.args, key=default_sort_key)
|
||||
if len(args) > 2 and all((x**2).is_Integer for x in args):
|
||||
try:
|
||||
return _sqrtdenest_rec(n)
|
||||
except SqrtdenestStopIteration:
|
||||
pass
|
||||
expr = sqrt(_mexpand(Add(*[_sqrtdenest0(x) for x in args])))
|
||||
return _sqrtdenest1(expr)
|
||||
else:
|
||||
n, d = [_sqrtdenest0(i) for i in (n, d)]
|
||||
return n/d
|
||||
|
||||
if isinstance(expr, Add):
|
||||
cs = []
|
||||
args = []
|
||||
for arg in expr.args:
|
||||
c, a = arg.as_coeff_Mul()
|
||||
cs.append(c)
|
||||
args.append(a)
|
||||
|
||||
if all(c.is_Rational for c in cs) and all(is_sqrt(arg) for arg in args):
|
||||
return _sqrt_ratcomb(cs, args)
|
||||
|
||||
if isinstance(expr, Expr):
|
||||
args = expr.args
|
||||
if args:
|
||||
return expr.func(*[_sqrtdenest0(a) for a in args])
|
||||
return expr
|
||||
|
||||
|
||||
def _sqrtdenest_rec(expr):
|
||||
"""Helper that denests the square root of three or more surds.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
It returns the denested expression; if it cannot be denested it
|
||||
throws SqrtdenestStopIteration
|
||||
|
||||
Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k));
|
||||
split expr.base = a + b*sqrt(r_k), where `a` and `b` are on
|
||||
Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is
|
||||
on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on.
|
||||
See [1], section 6.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec
|
||||
>>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498))
|
||||
-sqrt(10) + sqrt(2) + 9 + 9*sqrt(5)
|
||||
>>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65
|
||||
>>> _sqrtdenest_rec(sqrt(w))
|
||||
-sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5)
|
||||
"""
|
||||
from sympy.simplify.radsimp import radsimp, rad_rationalize, split_surds
|
||||
if not expr.is_Pow:
|
||||
return sqrtdenest(expr)
|
||||
if expr.base < 0:
|
||||
return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base))
|
||||
g, a, b = split_surds(expr.base)
|
||||
a = a*sqrt(g)
|
||||
if a < b:
|
||||
a, b = b, a
|
||||
c2 = _mexpand(a**2 - b**2)
|
||||
if len(c2.args) > 2:
|
||||
g, a1, b1 = split_surds(c2)
|
||||
a1 = a1*sqrt(g)
|
||||
if a1 < b1:
|
||||
a1, b1 = b1, a1
|
||||
c2_1 = _mexpand(a1**2 - b1**2)
|
||||
c_1 = _sqrtdenest_rec(sqrt(c2_1))
|
||||
d_1 = _sqrtdenest_rec(sqrt(a1 + c_1))
|
||||
num, den = rad_rationalize(b1, d_1)
|
||||
c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2)))
|
||||
else:
|
||||
c = _sqrtdenest1(sqrt(c2))
|
||||
|
||||
if sqrt_depth(c) > 1:
|
||||
raise SqrtdenestStopIteration
|
||||
ac = a + c
|
||||
if len(ac.args) >= len(expr.args):
|
||||
if count_ops(ac) >= count_ops(expr.base):
|
||||
raise SqrtdenestStopIteration
|
||||
d = sqrtdenest(sqrt(ac))
|
||||
if sqrt_depth(d) > 1:
|
||||
raise SqrtdenestStopIteration
|
||||
num, den = rad_rationalize(b, d)
|
||||
r = d/sqrt(2) + num/(den*sqrt(2))
|
||||
r = radsimp(r)
|
||||
return _mexpand(r)
|
||||
|
||||
|
||||
def _sqrtdenest1(expr, denester=True):
|
||||
"""Return denested expr after denesting with simpler methods or, that
|
||||
failing, using the denester."""
|
||||
|
||||
from sympy.simplify.simplify import radsimp
|
||||
|
||||
if not is_sqrt(expr):
|
||||
return expr
|
||||
|
||||
a = expr.base
|
||||
if a.is_Atom:
|
||||
return expr
|
||||
val = _sqrt_match(a)
|
||||
if not val:
|
||||
return expr
|
||||
|
||||
a, b, r = val
|
||||
# try a quick numeric denesting
|
||||
d2 = _mexpand(a**2 - b**2*r)
|
||||
if d2.is_Rational:
|
||||
if d2.is_positive:
|
||||
z = _sqrt_numeric_denest(a, b, r, d2)
|
||||
if z is not None:
|
||||
return z
|
||||
else:
|
||||
# fourth root case
|
||||
# sqrtdenest(sqrt(3 + 2*sqrt(3))) =
|
||||
# sqrt(2)*3**(1/4)/2 + sqrt(2)*3**(3/4)/2
|
||||
dr2 = _mexpand(-d2*r)
|
||||
dr = sqrt(dr2)
|
||||
if dr.is_Rational:
|
||||
z = _sqrt_numeric_denest(_mexpand(b*r), a, r, dr2)
|
||||
if z is not None:
|
||||
return z/root(r, 4)
|
||||
|
||||
else:
|
||||
z = _sqrt_symbolic_denest(a, b, r)
|
||||
if z is not None:
|
||||
return z
|
||||
|
||||
if not denester or not is_algebraic(expr):
|
||||
return expr
|
||||
|
||||
res = sqrt_biquadratic_denest(expr, a, b, r, d2)
|
||||
if res:
|
||||
return res
|
||||
|
||||
# now call to the denester
|
||||
av0 = [a, b, r, d2]
|
||||
z = _denester([radsimp(expr**2)], av0, 0, sqrt_depth(expr))[0]
|
||||
if av0[1] is None:
|
||||
return expr
|
||||
if z is not None:
|
||||
if sqrt_depth(z) == sqrt_depth(expr) and count_ops(z) > count_ops(expr):
|
||||
return expr
|
||||
return z
|
||||
return expr
|
||||
|
||||
|
||||
def _sqrt_symbolic_denest(a, b, r):
|
||||
"""Given an expression, sqrt(a + b*sqrt(b)), return the denested
|
||||
expression or None.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with
|
||||
(y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and
|
||||
(cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as
|
||||
sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest
|
||||
>>> from sympy import sqrt, Symbol
|
||||
>>> from sympy.abc import x
|
||||
|
||||
>>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55
|
||||
>>> _sqrt_symbolic_denest(a, b, r)
|
||||
sqrt(11 - 2*sqrt(29)) + sqrt(5)
|
||||
|
||||
If the expression is numeric, it will be simplified:
|
||||
|
||||
>>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2)
|
||||
>>> sqrtdenest(sqrt((w**2).expand()))
|
||||
1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3)))
|
||||
|
||||
Otherwise, it will only be simplified if assumptions allow:
|
||||
|
||||
>>> w = w.subs(sqrt(3), sqrt(x + 3))
|
||||
>>> sqrtdenest(sqrt((w**2).expand()))
|
||||
sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2)
|
||||
|
||||
Notice that the argument of the sqrt is a square. If x is made positive
|
||||
then the sqrt of the square is resolved:
|
||||
|
||||
>>> _.subs(x, Symbol('x', positive=True))
|
||||
sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2)
|
||||
"""
|
||||
|
||||
a, b, r = map(sympify, (a, b, r))
|
||||
rval = _sqrt_match(r)
|
||||
if not rval:
|
||||
return None
|
||||
ra, rb, rr = rval
|
||||
if rb:
|
||||
y = Dummy('y', positive=True)
|
||||
try:
|
||||
newa = Poly(a.subs(sqrt(rr), (y**2 - ra)/rb), y)
|
||||
except PolynomialError:
|
||||
return None
|
||||
if newa.degree() == 2:
|
||||
ca, cb, cc = newa.all_coeffs()
|
||||
cb += b
|
||||
if _mexpand(cb**2 - 4*ca*cc).equals(0):
|
||||
z = sqrt(ca*(sqrt(r) + cb/(2*ca))**2)
|
||||
if z.is_number:
|
||||
z = _mexpand(Mul._from_args(z.as_content_primitive()))
|
||||
return z
|
||||
|
||||
|
||||
def _sqrt_numeric_denest(a, b, r, d2):
|
||||
r"""Helper that denest
|
||||
$\sqrt{a + b \sqrt{r}}, d^2 = a^2 - b^2 r > 0$
|
||||
|
||||
If it cannot be denested, it returns ``None``.
|
||||
"""
|
||||
d = sqrt(d2)
|
||||
s = a + d
|
||||
# sqrt_depth(res) <= sqrt_depth(s) + 1
|
||||
# sqrt_depth(expr) = sqrt_depth(r) + 2
|
||||
# there is denesting if sqrt_depth(s) + 1 < sqrt_depth(r) + 2
|
||||
# if s**2 is Number there is a fourth root
|
||||
if sqrt_depth(s) < sqrt_depth(r) + 1 or (s**2).is_Rational:
|
||||
s1, s2 = sign(s), sign(b)
|
||||
if s1 == s2 == -1:
|
||||
s1 = s2 = 1
|
||||
res = (s1 * sqrt(a + d) + s2 * sqrt(a - d)) * sqrt(2) / 2
|
||||
return res.expand()
|
||||
|
||||
|
||||
def sqrt_biquadratic_denest(expr, a, b, r, d2):
|
||||
"""denest expr = sqrt(a + b*sqrt(r))
|
||||
where a, b, r are linear combinations of square roots of
|
||||
positive rationals on the rationals (SQRR) and r > 0, b != 0,
|
||||
d2 = a**2 - b**2*r > 0
|
||||
|
||||
If it cannot denest it returns None.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Search for a solution A of type SQRR of the biquadratic equation
|
||||
4*A**4 - 4*a*A**2 + b**2*r = 0 (1)
|
||||
sqd = sqrt(a**2 - b**2*r)
|
||||
Choosing the sqrt to be positive, the possible solutions are
|
||||
A = sqrt(a/2 +/- sqd/2)
|
||||
Since a, b, r are SQRR, then a**2 - b**2*r is a SQRR,
|
||||
so if sqd can be denested, it is done by
|
||||
_sqrtdenest_rec, and the result is a SQRR.
|
||||
Similarly for A.
|
||||
Examples of solutions (in both cases a and sqd are positive):
|
||||
|
||||
Example of expr with solution sqrt(a/2 + sqd/2) but not
|
||||
solution sqrt(a/2 - sqd/2):
|
||||
expr = sqrt(-sqrt(15) - sqrt(2)*sqrt(-sqrt(5) + 5) - sqrt(3) + 8)
|
||||
a = -sqrt(15) - sqrt(3) + 8; sqd = -2*sqrt(5) - 2 + 4*sqrt(3)
|
||||
|
||||
Example of expr with solution sqrt(a/2 - sqd/2) but not
|
||||
solution sqrt(a/2 + sqd/2):
|
||||
w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
|
||||
expr = sqrt((w**2).expand())
|
||||
a = 4*sqrt(6) + 8*sqrt(2) + 47 + 28*sqrt(3)
|
||||
sqd = 29 + 20*sqrt(3)
|
||||
|
||||
Define B = b/2*A; eq.(1) implies a = A**2 + B**2*r; then
|
||||
expr**2 = a + b*sqrt(r) = (A + B*sqrt(r))**2
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import _sqrt_match, sqrt_biquadratic_denest
|
||||
>>> z = sqrt((2*sqrt(2) + 4)*sqrt(2 + sqrt(2)) + 5*sqrt(2) + 8)
|
||||
>>> a, b, r = _sqrt_match(z**2)
|
||||
>>> d2 = a**2 - b**2*r
|
||||
>>> sqrt_biquadratic_denest(z, a, b, r, d2)
|
||||
sqrt(2) + sqrt(sqrt(2) + 2) + 2
|
||||
"""
|
||||
from sympy.simplify.radsimp import radsimp, rad_rationalize
|
||||
if r <= 0 or d2 < 0 or not b or sqrt_depth(expr.base) < 2:
|
||||
return None
|
||||
for x in (a, b, r):
|
||||
for y in x.args:
|
||||
y2 = y**2
|
||||
if not y2.is_Integer or not y2.is_positive:
|
||||
return None
|
||||
sqd = _mexpand(sqrtdenest(sqrt(radsimp(d2))))
|
||||
if sqrt_depth(sqd) > 1:
|
||||
return None
|
||||
x1, x2 = [a/2 + sqd/2, a/2 - sqd/2]
|
||||
# look for a solution A with depth 1
|
||||
for x in (x1, x2):
|
||||
A = sqrtdenest(sqrt(x))
|
||||
if sqrt_depth(A) > 1:
|
||||
continue
|
||||
Bn, Bd = rad_rationalize(b, _mexpand(2*A))
|
||||
B = Bn/Bd
|
||||
z = A + B*sqrt(r)
|
||||
if z < 0:
|
||||
z = -z
|
||||
return _mexpand(z)
|
||||
return None
|
||||
|
||||
|
||||
def _denester(nested, av0, h, max_depth_level):
|
||||
"""Denests a list of expressions that contain nested square roots.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Algorithm based on <http://www.almaden.ibm.com/cs/people/fagin/symb85.pdf>.
|
||||
|
||||
It is assumed that all of the elements of 'nested' share the same
|
||||
bottom-level radicand. (This is stated in the paper, on page 177, in
|
||||
the paragraph immediately preceding the algorithm.)
|
||||
|
||||
When evaluating all of the arguments in parallel, the bottom-level
|
||||
radicand only needs to be denested once. This means that calling
|
||||
_denester with x arguments results in a recursive invocation with x+1
|
||||
arguments; hence _denester has polynomial complexity.
|
||||
|
||||
However, if the arguments were evaluated separately, each call would
|
||||
result in two recursive invocations, and the algorithm would have
|
||||
exponential complexity.
|
||||
|
||||
This is discussed in the paper in the middle paragraph of page 179.
|
||||
"""
|
||||
from sympy.simplify.simplify import radsimp
|
||||
if h > max_depth_level:
|
||||
return None, None
|
||||
if av0[1] is None:
|
||||
return None, None
|
||||
if (av0[0] is None and
|
||||
all(n.is_Number for n in nested)): # no arguments are nested
|
||||
for f in _subsets(len(nested)): # test subset 'f' of nested
|
||||
p = _mexpand(Mul(*[nested[i] for i in range(len(f)) if f[i]]))
|
||||
if f.count(1) > 1 and f[-1]:
|
||||
p = -p
|
||||
sqp = sqrt(p)
|
||||
if sqp.is_Rational:
|
||||
return sqp, f # got a perfect square so return its square root.
|
||||
# Otherwise, return the radicand from the previous invocation.
|
||||
return sqrt(nested[-1]), [0]*len(nested)
|
||||
else:
|
||||
R = None
|
||||
if av0[0] is not None:
|
||||
values = [av0[:2]]
|
||||
R = av0[2]
|
||||
nested2 = [av0[3], R]
|
||||
av0[0] = None
|
||||
else:
|
||||
values = list(filter(None, [_sqrt_match(expr) for expr in nested]))
|
||||
for v in values:
|
||||
if v[2]: # Since if b=0, r is not defined
|
||||
if R is not None:
|
||||
if R != v[2]:
|
||||
av0[1] = None
|
||||
return None, None
|
||||
else:
|
||||
R = v[2]
|
||||
if R is None:
|
||||
# return the radicand from the previous invocation
|
||||
return sqrt(nested[-1]), [0]*len(nested)
|
||||
nested2 = [_mexpand(v[0]**2) -
|
||||
_mexpand(R*v[1]**2) for v in values] + [R]
|
||||
d, f = _denester(nested2, av0, h + 1, max_depth_level)
|
||||
if not f:
|
||||
return None, None
|
||||
if not any(f[i] for i in range(len(nested))):
|
||||
v = values[-1]
|
||||
return sqrt(v[0] + _mexpand(v[1]*d)), f
|
||||
else:
|
||||
p = Mul(*[nested[i] for i in range(len(nested)) if f[i]])
|
||||
v = _sqrt_match(p)
|
||||
if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]:
|
||||
v[0] = -v[0]
|
||||
v[1] = -v[1]
|
||||
if not f[len(nested)]: # Solution denests with square roots
|
||||
vad = _mexpand(v[0] + d)
|
||||
if vad <= 0:
|
||||
# return the radicand from the previous invocation.
|
||||
return sqrt(nested[-1]), [0]*len(nested)
|
||||
if not(sqrt_depth(vad) <= sqrt_depth(R) + 1 or
|
||||
(vad**2).is_Number):
|
||||
av0[1] = None
|
||||
return None, None
|
||||
|
||||
sqvad = _sqrtdenest1(sqrt(vad), denester=False)
|
||||
if not (sqrt_depth(sqvad) <= sqrt_depth(R) + 1):
|
||||
av0[1] = None
|
||||
return None, None
|
||||
sqvad1 = radsimp(1/sqvad)
|
||||
res = _mexpand(sqvad/sqrt(2) + (v[1]*sqrt(R)*sqvad1/sqrt(2)))
|
||||
return res, f
|
||||
|
||||
# sign(v[1])*sqrt(_mexpand(v[1]**2*R*vad1/2))), f
|
||||
else: # Solution requires a fourth root
|
||||
s2 = _mexpand(v[1]*R) + d
|
||||
if s2 <= 0:
|
||||
return sqrt(nested[-1]), [0]*len(nested)
|
||||
FR, s = root(_mexpand(R), 4), sqrt(s2)
|
||||
return _mexpand(s/(sqrt(2)*FR) + v[0]*FR/(sqrt(2)*s)), f
|
||||
|
||||
|
||||
def _sqrt_ratcomb(cs, args):
|
||||
"""Denest rational combinations of radicals.
|
||||
|
||||
Based on section 5 of [1].
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import sqrt
|
||||
>>> from sympy.simplify.sqrtdenest import sqrtdenest
|
||||
>>> z = sqrt(1+sqrt(3)) + sqrt(3+3*sqrt(3)) - sqrt(10+6*sqrt(3))
|
||||
>>> sqrtdenest(z)
|
||||
0
|
||||
"""
|
||||
from sympy.simplify.radsimp import radsimp
|
||||
|
||||
# check if there exists a pair of sqrt that can be denested
|
||||
def find(a):
|
||||
n = len(a)
|
||||
for i in range(n - 1):
|
||||
for j in range(i + 1, n):
|
||||
s1 = a[i].base
|
||||
s2 = a[j].base
|
||||
p = _mexpand(s1 * s2)
|
||||
s = sqrtdenest(sqrt(p))
|
||||
if s != sqrt(p):
|
||||
return s, i, j
|
||||
|
||||
indices = find(args)
|
||||
if indices is None:
|
||||
return Add(*[c * arg for c, arg in zip(cs, args)])
|
||||
|
||||
s, i1, i2 = indices
|
||||
|
||||
c2 = cs.pop(i2)
|
||||
args.pop(i2)
|
||||
a1 = args[i1]
|
||||
|
||||
# replace a2 by s/a1
|
||||
cs[i1] += radsimp(c2 * s / a1.base)
|
||||
|
||||
return _sqrt_ratcomb(cs, args)
|
||||
@@ -0,0 +1,75 @@
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial)
|
||||
from sympy.functions.special.gamma_functions import gamma
|
||||
from sympy.simplify.combsimp import combsimp
|
||||
from sympy.abc import x
|
||||
|
||||
|
||||
def test_combsimp():
|
||||
k, m, n = symbols('k m n', integer = True)
|
||||
|
||||
assert combsimp(factorial(n)) == factorial(n)
|
||||
assert combsimp(binomial(n, k)) == binomial(n, k)
|
||||
|
||||
assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n)
|
||||
assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k)
|
||||
|
||||
assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \
|
||||
Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3)))
|
||||
|
||||
assert combsimp(factorial(n)**2/factorial(n - 3)) == \
|
||||
factorial(n)*n*(-1 + n)*(-2 + n)
|
||||
assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \
|
||||
factorial(n + 1)/(1 + k)
|
||||
|
||||
assert combsimp(gamma(n + 3)) == factorial(n + 2)
|
||||
|
||||
assert combsimp(factorial(x)) == gamma(x + 1)
|
||||
|
||||
# issue 9699
|
||||
assert combsimp((n + 1)*factorial(n)) == factorial(n + 1)
|
||||
assert combsimp(factorial(n)/n) == factorial(n-1)
|
||||
|
||||
# issue 6658
|
||||
assert combsimp(binomial(n, n - k)) == binomial(n, k)
|
||||
|
||||
# issue 6341, 7135
|
||||
assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \
|
||||
binomial(n, k)
|
||||
assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \
|
||||
1/binomial(n, k)
|
||||
assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n)
|
||||
assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/
|
||||
factorial(n)**3) == binomial(2*n, n)/binomial(n, k)
|
||||
|
||||
assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1
|
||||
|
||||
assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \
|
||||
(-1)**n*(n + 1)*(n + 2)*(n + 3)
|
||||
assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \
|
||||
(-1)**(n - 1)*n*(n + 1)*(n + 2)
|
||||
assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \
|
||||
(-1)**(n - 3)*n*(n - 1)*(n - 2)
|
||||
assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \
|
||||
-(-1)**(-n - 1)*n*(n - 1)*(n - 2)
|
||||
|
||||
assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \
|
||||
(n + 1)*(n + 2)*(n + 3)
|
||||
assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \
|
||||
n*(n + 1)*(n + 2)
|
||||
assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \
|
||||
n*(n - 1)*(n - 2)
|
||||
assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \
|
||||
-n*(n - 1)*(n - 2)
|
||||
|
||||
|
||||
def test_issue_6878():
|
||||
n = symbols('n', integer=True)
|
||||
assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n)
|
||||
|
||||
|
||||
def test_issue_14528():
|
||||
p = symbols("p", integer=True, positive=True)
|
||||
assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p))
|
||||
assert combsimp(factorial(2-p)) == factorial(2-p)
|
||||
@@ -0,0 +1,761 @@
|
||||
from functools import reduce
|
||||
import itertools
|
||||
from operator import add
|
||||
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.expr import UnevaluatedExpr
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose
|
||||
from sympy.polys.rootoftools import CRootOf
|
||||
from sympy.series.order import O
|
||||
from sympy.simplify.cse_main import cse
|
||||
from sympy.simplify.simplify import signsimp
|
||||
from sympy.tensor.indexed import (Idx, IndexedBase)
|
||||
|
||||
from sympy.core.function import count_ops
|
||||
from sympy.simplify.cse_opts import sub_pre, sub_post
|
||||
from sympy.functions.special.hyper import meijerg
|
||||
from sympy.simplify import cse_main, cse_opts
|
||||
from sympy.utilities.iterables import subsets
|
||||
from sympy.testing.pytest import XFAIL, raises
|
||||
from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
|
||||
ImmutableDenseMatrix, ImmutableSparseMatrix)
|
||||
from sympy.matrices.expressions import MatrixSymbol
|
||||
|
||||
|
||||
w, x, y, z = symbols('w,x,y,z')
|
||||
x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')
|
||||
|
||||
|
||||
def test_numbered_symbols():
|
||||
ns = cse_main.numbered_symbols(prefix='y')
|
||||
assert list(itertools.islice(
|
||||
ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
|
||||
ns = cse_main.numbered_symbols(prefix='y')
|
||||
assert list(itertools.islice(
|
||||
ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
|
||||
ns = cse_main.numbered_symbols()
|
||||
assert list(itertools.islice(
|
||||
ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]
|
||||
|
||||
# Dummy "optimization" functions for testing.
|
||||
|
||||
|
||||
def opt1(expr):
|
||||
return expr + y
|
||||
|
||||
|
||||
def opt2(expr):
|
||||
return expr*z
|
||||
|
||||
|
||||
def test_preprocess_for_cse():
|
||||
assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
|
||||
assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
|
||||
assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
|
||||
assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
|
||||
assert cse_main.preprocess_for_cse(
|
||||
x, [(opt1, None), (opt2, None)]) == (x + y)*z
|
||||
|
||||
|
||||
def test_postprocess_for_cse():
|
||||
assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
|
||||
assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
|
||||
assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
|
||||
assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
|
||||
# Note the reverse order of application.
|
||||
assert cse_main.postprocess_for_cse(
|
||||
x, [(None, opt1), (None, opt2)]) == x*z + y
|
||||
|
||||
|
||||
def test_cse_single():
|
||||
# Simple substitution.
|
||||
e = Add(Pow(x + y, 2), sqrt(x + y))
|
||||
substs, reduced = cse([e])
|
||||
assert substs == [(x0, x + y)]
|
||||
assert reduced == [sqrt(x0) + x0**2]
|
||||
|
||||
subst42, (red42,) = cse([42]) # issue_15082
|
||||
assert len(subst42) == 0 and red42 == 42
|
||||
subst_half, (red_half,) = cse([0.5])
|
||||
assert len(subst_half) == 0 and red_half == 0.5
|
||||
|
||||
|
||||
def test_cse_single2():
|
||||
# Simple substitution, test for being able to pass the expression directly
|
||||
e = Add(Pow(x + y, 2), sqrt(x + y))
|
||||
substs, reduced = cse(e)
|
||||
assert substs == [(x0, x + y)]
|
||||
assert reduced == [sqrt(x0) + x0**2]
|
||||
substs, reduced = cse(Matrix([[1]]))
|
||||
assert isinstance(reduced[0], Matrix)
|
||||
|
||||
subst42, (red42,) = cse(42) # issue 15082
|
||||
assert len(subst42) == 0 and red42 == 42
|
||||
subst_half, (red_half,) = cse(0.5) # issue 15082
|
||||
assert len(subst_half) == 0 and red_half == 0.5
|
||||
|
||||
|
||||
def test_cse_not_possible():
|
||||
# No substitution possible.
|
||||
e = Add(x, y)
|
||||
substs, reduced = cse([e])
|
||||
assert substs == []
|
||||
assert reduced == [x + y]
|
||||
# issue 6329
|
||||
eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
|
||||
meijerg((1, 3), (y, 4), (5,), [], x))
|
||||
assert cse(eq) == ([], [eq])
|
||||
|
||||
|
||||
def test_nested_substitution():
|
||||
# Substitution within a substitution.
|
||||
e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
|
||||
substs, reduced = cse([e])
|
||||
assert substs == [(x0, w*x + y)]
|
||||
assert reduced == [sqrt(x0) + x0**2]
|
||||
|
||||
|
||||
def test_subtraction_opt():
|
||||
# Make sure subtraction is optimized.
|
||||
e = (x - y)*(z - y) + exp((x - y)*(z - y))
|
||||
substs, reduced = cse(
|
||||
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
|
||||
assert substs == [(x0, (x - y)*(y - z))]
|
||||
assert reduced == [-x0 + exp(-x0)]
|
||||
e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
|
||||
substs, reduced = cse(
|
||||
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
|
||||
assert substs == [(x0, (x - y)*(y - z))]
|
||||
assert reduced == [x0 + exp(x0)]
|
||||
# issue 4077
|
||||
n = -1 + 1/x
|
||||
e = n/x/(-n)**2 - 1/n/x
|
||||
assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
|
||||
([], [0])
|
||||
assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
|
||||
([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])
|
||||
|
||||
|
||||
def test_multiple_expressions():
|
||||
e1 = (x + y)*z
|
||||
e2 = (x + y)*w
|
||||
substs, reduced = cse([e1, e2])
|
||||
assert substs == [(x0, x + y)]
|
||||
assert reduced == [x0*z, x0*w]
|
||||
l = [w*x*y + z, w*y]
|
||||
substs, reduced = cse(l)
|
||||
rsubsts, _ = cse(reversed(l))
|
||||
assert substs == rsubsts
|
||||
assert reduced == [z + x*x0, x0]
|
||||
l = [w*x*y, w*x*y + z, w*y]
|
||||
substs, reduced = cse(l)
|
||||
rsubsts, _ = cse(reversed(l))
|
||||
assert substs == rsubsts
|
||||
assert reduced == [x1, x1 + z, x0]
|
||||
l = [(x - z)*(y - z), x - z, y - z]
|
||||
substs, reduced = cse(l)
|
||||
rsubsts, _ = cse(reversed(l))
|
||||
assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
|
||||
assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
|
||||
assert reduced == [x1*x2, x1, x2]
|
||||
l = [w*y + w + x + y + z, w*x*y]
|
||||
assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
|
||||
assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
|
||||
assert cse([x + y, x + z]) == ([], [x + y, x + z])
|
||||
assert cse([x*y, z + x*y, x*y*z + 3]) == \
|
||||
([(x0, x*y)], [x0, z + x0, 3 + x0*z])
|
||||
|
||||
|
||||
@XFAIL # CSE of non-commutative Mul terms is disabled
|
||||
def test_non_commutative_cse():
|
||||
A, B, C = symbols('A B C', commutative=False)
|
||||
l = [A*B*C, A*C]
|
||||
assert cse(l) == ([], l)
|
||||
l = [A*B*C, A*B]
|
||||
assert cse(l) == ([(x0, A*B)], [x0*C, x0])
|
||||
|
||||
|
||||
# Test if CSE of non-commutative Mul terms is disabled
|
||||
def test_bypass_non_commutatives():
|
||||
A, B, C = symbols('A B C', commutative=False)
|
||||
l = [A*B*C, A*C]
|
||||
assert cse(l) == ([], l)
|
||||
l = [A*B*C, A*B]
|
||||
assert cse(l) == ([], l)
|
||||
l = [B*C, A*B*C]
|
||||
assert cse(l) == ([], l)
|
||||
|
||||
|
||||
@XFAIL # CSE fails when replacing non-commutative sub-expressions
|
||||
def test_non_commutative_order():
|
||||
A, B, C = symbols('A B C', commutative=False)
|
||||
x0 = symbols('x0', commutative=False)
|
||||
l = [B+C, A*(B+C)]
|
||||
assert cse(l) == ([(x0, B+C)], [x0, A*x0])
|
||||
|
||||
|
||||
@XFAIL # Worked in gh-11232, but was reverted due to performance considerations
|
||||
def test_issue_10228():
|
||||
assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
|
||||
assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
|
||||
assert cse((w + 2*x + y + z, w + x + 1)) == (
|
||||
[(x0, w + x)], [x0 + x + y + z, x0 + 1])
|
||||
assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
|
||||
[(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
|
||||
a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
|
||||
exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
|
||||
assert cse(exprs) == (
|
||||
[(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
|
||||
)
|
||||
|
||||
@XFAIL
|
||||
def test_powers():
|
||||
assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])
|
||||
|
||||
|
||||
def test_issue_4498():
|
||||
assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
|
||||
([], [(w - z)/(x - y)])
|
||||
|
||||
|
||||
def test_issue_4020():
|
||||
assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
|
||||
== ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
|
||||
|
||||
|
||||
def test_issue_4203():
|
||||
assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
|
||||
|
||||
|
||||
def test_issue_6263():
|
||||
e = Eq(x*(-x + 1) + x*(x - 1), 0)
|
||||
assert cse(e, optimizations='basic') == ([], [True])
|
||||
|
||||
|
||||
def test_issue_25043():
|
||||
c = symbols("c")
|
||||
x = symbols("x0", real=True)
|
||||
cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1]
|
||||
free = cse_expr.free_symbols
|
||||
assert len(free) == len({i.name for i in free})
|
||||
|
||||
|
||||
def test_dont_cse_tuples():
|
||||
from sympy.core.function import Subs
|
||||
f = Function("f")
|
||||
g = Function("g")
|
||||
|
||||
name_val, (expr,) = cse(
|
||||
Subs(f(x, y), (x, y), (0, 1))
|
||||
+ Subs(g(x, y), (x, y), (0, 1)))
|
||||
|
||||
assert name_val == []
|
||||
assert expr == (Subs(f(x, y), (x, y), (0, 1))
|
||||
+ Subs(g(x, y), (x, y), (0, 1)))
|
||||
|
||||
name_val, (expr,) = cse(
|
||||
Subs(f(x, y), (x, y), (0, x + y))
|
||||
+ Subs(g(x, y), (x, y), (0, x + y)))
|
||||
|
||||
assert name_val == [(x0, x + y)]
|
||||
assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
|
||||
Subs(g(x, y), (x, y), (0, x0))
|
||||
|
||||
|
||||
def test_pow_invpow():
|
||||
assert cse(1/x**2 + x**2) == \
|
||||
([(x0, x**2)], [x0 + 1/x0])
|
||||
assert cse(x**2 + (1 + 1/x**2)/x**2) == \
|
||||
([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
|
||||
assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
|
||||
([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
|
||||
assert cse(cos(1/x**2) + sin(1/x**2)) == \
|
||||
([(x0, x**(-2))], [sin(x0) + cos(x0)])
|
||||
assert cse(cos(x**2) + sin(x**2)) == \
|
||||
([(x0, x**2)], [sin(x0) + cos(x0)])
|
||||
assert cse(y/(2 + x**2) + z/x**2/y) == \
|
||||
([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
|
||||
assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
|
||||
([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
|
||||
assert cse((1 + 1/x**2)/x**2) == \
|
||||
([(x0, x**(-2))], [x0*(x0 + 1)])
|
||||
assert cse(x**(2*y) + x**(-2*y)) == \
|
||||
([(x0, x**(2*y))], [x0 + 1/x0])
|
||||
|
||||
|
||||
def test_postprocess():
|
||||
eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
|
||||
assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
|
||||
postprocess=cse_main.cse_separate) == \
|
||||
[[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
|
||||
[x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]
|
||||
|
||||
|
||||
def test_issue_4499():
|
||||
# previously, this gave 16 constants
|
||||
from sympy.abc import a, b
|
||||
B = Function('B')
|
||||
G = Function('G')
|
||||
t = Tuple(*
|
||||
(a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
|
||||
b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
|
||||
sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
|
||||
sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
|
||||
sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
|
||||
(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
|
||||
sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
|
||||
-2*a))
|
||||
c = cse(t)
|
||||
ans = (
|
||||
[(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
|
||||
(x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
|
||||
(x8, B(b, x4)), (x9, x6*B(x2, x4))],
|
||||
[(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
|
||||
1, 0, S.Half, z/2, -x3, -x1, -x0)])
|
||||
assert ans == c
|
||||
|
||||
|
||||
def test_issue_6169():
|
||||
r = CRootOf(x**6 - 4*x**5 - 2, 1)
|
||||
assert cse(r) == ([], [r])
|
||||
# and a check that the right thing is done with the new
|
||||
# mechanism
|
||||
assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
|
||||
|
||||
|
||||
def test_cse_Indexed():
|
||||
len_y = 5
|
||||
y = IndexedBase('y', shape=(len_y,))
|
||||
x = IndexedBase('x', shape=(len_y,))
|
||||
i = Idx('i', len_y-1)
|
||||
|
||||
expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
|
||||
expr2 = 1/(x[i+1]-x[i])
|
||||
replacements, reduced_exprs = cse([expr1, expr2])
|
||||
assert len(replacements) > 0
|
||||
|
||||
|
||||
def test_cse_MatrixSymbol():
|
||||
# MatrixSymbols have non-Basic args, so make sure that works
|
||||
A = MatrixSymbol("A", 3, 3)
|
||||
assert cse(A) == ([], [A])
|
||||
|
||||
n = symbols('n', integer=True)
|
||||
B = MatrixSymbol("B", n, n)
|
||||
assert cse(B) == ([], [B])
|
||||
|
||||
assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])
|
||||
|
||||
assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])
|
||||
|
||||
def test_cse_MatrixExpr():
|
||||
A = MatrixSymbol('A', 3, 3)
|
||||
y = MatrixSymbol('y', 3, 1)
|
||||
|
||||
expr1 = (A.T*A).I * A * y
|
||||
expr2 = (A.T*A) * A * y
|
||||
replacements, reduced_exprs = cse([expr1, expr2])
|
||||
assert len(replacements) > 0
|
||||
|
||||
replacements, reduced_exprs = cse([expr1 + expr2, expr1])
|
||||
assert replacements
|
||||
|
||||
replacements, reduced_exprs = cse([A**2, A + A**2])
|
||||
assert replacements
|
||||
|
||||
|
||||
def test_Piecewise():
|
||||
f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
|
||||
ans = cse(f)
|
||||
actual_ans = ([(x0, x*y)],
|
||||
[Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
|
||||
assert ans == actual_ans
|
||||
|
||||
|
||||
def test_ignore_order_terms():
|
||||
eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
|
||||
assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
|
||||
|
||||
|
||||
def test_name_conflict():
|
||||
z1 = x0 + y
|
||||
z2 = x2 + x3
|
||||
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
|
||||
substs, reduced = cse(l)
|
||||
assert [e.subs(reversed(substs)) for e in reduced] == l
|
||||
|
||||
|
||||
def test_name_conflict_cust_symbols():
|
||||
z1 = x0 + y
|
||||
z2 = x2 + x3
|
||||
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
|
||||
substs, reduced = cse(l, symbols("x:10"))
|
||||
assert [e.subs(reversed(substs)) for e in reduced] == l
|
||||
|
||||
|
||||
def test_symbols_exhausted_error():
|
||||
l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
|
||||
sym = [x, y, z]
|
||||
with raises(ValueError):
|
||||
cse(l, symbols=sym)
|
||||
|
||||
|
||||
def test_issue_7840():
|
||||
# daveknippers' example
|
||||
C393 = sympify( \
|
||||
'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
|
||||
C391 > 2.35), (C392, True)), True))'
|
||||
)
|
||||
C391 = sympify( \
|
||||
'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
|
||||
)
|
||||
C393 = C393.subs('C391',C391)
|
||||
# simple substitution
|
||||
sub = {}
|
||||
sub['C390'] = 0.703451854
|
||||
sub['C392'] = 1.01417794
|
||||
ss_answer = C393.subs(sub)
|
||||
# cse
|
||||
substitutions,new_eqn = cse(C393)
|
||||
for pair in substitutions:
|
||||
sub[pair[0].name] = pair[1].subs(sub)
|
||||
cse_answer = new_eqn[0].subs(sub)
|
||||
# both methods should be the same
|
||||
assert ss_answer == cse_answer
|
||||
|
||||
# GitRay's example
|
||||
expr = sympify(
|
||||
"Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
|
||||
(Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
|
||||
Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
|
||||
Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
|
||||
)
|
||||
substitutions, new_eqn = cse(expr)
|
||||
# this Piecewise should be exactly the same
|
||||
assert new_eqn[0] == expr
|
||||
# there should not be any replacements
|
||||
assert len(substitutions) < 1
|
||||
|
||||
|
||||
def test_issue_8891():
|
||||
for cls in (MutableDenseMatrix, MutableSparseMatrix,
|
||||
ImmutableDenseMatrix, ImmutableSparseMatrix):
|
||||
m = cls(2, 2, [x + y, 0, 0, 0])
|
||||
res = cse([x + y, m])
|
||||
ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
|
||||
assert res == ans
|
||||
assert isinstance(res[1][-1], cls)
|
||||
|
||||
|
||||
def test_issue_11230():
|
||||
# a specific test that always failed
|
||||
a, b, f, k, l, i = symbols('a b f k l i')
|
||||
p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
|
||||
R, C = cse(p)
|
||||
assert not any(i.is_Mul for a in C for i in a.args)
|
||||
|
||||
# random tests for the issue
|
||||
from sympy.core.random import choice
|
||||
from sympy.core.function import expand_mul
|
||||
s = symbols('a:m')
|
||||
# 35 Mul tests, none of which should ever fail
|
||||
ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
|
||||
for p in subsets(ex, 3):
|
||||
p = list(p)
|
||||
R, C = cse(p)
|
||||
assert not any(i.is_Mul for a in C for i in a.args)
|
||||
for ri in reversed(R):
|
||||
for i in range(len(C)):
|
||||
C[i] = C[i].subs(*ri)
|
||||
assert p == C
|
||||
# 35 Add tests, none of which should ever fail
|
||||
ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
|
||||
for p in subsets(ex, 3):
|
||||
p = list(p)
|
||||
R, C = cse(p)
|
||||
assert not any(i.is_Add for a in C for i in a.args)
|
||||
for ri in reversed(R):
|
||||
for i in range(len(C)):
|
||||
C[i] = C[i].subs(*ri)
|
||||
# use expand_mul to handle cases like this:
|
||||
# p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
|
||||
# x0 = 2*(b + e) is identified giving a rebuilt p that
|
||||
# is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
|
||||
assert p == [expand_mul(i) for i in C]
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_issue_11577():
|
||||
def check(eq):
|
||||
r, c = cse(eq)
|
||||
assert eq.count_ops() >= \
|
||||
len(r) + sum(i[1].count_ops() for i in r) + \
|
||||
count_ops(c)
|
||||
|
||||
eq = x**5*y**2 + x**5*y + x**5
|
||||
assert cse(eq) == (
|
||||
[(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
|
||||
# ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
|
||||
# ([(x0, x**5)], [x0*y**2 + x0*y + x0])
|
||||
check(eq)
|
||||
|
||||
eq = x**2/(y + 1)**2 + x/(y + 1)
|
||||
assert cse(eq) == (
|
||||
[(x0, y + 1)], [x**2/x0**2 + x/x0])
|
||||
# ([(x0, x/(y + 1))], [x0**2 + x0])
|
||||
check(eq)
|
||||
|
||||
|
||||
def test_hollow_rejection():
|
||||
eq = [x + 3, x + 4]
|
||||
assert cse(eq) == ([], eq)
|
||||
|
||||
|
||||
def test_cse_ignore():
|
||||
exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
|
||||
subst1, red1 = cse(exprs)
|
||||
assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"
|
||||
|
||||
subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions
|
||||
assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
|
||||
assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"
|
||||
|
||||
|
||||
def test_cse_ignore_issue_15002():
|
||||
l = [
|
||||
w*exp(x)*exp(-z),
|
||||
exp(y)*exp(x)*exp(-z)
|
||||
]
|
||||
substs, reduced = cse(l, ignore=(x,))
|
||||
rl = [e.subs(reversed(substs)) for e in reduced]
|
||||
assert rl == l
|
||||
|
||||
|
||||
def test_cse_unevaluated():
|
||||
xp1 = UnevaluatedExpr(x + 1)
|
||||
# This used to cause RecursionError
|
||||
[(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)])
|
||||
if ue == xp1:
|
||||
assert red == (-1 - x0) / (1 - x0)
|
||||
elif ue == -xp1:
|
||||
assert red == (-1 + x0) / (1 + x0)
|
||||
else:
|
||||
msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}'
|
||||
assert False, msg
|
||||
|
||||
|
||||
def test_cse__performance():
|
||||
nexprs, nterms = 3, 20
|
||||
x = symbols('x:%d' % nterms)
|
||||
exprs = [
|
||||
reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
|
||||
for i in range(nexprs)
|
||||
]
|
||||
assert (exprs[0] + exprs[1]).simplify() == 0
|
||||
subst, red = cse(exprs)
|
||||
assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
|
||||
for i, e in enumerate(red):
|
||||
assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0
|
||||
|
||||
|
||||
def test_issue_12070():
|
||||
exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
|
||||
subst, red = cse(exprs)
|
||||
assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) +
|
||||
count_ops(red))
|
||||
|
||||
|
||||
def test_issue_13000():
|
||||
eq = x/(-4*x**2 + y**2)
|
||||
cse_eq = cse(eq)[1][0]
|
||||
assert cse_eq == eq
|
||||
|
||||
|
||||
def test_issue_18203():
|
||||
eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
|
||||
assert cse(eq) == ([], [eq])
|
||||
|
||||
|
||||
def test_unevaluated_mul():
|
||||
eq = Mul(x + y, x + y, evaluate=False)
|
||||
assert cse(eq) == ([(x0, x + y)], [x0**2])
|
||||
|
||||
|
||||
def test_cse_release_variables():
|
||||
from sympy.simplify.cse_main import cse_release_variables
|
||||
_0, _1, _2, _3, _4 = symbols('_:5')
|
||||
eqs = [(x + y - 1)**2, x,
|
||||
x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
|
||||
(2*x + 1)**(x + y)]
|
||||
r, e = cse(eqs, postprocess=cse_release_variables)
|
||||
# this can change in keeping with the intention of the function
|
||||
assert r, e == ([
|
||||
(x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
|
||||
(_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
|
||||
(x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
|
||||
r.reverse()
|
||||
r = [(s, v) for s, v in r if v is not None]
|
||||
assert eqs == [i.subs(r) for i in e]
|
||||
|
||||
|
||||
def test_cse_list():
|
||||
_cse = lambda x: cse(x, list=False)
|
||||
assert _cse(x) == ([], x)
|
||||
assert _cse('x') == ([], 'x')
|
||||
it = [x]
|
||||
for c in (list, tuple, set):
|
||||
assert _cse(c(it)) == ([], c(it))
|
||||
#Tuple works different from tuple:
|
||||
assert _cse(Tuple(*it)) == ([], Tuple(*it))
|
||||
d = {x: 1}
|
||||
assert _cse(d) == ([], d)
|
||||
|
||||
def test_issue_18991():
|
||||
A = MatrixSymbol('A', 2, 2)
|
||||
assert signsimp(-A * A - A) == -A * A - A
|
||||
|
||||
|
||||
def test_unevaluated_Mul():
|
||||
m = [Mul(1, 2, evaluate=False)]
|
||||
assert cse(m) == ([], m)
|
||||
|
||||
|
||||
def test_cse_matrix_expression_inverse():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = Inverse(A)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [Inverse(A)])
|
||||
|
||||
|
||||
def test_cse_matrix_expression_matmul_inverse():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
b = ImmutableDenseMatrix(symbols('b:2'))
|
||||
x = MatMul(Inverse(A), b)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [x])
|
||||
|
||||
|
||||
def test_cse_matrix_negate_matrix():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = MatMul(S.NegativeOne, A)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [x])
|
||||
|
||||
|
||||
def test_cse_matrix_negate_matmul_not_extracted():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
|
||||
x = MatMul(S.NegativeOne, A, B)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [x])
|
||||
|
||||
|
||||
@XFAIL # No simplification rule for nested associative operations
|
||||
def test_cse_matrix_nested_matmul_collapsed():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
|
||||
x = MatMul(S.NegativeOne, MatMul(A, B))
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)])
|
||||
|
||||
|
||||
def test_cse_matrix_optimize_out_single_argument_mul():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = MatMul(MatMul(MatMul(A)))
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [A])
|
||||
|
||||
|
||||
@XFAIL # Multiple simplification passed not supported in CSE
|
||||
def test_cse_matrix_optimize_out_single_argument_mul_combined():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [MatMul(4, A)])
|
||||
|
||||
|
||||
def test_cse_matrix_optimize_out_single_argument_add():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = MatAdd(MatAdd(MatAdd(MatAdd(A))))
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [A])
|
||||
|
||||
|
||||
@XFAIL # Multiple simplification passed not supported in CSE
|
||||
def test_cse_matrix_optimize_out_single_argument_add_combined():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [MatMul(4, A)])
|
||||
|
||||
|
||||
def test_cse_matrix_expression_matrix_solve():
|
||||
A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
|
||||
b = ImmutableDenseMatrix(symbols('b:2'))
|
||||
x = MatrixSolve(A, b)
|
||||
cse_expr = cse(x)
|
||||
assert cse_expr == ([], [x])
|
||||
|
||||
|
||||
def test_cse_matrix_matrix_expression():
|
||||
X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2)
|
||||
y = ImmutableDenseMatrix(symbols('y:2'))
|
||||
b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y)
|
||||
cse_expr = cse(b)
|
||||
x0 = MatrixSymbol('x0', 2, 2)
|
||||
reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y)
|
||||
assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected])
|
||||
|
||||
|
||||
def test_cse_matrix_kalman_filter():
|
||||
"""Kalman Filter example from Matthew Rocklin's SciPy 2013 talk.
|
||||
|
||||
Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation"
|
||||
|
||||
Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
Equations are:
|
||||
|
||||
new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data)
|
||||
= MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
|
||||
new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma
|
||||
= MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma))
|
||||
|
||||
"""
|
||||
N = 2
|
||||
mu = ImmutableDenseMatrix(symbols(f'mu:{N}'))
|
||||
Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N)
|
||||
H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N)
|
||||
R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N)
|
||||
data = ImmutableDenseMatrix(symbols(f'data:{N}'))
|
||||
new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
|
||||
new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma))
|
||||
cse_expr = cse([new_mu, new_Sigma])
|
||||
x0 = MatrixSymbol('x0', N, N)
|
||||
x1 = MatrixSymbol('x1', N, N)
|
||||
replacements_expected = [
|
||||
(x0, Transpose(H)),
|
||||
(x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))),
|
||||
]
|
||||
reduced_exprs_expected = [
|
||||
MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))),
|
||||
MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)),
|
||||
]
|
||||
assert cse_expr == (replacements_expected, reduced_exprs_expected)
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Tests for the ``sympy.simplify._cse_diff.py`` module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.core.numbers import Integer
|
||||
from sympy.core.function import Function
|
||||
from sympy.core import Derivative
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.matrices.immutable import ImmutableDenseMatrix
|
||||
from sympy.physics.mechanics import dynamicsymbols
|
||||
from sympy.simplify._cse_diff import (_forward_jacobian,
|
||||
_remove_cse_from_derivative,
|
||||
_forward_jacobian_cse,
|
||||
_forward_jacobian_norm_in_cse_out)
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.matrices import Matrix, eye
|
||||
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
|
||||
from sympy.simplify.trigsimp import trigsimp
|
||||
|
||||
from sympy import cse
|
||||
|
||||
|
||||
w = Symbol('w')
|
||||
x = Symbol('x')
|
||||
y = Symbol('y')
|
||||
z = Symbol('z')
|
||||
|
||||
q1, q2, q3 = dynamicsymbols('q1 q2 q3')
|
||||
|
||||
# Define the custom functions
|
||||
k = Function('k')(x, y)
|
||||
f = Function('f')(k, z)
|
||||
|
||||
zero = Integer(0)
|
||||
one = Integer(1)
|
||||
two = Integer(2)
|
||||
neg_one = Integer(-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'expr, wrt',
|
||||
[
|
||||
([zero], [x]),
|
||||
([one], [x]),
|
||||
([two], [x]),
|
||||
([neg_one], [x]),
|
||||
([x], [x]),
|
||||
([y], [x]),
|
||||
([x + y], [x]),
|
||||
([x*y], [x]),
|
||||
([x**2], [x]),
|
||||
([x**y], [x]),
|
||||
([exp(x)], [x]),
|
||||
([sin(x)], [x]),
|
||||
([tan(x)], [x]),
|
||||
([zero, one, x, y, x*y, x + y], [x, y]),
|
||||
([((x/y) + sin(x/y) - exp(y))*((x/y) - exp(y))], [x, y]),
|
||||
([w*tan(y*z)/(x - tan(y*z)), w*x*tan(y*z)/(x - tan(y*z))], [w, x, y, z]),
|
||||
([q1**2 + q2, q2**2 + q3, q3**2 + q1], [q1, q2, q3]),
|
||||
([f + Derivative(f, x) + k + 2*x], [x])
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_forward_jacobian(expr, wrt):
|
||||
expr = ImmutableDenseMatrix([expr]).T
|
||||
wrt = ImmutableDenseMatrix([wrt]).T
|
||||
jacobian = _forward_jacobian(expr, wrt)
|
||||
zeros = ImmutableDenseMatrix.zeros(*jacobian.shape)
|
||||
assert simplify(jacobian - expr.jacobian(wrt)) == zeros
|
||||
|
||||
|
||||
def test_process_cse():
|
||||
x, y, z = symbols('x y z')
|
||||
f = Function('f')
|
||||
k = Function('k')
|
||||
expr = Matrix([f(k(x,y), z) + Derivative(f(k(x,y), z), x) + k(x,y) + 2*x])
|
||||
repl, reduced = cse(expr)
|
||||
p_repl, p_reduced = _remove_cse_from_derivative(repl, reduced)
|
||||
|
||||
x0 = symbols('x0')
|
||||
x1 = symbols('x1')
|
||||
|
||||
expected_output = (
|
||||
[(x0, k(x, y)), (x1, f(x0, z))],
|
||||
[Matrix([2 * x + x0 + x1 + Derivative(f(k(x, y), z), x)])]
|
||||
)
|
||||
|
||||
assert p_repl == expected_output[0], f"Expected {expected_output[0]}, but got {p_repl}"
|
||||
assert p_reduced == expected_output[1], f"Expected {expected_output[1]}, but got {p_reduced}"
|
||||
|
||||
|
||||
def test_io_matrix_type():
|
||||
x, y, z = symbols('x y z')
|
||||
expr = ImmutableDenseMatrix([
|
||||
x * y + y * z + x * y * z,
|
||||
x ** 2 + y ** 2 + z ** 2,
|
||||
x * y + x * z + y * z
|
||||
])
|
||||
wrt = ImmutableDenseMatrix([x, y, z])
|
||||
|
||||
replacements, reduced_expr = cse(expr)
|
||||
|
||||
# Test _forward_jacobian_core
|
||||
replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt)
|
||||
assert isinstance(jacobian_core[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input"
|
||||
|
||||
# Test _forward_jacobian_norm_in_dag_out
|
||||
replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(
|
||||
expr, wrt)
|
||||
assert isinstance(jacobian_norm[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input"
|
||||
|
||||
# Test _forward_jacobian
|
||||
jacobian = _forward_jacobian(expr, wrt)
|
||||
assert isinstance(jacobian, type(expr)), "Jacobian should be a Matrix of the same type as the input"
|
||||
|
||||
|
||||
def test_forward_jacobian_input_output():
|
||||
x, y, z = symbols('x y z')
|
||||
expr = Matrix([
|
||||
x * y + y * z + x * y * z,
|
||||
x ** 2 + y ** 2 + z ** 2,
|
||||
x * y + x * z + y * z
|
||||
])
|
||||
wrt = Matrix([x, y, z])
|
||||
|
||||
replacements, reduced_expr = cse(expr)
|
||||
|
||||
# Test _forward_jacobian_core
|
||||
replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt)
|
||||
assert isinstance(replacements_core, type(replacements)), "Replacements should be a list"
|
||||
assert isinstance(jacobian_core, type(reduced_expr)), "Jacobian should be a list"
|
||||
assert isinstance(precomputed_fs_core, list), "Precomputed free symbols should be a list"
|
||||
assert len(replacements_core) == len(replacements), "Length of replacements does not match"
|
||||
assert len(jacobian_core) == 1, "Jacobian should have one element"
|
||||
assert len(precomputed_fs_core) == len(replacements), "Length of precomputed free symbols does not match"
|
||||
|
||||
# Test _forward_jacobian_norm_in_dag_out
|
||||
replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(expr, wrt)
|
||||
assert isinstance(replacements_norm, type(replacements)), "Replacements should be a list"
|
||||
assert isinstance(jacobian_norm, type(reduced_expr)), "Jacobian should be a list"
|
||||
assert isinstance(precomputed_fs_norm, list), "Precomputed free symbols should be a list"
|
||||
assert len(replacements_norm) == len(replacements), "Length of replacements does not match"
|
||||
assert len(jacobian_norm) == 1, "Jacobian should have one element"
|
||||
assert len(precomputed_fs_norm) == len(replacements), "Length of precomputed free symbols does not match"
|
||||
|
||||
|
||||
def test_jacobian_hessian():
|
||||
L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y])
|
||||
syms = [x, y]
|
||||
assert _forward_jacobian(L, syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]])
|
||||
|
||||
L = Matrix(1, 2, [x, x**2*y**3])
|
||||
assert _forward_jacobian(L, syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]])
|
||||
|
||||
|
||||
def test_jacobian_metrics():
|
||||
rho, phi = symbols("rho,phi")
|
||||
X = Matrix([rho * cos(phi), rho * sin(phi)])
|
||||
Y = Matrix([rho, phi])
|
||||
J = _forward_jacobian(X, Y)
|
||||
assert J == X.jacobian(Y.T)
|
||||
assert J == (X.T).jacobian(Y)
|
||||
assert J == (X.T).jacobian(Y.T)
|
||||
g = J.T * eye(J.shape[0]) * J
|
||||
g = g.applyfunc(trigsimp)
|
||||
assert g == Matrix([[1, 0], [0, rho ** 2]])
|
||||
|
||||
|
||||
def test_jacobian2():
|
||||
rho, phi = symbols("rho,phi")
|
||||
X = Matrix([rho * cos(phi), rho * sin(phi), rho ** 2])
|
||||
Y = Matrix([rho, phi])
|
||||
J = Matrix([
|
||||
[cos(phi), -rho * sin(phi)],
|
||||
[sin(phi), rho * cos(phi)],
|
||||
[2 * rho, 0],
|
||||
])
|
||||
assert _forward_jacobian(X, Y) == J
|
||||
|
||||
|
||||
def test_issue_4564():
|
||||
X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)])
|
||||
Y = Matrix([x, y, z])
|
||||
for i in range(1, 3):
|
||||
for j in range(1, 3):
|
||||
X_slice = X[:i, :]
|
||||
Y_slice = Y[:j, :]
|
||||
J = _forward_jacobian(X_slice, Y_slice)
|
||||
assert J.rows == i
|
||||
assert J.cols == j
|
||||
for k in range(j):
|
||||
assert J[:, k] == X_slice
|
||||
|
||||
|
||||
def test_nonvectorJacobian():
|
||||
X = Matrix([[exp(x + y + z), exp(x + y + z)],
|
||||
[exp(x + y + z), exp(x + y + z)]])
|
||||
raises(TypeError, lambda: _forward_jacobian(X, Matrix([x, y, z])))
|
||||
X = X[0, :]
|
||||
Y = Matrix([[x, y], [x, z]])
|
||||
raises(TypeError, lambda: _forward_jacobian(X, Y))
|
||||
raises(TypeError, lambda: _forward_jacobian(X, Matrix([[x, y], [x, z]])))
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Tests for tools for manipulation of expressions using paths. """
|
||||
|
||||
from sympy.simplify.epathtools import epath, EPath
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
from sympy.core.numbers import E
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.abc import x, y, z, t
|
||||
|
||||
|
||||
def test_epath_select():
|
||||
expr = [((x, 1, t), 2), ((3, y, 4), z)]
|
||||
|
||||
assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
|
||||
assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z]
|
||||
assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4]
|
||||
assert epath("/*/*/*/*", expr) == []
|
||||
|
||||
assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)]
|
||||
assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
|
||||
assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4]
|
||||
assert epath("/[:]/[:]/[:]/[:]", expr) == []
|
||||
|
||||
assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z]
|
||||
|
||||
assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)]
|
||||
assert epath("/*/[1]", expr) == [2, z]
|
||||
assert epath("/*/[2]", expr) == []
|
||||
|
||||
assert epath("/*/int", expr) == [2]
|
||||
assert epath("/*/Symbol", expr) == [z]
|
||||
assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)]
|
||||
assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)]
|
||||
|
||||
assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)]
|
||||
assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z]
|
||||
assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z]
|
||||
|
||||
assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)]
|
||||
assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z]
|
||||
assert epath(
|
||||
"/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z]
|
||||
|
||||
assert epath("/*/[0]/int", expr) == [1, 3, 4]
|
||||
assert epath("/*/[0]/Symbol", expr) == [x, t, y]
|
||||
|
||||
assert epath("/*/[0]/int[1:]", expr) == [1, 4]
|
||||
assert epath("/*/[0]/Symbol[1:]", expr) == [t, y]
|
||||
|
||||
assert epath("/Symbol", x + y + z + 1) == [x, y, z]
|
||||
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y]
|
||||
|
||||
|
||||
def test_epath_apply():
|
||||
expr = [((x, 1, t), 2), ((3, y, 4), z)]
|
||||
func = lambda expr: expr**2
|
||||
|
||||
assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]]
|
||||
|
||||
assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)]
|
||||
assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)]
|
||||
assert epath("/*/[2]", expr, list) == expr
|
||||
|
||||
assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)]
|
||||
assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2),
|
||||
((3, y**2, 4), z)]
|
||||
assert epath(
|
||||
"/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)]
|
||||
assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2),
|
||||
2), ((3, y**2, 4), z)]
|
||||
|
||||
assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1
|
||||
assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \
|
||||
t + sin(x**2 + 1) + cos(x**2 + y**2 + E)
|
||||
|
||||
|
||||
def test_EPath():
|
||||
assert EPath("/*/[0]")._path == "/*/[0]"
|
||||
assert EPath(EPath("/*/[0]"))._path == "/*/[0]"
|
||||
assert isinstance(epath("/*/[0]"), EPath) is True
|
||||
|
||||
assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')"
|
||||
|
||||
raises(ValueError, lambda: EPath(""))
|
||||
raises(ValueError, lambda: EPath("/"))
|
||||
raises(ValueError, lambda: EPath("/|x"))
|
||||
raises(ValueError, lambda: EPath("/["))
|
||||
raises(ValueError, lambda: EPath("/[0]%"))
|
||||
|
||||
raises(NotImplementedError, lambda: EPath("Symbol"))
|
||||
@@ -0,0 +1,492 @@
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.numbers import (I, Rational, pi)
|
||||
from sympy.core.parameters import evaluate
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, Symbol, symbols)
|
||||
from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh)
|
||||
from sympy.functions.elementary.miscellaneous import (root, sqrt)
|
||||
from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan)
|
||||
from sympy.simplify.powsimp import powsimp
|
||||
from sympy.simplify.fu import (
|
||||
L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16,
|
||||
TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T,
|
||||
TRpower, hyper_as_trig, fu, process_common_addends, trig_split,
|
||||
as_f_sign_1)
|
||||
from sympy.core.random import verify_numerically
|
||||
from sympy.abc import a, b, c, x, y, z
|
||||
|
||||
|
||||
def test_TR1():
|
||||
assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x)
|
||||
|
||||
|
||||
def test_TR2():
|
||||
assert TR2(tan(x)) == sin(x)/cos(x)
|
||||
assert TR2(cot(x)) == cos(x)/sin(x)
|
||||
assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0
|
||||
|
||||
|
||||
def test_TR2i():
|
||||
# just a reminder that ratios of powers only simplify if both
|
||||
# numerator and denominator satisfy the condition that each
|
||||
# has a positive base or an integer exponent; e.g. the following,
|
||||
# at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I
|
||||
assert powsimp(2**x/y**x) != (2/y)**x
|
||||
|
||||
assert TR2i(sin(x)/cos(x)) == tan(x)
|
||||
assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y)
|
||||
assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x)
|
||||
assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y)
|
||||
assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2
|
||||
|
||||
assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2
|
||||
assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half)
|
||||
assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1)
|
||||
assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2)
|
||||
assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half)
|
||||
assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half)
|
||||
assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1)
|
||||
assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2)
|
||||
assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half)
|
||||
assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a
|
||||
assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a
|
||||
assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a
|
||||
assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a
|
||||
assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a)
|
||||
assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a)
|
||||
assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a)
|
||||
assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a)
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i)
|
||||
assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i
|
||||
|
||||
|
||||
def test_TR3():
|
||||
assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y)
|
||||
assert cos(pi/2 + x) == -sin(x)
|
||||
assert cos(30*pi/2 + x) == -cos(x)
|
||||
|
||||
for f in (cos, sin, tan, cot, csc, sec):
|
||||
i = f(pi*Rational(3, 7))
|
||||
j = TR3(i)
|
||||
assert verify_numerically(i, j) and i.func != j.func
|
||||
|
||||
with evaluate(False):
|
||||
eq = cos(9*pi/22)
|
||||
assert eq.has(9*pi) and TR3(eq) == sin(pi/11)
|
||||
|
||||
|
||||
def test_TR4():
|
||||
for i in [0, pi/6, pi/4, pi/3, pi/2]:
|
||||
with evaluate(False):
|
||||
eq = cos(i)
|
||||
assert isinstance(eq, cos) and TR4(eq) == cos(i)
|
||||
|
||||
|
||||
def test__TR56():
|
||||
h = lambda x: 1 - x
|
||||
assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1)
|
||||
assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10
|
||||
assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3
|
||||
assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6
|
||||
assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4
|
||||
|
||||
# issue 17137
|
||||
assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I
|
||||
assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1)
|
||||
|
||||
|
||||
def test_TR5():
|
||||
assert TR5(sin(x)**2) == -cos(x)**2 + 1
|
||||
assert TR5(sin(x)**-2) == sin(x)**(-2)
|
||||
assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2
|
||||
|
||||
|
||||
def test_TR6():
|
||||
assert TR6(cos(x)**2) == -sin(x)**2 + 1
|
||||
assert TR6(cos(x)**-2) == cos(x)**(-2)
|
||||
assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2
|
||||
|
||||
|
||||
def test_TR7():
|
||||
assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half
|
||||
assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2)
|
||||
|
||||
|
||||
def test_TR8():
|
||||
assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2
|
||||
assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2
|
||||
assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2
|
||||
assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4
|
||||
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \
|
||||
cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \
|
||||
cos(6)/8 + Rational(1, 8)
|
||||
assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \
|
||||
cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \
|
||||
cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8
|
||||
assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64)
|
||||
|
||||
def test_TR9():
|
||||
a = S.Half
|
||||
b = 3*a
|
||||
assert TR9(a) == a
|
||||
assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b)
|
||||
assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b)
|
||||
assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b)
|
||||
assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a)
|
||||
assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a)
|
||||
assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3)
|
||||
assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2)
|
||||
assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \
|
||||
4*cos(S.Half)*cos(1)*cos(Rational(9, 2))
|
||||
assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3)
|
||||
assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2)
|
||||
assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2)
|
||||
c = cos(x)
|
||||
s = sin(x)
|
||||
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
|
||||
for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))):
|
||||
args = zip(si, a)
|
||||
ex = Add(*[Mul(*ai) for ai in args])
|
||||
t = TR9(ex)
|
||||
assert not (a[0].func == a[1].func and (
|
||||
not verify_numerically(ex, t.expand(trig=True)) or t.is_Add)
|
||||
or a[1].func != a[0].func and ex != t)
|
||||
|
||||
|
||||
def test_TR10():
|
||||
assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b)
|
||||
assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a)
|
||||
assert TR10(sin(a + b + c)) == \
|
||||
(-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \
|
||||
(sin(a)*cos(b) + sin(b)*cos(a))*cos(c)
|
||||
assert TR10(cos(a + b + c)) == \
|
||||
(-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \
|
||||
(sin(a)*cos(b) + sin(b)*cos(a))*sin(c)
|
||||
|
||||
|
||||
def test_TR10i():
|
||||
assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2)
|
||||
assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4)
|
||||
assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2)
|
||||
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4)
|
||||
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7
|
||||
assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4)
|
||||
assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \
|
||||
2*sin(4) + cos(3)
|
||||
assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \
|
||||
cos(1)
|
||||
eq = (cos(2)*cos(3) + sin(2)*(
|
||||
cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5)
|
||||
assert TR10i(eq) == TR10i(eq.expand()) == cos(4)
|
||||
assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \
|
||||
2*sqrt(2)*x*sin(x + pi/6)
|
||||
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
|
||||
cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9
|
||||
assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) +
|
||||
cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \
|
||||
sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9
|
||||
assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x)
|
||||
assert TR10i(cos(x) + sqrt(3)*sin(x) +
|
||||
2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4)
|
||||
assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \
|
||||
sin(2)*cos(4) + sin(3)*cos(2)
|
||||
|
||||
A = Symbol('A', commutative=False)
|
||||
assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \
|
||||
2*sqrt(2)*sin(x + pi/6)*A
|
||||
|
||||
|
||||
c = cos(x)
|
||||
s = sin(x)
|
||||
h = sin(y)
|
||||
r = cos(y)
|
||||
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
|
||||
for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args
|
||||
args = zip(si, argsi)
|
||||
ex = Add(*[Mul(*ai) for ai in args])
|
||||
t = TR10i(ex)
|
||||
assert not (ex - t.expand(trig=True) or t.is_Add)
|
||||
|
||||
c = cos(x)
|
||||
s = sin(x)
|
||||
h = sin(pi/6)
|
||||
r = cos(pi/6)
|
||||
for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
|
||||
for argsi in ((c*r, s*h), (c*h, s*r)): # induced
|
||||
args = zip(si, argsi)
|
||||
ex = Add(*[Mul(*ai) for ai in args])
|
||||
t = TR10i(ex)
|
||||
assert not (ex - t.expand(trig=True) or t.is_Add)
|
||||
|
||||
|
||||
def test_TR11():
|
||||
|
||||
assert TR11(sin(2*x)) == 2*sin(x)*cos(x)
|
||||
assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x))
|
||||
assert TR11(sin(x*Rational(4, 3))) == \
|
||||
4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3))
|
||||
|
||||
assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2
|
||||
assert TR11(cos(4*x)) == \
|
||||
(-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2
|
||||
|
||||
assert TR11(cos(2)) == cos(2)
|
||||
|
||||
assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2
|
||||
assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2
|
||||
assert TR11(cos(6), 2) == cos(6)
|
||||
assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2)
|
||||
|
||||
def test__TR11():
|
||||
|
||||
assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \
|
||||
4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8)))
|
||||
assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6)
|
||||
|
||||
assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6))
|
||||
assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4))
|
||||
assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2)
|
||||
|
||||
|
||||
def test_TR12():
|
||||
assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1)
|
||||
assert TR12(tan(x + y + z)) ==\
|
||||
(tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/(
|
||||
1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1))
|
||||
assert TR12(tan(x*y)) == tan(x*y)
|
||||
|
||||
|
||||
def test_TR13():
|
||||
assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1
|
||||
assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5)
|
||||
assert TR13(tan(1)*tan(2)*tan(3)) == \
|
||||
(-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1)
|
||||
assert TR13(tan(1)*tan(2)*cot(3)) == \
|
||||
(-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3)
|
||||
|
||||
|
||||
def test_L():
|
||||
assert L(cos(x) + sin(x)) == 2
|
||||
|
||||
|
||||
def test_fu():
|
||||
|
||||
assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2)
|
||||
assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3)
|
||||
|
||||
|
||||
eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2
|
||||
assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2
|
||||
|
||||
assert fu(S.Half - cos(2*x)/2) == sin(x)**2
|
||||
|
||||
assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \
|
||||
sqrt(2)*sin(a + b + pi/4)
|
||||
|
||||
assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3)
|
||||
|
||||
assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \
|
||||
-cos(x)**2 + cos(y)**2
|
||||
|
||||
assert fu(cos(pi*Rational(4, 9))) == sin(pi/18)
|
||||
assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16)
|
||||
|
||||
assert fu(
|
||||
tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \
|
||||
-sqrt(3)
|
||||
|
||||
assert fu(tan(1)*tan(2)) == tan(1)*tan(2)
|
||||
|
||||
expr = Mul(*[cos(2**i) for i in range(10)])
|
||||
assert fu(expr) == sin(1024)/(1024*sin(1))
|
||||
|
||||
# issue #18059:
|
||||
assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2)
|
||||
|
||||
assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \
|
||||
7*sin(x) + 3*sqrt(3)*cos(x)
|
||||
|
||||
|
||||
def test_objective():
|
||||
assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \
|
||||
tan(x)
|
||||
assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \
|
||||
sin(x)/cos(x)
|
||||
|
||||
|
||||
def test_process_common_addends():
|
||||
# this tests that the args are not evaluated as they are given to do
|
||||
# and that key2 works when key1 is False
|
||||
do = lambda x: Add(*[i**(i%2) for i in x.args])
|
||||
assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do,
|
||||
key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0
|
||||
|
||||
|
||||
def test_trig_split():
|
||||
assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True)
|
||||
assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True)
|
||||
assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \
|
||||
(sin(y), 1, 1, x, y, True)
|
||||
|
||||
assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \
|
||||
(2, 1, -1, x, pi/6, False)
|
||||
assert trig_split(cos(x), sin(x), two=True) == \
|
||||
(sqrt(2), 1, 1, x, pi/4, False)
|
||||
assert trig_split(cos(x), -sin(x), two=True) == \
|
||||
(sqrt(2), 1, -1, x, pi/4, False)
|
||||
assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \
|
||||
(2*sqrt(2), 1, -1, x, pi/6, False)
|
||||
assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \
|
||||
(-2*sqrt(2), 1, 1, x, pi/3, False)
|
||||
assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \
|
||||
(sqrt(6)/3, 1, 1, x, pi/6, False)
|
||||
assert trig_split(-sqrt(6)*cos(x)*sin(y),
|
||||
-sqrt(2)*sin(x)*sin(y), two=True) == \
|
||||
(-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False)
|
||||
|
||||
assert trig_split(cos(x), sin(x)) is None
|
||||
assert trig_split(cos(x), sin(z)) is None
|
||||
assert trig_split(2*cos(x), -sin(x)) is None
|
||||
assert trig_split(cos(x), -sqrt(3)*sin(x)) is None
|
||||
assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None
|
||||
assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None
|
||||
assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \
|
||||
None
|
||||
|
||||
assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None
|
||||
assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None
|
||||
assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None
|
||||
|
||||
|
||||
def test_TRmorrie():
|
||||
assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \
|
||||
7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3))
|
||||
assert TRmorrie(x) == x
|
||||
assert TRmorrie(2*x) == 2*x
|
||||
e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7))
|
||||
assert TR8(TRmorrie(e)) == Rational(-1, 8)
|
||||
e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)])
|
||||
assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536)
|
||||
# issue 17063
|
||||
eq = cos(x)/cos(x/2)
|
||||
assert TRmorrie(eq) == eq
|
||||
# issue #20430
|
||||
eq = cos(x/2)*sin(x/2)*cos(x)**3
|
||||
assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4
|
||||
|
||||
|
||||
def test_TRpower():
|
||||
assert TRpower(1/sin(x)**2) == 1/sin(x)**2
|
||||
assert TRpower(cos(x)**3*sin(x/2)**4) == \
|
||||
(3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8))
|
||||
for k in range(2, 8):
|
||||
assert verify_numerically(sin(x)**k, TRpower(sin(x)**k))
|
||||
assert verify_numerically(cos(x)**k, TRpower(cos(x)**k))
|
||||
|
||||
|
||||
def test_hyper_as_trig():
|
||||
from sympy.simplify.fu import _osborne, _osbornei
|
||||
|
||||
eq = sinh(x)**2 + cosh(x)**2
|
||||
t, f = hyper_as_trig(eq)
|
||||
assert f(fu(t)) == cosh(2*x)
|
||||
e, f = hyper_as_trig(tanh(x + y))
|
||||
assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1)
|
||||
|
||||
d = Dummy()
|
||||
assert _osborne(sinh(x), d) == I*sin(x*d)
|
||||
assert _osborne(tanh(x), d) == I*tan(x*d)
|
||||
assert _osborne(coth(x), d) == cot(x*d)/I
|
||||
assert _osborne(cosh(x), d) == cos(x*d)
|
||||
assert _osborne(sech(x), d) == sec(x*d)
|
||||
assert _osborne(csch(x), d) == csc(x*d)/I
|
||||
for func in (sinh, cosh, tanh, coth, sech, csch):
|
||||
h = func(pi)
|
||||
assert _osbornei(_osborne(h, d), d) == h
|
||||
# /!\ the _osborne functions are not meant to work
|
||||
# in the o(i(trig, d), d) direction so we just check
|
||||
# that they work as they are supposed to work
|
||||
assert _osbornei(cos(x*y + z), y) == cosh(x + z*I)
|
||||
assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I
|
||||
assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I
|
||||
assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I
|
||||
assert _osbornei(sec(x*y + z), y) == sech(x + z*I)
|
||||
assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I
|
||||
|
||||
|
||||
def test_TR12i():
|
||||
ta, tb, tc = [tan(i) for i in (a, b, c)]
|
||||
assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b)
|
||||
assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b)
|
||||
assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b)
|
||||
eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1))
|
||||
assert TR12i(eq.expand()) == \
|
||||
-3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2
|
||||
assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x)
|
||||
eq = (ta + cos(2))/(-ta*tb + 1)
|
||||
assert TR12i(eq) == eq
|
||||
eq = (ta + tb + 2)**2/(-ta*tb + 1)
|
||||
assert TR12i(eq) == eq
|
||||
eq = ta/(-ta*tb + 1)
|
||||
assert TR12i(eq) == eq
|
||||
eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1)
|
||||
assert TR12i(eq) == -(a + 1)**2*tan(a + b)
|
||||
|
||||
|
||||
def test_TR14():
|
||||
eq = (cos(x) - 1)*(cos(x) + 1)
|
||||
ans = -sin(x)**2
|
||||
assert TR14(eq) == ans
|
||||
assert TR14(1/eq) == 1/ans
|
||||
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2
|
||||
assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1)
|
||||
assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1)
|
||||
eq = (cos(x) - 1)**y*(cos(x) + 1)**y
|
||||
assert TR14(eq) == eq
|
||||
eq = (cos(x) - 2)**y*(cos(x) + 1)
|
||||
assert TR14(eq) == eq
|
||||
eq = (tan(x) - 2)**2*(cos(x) + 1)
|
||||
assert TR14(eq) == eq
|
||||
i = symbols('i', integer=True)
|
||||
assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i
|
||||
assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i
|
||||
# could use extraction in this case
|
||||
eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i
|
||||
assert TR14(eq) in [(cos(x) - 1)*ans**i, eq]
|
||||
|
||||
assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2
|
||||
p1 = (cos(x) + 1)*(cos(x) - 1)
|
||||
p2 = (cos(y) - 1)*2*(cos(y) + 1)
|
||||
p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1))
|
||||
assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4)
|
||||
|
||||
|
||||
def test_TR15_16_17():
|
||||
assert TR15(1 - 1/sin(x)**2) == -cot(x)**2
|
||||
assert TR16(1 - 1/cos(x)**2) == -tan(x)**2
|
||||
assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2
|
||||
|
||||
|
||||
def test_as_f_sign_1():
|
||||
assert as_f_sign_1(x + 1) == (1, x, 1)
|
||||
assert as_f_sign_1(x - 1) == (1, x, -1)
|
||||
assert as_f_sign_1(-x + 1) == (-1, x, -1)
|
||||
assert as_f_sign_1(-x - 1) == (-1, x, 1)
|
||||
assert as_f_sign_1(2*x + 2) == (2, x, 1)
|
||||
assert as_f_sign_1(x*y - y) == (y, x, -1)
|
||||
assert as_f_sign_1(-x*y + y) == (-y, x, -1)
|
||||
|
||||
|
||||
def test_issue_25590():
|
||||
A = Symbol('A', commutative=False)
|
||||
B = Symbol('B', commutative=False)
|
||||
|
||||
assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A
|
||||
assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A
|
||||
|
||||
# XXX The result may not be optimal than
|
||||
# sin(2*x)*B*A + cos(x)**2 and may change in the future
|
||||
assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2
|
||||
@@ -0,0 +1,54 @@
|
||||
""" Unit tests for Hyper_Function"""
|
||||
from sympy.core import symbols, Dummy, Tuple, S, Rational
|
||||
from sympy.functions import hyper
|
||||
|
||||
from sympy.simplify.hyperexpand import Hyper_Function
|
||||
|
||||
def test_attrs():
|
||||
a, b = symbols('a, b', cls=Dummy)
|
||||
f = Hyper_Function([2, a], [b])
|
||||
assert f.ap == Tuple(2, a)
|
||||
assert f.bq == Tuple(b)
|
||||
assert f.args == (Tuple(2, a), Tuple(b))
|
||||
assert f.sizes == (2, 1)
|
||||
|
||||
def test_call():
|
||||
a, b, x = symbols('a, b, x', cls=Dummy)
|
||||
f = Hyper_Function([2, a], [b])
|
||||
assert f(x) == hyper([2, a], [b], x)
|
||||
|
||||
def test_has():
|
||||
a, b, c = symbols('a, b, c', cls=Dummy)
|
||||
f = Hyper_Function([2, -a], [b])
|
||||
assert f.has(a)
|
||||
assert f.has(Tuple(b))
|
||||
assert not f.has(c)
|
||||
|
||||
def test_eq():
|
||||
assert Hyper_Function([1], []) == Hyper_Function([1], [])
|
||||
assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False
|
||||
assert Hyper_Function([1], []) != Hyper_Function([2], [])
|
||||
assert Hyper_Function([1], []) != Hyper_Function([1, 2], [])
|
||||
assert Hyper_Function([1], []) != Hyper_Function([1], [2])
|
||||
|
||||
def test_gamma():
|
||||
assert Hyper_Function([2, 3], [-1]).gamma == 0
|
||||
assert Hyper_Function([-2, -3], [-1]).gamma == 2
|
||||
n = Dummy(integer=True)
|
||||
assert Hyper_Function([-1, n, 1], []).gamma == 1
|
||||
assert Hyper_Function([-1, -n, 1], []).gamma == 1
|
||||
p = Dummy(integer=True, positive=True)
|
||||
assert Hyper_Function([-1, p, 1], []).gamma == 1
|
||||
assert Hyper_Function([-1, -p, 1], []).gamma == 2
|
||||
|
||||
def test_suitable_origin():
|
||||
assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True
|
||||
assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False
|
||||
assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False
|
||||
assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False
|
||||
assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False
|
||||
assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False
|
||||
assert Hyper_Function((S.Half, 1),
|
||||
(2, Rational(-2, 3)))._is_suitable_origin() is True
|
||||
assert Hyper_Function((S.Half, 1),
|
||||
(2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True
|
||||
@@ -0,0 +1,127 @@
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.functions.special.gamma_functions import gamma
|
||||
from sympy.simplify.gammasimp import gammasimp
|
||||
from sympy.simplify.powsimp import powsimp
|
||||
from sympy.simplify.simplify import simplify
|
||||
|
||||
from sympy.abc import x, y, n, k
|
||||
|
||||
|
||||
def test_gammasimp():
|
||||
R = Rational
|
||||
|
||||
# was part of test_combsimp_gamma() in test_combsimp.py
|
||||
assert gammasimp(gamma(x)) == gamma(x)
|
||||
assert gammasimp(gamma(x + 1)/x) == gamma(x)
|
||||
assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1)
|
||||
assert gammasimp(x*gamma(x)) == gamma(x + 1)
|
||||
assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2)
|
||||
assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1)
|
||||
assert gammasimp(x/gamma(x + 1)) == 1/gamma(x)
|
||||
assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1)
|
||||
assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \
|
||||
(x + 2)*gamma(x + 1)
|
||||
|
||||
assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2
|
||||
assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1)
|
||||
|
||||
assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x)
|
||||
assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x))
|
||||
assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \
|
||||
sin(pi*x)/(pi*x*(x + 1)*(x + 2))
|
||||
|
||||
assert gammasimp(factorial(n + 2)) == gamma(n + 3)
|
||||
assert gammasimp(binomial(n, k)) == \
|
||||
gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1))
|
||||
|
||||
assert powsimp(gammasimp(
|
||||
gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \
|
||||
2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y)
|
||||
assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \
|
||||
3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1))
|
||||
assert simplify(
|
||||
gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1
|
||||
assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3
|
||||
|
||||
assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \
|
||||
2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi)
|
||||
|
||||
# issue 6792
|
||||
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
|
||||
assert gammasimp(e) == -k
|
||||
assert gammasimp(1/e) == -1/k
|
||||
e = (gamma(x) + gamma(x + 1))/gamma(x)
|
||||
assert gammasimp(e) == x + 1
|
||||
assert gammasimp(1/e) == 1/(x + 1)
|
||||
e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x)
|
||||
assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1)
|
||||
e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2
|
||||
assert gammasimp(e**2) == k**2
|
||||
assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k)
|
||||
a = R(1, 2) + R(1, 3)
|
||||
b = a + R(1, 3)
|
||||
assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b)
|
||||
) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2
|
||||
|
||||
# issue 9699
|
||||
assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y)
|
||||
assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise(
|
||||
(gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x),
|
||||
((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True))
|
||||
|
||||
A, B = symbols('A B', commutative=False)
|
||||
assert gammasimp(e*B*A) == gammasimp(e)*B*A
|
||||
|
||||
# check iteration
|
||||
assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == (
|
||||
-2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k))))
|
||||
assert gammasimp(
|
||||
gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == (
|
||||
3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2)
|
||||
|
||||
# issue 6153
|
||||
assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4
|
||||
|
||||
# was part of test_combsimp() in test_combsimp.py
|
||||
assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \
|
||||
(gamma(k + R(3, 2))*gamma(-k + n + R(5, 2)))
|
||||
assert gammasimp(binomial(n + 2, k + 2.0)) == \
|
||||
gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1))
|
||||
|
||||
# issue 11548
|
||||
assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x)
|
||||
|
||||
e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3))
|
||||
assert gammasimp(e) == e
|
||||
assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \
|
||||
2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi)
|
||||
|
||||
i, m = symbols('i m', integer = True)
|
||||
e = gamma(exp(i))
|
||||
assert gammasimp(e) == e
|
||||
e = gamma(m + 3)
|
||||
assert gammasimp(e) == e
|
||||
e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1))
|
||||
assert gammasimp(e) == e
|
||||
|
||||
p = symbols("p", integer=True, positive=True)
|
||||
assert gammasimp(gamma(-p + 4)) == gamma(-p + 4)
|
||||
|
||||
|
||||
def test_issue_22606():
|
||||
fx = Function('f')(x)
|
||||
eq = x + gamma(y)
|
||||
# seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y`
|
||||
ans = gammasimp(eq)
|
||||
assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans
|
||||
assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans
|
||||
assert 1/gammasimp(1/eq) == ans
|
||||
assert gammasimp(fx.subs(x, eq)).args[0] == ans
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,368 @@
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.numbers import (E, I, Rational, oo, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import (root, sqrt)
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.functions.special.gamma_functions import gamma
|
||||
from sympy.functions.special.hyper import hyper
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.simplify.powsimp import (powdenest, powsimp)
|
||||
from sympy.simplify.simplify import (signsimp, simplify)
|
||||
from sympy.core.symbol import Str
|
||||
|
||||
from sympy.abc import x, y, z, a, b
|
||||
|
||||
|
||||
def test_powsimp():
|
||||
x, y, z, n = symbols('x,y,z,n')
|
||||
f = Function('f')
|
||||
assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1
|
||||
assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1
|
||||
|
||||
assert powsimp(
|
||||
f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x))
|
||||
assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1)
|
||||
assert exp(x)*exp(y) == exp(x)*exp(y)
|
||||
assert powsimp(exp(x)*exp(y)) == exp(x + y)
|
||||
assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y)
|
||||
assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \
|
||||
exp(x + y)*2**(x + y)
|
||||
assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \
|
||||
exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y)
|
||||
assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y))
|
||||
assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y))
|
||||
assert powsimp(x**2*x**y) == x**(2 + y)
|
||||
# This should remain factored, because 'exp' with deep=True is supposed
|
||||
# to act like old automatic exponent combining.
|
||||
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \
|
||||
(1 + exp(1 + E))*exp(-E)
|
||||
assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \
|
||||
(1 + exp(1 + E))*exp(-E)
|
||||
assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E)
|
||||
assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \
|
||||
(1 + exp(1 + E))*exp(-E)
|
||||
assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \
|
||||
(1 + E*exp(E))*exp(-E)
|
||||
x, y = symbols('x,y', nonnegative=True)
|
||||
n = Symbol('n', real=True)
|
||||
assert powsimp(y**n * (y/x)**(-n)) == x**n
|
||||
assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \
|
||||
== (x*y)**(x*y)**(x*y)
|
||||
assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x)
|
||||
assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x)
|
||||
assert powsimp(
|
||||
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
|
||||
exp(-x + exp(-x)*exp(-x*log(x)))
|
||||
assert powsimp(
|
||||
exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \
|
||||
exp(-x + exp(-x)*exp(-x*log(x)))
|
||||
assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z)
|
||||
assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z
|
||||
assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \
|
||||
exp(x)/(1 + exp(x + y))
|
||||
assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y))
|
||||
assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x
|
||||
assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x
|
||||
p = symbols('p', positive=True)
|
||||
assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2))
|
||||
assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2))
|
||||
|
||||
# coefficient of exponent can only be simplified for positive bases
|
||||
assert powsimp(2**(2*x)) == 4**x
|
||||
assert powsimp((-1)**(2*x)) == (-1)**(2*x)
|
||||
i = symbols('i', integer=True)
|
||||
assert powsimp((-1)**(2*i)) == 1
|
||||
assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not
|
||||
# force=True overrides assumptions
|
||||
assert powsimp((-1)**(2*x), force=True) == 1
|
||||
|
||||
# rational exponents allow combining of negative terms
|
||||
w, n, m = symbols('w n m', negative=True)
|
||||
e = i/a # not a rational exponent if `a` is unknown
|
||||
ex = w**e*n**e*m**e
|
||||
assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a)
|
||||
e = i/3
|
||||
ex = w**e*n**e*m**e
|
||||
assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3)
|
||||
e = (3 + i)/i
|
||||
ex = w**e*n**e*m**e
|
||||
assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e
|
||||
|
||||
eq = x**(a*Rational(2, 3))
|
||||
# eq != (x**a)**(2/3) (try x = -1 and a = 3 to see)
|
||||
assert powsimp(eq).exp == eq.exp == a*Rational(2, 3)
|
||||
# powdenest goes the other direction
|
||||
assert powsimp(2**(2*x)) == 4**x
|
||||
|
||||
assert powsimp(exp(p/2)) == exp(p/2)
|
||||
|
||||
# issue 6368
|
||||
eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)])
|
||||
assert powsimp(eq) == eq and eq.is_Mul
|
||||
|
||||
assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2)))
|
||||
|
||||
# issue 8836
|
||||
assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)'
|
||||
|
||||
# issue 9183
|
||||
assert powsimp(-0.1**x) == -0.1**x
|
||||
|
||||
# issue 10095
|
||||
assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo
|
||||
|
||||
# PR 13131
|
||||
eq = sin(2*x)**2*sin(2.0*x)**2
|
||||
assert powsimp(eq) == eq
|
||||
|
||||
# issue 14615
|
||||
assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2)
|
||||
) == x*y*(x*y**2)**Rational(5, 2)
|
||||
|
||||
#issue 27380
|
||||
assert powsimp(1.0**(x+1)/1.0**x) == 1.0
|
||||
|
||||
def test_powsimp_negated_base():
|
||||
assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y)
|
||||
assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y)
|
||||
p = symbols('p', positive=True)
|
||||
reps = {p: 2, a: S.Half}
|
||||
assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps)
|
||||
assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps)
|
||||
n = symbols('n', negative=True)
|
||||
reps = {p: -2, a: S.Half}
|
||||
assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half)
|
||||
assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps)
|
||||
# if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a
|
||||
eq = (-x)**a/x**a
|
||||
assert powsimp(eq) == eq
|
||||
|
||||
|
||||
def test_powsimp_nc():
|
||||
x, y, z = symbols('x,y,z')
|
||||
A, B, C = symbols('A B C', commutative=False)
|
||||
|
||||
assert powsimp(A**x*A**y, combine='all') == A**(x + y)
|
||||
assert powsimp(A**x*A**y, combine='base') == A**x*A**y
|
||||
assert powsimp(A**x*A**y, combine='exp') == A**(x + y)
|
||||
|
||||
assert powsimp(A**x*B**x, combine='all') == A**x*B**x
|
||||
assert powsimp(A**x*B**x, combine='base') == A**x*B**x
|
||||
assert powsimp(A**x*B**x, combine='exp') == A**x*B**x
|
||||
|
||||
assert powsimp(B**x*A**x, combine='all') == B**x*A**x
|
||||
assert powsimp(B**x*A**x, combine='base') == B**x*A**x
|
||||
assert powsimp(B**x*A**x, combine='exp') == B**x*A**x
|
||||
|
||||
assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z)
|
||||
assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z
|
||||
assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z)
|
||||
|
||||
assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x
|
||||
assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x
|
||||
assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x
|
||||
|
||||
assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x
|
||||
assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x
|
||||
assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x
|
||||
|
||||
|
||||
def test_issue_6440():
|
||||
assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4)
|
||||
|
||||
|
||||
def test_powdenest():
|
||||
x, y = symbols('x,y')
|
||||
p, q = symbols('p q', positive=True)
|
||||
i, j = symbols('i,j', integer=True)
|
||||
|
||||
assert powdenest(x) == x
|
||||
assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x))
|
||||
assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x)
|
||||
assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x))
|
||||
assert powdenest(exp(3*x*log(2))) == 2**(3*x)
|
||||
assert powdenest(sqrt(p**2)) == p
|
||||
eq = p**(2*i)*q**(4*i)
|
||||
assert powdenest(eq) == (p*q**2)**(2*i)
|
||||
# -X-> (x**x)**i*(x**x)**j == x**(x*(i + j))
|
||||
assert powdenest((x**x)**(i + j))
|
||||
assert powdenest(exp(3*y*log(x))) == x**(3*y)
|
||||
assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y
|
||||
assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3
|
||||
assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x
|
||||
assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y)
|
||||
assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \
|
||||
(((x**(a*Rational(2, 3)))**(3*y/i))**x)
|
||||
assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z)
|
||||
assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j)
|
||||
e = ((p**(2*a))**(3*y))**x
|
||||
assert powdenest(e) == e
|
||||
e = ((x**2*y**4)**a)**(x*y)
|
||||
assert powdenest(e) == e
|
||||
e = (((x**2*y**4)**a)**(x*y))**3
|
||||
assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y)
|
||||
assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \
|
||||
(x*y**2)**(2*a*x*y)
|
||||
assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \
|
||||
(x*y**2)**(6*a*x*y)
|
||||
assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i)
|
||||
x, y = symbols('x,y', positive=True)
|
||||
assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i)
|
||||
|
||||
assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2)
|
||||
assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i
|
||||
|
||||
assert powdenest(4**x) == 2**(2*x)
|
||||
assert powdenest((4**x)**y) == 2**(2*x*y)
|
||||
assert powdenest(4**x*y) == 2**(2*x)*y
|
||||
|
||||
|
||||
def test_powdenest_polar():
|
||||
x, y, z = symbols('x y z', polar=True)
|
||||
a, b, c = symbols('a b c')
|
||||
assert powdenest((x*y*z)**a) == x**a*y**a*z**a
|
||||
assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c)
|
||||
assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2)
|
||||
|
||||
|
||||
def test_issue_5805():
|
||||
arg = ((gamma(x)*hyper((), (), x))*pi)**2
|
||||
assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2
|
||||
assert arg.is_positive is None
|
||||
|
||||
|
||||
def test_issue_9324_powsimp_on_matrix_symbol():
|
||||
M = MatrixSymbol('M', 10, 10)
|
||||
expr = powsimp(M, deep=True)
|
||||
assert expr == M
|
||||
assert expr.args[0] == Str('M')
|
||||
|
||||
|
||||
def test_issue_6367():
|
||||
z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half)
|
||||
assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0
|
||||
assert powsimp(z.normal()) == 0
|
||||
assert simplify(z) == 0
|
||||
assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2
|
||||
assert powsimp(z) != 0
|
||||
|
||||
|
||||
def test_powsimp_polar():
|
||||
from sympy.functions.elementary.complexes import polar_lift
|
||||
from sympy.functions.elementary.exponential import exp_polar
|
||||
x, y, z = symbols('x y z')
|
||||
p, q, r = symbols('p q r', polar=True)
|
||||
|
||||
assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x)
|
||||
assert powsimp(p**x * q**x) == (p*q)**x
|
||||
assert p**x * (1/p)**x == 1
|
||||
assert (1/p)**x == p**(-x)
|
||||
|
||||
assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y)
|
||||
assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y)
|
||||
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \
|
||||
(p*exp_polar(1))**(x + y)
|
||||
assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \
|
||||
exp_polar(x + y)*p**(x + y)
|
||||
assert powsimp(
|
||||
exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \
|
||||
== p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y)
|
||||
assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \
|
||||
sin(exp_polar(x)*exp_polar(y))
|
||||
assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \
|
||||
sin(exp_polar(x + y))
|
||||
|
||||
|
||||
def test_issue_5728():
|
||||
b = x*sqrt(y)
|
||||
a = sqrt(b)
|
||||
c = sqrt(sqrt(x)*y)
|
||||
assert powsimp(a*b) == sqrt(b)**3
|
||||
assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5
|
||||
assert powsimp(a*x**2*c**3*y) == c**3*a**5
|
||||
assert powsimp(a*x*c**3*y**2) == c**7*a
|
||||
assert powsimp(x*c**3*y**2) == c**7
|
||||
assert powsimp(x*c**3*y) == x*y*c**3
|
||||
assert powsimp(sqrt(x)*c**3*y) == c**5
|
||||
assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3
|
||||
assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \
|
||||
sqrt(x)*sqrt(y)**3*c**3
|
||||
assert powsimp(a**2*a*x**2*y) == a**7
|
||||
|
||||
# symbolic powers work, too
|
||||
b = x**y*y
|
||||
a = b*sqrt(b)
|
||||
assert a.is_Mul is True
|
||||
assert powsimp(a) == sqrt(b)**3
|
||||
|
||||
# as does exp
|
||||
a = x*exp(y*Rational(2, 3))
|
||||
assert powsimp(a*sqrt(a)) == sqrt(a)**3
|
||||
assert powsimp(a**2*sqrt(a)) == sqrt(a)**5
|
||||
assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9
|
||||
|
||||
|
||||
def test_issue_from_PR1599():
|
||||
n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True)
|
||||
assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) ==
|
||||
-I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3))
|
||||
assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) ==
|
||||
-(-1)**Rational(1, 3)*
|
||||
(-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3))
|
||||
|
||||
|
||||
def test_issue_10195():
|
||||
a = Symbol('a', integer=True)
|
||||
l = Symbol('l', even=True, nonzero=True)
|
||||
n = Symbol('n', odd=True)
|
||||
e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half)
|
||||
assert powsimp((-1)**(l/2)) == I**l
|
||||
assert powsimp((-1)**(n/2)) == I**n
|
||||
assert powsimp((-1)**(n*Rational(3, 2))) == -I**n
|
||||
assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) +
|
||||
S.Half)
|
||||
assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a
|
||||
|
||||
def test_issue_15709():
|
||||
assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1)
|
||||
assert powsimp(2*3**x/3) == 2*3**(x-1)
|
||||
|
||||
|
||||
def test_issue_11981():
|
||||
x, y = symbols('x y', commutative=False)
|
||||
assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2
|
||||
|
||||
|
||||
def test_issue_17524():
|
||||
a = symbols("a", real=True)
|
||||
e = (-1 - a**2)*sqrt(1 + a**2)
|
||||
assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2)
|
||||
|
||||
|
||||
def test_issue_19627():
|
||||
# if you use force the user must verify
|
||||
assert powdenest(sqrt(sin(x)**2), force=True) == sin(x)
|
||||
assert powdenest((x**(S.Half/y))**(2*y), force=True) == x
|
||||
from sympy.core.function import expand_power_base
|
||||
e = 1 - a
|
||||
expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e
|
||||
assert powdenest(expand_power_base(expr, force=True), force=True
|
||||
) == x**b*y**(1 - b)*exp(z)
|
||||
|
||||
|
||||
def test_issue_22546():
|
||||
p1, p2 = symbols('p1, p2', positive=True)
|
||||
ref = powsimp(p1**z/p2**z)
|
||||
e = z + 1
|
||||
ans = ref.subs(z, e)
|
||||
assert ans.is_Pow
|
||||
assert powsimp(p1**e/p2**e) == ans
|
||||
i = symbols('i', integer=True)
|
||||
ref = powsimp(x**i/y**i)
|
||||
e = i + 1
|
||||
ans = ref.subs(i, e)
|
||||
assert ans.is_Pow
|
||||
assert powsimp(x**e/y**e) == ans
|
||||
@@ -0,0 +1,498 @@
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.function import (Derivative, Function, diff)
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.numbers import (I, Rational)
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, Wild, symbols)
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import (root, sqrt)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.polys.polytools import factor
|
||||
from sympy.series.order import O
|
||||
from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect)
|
||||
|
||||
from sympy.core.expr import unchanged
|
||||
from sympy.core.mul import _unevaluated_Mul as umul
|
||||
from sympy.simplify.radsimp import (_unevaluated_Add,
|
||||
collect_sqrt, fraction_expand, collect_abs)
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
from sympy.abc import x, y, z, a, b, c, d
|
||||
|
||||
|
||||
def test_radsimp():
|
||||
r2 = sqrt(2)
|
||||
r3 = sqrt(3)
|
||||
r5 = sqrt(5)
|
||||
r7 = sqrt(7)
|
||||
assert fraction(radsimp(1/r2)) == (sqrt(2), 2)
|
||||
assert radsimp(1/(1 + r2)) == \
|
||||
-1 + sqrt(2)
|
||||
assert radsimp(1/(r2 + r3)) == \
|
||||
-sqrt(2) + sqrt(3)
|
||||
assert fraction(radsimp(1/(1 + r2 + r3))) == \
|
||||
(-sqrt(6) + sqrt(2) + 2, 4)
|
||||
assert fraction(radsimp(1/(r2 + r3 + r5))) == \
|
||||
(-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
|
||||
assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == (
|
||||
(-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) +
|
||||
93 + 46*sqrt(6) + 53*sqrt(5), 71))
|
||||
assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == (
|
||||
(-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105)
|
||||
+ 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215))
|
||||
z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7))
|
||||
assert len((3616791619821680643598*z).args) == 16
|
||||
assert radsimp(1/z) == 1/z
|
||||
assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7
|
||||
assert radsimp(1/(r2*3)) == \
|
||||
sqrt(2)/6
|
||||
assert radsimp(1/(r2*a + r3 + r5 + r7)) == (
|
||||
(8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 -
|
||||
180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5
|
||||
- 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 +
|
||||
116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 -
|
||||
8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 -
|
||||
302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 -
|
||||
795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a -
|
||||
118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 -
|
||||
480*a**6 + 3128*a**4 - 6360*a**2 + 3481))
|
||||
assert radsimp(1/(r2*a + r2*b + r3 + r7)) == (
|
||||
(sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a +
|
||||
b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a +
|
||||
b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 -
|
||||
20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8))
|
||||
assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \
|
||||
sqrt(2)/(2*a + 2*b + 2*c + 2*d)
|
||||
assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == (
|
||||
(sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b +
|
||||
4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1))
|
||||
assert radsimp((y**2 - x)/(y - sqrt(x))) == \
|
||||
sqrt(x) + y
|
||||
assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \
|
||||
-(sqrt(x) + y)
|
||||
assert radsimp(1/(1 - I + a*I)) == \
|
||||
(-I*a + 1 + I)/(a**2 - 2*a + 2)
|
||||
assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \
|
||||
(-x - sqrt(y))/((x - y)*(x**2 - y))
|
||||
e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
|
||||
assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y))
|
||||
assert radsimp(1/e) == (
|
||||
(-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 -
|
||||
9*y)))
|
||||
assert radsimp(1 + 1/(1 + sqrt(3))) == \
|
||||
Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1
|
||||
A = symbols("A", commutative=False)
|
||||
assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \
|
||||
x**2 + sqrt(2)*x**2 - sqrt(2)*x*A
|
||||
assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3)
|
||||
assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3
|
||||
|
||||
# issue 6532
|
||||
assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x)
|
||||
assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3)
|
||||
assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6)
|
||||
|
||||
# issue 5994
|
||||
e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/'
|
||||
'(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))')
|
||||
assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2)
|
||||
|
||||
# issue 5986 (modifications to radimp didn't initially recognize this so
|
||||
# the test is included here)
|
||||
assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1
|
||||
|
||||
# from issue 5934
|
||||
eq = (
|
||||
(-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) -
|
||||
360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) -
|
||||
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) +
|
||||
120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
|
||||
120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) +
|
||||
120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
|
||||
120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 -
|
||||
7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) +
|
||||
24*sqrt(10)*sqrt(-sqrt(5) + 5))**2))
|
||||
assert radsimp(eq) is S.NaN # it's 0/0
|
||||
|
||||
# work with normal form
|
||||
e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3
|
||||
assert radsimp(e) == (
|
||||
-sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) +
|
||||
35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15)
|
||||
- 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) +
|
||||
8291415*sqrt(21))/1300423175 + 3)
|
||||
|
||||
# obey power rules
|
||||
base = sqrt(3) - sqrt(2)
|
||||
assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3
|
||||
assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3
|
||||
assert radsimp(1/(-base)**x) == (-base)**(-x)
|
||||
assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x
|
||||
assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x)
|
||||
|
||||
# recurse
|
||||
e = cos(1/(1 + sqrt(2)))
|
||||
assert radsimp(e) == cos(-sqrt(2) + 1)
|
||||
assert radsimp(e/2) == cos(-sqrt(2) + 1)/2
|
||||
assert radsimp(1/e) == 1/cos(-sqrt(2) + 1)
|
||||
assert radsimp(2/e) == 2/cos(-sqrt(2) + 1)
|
||||
assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x)
|
||||
|
||||
# test that symbolic denominators are not processed
|
||||
r = 1 + sqrt(2)
|
||||
assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1)
|
||||
assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2))
|
||||
assert radsimp(x/(y + r)/r, symbolic=False) == \
|
||||
-x*(-sqrt(2) + 1)/(y + 1 + sqrt(2))
|
||||
|
||||
# issue 7408
|
||||
eq = sqrt(x)/sqrt(y)
|
||||
assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y)
|
||||
assert radsimp(eq, symbolic=False) == eq
|
||||
|
||||
# issue 7498
|
||||
assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3)
|
||||
|
||||
# for coverage
|
||||
eq = sqrt(x)/y**2
|
||||
assert radsimp(eq) == eq
|
||||
|
||||
# handle non-Expr args
|
||||
from sympy.integrals.integrals import Integral
|
||||
eq = Integral(x/(sqrt(2) - 1), (x, 0, 1/(sqrt(2) + 1)))
|
||||
assert radsimp(eq) == Integral((sqrt(2) + 1)*x , (x, 0, sqrt(2) - 1))
|
||||
|
||||
from sympy.sets import FiniteSet
|
||||
eq = FiniteSet(x/(sqrt(2) - 1))
|
||||
assert radsimp(eq) == FiniteSet((sqrt(2) + 1)*x)
|
||||
|
||||
def test_radsimp_issue_3214():
|
||||
c, p = symbols('c p', positive=True)
|
||||
s = sqrt(c**2 - p**2)
|
||||
b = (c + I*p - s)/(c + I*p + s)
|
||||
assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p)
|
||||
|
||||
|
||||
def test_collect_1():
|
||||
"""Collect with respect to Symbol"""
|
||||
x, y, z, n = symbols('x,y,z,n')
|
||||
assert collect(1, x) == 1
|
||||
assert collect( x + y*x, x ) == x * (1 + y)
|
||||
assert collect( x + x**2, x ) == x + x**2
|
||||
assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y)
|
||||
assert collect( x**2 + y*x, x ) == x*y + x**2
|
||||
assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y
|
||||
assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x)
|
||||
|
||||
assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \
|
||||
x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \
|
||||
x**3*(4*(1 + y)).expand() + x**4
|
||||
# symbols can be given as any iterable
|
||||
expr = x + y
|
||||
assert collect(expr, expr.free_symbols) == expr
|
||||
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None
|
||||
) == x*exp(x) + 3*x + (y + 2)*sin(x)
|
||||
assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x +
|
||||
y*x*exp(x), x, exact=None
|
||||
) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x)
|
||||
|
||||
|
||||
def test_collect_2():
|
||||
"""Collect with respect to a sum"""
|
||||
a, b, x = symbols('a,b,x')
|
||||
assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)),
|
||||
sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x))
|
||||
|
||||
|
||||
def test_collect_3():
|
||||
"""Collect with respect to a product"""
|
||||
a, b, c = symbols('a,b,c')
|
||||
f = Function('f')
|
||||
x, y, z, n = symbols('x,y,z,n')
|
||||
|
||||
assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8))
|
||||
|
||||
assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2)
|
||||
assert collect( x*y + a*x*y, x*y) == x*y*(1 + a)
|
||||
assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a)
|
||||
assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x)
|
||||
|
||||
assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x)
|
||||
assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \
|
||||
x**2*log(x)**2*(a + b)
|
||||
|
||||
# with respect to a product of three symbols
|
||||
assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z
|
||||
|
||||
|
||||
def test_collect_4():
|
||||
"""Collect with respect to a power"""
|
||||
a, b, c, x = symbols('a,b,c,x')
|
||||
|
||||
assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b)
|
||||
# issue 6096: 2 stays with c (unless c is integer or x is positive0
|
||||
assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b)
|
||||
|
||||
|
||||
def test_collect_5():
|
||||
"""Collect with respect to a tuple"""
|
||||
a, x, y, z, n = symbols('a,x,y,z,n')
|
||||
assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [
|
||||
z*(1 + a + x**2*y**4) + x**2*y**4,
|
||||
z*(1 + a) + x**2*y**4*(1 + z) ]
|
||||
assert collect((1 + (x + y) + (x + y)**2).expand(),
|
||||
[x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2
|
||||
|
||||
|
||||
def test_collect_pr19431():
|
||||
"""Unevaluated collect with respect to a product"""
|
||||
a = symbols('a')
|
||||
assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1)
|
||||
|
||||
|
||||
def test_collect_D():
|
||||
D = Derivative
|
||||
f = Function('f')
|
||||
x, a, b = symbols('x,a,b')
|
||||
fx = D(f(x), x)
|
||||
fxx = D(f(x), x, x)
|
||||
|
||||
assert collect(a*fx + b*fx, fx) == (a + b)*fx
|
||||
assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x)
|
||||
assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x)
|
||||
# issue 4784
|
||||
assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx
|
||||
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \
|
||||
(x*f(x) + f(x))*D(f(x), x) + f(x)
|
||||
assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \
|
||||
(x*f(x) + f(x))*D(f(x), x) + f(x)
|
||||
assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \
|
||||
(1/f(x) + x/f(x))*D(f(x), x) + 1/f(x)
|
||||
e = (1 + x*fx + fx)/f(x)
|
||||
assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x)
|
||||
|
||||
|
||||
def test_collect_func():
|
||||
f = ((x + a + 1)**3).expand()
|
||||
|
||||
assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \
|
||||
x*(3*a**2 + 6*a + 3) + 1
|
||||
assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \
|
||||
(a + 1)**3
|
||||
|
||||
assert collect(f, x, evaluate=False) == {
|
||||
S.One: a**3 + 3*a**2 + 3*a + 1,
|
||||
x: 3*a**2 + 6*a + 3, x**2: 3*a + 3,
|
||||
x**3: 1
|
||||
}
|
||||
|
||||
assert collect(f, x, factor, evaluate=False) == {
|
||||
S.One: (a + 1)**3, x: 3*(a + 1)**2,
|
||||
x**2: umul(S(3), a + 1), x**3: 1}
|
||||
|
||||
|
||||
def test_collect_order():
|
||||
a, b, x, t = symbols('a,b,x,t')
|
||||
|
||||
assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3))
|
||||
assert collect(t + t*x + x**2 + O(x**3), t) == \
|
||||
t*(1 + x + O(x**3)) + x**2 + O(x**3)
|
||||
|
||||
f = a*x + b*x + c*x**2 + d*x**2 + O(x**3)
|
||||
g = x*(a + b) + x**2*(c + d) + O(x**3)
|
||||
|
||||
assert collect(f, x) == g
|
||||
assert collect(f, x, distribute_order_term=False) == g
|
||||
|
||||
f = sin(a + b).series(b, 0, 10)
|
||||
|
||||
assert collect(f, [sin(a), cos(a)]) == \
|
||||
sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10)
|
||||
assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \
|
||||
sin(a)*cos(b).series(b, 0, 10).removeO() + \
|
||||
cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10)
|
||||
|
||||
|
||||
def test_rcollect():
|
||||
assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \
|
||||
(x + y*(1 + x + x**2))/(x + y)
|
||||
assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1)))
|
||||
|
||||
|
||||
def test_collect_D_0():
|
||||
D = Derivative
|
||||
f = Function('f')
|
||||
x, a, b = symbols('x,a,b')
|
||||
fxx = D(f(x), x, x)
|
||||
|
||||
assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx
|
||||
|
||||
|
||||
def test_collect_Wild():
|
||||
"""Collect with respect to functions with Wild argument"""
|
||||
a, b, x, y = symbols('a b x y')
|
||||
f = Function('f')
|
||||
w1 = Wild('.1')
|
||||
w2 = Wild('.2')
|
||||
assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x)
|
||||
assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y)
|
||||
assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y)
|
||||
assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y)
|
||||
assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x)
|
||||
assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y
|
||||
assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \
|
||||
a*(x + 1)**y + (x + 1)**y
|
||||
assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \
|
||||
(1 + a)*(x + 1)**y
|
||||
assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y
|
||||
|
||||
|
||||
def test_collect_const():
|
||||
# coverage not provided by above tests
|
||||
assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \
|
||||
2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb
|
||||
assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \
|
||||
2*sqrt(3) + 4*a*sqrt(5)
|
||||
assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \
|
||||
sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
|
||||
|
||||
# issue 5290
|
||||
assert collect_const(2*x + 2*y + 1, 2) == \
|
||||
collect_const(2*x + 2*y + 1) == \
|
||||
Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False)
|
||||
assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
|
||||
assert collect_const(2*x - 2*y - 2*z, 2) == \
|
||||
Mul(2, x - y - z, evaluate=False)
|
||||
assert collect_const(2*x - 2*y - 2*z, -2) == \
|
||||
_unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False))
|
||||
|
||||
# this is why the content_primitive is used
|
||||
eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
|
||||
assert collect_sqrt(eq + 2) == \
|
||||
2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2
|
||||
|
||||
# issue 16296
|
||||
assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False)
|
||||
|
||||
|
||||
def test_issue_13143():
|
||||
f = Function('f')
|
||||
fx = f(x).diff(x)
|
||||
e = f(x) + fx + f(x)*fx
|
||||
# collect function before derivative
|
||||
assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx
|
||||
e = f(x) + f(x)*fx + x*fx*f(x)
|
||||
assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x)
|
||||
assert collect(e, f(x)) == (x*fx + fx + 1)*f(x)
|
||||
e = f(x) + fx + f(x)*fx
|
||||
assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx
|
||||
assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x)
|
||||
|
||||
|
||||
def test_issue_6097():
|
||||
assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0
|
||||
assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0
|
||||
|
||||
|
||||
def test_fraction_expand():
|
||||
eq = (x + y)*y/x
|
||||
assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x
|
||||
assert eq.expand() == y + y**2/x
|
||||
|
||||
|
||||
def test_fraction():
|
||||
x, y, z = map(Symbol, 'xyz')
|
||||
A = Symbol('A', commutative=False)
|
||||
|
||||
assert fraction(S.Half) == (1, 2)
|
||||
|
||||
assert fraction(x) == (x, 1)
|
||||
assert fraction(1/x) == (1, x)
|
||||
assert fraction(x/y) == (x, y)
|
||||
assert fraction(x/2) == (x, 2)
|
||||
|
||||
assert fraction(x*y/z) == (x*y, z)
|
||||
assert fraction(x/(y*z)) == (x, y*z)
|
||||
|
||||
assert fraction(1/y**2) == (1, y**2)
|
||||
assert fraction(x/y**2) == (x, y**2)
|
||||
|
||||
assert fraction((x**2 + 1)/y) == (x**2 + 1, y)
|
||||
assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7)
|
||||
|
||||
assert fraction(exp(-x), exact=True) == (exp(-x), 1)
|
||||
assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False))
|
||||
|
||||
assert fraction(x*A/y) == (x*A, y)
|
||||
assert fraction(x*A**-1/y) == (x*A**-1, y)
|
||||
|
||||
n = symbols('n', negative=True)
|
||||
assert fraction(exp(n)) == (1, exp(-n))
|
||||
assert fraction(exp(-n)) == (exp(-n), 1)
|
||||
|
||||
p = symbols('p', positive=True)
|
||||
assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1)
|
||||
|
||||
m = Mul(1, 1, S.Half, evaluate=False)
|
||||
assert fraction(m) == (1, 2)
|
||||
assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2)
|
||||
|
||||
m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False)
|
||||
assert fraction(m) == (1, 4)
|
||||
assert fraction(m, exact=True) == \
|
||||
(Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False))
|
||||
|
||||
|
||||
def test_issue_5615():
|
||||
aA, Re, a, b, D = symbols('aA Re a b D')
|
||||
e = ((D**3*a + b*aA**3)/Re).expand()
|
||||
assert collect(e, [aA**3/Re, a]) == e
|
||||
|
||||
|
||||
def test_issue_5933():
|
||||
from sympy.geometry.polygon import (Polygon, RegularPolygon)
|
||||
from sympy.simplify.radsimp import denom
|
||||
x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x
|
||||
assert abs(denom(x).n()) > 1e-12
|
||||
assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it
|
||||
|
||||
|
||||
def test_issue_14608():
|
||||
a, b = symbols('a b', commutative=False)
|
||||
x, y = symbols('x y')
|
||||
raises(AttributeError, lambda: collect(a*b + b*a, a))
|
||||
assert collect(x*y + y*(x+1), a) == x*y + y*(x+1)
|
||||
assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a
|
||||
|
||||
|
||||
def test_collect_abs():
|
||||
s = abs(x) + abs(y)
|
||||
assert collect_abs(s) == s
|
||||
assert unchanged(Mul, abs(x), abs(y))
|
||||
ans = Abs(x*y)
|
||||
assert isinstance(ans, Abs)
|
||||
assert collect_abs(abs(x)*abs(y)) == ans
|
||||
assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans)
|
||||
|
||||
# See https://github.com/sympy/sympy/issues/12910
|
||||
p = Symbol('p', positive=True)
|
||||
assert collect_abs(p/abs(1-p)).is_commutative is True
|
||||
|
||||
|
||||
def test_issue_19149():
|
||||
eq = exp(3*x/4)
|
||||
assert collect(eq, exp(x)) == eq
|
||||
|
||||
def test_issue_19719():
|
||||
a, b = symbols('a, b')
|
||||
expr = a**2 * (b + 1) + (7 + 1/b)/a
|
||||
collected = collect(expr, (a**2, 1/a), evaluate=False)
|
||||
# Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace
|
||||
assert collected == {a**2: b + 1, 1/a: 7 + 1/b}
|
||||
|
||||
|
||||
def test_issue_21355():
|
||||
assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2))
|
||||
assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2))
|
||||
@@ -0,0 +1,78 @@
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.special.error_functions import erf
|
||||
from sympy.polys.domains import GF
|
||||
from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime)
|
||||
|
||||
from sympy.abc import x, y, z, t, a, b, c, d, e
|
||||
|
||||
|
||||
def test_ratsimp():
|
||||
f, g = 1/x + 1/y, (x + y)/(x*y)
|
||||
|
||||
assert f != g and ratsimp(f) == g
|
||||
|
||||
f, g = 1/(1 + 1/x), 1 - 1/(x + 1)
|
||||
|
||||
assert f != g and ratsimp(f) == g
|
||||
|
||||
f, g = x/(x + y) + y/(x + y), 1
|
||||
|
||||
assert f != g and ratsimp(f) == g
|
||||
|
||||
f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y
|
||||
|
||||
assert f != g and ratsimp(f) == g
|
||||
|
||||
f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z +
|
||||
e*x)/(x*y + z)
|
||||
G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z),
|
||||
a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)]
|
||||
|
||||
assert f != g and ratsimp(f) in G
|
||||
|
||||
A = sqrt(pi)
|
||||
|
||||
B = log(erf(x) - 1)
|
||||
C = log(erf(x) + 1)
|
||||
|
||||
D = 8 - 8*erf(x)
|
||||
|
||||
f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D
|
||||
|
||||
assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4)
|
||||
|
||||
|
||||
def test_ratsimpmodprime():
|
||||
a = y**5 + x + y
|
||||
b = x - y
|
||||
F = [x*y**5 - x - y]
|
||||
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
|
||||
(-x**2 - x*y - x - y) / (-x**2 + x*y)
|
||||
|
||||
a = x + y**2 - 2
|
||||
b = x + y**2 - y - 1
|
||||
F = [x*y - 1]
|
||||
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
|
||||
(1 + y - x)/(y - x)
|
||||
|
||||
a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15
|
||||
b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21
|
||||
F = [x**2 + y**2 - 1]
|
||||
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
|
||||
(1 + 5*y - 5*x)/(8*y - 6*x)
|
||||
|
||||
a = x*y - x - 2*y + 4
|
||||
b = x + y**2 - 2*y
|
||||
F = [x - 2, y - 3]
|
||||
assert ratsimpmodprime(a/b, F, x, y, order='lex') == \
|
||||
Rational(2, 5)
|
||||
|
||||
# Test a bug where denominators would be dropped
|
||||
assert ratsimpmodprime(x, [y - 2*x], order='lex') == \
|
||||
y/2
|
||||
|
||||
a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2))
|
||||
assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1
|
||||
assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1
|
||||
@@ -0,0 +1,31 @@
|
||||
from sympy.core.numbers import I
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.trigonometric import (cos, cot, sin)
|
||||
from sympy.testing.pytest import _both_exp_pow
|
||||
|
||||
x, y, z, n = symbols('x,y,z,n')
|
||||
|
||||
|
||||
@_both_exp_pow
|
||||
def test_has():
|
||||
assert cot(x).has(x)
|
||||
assert cot(x).has(cot)
|
||||
assert not cot(x).has(sin)
|
||||
assert sin(x).has(x)
|
||||
assert sin(x).has(sin)
|
||||
assert not sin(x).has(cot)
|
||||
assert exp(x).has(exp)
|
||||
|
||||
|
||||
@_both_exp_pow
|
||||
def test_sin_exp_rewrite():
|
||||
assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x))
|
||||
assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x)
|
||||
assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x)
|
||||
assert (sin(5*y) - sin(
|
||||
2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x)
|
||||
assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y)
|
||||
assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y)
|
||||
# This next test currently passes... not clear whether it should or not?
|
||||
assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,204 @@
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.numbers import (I, Integer, Rational)
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.miscellaneous import (root, sqrt)
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.simplify.sqrtdenest import sqrtdenest
|
||||
from sympy.simplify.sqrtdenest import (
|
||||
_subsets as subsets, _sqrt_numeric_denest)
|
||||
|
||||
r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10,
|
||||
15, 29)]
|
||||
|
||||
|
||||
def test_sqrtdenest():
|
||||
d = {sqrt(5 + 2 * r6): r2 + r3,
|
||||
sqrt(5. + 2 * r6): sqrt(5. + 2 * r6),
|
||||
sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3),
|
||||
sqrt(r2): sqrt(r2),
|
||||
sqrt(5 + r7): sqrt(5 + r7),
|
||||
sqrt(3 + sqrt(5 + 2*r7)):
|
||||
3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) +
|
||||
r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)),
|
||||
sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3}
|
||||
for i in d:
|
||||
assert sqrtdenest(i) == d[i], i
|
||||
|
||||
|
||||
def test_sqrtdenest2():
|
||||
assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \
|
||||
r5 + sqrt(11 - 2*r29)
|
||||
e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
|
||||
assert sqrtdenest(e) == root(-2*r29 + 11, 4)
|
||||
r = sqrt(1 + r7)
|
||||
assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r)
|
||||
e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand())
|
||||
assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3))
|
||||
|
||||
assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \
|
||||
sqrt(2)*root(3, 4) + root(3, 4)**3
|
||||
|
||||
assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \
|
||||
1 + r5 + sqrt(1 + r3)
|
||||
|
||||
assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \
|
||||
1 + sqrt(1 + r3) + r5 + r7
|
||||
|
||||
e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand())
|
||||
assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3)
|
||||
|
||||
e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14)
|
||||
assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14)
|
||||
|
||||
# check that the result is not more complicated than the input
|
||||
z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16)
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15))
|
||||
|
||||
z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29))
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
|
||||
def test_sqrtdenest_rec():
|
||||
assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \
|
||||
-r2 + r3 + 2*r7
|
||||
assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \
|
||||
-7 + r5 + 2*r7
|
||||
assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \
|
||||
sqrt(11)*(r2 + 3 + sqrt(11))/11
|
||||
assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \
|
||||
9*r3 + 26 + 56*r6
|
||||
z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107)
|
||||
assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23))
|
||||
z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34)
|
||||
assert sqrtdenest(z) == z
|
||||
assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5
|
||||
assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \
|
||||
sqrt(-1)*(-r10 + 1 + r2 + r5)
|
||||
assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \
|
||||
-r10/3 + r2 + r5 + 3
|
||||
assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \
|
||||
sqrt(1 + r2 + r3 + r7)
|
||||
assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15
|
||||
|
||||
w = 1 + r2 + r3 + r5 + r7
|
||||
assert sqrtdenest(sqrt((w**2).expand())) == w
|
||||
z = sqrt((w**2).expand() + 1)
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3)
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
|
||||
def test_issue_6241():
|
||||
z = sqrt( -320 + 32*sqrt(5) + 64*r15)
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
|
||||
def test_sqrtdenest3():
|
||||
z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11))
|
||||
assert sqrtdenest(z) == -1 + r2 + r10
|
||||
assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10)
|
||||
z = sqrt(sqrt(r2 + 2) + 2)
|
||||
assert sqrtdenest(z) == z
|
||||
assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \
|
||||
sqrt(-2*r10 - 4*r2 + 8*r5 + 20)
|
||||
assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \
|
||||
r10 + 5 + 4*r2 + 3*r5
|
||||
z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
|
||||
r = sqrt(-2*r29 + 11)
|
||||
assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5)
|
||||
|
||||
n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2)
|
||||
d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))
|
||||
assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5),
|
||||
evaluate=False))
|
||||
|
||||
|
||||
def test_sqrtdenest4():
|
||||
# see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192
|
||||
z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5))
|
||||
z1 = sqrtdenest(z)
|
||||
c = sqrt(-r5 + 5)
|
||||
z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand()
|
||||
assert sqrtdenest(z) == z1
|
||||
|
||||
z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8)
|
||||
assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2
|
||||
|
||||
w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
|
||||
z = sqrt((w**2).expand())
|
||||
assert sqrtdenest(z) == w.expand()
|
||||
|
||||
|
||||
def test_sqrt_symbolic_denest():
|
||||
x = Symbol('x')
|
||||
z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand())
|
||||
assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2)
|
||||
z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand())
|
||||
assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3)
|
||||
z = ((1 + cos(2))**4 + 1).expand()
|
||||
assert sqrtdenest(z) == z
|
||||
z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand())
|
||||
assert sqrtdenest(z) == z
|
||||
c = cos(3)
|
||||
c2 = c**2
|
||||
assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \
|
||||
-1 - sqrt(1 + r3)*c
|
||||
ra = sqrt(1 + r3)
|
||||
z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112)
|
||||
assert sqrtdenest(z) == z
|
||||
|
||||
|
||||
def test_issue_5857():
|
||||
from sympy.abc import x, y
|
||||
z = sqrt(1/(4*r3 + 7) + 1)
|
||||
ans = (r2 + r6)/(r3 + 2)
|
||||
assert sqrtdenest(z) == ans
|
||||
assert sqrtdenest(1 + z) == 1 + ans
|
||||
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
|
||||
Integral(1 + ans, (x, 1, 2))
|
||||
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
|
||||
ans = (r2 + r6)/(r3 + 2)
|
||||
assert sqrtdenest(z) == ans
|
||||
assert sqrtdenest(1 + z) == 1 + ans
|
||||
assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
|
||||
Integral(1 + ans, (x, 1, 2))
|
||||
assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
|
||||
|
||||
|
||||
def test_subsets():
|
||||
assert subsets(1) == [[1]]
|
||||
assert subsets(4) == [
|
||||
[1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0],
|
||||
[0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1],
|
||||
[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]
|
||||
|
||||
|
||||
def test_issue_5653():
|
||||
assert sqrtdenest(
|
||||
sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2)))
|
||||
|
||||
def test_issue_12420():
|
||||
assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I
|
||||
e = 3 - sqrt(2)*sqrt(4 + I) + 3*I
|
||||
assert sqrtdenest(e) == e
|
||||
|
||||
def test_sqrt_ratcomb():
|
||||
assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0
|
||||
|
||||
def test_issue_18041():
|
||||
e = -sqrt(-2 + 2*sqrt(3)*I)
|
||||
assert sqrtdenest(e) == -1 - sqrt(3)*I
|
||||
|
||||
def test_issue_19914():
|
||||
a = Integer(-8)
|
||||
b = Integer(-1)
|
||||
r = Integer(63)
|
||||
d2 = a*a - b*b*r
|
||||
|
||||
assert _sqrt_numeric_denest(a, b, r, d2) == \
|
||||
sqrt(14)*I/2 + 3*sqrt(2)*I/2
|
||||
assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2
|
||||
@@ -0,0 +1,520 @@
|
||||
from itertools import product
|
||||
from sympy.core.function import (Subs, count_ops, diff, expand)
|
||||
from sympy.core.numbers import (E, I, Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan)
|
||||
from sympy.functions.elementary.trigonometric import (acos, asin, atan2)
|
||||
from sympy.functions.elementary.trigonometric import (asec, acsc)
|
||||
from sympy.functions.elementary.trigonometric import (acot, atan)
|
||||
from sympy.integrals.integrals import integrate
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.simplify.trigsimp import (exptrigsimp, trigsimp)
|
||||
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
from sympy.abc import x, y
|
||||
|
||||
|
||||
|
||||
def test_trigsimp1():
|
||||
x, y = symbols('x,y')
|
||||
|
||||
assert trigsimp(1 - sin(x)**2) == cos(x)**2
|
||||
assert trigsimp(1 - cos(x)**2) == sin(x)**2
|
||||
assert trigsimp(sin(x)**2 + cos(x)**2) == 1
|
||||
assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2
|
||||
assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2
|
||||
assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1
|
||||
assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2
|
||||
assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2
|
||||
assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1
|
||||
|
||||
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5
|
||||
assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2)
|
||||
|
||||
assert trigsimp(sin(x)/cos(x)) == tan(x)
|
||||
assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x)
|
||||
assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3
|
||||
assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2
|
||||
assert trigsimp(cot(x)/cos(x)) == 1/sin(x)
|
||||
|
||||
assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y)
|
||||
assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x)
|
||||
assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y)
|
||||
assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y)
|
||||
assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \
|
||||
sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1)
|
||||
|
||||
assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y)
|
||||
assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x)
|
||||
assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y)
|
||||
assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y)
|
||||
assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \
|
||||
sinh(y)/(sinh(y)*tanh(x) + cosh(y))
|
||||
|
||||
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0
|
||||
e = 2*sin(x)**2 + 2*cos(x)**2
|
||||
assert trigsimp(log(e)) == log(2)
|
||||
|
||||
|
||||
def test_trigsimp1a():
|
||||
assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2)
|
||||
assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2)
|
||||
assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2)
|
||||
assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2)
|
||||
assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2)
|
||||
assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2)
|
||||
assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2)
|
||||
assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2)
|
||||
assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2)
|
||||
assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2)
|
||||
assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2)
|
||||
assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2)
|
||||
|
||||
|
||||
def test_trigsimp2():
|
||||
x, y = symbols('x,y')
|
||||
assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2,
|
||||
recursive=True) == 1
|
||||
assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2,
|
||||
recursive=True) == 1
|
||||
assert trigsimp(
|
||||
Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1)
|
||||
|
||||
|
||||
def test_issue_4373():
|
||||
x = Symbol("x")
|
||||
assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10
|
||||
|
||||
|
||||
def test_trigsimp3():
|
||||
x, y = symbols('x,y')
|
||||
assert trigsimp(sin(x)/cos(x)) == tan(x)
|
||||
assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2
|
||||
assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3
|
||||
assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10
|
||||
|
||||
assert trigsimp(cos(x)/sin(x)) == 1/tan(x)
|
||||
assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2
|
||||
assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10
|
||||
|
||||
assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x))
|
||||
|
||||
|
||||
def test_issue_4661():
|
||||
a, x, y = symbols('a x y')
|
||||
eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2
|
||||
assert trigsimp(eq) == -4
|
||||
n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6
|
||||
d = -sin(x)**2 - 2*cos(x)**2
|
||||
assert simplify(n/d) == -1
|
||||
assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1
|
||||
eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8
|
||||
assert trigsimp(eq) == 0
|
||||
|
||||
|
||||
def test_issue_4494():
|
||||
a, b = symbols('a b')
|
||||
eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2
|
||||
assert trigsimp(eq) == 1
|
||||
|
||||
|
||||
def test_issue_5948():
|
||||
a, x, y = symbols('a x y')
|
||||
assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \
|
||||
cos(x)/sin(x)**7
|
||||
|
||||
|
||||
def test_issue_4775():
|
||||
a, x, y = symbols('a x y')
|
||||
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y)
|
||||
assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3
|
||||
|
||||
|
||||
def test_issue_4280():
|
||||
a, x, y = symbols('a x y')
|
||||
assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1
|
||||
assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2
|
||||
assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2
|
||||
|
||||
|
||||
def test_issue_3210():
|
||||
eqs = (sin(2)*cos(3) + sin(3)*cos(2),
|
||||
-sin(2)*sin(3) + cos(2)*cos(3),
|
||||
sin(2)*cos(3) - sin(3)*cos(2),
|
||||
sin(2)*sin(3) + cos(2)*cos(3),
|
||||
sin(2)*sin(3) + cos(2)*cos(3) + cos(2),
|
||||
sinh(2)*cosh(3) + sinh(3)*cosh(2),
|
||||
sinh(2)*sinh(3) + cosh(2)*cosh(3),
|
||||
)
|
||||
assert [trigsimp(e) for e in eqs] == [
|
||||
sin(5),
|
||||
cos(5),
|
||||
-sin(1),
|
||||
cos(1),
|
||||
cos(1) + cos(2),
|
||||
sinh(5),
|
||||
cosh(5),
|
||||
]
|
||||
|
||||
|
||||
def test_trigsimp_issues():
|
||||
a, x, y = symbols('a x y')
|
||||
|
||||
# issue 4625 - factor_terms works, too
|
||||
assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x)
|
||||
|
||||
# issue 5948
|
||||
assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \
|
||||
cos(x)/sin(x)**3
|
||||
assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \
|
||||
sin(x)/cos(x)**3
|
||||
|
||||
# check integer exponents
|
||||
e = sin(x)**y/cos(x)**y
|
||||
assert trigsimp(e) == e
|
||||
assert trigsimp(e.subs(y, 2)) == tan(x)**2
|
||||
assert trigsimp(e.subs(x, 1)) == tan(1)**y
|
||||
|
||||
# check for multiple patterns
|
||||
assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \
|
||||
1/tan(x)**2/tan(y)**2
|
||||
assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \
|
||||
1/(tan(x)*tan(x + y))
|
||||
|
||||
eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2
|
||||
assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2
|
||||
assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \
|
||||
cos(2)*sin(3)**4
|
||||
|
||||
# issue 6789; this generates an expression that formerly caused
|
||||
# trigsimp to hang
|
||||
assert cot(x).equals(tan(x)) is False
|
||||
|
||||
# nan or the unchanged expression is ok, but not sin(1)
|
||||
z = cos(x)**2 + sin(x)**2 - 1
|
||||
z1 = tan(x)**2 - 1/cot(x)**2
|
||||
n = (1 + z1/z)
|
||||
assert trigsimp(sin(n)) != sin(1)
|
||||
eq = x*(n - 1) - x*n
|
||||
assert trigsimp(eq) is S.NaN
|
||||
assert trigsimp(eq, recursive=True) is S.NaN
|
||||
assert trigsimp(1).is_Integer
|
||||
|
||||
assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1
|
||||
|
||||
|
||||
def test_trigsimp_issue_2515():
|
||||
x = Symbol('x')
|
||||
assert trigsimp(x*cos(x)*tan(x)) == x*sin(x)
|
||||
assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0
|
||||
|
||||
|
||||
def test_trigsimp_issue_3826():
|
||||
assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x)
|
||||
|
||||
|
||||
def test_trigsimp_issue_4032():
|
||||
n = Symbol('n', integer=True, positive=True)
|
||||
assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \
|
||||
2**(n/2)*cos(pi*n/4)/2 + 2**n/4
|
||||
|
||||
|
||||
def test_trigsimp_issue_7761():
|
||||
assert trigsimp(cosh(pi/4)) == cosh(pi/4)
|
||||
|
||||
|
||||
def test_trigsimp_noncommutative():
|
||||
x, y = symbols('x,y')
|
||||
A, B = symbols('A,B', commutative=False)
|
||||
|
||||
assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2
|
||||
assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2
|
||||
assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A
|
||||
assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2
|
||||
assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2
|
||||
assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A
|
||||
assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2
|
||||
assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2
|
||||
assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A
|
||||
|
||||
assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A
|
||||
|
||||
assert trigsimp(A*sin(x)/cos(x)) == A*tan(x)
|
||||
assert trigsimp(A*tan(x)*cos(x)) == A*sin(x)
|
||||
assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3
|
||||
assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2
|
||||
assert trigsimp(A*cot(x)/cos(x)) == A/sin(x)
|
||||
|
||||
assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y)
|
||||
assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x)
|
||||
assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y)
|
||||
assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y)
|
||||
|
||||
assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y)
|
||||
assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x)
|
||||
assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y)
|
||||
assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y)
|
||||
|
||||
assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A
|
||||
|
||||
|
||||
def test_hyperbolic_simp():
|
||||
x, y = symbols('x,y')
|
||||
|
||||
assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2
|
||||
assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2
|
||||
assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1
|
||||
assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2
|
||||
assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2
|
||||
assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1
|
||||
assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2
|
||||
assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2
|
||||
assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1
|
||||
|
||||
assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5
|
||||
assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2)
|
||||
|
||||
assert trigsimp(sinh(x)/cosh(x)) == tanh(x)
|
||||
assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x))
|
||||
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
|
||||
assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x)
|
||||
assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3
|
||||
assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2
|
||||
assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x)
|
||||
|
||||
for a in (pi/6*I, pi/4*I, pi/3*I):
|
||||
assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a)
|
||||
assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a)
|
||||
|
||||
e = 2*cosh(x)**2 - 2*sinh(x)**2
|
||||
assert trigsimp(log(e)) == log(2)
|
||||
|
||||
# issue 19535:
|
||||
assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2)
|
||||
|
||||
assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2,
|
||||
recursive=True) == 1
|
||||
assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2,
|
||||
recursive=True) == 1
|
||||
|
||||
assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10
|
||||
|
||||
assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2
|
||||
assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3
|
||||
assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10
|
||||
assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3
|
||||
|
||||
assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x)
|
||||
assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2
|
||||
assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10
|
||||
|
||||
assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x)
|
||||
assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0
|
||||
|
||||
assert tan(x) != 1/cot(x) # cot doesn't auto-simplify
|
||||
|
||||
assert trigsimp(tan(x) - 1/cot(x)) == 0
|
||||
assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7
|
||||
|
||||
|
||||
def test_trigsimp_groebner():
|
||||
from sympy.simplify.trigsimp import trigsimp_groebner
|
||||
|
||||
c = cos(x)
|
||||
s = sin(x)
|
||||
ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/(
|
||||
-s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21)
|
||||
resnum = (5*s - 5*c + 1)
|
||||
resdenom = (8*s - 6*c)
|
||||
results = [resnum/resdenom, (-resnum)/(-resdenom)]
|
||||
assert trigsimp_groebner(ex) in results
|
||||
assert trigsimp_groebner(s/c, hints=[tan]) == tan(x)
|
||||
assert trigsimp_groebner(c*s) == c*s
|
||||
assert trigsimp((-s + 1)/c + c/(-s + 1),
|
||||
method='groebner') == 2/c
|
||||
assert trigsimp((-s + 1)/c + c/(-s + 1),
|
||||
method='groebner', polynomial=True) == 2/c
|
||||
|
||||
# Test quick=False works
|
||||
assert trigsimp_groebner(ex, hints=[2]) in results
|
||||
assert trigsimp_groebner(ex, hints=[int(2)]) in results
|
||||
|
||||
# test "I"
|
||||
assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x)
|
||||
|
||||
# test hyperbolic / sums
|
||||
assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)),
|
||||
hints=[(tanh, x, y)]) == tanh(x + y)
|
||||
|
||||
|
||||
def test_issue_2827_trigsimp_methods():
|
||||
measure1 = lambda expr: len(str(expr))
|
||||
measure2 = lambda expr: -count_ops(expr)
|
||||
# Return the most complicated result
|
||||
expr = (x + 1)/(x + sin(x)**2 + cos(x)**2)
|
||||
ans = Matrix([1])
|
||||
M = Matrix([expr])
|
||||
assert trigsimp(M, method='fu', measure=measure1) == ans
|
||||
assert trigsimp(M, method='fu', measure=measure2) != ans
|
||||
# all methods should work with Basic expressions even if they
|
||||
# aren't Expr
|
||||
M = Matrix.eye(1)
|
||||
assert all(trigsimp(M, method=m) == M for m in
|
||||
'fu matching groebner old'.split())
|
||||
# watch for E in exptrigsimp, not only exp()
|
||||
eq = 1/sqrt(E) + E
|
||||
assert exptrigsimp(eq) == eq
|
||||
|
||||
def test_issue_15129_trigsimp_methods():
|
||||
t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0])
|
||||
t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0])
|
||||
t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0])
|
||||
r1 = t1.dot(t2)
|
||||
r2 = t1.dot(t3)
|
||||
assert trigsimp(r1) == cos(Rational(1, 50))
|
||||
assert trigsimp(r2) == sin(Rational(3, 50))
|
||||
|
||||
def test_exptrigsimp():
|
||||
def valid(a, b):
|
||||
from sympy.core.random import verify_numerically as tn
|
||||
if not (tn(a, b) and a == b):
|
||||
return False
|
||||
return True
|
||||
|
||||
assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x)
|
||||
assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x)
|
||||
assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x)
|
||||
assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x)
|
||||
e = [cos(x) + I*sin(x), cos(x) - I*sin(x),
|
||||
cosh(x) - sinh(x), cosh(x) + sinh(x)]
|
||||
ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)]
|
||||
assert all(valid(i, j) for i, j in zip(
|
||||
[exptrigsimp(ei) for ei in e], ok))
|
||||
|
||||
ue = [cos(x) + sin(x), cos(x) - sin(x),
|
||||
cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)]
|
||||
assert [exptrigsimp(ei) == ei for ei in ue]
|
||||
|
||||
res = []
|
||||
ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)),
|
||||
y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)),
|
||||
y*tanh(1 + I), 1/(y*tanh(1 + I))]
|
||||
for a in (1, I, x, I*x, 1 + I):
|
||||
w = exp(a)
|
||||
eq = y*(w - 1/w)/(w + 1/w)
|
||||
res.append(simplify(eq))
|
||||
res.append(simplify(1/eq))
|
||||
assert all(valid(i, j) for i, j in zip(res, ok))
|
||||
|
||||
for a in range(1, 3):
|
||||
w = exp(a)
|
||||
e = w + 1/w
|
||||
s = simplify(e)
|
||||
assert s == exptrigsimp(e)
|
||||
assert valid(s, 2*cosh(a))
|
||||
e = w - 1/w
|
||||
s = simplify(e)
|
||||
assert s == exptrigsimp(e)
|
||||
assert valid(s, 2*sinh(a))
|
||||
|
||||
def test_exptrigsimp_noncommutative():
|
||||
a,b = symbols('a b', commutative=False)
|
||||
x = Symbol('x', commutative=True)
|
||||
assert exp(a + x) == exptrigsimp(exp(a)*exp(x))
|
||||
p = exp(a)*exp(b) - exp(b)*exp(a)
|
||||
assert p == exptrigsimp(p) != 0
|
||||
|
||||
def test_powsimp_on_numbers():
|
||||
assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_issue_6811_fail():
|
||||
# from doc/src/modules/physics/mechanics/examples.rst, the current `eq`
|
||||
# at Line 576 (in different variables) was formerly the equivalent and
|
||||
# shorter expression given below...it would be nice to get the short one
|
||||
# back again
|
||||
xp, y, x, z = symbols('xp, y, x, z')
|
||||
eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x))
|
||||
assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x)
|
||||
|
||||
|
||||
def test_Piecewise():
|
||||
e1 = x*(x + y) - y*(x + y)
|
||||
e2 = sin(x)**2 + cos(x)**2
|
||||
e3 = expand((x + y)*y/x)
|
||||
# s1 = simplify(e1)
|
||||
s2 = simplify(e2)
|
||||
# s3 = simplify(e3)
|
||||
|
||||
# trigsimp tries not to touch non-trig containing args
|
||||
assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \
|
||||
Piecewise((e1, e3 < s2), (e3, True))
|
||||
|
||||
|
||||
def test_issue_21594():
|
||||
assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2
|
||||
|
||||
|
||||
def test_trigsimp_old():
|
||||
x, y = symbols('x,y')
|
||||
|
||||
assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2
|
||||
assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2
|
||||
assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1
|
||||
assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2
|
||||
assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2
|
||||
assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1
|
||||
assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2
|
||||
assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1
|
||||
|
||||
assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5
|
||||
|
||||
assert trigsimp(sin(x)/cos(x), old=True) == tan(x)
|
||||
assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x)
|
||||
assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3
|
||||
assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2
|
||||
assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x)
|
||||
|
||||
assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y)
|
||||
assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x)
|
||||
assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y)
|
||||
assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y)
|
||||
|
||||
assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y)
|
||||
assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x)
|
||||
assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y)
|
||||
assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y)
|
||||
|
||||
assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0
|
||||
|
||||
assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x)
|
||||
assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x)
|
||||
assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x)
|
||||
|
||||
assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2
|
||||
|
||||
|
||||
def test_trigsimp_inverse():
|
||||
alpha = symbols('alpha')
|
||||
s, c = sin(alpha), cos(alpha)
|
||||
|
||||
for finv in [asin, acos, asec, acsc, atan, acot]:
|
||||
f = finv.inverse(None)
|
||||
assert alpha == trigsimp(finv(f(alpha)), inverse=True)
|
||||
|
||||
# test atan2(cos, sin), atan2(sin, cos), etc...
|
||||
for a, b in [[c, s], [s, c]]:
|
||||
for i, j in product([-1, 1], repeat=2):
|
||||
angle = atan2(i*b, j*a)
|
||||
angle_inverted = trigsimp(angle, inverse=True)
|
||||
assert angle_inverted != angle # assures simplification happened
|
||||
assert sin(angle_inverted) == trigsimp(sin(angle))
|
||||
assert cos(angle_inverted) == trigsimp(cos(angle))
|
||||
@@ -0,0 +1,15 @@
|
||||
from sympy.core.traversal import use as _use
|
||||
from sympy.utilities.decorator import deprecated
|
||||
|
||||
use = deprecated(
|
||||
"""
|
||||
Using use from the sympy.simplify.traversaltools submodule is
|
||||
deprecated.
|
||||
|
||||
Instead, use use from the top-level sympy namespace, like
|
||||
|
||||
sympy.use
|
||||
""",
|
||||
deprecated_since_version="1.10",
|
||||
active_deprecations_target="deprecated-traversal-functions-moved"
|
||||
)(_use)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user