switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
"""A module for solving all kinds of equations.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers import solve
|
||||
>>> from sympy.abc import x
|
||||
>>> solve(((x + 1)**5).expand(), x)
|
||||
[-1]
|
||||
"""
|
||||
from sympy.core.assumptions import check_assumptions, failing_assumptions
|
||||
|
||||
from .solvers import solve, solve_linear_system, solve_linear_system_LU, \
|
||||
solve_undetermined_coeffs, nsolve, solve_linear, checksol, \
|
||||
det_quick, inv_quick
|
||||
|
||||
from sympy.solvers.diophantine.diophantine import diophantine
|
||||
|
||||
from .recurr import rsolve, rsolve_poly, rsolve_ratio, rsolve_hyper
|
||||
|
||||
from .ode import checkodesol, classify_ode, dsolve, \
|
||||
homogeneous_order
|
||||
|
||||
from .polysys import solve_poly_system, solve_triangulated, factor_system
|
||||
|
||||
from .pde import pde_separate, pde_separate_add, pde_separate_mul, \
|
||||
pdsolve, classify_pde, checkpdesol
|
||||
|
||||
from .deutils import ode_order
|
||||
|
||||
from .inequalities import reduce_inequalities, reduce_abs_inequality, \
|
||||
reduce_abs_inequalities, solve_poly_inequality, solve_rational_inequalities, solve_univariate_inequality
|
||||
|
||||
from .decompogen import decompogen
|
||||
|
||||
from .solveset import solveset, linsolve, linear_eq_to_matrix, nonlinsolve, substitution
|
||||
|
||||
from .simplex import lpmin, lpmax, linprog
|
||||
|
||||
# This is here instead of sympy/sets/__init__.py to avoid circular import issues
|
||||
from ..core.singleton import S
|
||||
Complexes = S.Complexes
|
||||
|
||||
__all__ = [
|
||||
'solve', 'solve_linear_system', 'solve_linear_system_LU',
|
||||
'solve_undetermined_coeffs', 'nsolve', 'solve_linear', 'checksol',
|
||||
'det_quick', 'inv_quick', 'check_assumptions', 'failing_assumptions',
|
||||
|
||||
'diophantine',
|
||||
|
||||
'rsolve', 'rsolve_poly', 'rsolve_ratio', 'rsolve_hyper',
|
||||
|
||||
'checkodesol', 'classify_ode', 'dsolve', 'homogeneous_order',
|
||||
|
||||
'solve_poly_system', 'solve_triangulated', 'factor_system',
|
||||
|
||||
'pde_separate', 'pde_separate_add', 'pde_separate_mul', 'pdsolve',
|
||||
'classify_pde', 'checkpdesol',
|
||||
|
||||
'ode_order',
|
||||
|
||||
'reduce_inequalities', 'reduce_abs_inequality', 'reduce_abs_inequalities',
|
||||
'solve_poly_inequality', 'solve_rational_inequalities',
|
||||
'solve_univariate_inequality',
|
||||
|
||||
'decompogen',
|
||||
|
||||
'solveset', 'linsolve', 'linear_eq_to_matrix', 'nonlinsolve',
|
||||
'substitution',
|
||||
|
||||
# This is here instead of sympy/sets/__init__.py to avoid circular import issues
|
||||
'Complexes',
|
||||
|
||||
'lpmin', 'lpmax', 'linprog'
|
||||
]
|
||||
@@ -0,0 +1,12 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.matrices.dense import (eye, zeros)
|
||||
from sympy.solvers.solvers import solve_linear_system
|
||||
|
||||
N = 8
|
||||
M = zeros(N, N + 1)
|
||||
M[:, :N] = eye(N)
|
||||
S = [Symbol('A%i' % i) for i in range(N)]
|
||||
|
||||
|
||||
def timeit_linsolve_trivial():
|
||||
solve_linear_system(M, *S)
|
||||
@@ -0,0 +1,509 @@
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.exprtools import factor_terms
|
||||
from sympy.core.function import expand_log, _mexpand
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.sorting import ordered
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.functions.elementary.exponential import (LambertW, exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import root
|
||||
from sympy.polys.polyroots import roots
|
||||
from sympy.polys.polytools import Poly, factor
|
||||
from sympy.simplify.simplify import separatevars
|
||||
from sympy.simplify.radsimp import collect
|
||||
from sympy.simplify.simplify import powsimp
|
||||
from sympy.solvers.solvers import solve, _invert
|
||||
from sympy.utilities.iterables import uniq
|
||||
|
||||
|
||||
def _filtered_gens(poly, symbol):
|
||||
"""process the generators of ``poly``, returning the set of generators that
|
||||
have ``symbol``. If there are two generators that are inverses of each other,
|
||||
prefer the one that has no denominator.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.bivariate import _filtered_gens
|
||||
>>> from sympy import Poly, exp
|
||||
>>> from sympy.abc import x
|
||||
>>> _filtered_gens(Poly(x + 1/x + exp(x)), x)
|
||||
{x, exp(x)}
|
||||
|
||||
"""
|
||||
# TODO it would be good to pick the smallest divisible power
|
||||
# instead of the base for something like x**4 + x**2 -->
|
||||
# return x**2 not x
|
||||
gens = {g for g in poly.gens if symbol in g.free_symbols}
|
||||
for g in list(gens):
|
||||
ag = 1/g
|
||||
if g in gens and ag in gens:
|
||||
if ag.as_numer_denom()[1] is not S.One:
|
||||
g = ag
|
||||
gens.remove(g)
|
||||
return gens
|
||||
|
||||
|
||||
def _mostfunc(lhs, func, X=None):
|
||||
"""Returns the term in lhs which contains the most of the
|
||||
func-type things e.g. log(log(x)) wins over log(x) if both terms appear.
|
||||
|
||||
``func`` can be a function (exp, log, etc...) or any other SymPy object,
|
||||
like Pow.
|
||||
|
||||
If ``X`` is not ``None``, then the function returns the term composed with the
|
||||
most ``func`` having the specified variable.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.bivariate import _mostfunc
|
||||
>>> from sympy import exp
|
||||
>>> from sympy.abc import x, y
|
||||
>>> _mostfunc(exp(x) + exp(exp(x) + 2), exp)
|
||||
exp(exp(x) + 2)
|
||||
>>> _mostfunc(exp(x) + exp(exp(y) + 2), exp)
|
||||
exp(exp(y) + 2)
|
||||
>>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x)
|
||||
exp(x)
|
||||
>>> _mostfunc(x, exp, x) is None
|
||||
True
|
||||
>>> _mostfunc(exp(x) + exp(x*y), exp, x)
|
||||
exp(x)
|
||||
"""
|
||||
fterms = [tmp for tmp in lhs.atoms(func) if (not X or
|
||||
X.is_Symbol and X in tmp.free_symbols or
|
||||
not X.is_Symbol and tmp.has(X))]
|
||||
if len(fterms) == 1:
|
||||
return fterms[0]
|
||||
elif fterms:
|
||||
return max(list(ordered(fterms)), key=lambda x: x.count(func))
|
||||
return None
|
||||
|
||||
|
||||
def _linab(arg, symbol):
|
||||
"""Return ``a, b, X`` assuming ``arg`` can be written as ``a*X + b``
|
||||
where ``X`` is a symbol-dependent factor and ``a`` and ``b`` are
|
||||
independent of ``symbol``.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.bivariate import _linab
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy import exp, S
|
||||
>>> _linab(S(2), x)
|
||||
(2, 0, 1)
|
||||
>>> _linab(2*x, x)
|
||||
(2, 0, x)
|
||||
>>> _linab(y + y*x + 2*x, x)
|
||||
(y + 2, y, x)
|
||||
>>> _linab(3 + 2*exp(x), x)
|
||||
(2, 3, exp(x))
|
||||
"""
|
||||
arg = factor_terms(arg.expand())
|
||||
ind, dep = arg.as_independent(symbol)
|
||||
if arg.is_Mul and dep.is_Add:
|
||||
a, b, x = _linab(dep, symbol)
|
||||
return ind*a, ind*b, x
|
||||
if not arg.is_Add:
|
||||
b = 0
|
||||
a, x = ind, dep
|
||||
else:
|
||||
b = ind
|
||||
a, x = separatevars(dep).as_independent(symbol, as_Add=False)
|
||||
if x.could_extract_minus_sign():
|
||||
a = -a
|
||||
x = -x
|
||||
return a, b, x
|
||||
|
||||
|
||||
def _lambert(eq, x):
|
||||
"""
|
||||
Given an expression assumed to be in the form
|
||||
``F(X, a..f) = a*log(b*X + c) + d*X + f = 0``
|
||||
where X = g(x) and x = g^-1(X), return the Lambert solution,
|
||||
``x = g^-1(-c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(-f/a)))``.
|
||||
"""
|
||||
eq = _mexpand(expand_log(eq))
|
||||
mainlog = _mostfunc(eq, log, x)
|
||||
if not mainlog:
|
||||
return [] # violated assumptions
|
||||
other = eq.subs(mainlog, 0)
|
||||
if isinstance(-other, log):
|
||||
eq = (eq - other).subs(mainlog, mainlog.args[0])
|
||||
mainlog = mainlog.args[0]
|
||||
if not isinstance(mainlog, log):
|
||||
return [] # violated assumptions
|
||||
other = -(-other).args[0]
|
||||
eq += other
|
||||
if x not in other.free_symbols:
|
||||
return [] # violated assumptions
|
||||
d, f, X2 = _linab(other, x)
|
||||
logterm = collect(eq - other, mainlog)
|
||||
a = logterm.as_coefficient(mainlog)
|
||||
if a is None or x in a.free_symbols:
|
||||
return [] # violated assumptions
|
||||
logarg = mainlog.args[0]
|
||||
b, c, X1 = _linab(logarg, x)
|
||||
if X1 != X2:
|
||||
return [] # violated assumptions
|
||||
|
||||
# invert the generator X1 so we have x(u)
|
||||
u = Dummy('rhs')
|
||||
xusolns = solve(X1 - u, x)
|
||||
|
||||
# There are infinitely many branches for LambertW
|
||||
# but only branches for k = -1 and 0 might be real. The k = 0
|
||||
# branch is real and the k = -1 branch is real if the LambertW argument
|
||||
# in in range [-1/e, 0]. Since `solve` does not return infinite
|
||||
# solutions we will only include the -1 branch if it tests as real.
|
||||
# Otherwise, inclusion of any LambertW in the solution indicates to
|
||||
# the user that there are imaginary solutions corresponding to
|
||||
# different k values.
|
||||
lambert_real_branches = [-1, 0]
|
||||
sol = []
|
||||
|
||||
# solution of the given Lambert equation is like
|
||||
# sol = -c/b + (a/d)*LambertW(arg, k),
|
||||
# where arg = d/(a*b)*exp((c*d-b*f)/a/b) and k in lambert_real_branches.
|
||||
# Instead of considering the single arg, `d/(a*b)*exp((c*d-b*f)/a/b)`,
|
||||
# the individual `p` roots obtained when writing `exp((c*d-b*f)/a/b)`
|
||||
# as `exp(A/p) = exp(A)**(1/p)`, where `p` is an Integer, are used.
|
||||
|
||||
# calculating args for LambertW
|
||||
num, den = ((c*d-b*f)/a/b).as_numer_denom()
|
||||
p, den = den.as_coeff_Mul()
|
||||
e = exp(num/den)
|
||||
t = Dummy('t')
|
||||
args = [d/(a*b)*t for t in roots(t**p - e, t).keys()]
|
||||
|
||||
# calculating solutions from args
|
||||
for arg in args:
|
||||
for k in lambert_real_branches:
|
||||
w = LambertW(arg, k)
|
||||
if k and not w.is_real:
|
||||
continue
|
||||
rhs = -c/b + (a/d)*w
|
||||
|
||||
sol.extend(xu.subs(u, rhs) for xu in xusolns)
|
||||
return sol
|
||||
|
||||
|
||||
def _solve_lambert(f, symbol, gens):
|
||||
"""Return solution to ``f`` if it is a Lambert-type expression
|
||||
else raise NotImplementedError.
|
||||
|
||||
For ``f(X, a..f) = a*log(b*X + c) + d*X - f = 0`` the solution
|
||||
for ``X`` is ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))``.
|
||||
There are a variety of forms for `f(X, a..f)` as enumerated below:
|
||||
|
||||
1a1)
|
||||
if B**B = R for R not in [0, 1] (since those cases would already
|
||||
be solved before getting here) then log of both sides gives
|
||||
log(B) + log(log(B)) = log(log(R)) and
|
||||
X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
|
||||
1a2)
|
||||
if B*(b*log(B) + c)**a = R then log of both sides gives
|
||||
log(B) + a*log(b*log(B) + c) = log(R) and
|
||||
X = log(B), d=1, f=log(R)
|
||||
1b)
|
||||
if a*log(b*B + c) + d*B = R and
|
||||
X = B, f = R
|
||||
2a)
|
||||
if (b*B + c)*exp(d*B + g) = R then log of both sides gives
|
||||
log(b*B + c) + d*B + g = log(R) and
|
||||
X = B, a = 1, f = log(R) - g
|
||||
2b)
|
||||
if g*exp(d*B + h) - b*B = c then the log form is
|
||||
log(g) + d*B + h - log(b*B + c) = 0 and
|
||||
X = B, a = -1, f = -h - log(g)
|
||||
3)
|
||||
if d*p**(a*B + g) - b*B = c then the log form is
|
||||
log(d) + (a*B + g)*log(p) - log(b*B + c) = 0 and
|
||||
X = B, a = -1, d = a*log(p), f = -log(d) - g*log(p)
|
||||
"""
|
||||
|
||||
def _solve_even_degree_expr(expr, t, symbol):
|
||||
"""Return the unique solutions of equations derived from
|
||||
``expr`` by replacing ``t`` with ``+/- symbol``.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : Expr
|
||||
The expression which includes a dummy variable t to be
|
||||
replaced with +symbol and -symbol.
|
||||
|
||||
symbol : Symbol
|
||||
The symbol for which a solution is being sought.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
List of unique solution of the two equations generated by
|
||||
replacing ``t`` with positive and negative ``symbol``.
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
If ``expr = 2*log(t) + x/2` then solutions for
|
||||
``2*log(x) + x/2 = 0`` and ``2*log(-x) + x/2 = 0`` are
|
||||
returned by this function. Though this may seem
|
||||
counter-intuitive, one must note that the ``expr`` being
|
||||
solved here has been derived from a different expression. For
|
||||
an expression like ``eq = x**2*g(x) = 1``, if we take the
|
||||
log of both sides we obtain ``log(x**2) + log(g(x)) = 0``. If
|
||||
x is positive then this simplifies to
|
||||
``2*log(x) + log(g(x)) = 0``; the Lambert-solving routines will
|
||||
return solutions for this, but we must also consider the
|
||||
solutions for ``2*log(-x) + log(g(x))`` since those must also
|
||||
be a solution of ``eq`` which has the same value when the ``x``
|
||||
in ``x**2`` is negated. If `g(x)` does not have even powers of
|
||||
symbol then we do not want to replace the ``x`` there with
|
||||
``-x``. So the role of the ``t`` in the expression received by
|
||||
this function is to mark where ``+/-x`` should be inserted
|
||||
before obtaining the Lambert solutions.
|
||||
|
||||
"""
|
||||
nlhs, plhs = [
|
||||
expr.xreplace({t: sgn*symbol}) for sgn in (-1, 1)]
|
||||
sols = _solve_lambert(nlhs, symbol, gens)
|
||||
if plhs != nlhs:
|
||||
sols.extend(_solve_lambert(plhs, symbol, gens))
|
||||
# uniq is needed for a case like
|
||||
# 2*log(t) - log(-z**2) + log(z + log(x) + log(z))
|
||||
# where substituting t with +/-x gives all the same solution;
|
||||
# uniq, rather than list(set()), is used to maintain canonical
|
||||
# order
|
||||
return list(uniq(sols))
|
||||
|
||||
nrhs, lhs = f.as_independent(symbol, as_Add=True)
|
||||
rhs = -nrhs
|
||||
|
||||
lamcheck = [tmp for tmp in gens
|
||||
if (tmp.func in [exp, log] or
|
||||
(tmp.is_Pow and symbol in tmp.exp.free_symbols))]
|
||||
if not lamcheck:
|
||||
raise NotImplementedError()
|
||||
|
||||
if lhs.is_Add or lhs.is_Mul:
|
||||
# replacing all even_degrees of symbol with dummy variable t
|
||||
# since these will need special handling; non-Add/Mul do not
|
||||
# need this handling
|
||||
t = Dummy('t', **symbol.assumptions0)
|
||||
lhs = lhs.replace(
|
||||
lambda i: # find symbol**even
|
||||
i.is_Pow and i.base == symbol and i.exp.is_even,
|
||||
lambda i: # replace t**even
|
||||
t**i.exp)
|
||||
|
||||
if lhs.is_Add and lhs.has(t):
|
||||
t_indep = lhs.subs(t, 0)
|
||||
t_term = lhs - t_indep
|
||||
_rhs = rhs - t_indep
|
||||
if not t_term.is_Add and _rhs and not (
|
||||
t_term.has(S.ComplexInfinity, S.NaN)):
|
||||
eq = expand_log(log(t_term) - log(_rhs))
|
||||
return _solve_even_degree_expr(eq, t, symbol)
|
||||
elif lhs.is_Mul and rhs:
|
||||
# this needs to happen whether t is present or not
|
||||
lhs = expand_log(log(lhs), force=True)
|
||||
rhs = log(rhs)
|
||||
if lhs.has(t) and lhs.is_Add:
|
||||
# it expanded from Mul to Add
|
||||
eq = lhs - rhs
|
||||
return _solve_even_degree_expr(eq, t, symbol)
|
||||
|
||||
# restore symbol in lhs
|
||||
lhs = lhs.xreplace({t: symbol})
|
||||
|
||||
lhs = powsimp(factor(lhs, deep=True))
|
||||
|
||||
# make sure we have inverted as completely as possible
|
||||
r = Dummy()
|
||||
i, lhs = _invert(lhs - r, symbol)
|
||||
rhs = i.xreplace({r: rhs})
|
||||
|
||||
# For the first forms:
|
||||
#
|
||||
# 1a1) B**B = R will arrive here as B*log(B) = log(R)
|
||||
# lhs is Mul so take log of both sides:
|
||||
# log(B) + log(log(B)) = log(log(R))
|
||||
# 1a2) B*(b*log(B) + c)**a = R will arrive unchanged so
|
||||
# lhs is Mul, so take log of both sides:
|
||||
# log(B) + a*log(b*log(B) + c) = log(R)
|
||||
# 1b) d*log(a*B + b) + c*B = R will arrive unchanged so
|
||||
# lhs is Add, so isolate c*B and expand log of both sides:
|
||||
# log(c) + log(B) = log(R - d*log(a*B + b))
|
||||
|
||||
soln = []
|
||||
if not soln:
|
||||
mainlog = _mostfunc(lhs, log, symbol)
|
||||
if mainlog:
|
||||
if lhs.is_Mul and rhs != 0:
|
||||
soln = _lambert(log(lhs) - log(rhs), symbol)
|
||||
elif lhs.is_Add:
|
||||
other = lhs.subs(mainlog, 0)
|
||||
if other and not other.is_Add and [
|
||||
tmp for tmp in other.atoms(Pow)
|
||||
if symbol in tmp.free_symbols]:
|
||||
if not rhs:
|
||||
diff = log(other) - log(other - lhs)
|
||||
else:
|
||||
diff = log(lhs - other) - log(rhs - other)
|
||||
soln = _lambert(expand_log(diff), symbol)
|
||||
else:
|
||||
#it's ready to go
|
||||
soln = _lambert(lhs - rhs, symbol)
|
||||
|
||||
# For the next forms,
|
||||
#
|
||||
# collect on main exp
|
||||
# 2a) (b*B + c)*exp(d*B + g) = R
|
||||
# lhs is mul, so take log of both sides:
|
||||
# log(b*B + c) + d*B = log(R) - g
|
||||
# 2b) g*exp(d*B + h) - b*B = R
|
||||
# lhs is add, so add b*B to both sides,
|
||||
# take the log of both sides and rearrange to give
|
||||
# log(R + b*B) - d*B = log(g) + h
|
||||
|
||||
if not soln:
|
||||
mainexp = _mostfunc(lhs, exp, symbol)
|
||||
if mainexp:
|
||||
lhs = collect(lhs, mainexp)
|
||||
if lhs.is_Mul and rhs != 0:
|
||||
soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
|
||||
elif lhs.is_Add:
|
||||
# move all but mainexp-containing term to rhs
|
||||
other = lhs.subs(mainexp, 0)
|
||||
mainterm = lhs - other
|
||||
rhs = rhs - other
|
||||
if (mainterm.could_extract_minus_sign() and
|
||||
rhs.could_extract_minus_sign()):
|
||||
mainterm *= -1
|
||||
rhs *= -1
|
||||
diff = log(mainterm) - log(rhs)
|
||||
soln = _lambert(expand_log(diff), symbol)
|
||||
|
||||
# For the last form:
|
||||
#
|
||||
# 3) d*p**(a*B + g) - b*B = c
|
||||
# collect on main pow, add b*B to both sides,
|
||||
# take log of both sides and rearrange to give
|
||||
# a*B*log(p) - log(b*B + c) = -log(d) - g*log(p)
|
||||
if not soln:
|
||||
mainpow = _mostfunc(lhs, Pow, symbol)
|
||||
if mainpow and symbol in mainpow.exp.free_symbols:
|
||||
lhs = collect(lhs, mainpow)
|
||||
if lhs.is_Mul and rhs != 0:
|
||||
# b*B = 0
|
||||
soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
|
||||
elif lhs.is_Add:
|
||||
# move all but mainpow-containing term to rhs
|
||||
other = lhs.subs(mainpow, 0)
|
||||
mainterm = lhs - other
|
||||
rhs = rhs - other
|
||||
diff = log(mainterm) - log(rhs)
|
||||
soln = _lambert(expand_log(diff), symbol)
|
||||
|
||||
if not soln:
|
||||
raise NotImplementedError('%s does not appear to have a solution in '
|
||||
'terms of LambertW' % f)
|
||||
|
||||
return list(ordered(soln))
|
||||
|
||||
|
||||
def bivariate_type(f, x, y, *, first=True):
|
||||
"""Given an expression, f, 3 tests will be done to see what type
|
||||
of composite bivariate it might be, options for u(x, y) are::
|
||||
|
||||
x*y
|
||||
x+y
|
||||
x*y+x
|
||||
x*y+y
|
||||
|
||||
If it matches one of these types, ``u(x, y)``, ``P(u)`` and dummy
|
||||
variable ``u`` will be returned. Solving ``P(u)`` for ``u`` and
|
||||
equating the solutions to ``u(x, y)`` and then solving for ``x`` or
|
||||
``y`` is equivalent to solving the original expression for ``x`` or
|
||||
``y``. If ``x`` and ``y`` represent two functions in the same
|
||||
variable, e.g. ``x = g(t)`` and ``y = h(t)``, then if ``u(x, y) - p``
|
||||
can be solved for ``t`` then these represent the solutions to
|
||||
``P(u) = 0`` when ``p`` are the solutions of ``P(u) = 0``.
|
||||
|
||||
Only positive values of ``u`` are considered.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import solve
|
||||
>>> from sympy.solvers.bivariate import bivariate_type
|
||||
>>> from sympy.abc import x, y
|
||||
>>> eq = (x**2 - 3).subs(x, x + y)
|
||||
>>> bivariate_type(eq, x, y)
|
||||
(x + y, _u**2 - 3, _u)
|
||||
>>> uxy, pu, u = _
|
||||
>>> usol = solve(pu, u); usol
|
||||
[sqrt(3)]
|
||||
>>> [solve(uxy - s) for s in solve(pu, u)]
|
||||
[[{x: -y + sqrt(3)}]]
|
||||
>>> all(eq.subs(s).equals(0) for sol in _ for s in sol)
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
u = Dummy('u', positive=True)
|
||||
|
||||
if first:
|
||||
p = Poly(f, x, y)
|
||||
f = p.as_expr()
|
||||
_x = Dummy()
|
||||
_y = Dummy()
|
||||
rv = bivariate_type(Poly(f.subs({x: _x, y: _y}), _x, _y), _x, _y, first=False)
|
||||
if rv:
|
||||
reps = {_x: x, _y: y}
|
||||
return rv[0].xreplace(reps), rv[1].xreplace(reps), rv[2]
|
||||
return
|
||||
|
||||
p = f
|
||||
f = p.as_expr()
|
||||
|
||||
# f(x*y)
|
||||
args = Add.make_args(p.as_expr())
|
||||
new = []
|
||||
for a in args:
|
||||
a = _mexpand(a.subs(x, u/y))
|
||||
free = a.free_symbols
|
||||
if x in free or y in free:
|
||||
break
|
||||
new.append(a)
|
||||
else:
|
||||
return x*y, Add(*new), u
|
||||
|
||||
def ok(f, v, c):
|
||||
new = _mexpand(f.subs(v, c))
|
||||
free = new.free_symbols
|
||||
return None if (x in free or y in free) else new
|
||||
|
||||
# f(a*x + b*y)
|
||||
new = []
|
||||
d = p.degree(x)
|
||||
if p.degree(y) == d:
|
||||
a = root(p.coeff_monomial(x**d), d)
|
||||
b = root(p.coeff_monomial(y**d), d)
|
||||
new = ok(f, x, (u - b*y)/a)
|
||||
if new is not None:
|
||||
return a*x + b*y, new, u
|
||||
|
||||
# f(a*x*y + b*y)
|
||||
new = []
|
||||
d = p.degree(x)
|
||||
if p.degree(y) == d:
|
||||
for itry in range(2):
|
||||
a = root(p.coeff_monomial(x**d*y**d), d)
|
||||
b = root(p.coeff_monomial(y**d), d)
|
||||
new = ok(f, x, (u - b*y)/a/y)
|
||||
if new is not None:
|
||||
return a*x*y + b*y, new, u
|
||||
x, y = y, x
|
||||
@@ -0,0 +1,126 @@
|
||||
from sympy.core import (Function, Pow, sympify, Expr)
|
||||
from sympy.core.relational import Relational
|
||||
from sympy.core.singleton import S
|
||||
from sympy.polys import Poly, decompose
|
||||
from sympy.utilities.misc import func_name
|
||||
from sympy.functions.elementary.miscellaneous import Min, Max
|
||||
|
||||
|
||||
def decompogen(f, symbol):
|
||||
"""
|
||||
Computes General functional decomposition of ``f``.
|
||||
Given an expression ``f``, returns a list ``[f_1, f_2, ..., f_n]``,
|
||||
where::
|
||||
f = f_1 o f_2 o ... f_n = f_1(f_2(... f_n))
|
||||
|
||||
Note: This is a General decomposition function. It also decomposes
|
||||
Polynomials. For only Polynomial decomposition see ``decompose`` in polys.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy import decompogen, sqrt, sin, cos
|
||||
>>> decompogen(sin(cos(x)), x)
|
||||
[sin(x), cos(x)]
|
||||
>>> decompogen(sin(x)**2 + sin(x) + 1, x)
|
||||
[x**2 + x + 1, sin(x)]
|
||||
>>> decompogen(sqrt(6*x**2 - 5), x)
|
||||
[sqrt(x), 6*x**2 - 5]
|
||||
>>> decompogen(sin(sqrt(cos(x**2 + 1))), x)
|
||||
[sin(x), sqrt(x), cos(x), x**2 + 1]
|
||||
>>> decompogen(x**4 + 2*x**3 - x - 1, x)
|
||||
[x**2 - x - 1, x**2 + x]
|
||||
|
||||
"""
|
||||
f = sympify(f)
|
||||
if not isinstance(f, Expr) or isinstance(f, Relational):
|
||||
raise TypeError('expecting Expr but got: `%s`' % func_name(f))
|
||||
if symbol not in f.free_symbols:
|
||||
return [f]
|
||||
|
||||
|
||||
# ===== Simple Functions ===== #
|
||||
if isinstance(f, (Function, Pow)):
|
||||
if f.is_Pow and f.base == S.Exp1:
|
||||
arg = f.exp
|
||||
else:
|
||||
arg = f.args[0]
|
||||
if arg == symbol:
|
||||
return [f]
|
||||
return [f.subs(arg, symbol)] + decompogen(arg, symbol)
|
||||
|
||||
# ===== Min/Max Functions ===== #
|
||||
if isinstance(f, (Min, Max)):
|
||||
args = list(f.args)
|
||||
d0 = None
|
||||
for i, a in enumerate(args):
|
||||
if not a.has_free(symbol):
|
||||
continue
|
||||
d = decompogen(a, symbol)
|
||||
if len(d) == 1:
|
||||
d = [symbol] + d
|
||||
if d0 is None:
|
||||
d0 = d[1:]
|
||||
elif d[1:] != d0:
|
||||
# decomposition is not the same for each arg:
|
||||
# mark as having no decomposition
|
||||
d = [symbol]
|
||||
break
|
||||
args[i] = d[0]
|
||||
if d[0] == symbol:
|
||||
return [f]
|
||||
return [f.func(*args)] + d0
|
||||
|
||||
# ===== Convert to Polynomial ===== #
|
||||
fp = Poly(f)
|
||||
gens = list(filter(lambda x: symbol in x.free_symbols, fp.gens))
|
||||
|
||||
if len(gens) == 1 and gens[0] != symbol:
|
||||
f1 = f.subs(gens[0], symbol)
|
||||
f2 = gens[0]
|
||||
return [f1] + decompogen(f2, symbol)
|
||||
|
||||
# ===== Polynomial decompose() ====== #
|
||||
try:
|
||||
return decompose(f)
|
||||
except ValueError:
|
||||
return [f]
|
||||
|
||||
|
||||
def compogen(g_s, symbol):
|
||||
"""
|
||||
Returns the composition of functions.
|
||||
Given a list of functions ``g_s``, returns their composition ``f``,
|
||||
where:
|
||||
f = g_1 o g_2 o .. o g_n
|
||||
|
||||
Note: This is a General composition function. It also composes Polynomials.
|
||||
For only Polynomial composition see ``compose`` in polys.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.decompogen import compogen
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy import sqrt, sin, cos
|
||||
>>> compogen([sin(x), cos(x)], x)
|
||||
sin(cos(x))
|
||||
>>> compogen([x**2 + x + 1, sin(x)], x)
|
||||
sin(x)**2 + sin(x) + 1
|
||||
>>> compogen([sqrt(x), 6*x**2 - 5], x)
|
||||
sqrt(6*x**2 - 5)
|
||||
>>> compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x)
|
||||
sin(sqrt(cos(x**2 + 1)))
|
||||
>>> compogen([x**2 - x - 1, x**2 + x], x)
|
||||
-x**2 - x + (x**2 + x)**2 - 1
|
||||
"""
|
||||
if len(g_s) == 1:
|
||||
return g_s[0]
|
||||
|
||||
foo = g_s[0].subs(symbol, g_s[1])
|
||||
|
||||
if len(g_s) == 2:
|
||||
return foo
|
||||
|
||||
return compogen([foo] + g_s[2:], symbol)
|
||||
@@ -0,0 +1,273 @@
|
||||
"""Utility functions for classifying and solving
|
||||
ordinary and partial differential equations.
|
||||
|
||||
Contains
|
||||
========
|
||||
_preprocess
|
||||
ode_order
|
||||
_desolve
|
||||
|
||||
"""
|
||||
from sympy.core import Pow
|
||||
from sympy.core.function import Derivative, AppliedUndef
|
||||
from sympy.core.relational import Equality
|
||||
from sympy.core.symbol import Wild
|
||||
|
||||
def _preprocess(expr, func=None, hint='_Integral'):
|
||||
"""Prepare expr for solving by making sure that differentiation
|
||||
is done so that only func remains in unevaluated derivatives and
|
||||
(if hint does not end with _Integral) that doit is applied to all
|
||||
other derivatives. If hint is None, do not do any differentiation.
|
||||
(Currently this may cause some simple differential equations to
|
||||
fail.)
|
||||
|
||||
In case func is None, an attempt will be made to autodetect the
|
||||
function to be solved for.
|
||||
|
||||
>>> from sympy.solvers.deutils import _preprocess
|
||||
>>> from sympy import Derivative, Function
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> f, g = map(Function, 'fg')
|
||||
|
||||
If f(x)**p == 0 and p>0 then we can solve for f(x)=0
|
||||
>>> _preprocess((f(x).diff(x)-4)**5, f(x))
|
||||
(Derivative(f(x), x) - 4, f(x))
|
||||
|
||||
Apply doit to derivatives that contain more than the function
|
||||
of interest:
|
||||
|
||||
>>> _preprocess(Derivative(f(x) + x, x))
|
||||
(Derivative(f(x), x) + 1, f(x))
|
||||
|
||||
Do others if the differentiation variable(s) intersect with those
|
||||
of the function of interest or contain the function of interest:
|
||||
|
||||
>>> _preprocess(Derivative(g(x), y, z), f(y))
|
||||
(0, f(y))
|
||||
>>> _preprocess(Derivative(f(y), z), f(y))
|
||||
(0, f(y))
|
||||
|
||||
Do others if the hint does not end in '_Integral' (the default
|
||||
assumes that it does):
|
||||
|
||||
>>> _preprocess(Derivative(g(x), y), f(x))
|
||||
(Derivative(g(x), y), f(x))
|
||||
>>> _preprocess(Derivative(f(x), y), f(x), hint='')
|
||||
(0, f(x))
|
||||
|
||||
Do not do any derivatives if hint is None:
|
||||
|
||||
>>> eq = Derivative(f(x) + 1, x) + Derivative(f(x), y)
|
||||
>>> _preprocess(eq, f(x), hint=None)
|
||||
(Derivative(f(x) + 1, x) + Derivative(f(x), y), f(x))
|
||||
|
||||
If it's not clear what the function of interest is, it must be given:
|
||||
|
||||
>>> eq = Derivative(f(x) + g(x), x)
|
||||
>>> _preprocess(eq, g(x))
|
||||
(Derivative(f(x), x) + Derivative(g(x), x), g(x))
|
||||
>>> try: _preprocess(eq)
|
||||
... except ValueError: print("A ValueError was raised.")
|
||||
A ValueError was raised.
|
||||
|
||||
"""
|
||||
if isinstance(expr, Pow):
|
||||
# if f(x)**p=0 then f(x)=0 (p>0)
|
||||
if (expr.exp).is_positive:
|
||||
expr = expr.base
|
||||
derivs = expr.atoms(Derivative)
|
||||
if not func:
|
||||
funcs = set().union(*[d.atoms(AppliedUndef) for d in derivs])
|
||||
if len(funcs) != 1:
|
||||
raise ValueError('The function cannot be '
|
||||
'automatically detected for %s.' % expr)
|
||||
func = funcs.pop()
|
||||
fvars = set(func.args)
|
||||
if hint is None:
|
||||
return expr, func
|
||||
reps = [(d, d.doit()) for d in derivs if not hint.endswith('_Integral') or
|
||||
d.has(func) or set(d.variables) & fvars]
|
||||
eq = expr.subs(reps)
|
||||
return eq, func
|
||||
|
||||
|
||||
def ode_order(expr, func):
|
||||
"""
|
||||
Returns the order of a given differential
|
||||
equation with respect to func.
|
||||
|
||||
This function is implemented recursively.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Function
|
||||
>>> from sympy.solvers.deutils import ode_order
|
||||
>>> from sympy.abc import x
|
||||
>>> f, g = map(Function, ['f', 'g'])
|
||||
>>> ode_order(f(x).diff(x, 2) + f(x).diff(x)**2 +
|
||||
... f(x).diff(x), f(x))
|
||||
2
|
||||
>>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), f(x))
|
||||
2
|
||||
>>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), g(x))
|
||||
3
|
||||
|
||||
"""
|
||||
a = Wild('a', exclude=[func])
|
||||
if expr.match(a):
|
||||
return 0
|
||||
|
||||
if isinstance(expr, Derivative):
|
||||
if expr.args[0] == func:
|
||||
return len(expr.variables)
|
||||
else:
|
||||
args = expr.args[0].args
|
||||
rv = len(expr.variables)
|
||||
if args:
|
||||
rv += max(ode_order(_, func) for _ in args)
|
||||
return rv
|
||||
else:
|
||||
return max(ode_order(_, func) for _ in expr.args) if expr.args else 0
|
||||
|
||||
|
||||
def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=True, **kwargs):
|
||||
"""This is a helper function to dsolve and pdsolve in the ode
|
||||
and pde modules.
|
||||
|
||||
If the hint provided to the function is "default", then a dict with
|
||||
the following keys are returned
|
||||
|
||||
'func' - It provides the function for which the differential equation
|
||||
has to be solved. This is useful when the expression has
|
||||
more than one function in it.
|
||||
|
||||
'default' - The default key as returned by classifier functions in ode
|
||||
and pde.py
|
||||
|
||||
'hint' - The hint given by the user for which the differential equation
|
||||
is to be solved. If the hint given by the user is 'default',
|
||||
then the value of 'hint' and 'default' is the same.
|
||||
|
||||
'order' - The order of the function as returned by ode_order
|
||||
|
||||
'match' - It returns the match as given by the classifier functions, for
|
||||
the default hint.
|
||||
|
||||
If the hint provided to the function is not "default" and is not in
|
||||
('all', 'all_Integral', 'best'), then a dict with the above mentioned keys
|
||||
is returned along with the keys which are returned when dict in
|
||||
classify_ode or classify_pde is set True
|
||||
|
||||
If the hint given is in ('all', 'all_Integral', 'best'), then this function
|
||||
returns a nested dict, with the keys, being the set of classified hints
|
||||
returned by classifier functions, and the values being the dict of form
|
||||
as mentioned above.
|
||||
|
||||
Key 'eq' is a common key to all the above mentioned hints which returns an
|
||||
expression if eq given by user is an Equality.
|
||||
|
||||
See Also
|
||||
========
|
||||
classify_ode(ode.py)
|
||||
classify_pde(pde.py)
|
||||
"""
|
||||
if isinstance(eq, Equality):
|
||||
eq = eq.lhs - eq.rhs
|
||||
|
||||
# preprocess the equation and find func if not given
|
||||
if prep or func is None:
|
||||
eq, func = _preprocess(eq, func)
|
||||
prep = False
|
||||
|
||||
# type is an argument passed by the solve functions in ode and pde.py
|
||||
# that identifies whether the function caller is an ordinary
|
||||
# or partial differential equation. Accordingly corresponding
|
||||
# changes are made in the function.
|
||||
type = kwargs.get('type', None)
|
||||
xi = kwargs.get('xi')
|
||||
eta = kwargs.get('eta')
|
||||
x0 = kwargs.get('x0', 0)
|
||||
terms = kwargs.get('n')
|
||||
|
||||
if type == 'ode':
|
||||
from sympy.solvers.ode import classify_ode, allhints
|
||||
classifier = classify_ode
|
||||
string = 'ODE '
|
||||
dummy = ''
|
||||
|
||||
elif type == 'pde':
|
||||
from sympy.solvers.pde import classify_pde, allhints
|
||||
classifier = classify_pde
|
||||
string = 'PDE '
|
||||
dummy = 'p'
|
||||
|
||||
# Magic that should only be used internally. Prevents classify_ode from
|
||||
# being called more than it needs to be by passing its results through
|
||||
# recursive calls.
|
||||
if kwargs.get('classify', True):
|
||||
hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta,
|
||||
n=terms, x0=x0, hint=hint, prep=prep)
|
||||
|
||||
else:
|
||||
# Here is what all this means:
|
||||
#
|
||||
# hint: The hint method given to _desolve() by the user.
|
||||
# hints: The dictionary of hints that match the DE, along with other
|
||||
# information (including the internal pass-through magic).
|
||||
# default: The default hint to return, the first hint from allhints
|
||||
# that matches the hint; obtained from classify_ode().
|
||||
# match: Dictionary containing the match dictionary for each hint
|
||||
# (the parts of the DE for solving). When going through the
|
||||
# hints in "all", this holds the match string for the current
|
||||
# hint.
|
||||
# order: The order of the DE, as determined by ode_order().
|
||||
hints = kwargs.get('hint',
|
||||
{'default': hint,
|
||||
hint: kwargs['match'],
|
||||
'order': kwargs['order']})
|
||||
if not hints['default']:
|
||||
# classify_ode will set hints['default'] to None if no hints match.
|
||||
if hint not in allhints and hint != 'default':
|
||||
raise ValueError("Hint not recognized: " + hint)
|
||||
elif hint not in hints['ordered_hints'] and hint != 'default':
|
||||
raise ValueError(string + str(eq) + " does not match hint " + hint)
|
||||
# If dsolve can't solve the purely algebraic equation then dsolve will raise
|
||||
# ValueError
|
||||
elif hints['order'] == 0:
|
||||
raise ValueError(
|
||||
str(eq) + " is not a solvable differential equation in " + str(func))
|
||||
else:
|
||||
raise NotImplementedError(dummy + "solve" + ": Cannot solve " + str(eq))
|
||||
if hint == 'default':
|
||||
return _desolve(eq, func, ics=ics, hint=hints['default'], simplify=simplify,
|
||||
prep=prep, x0=x0, classify=False, order=hints['order'],
|
||||
match=hints[hints['default']], xi=xi, eta=eta, n=terms, type=type)
|
||||
elif hint in ('all', 'all_Integral', 'best'):
|
||||
retdict = {}
|
||||
gethints = set(hints) - {'order', 'default', 'ordered_hints'}
|
||||
if hint == 'all_Integral':
|
||||
for i in hints:
|
||||
if i.endswith('_Integral'):
|
||||
gethints.remove(i.removesuffix('_Integral'))
|
||||
# special cases
|
||||
for k in ["1st_homogeneous_coeff_best", "1st_power_series",
|
||||
"lie_group", "2nd_power_series_ordinary", "2nd_power_series_regular"]:
|
||||
if k in gethints:
|
||||
gethints.remove(k)
|
||||
for i in gethints:
|
||||
sol = _desolve(eq, func, ics=ics, hint=i, x0=x0, simplify=simplify, prep=prep,
|
||||
classify=False, n=terms, order=hints['order'], match=hints[i], type=type)
|
||||
retdict[i] = sol
|
||||
retdict['all'] = True
|
||||
retdict['eq'] = eq
|
||||
return retdict
|
||||
elif hint not in allhints: # and hint not in ('default', 'ordered_hints'):
|
||||
raise ValueError("Hint not recognized: " + hint)
|
||||
elif hint not in hints:
|
||||
raise ValueError(string + str(eq) + " does not match hint " + hint)
|
||||
else:
|
||||
# Key added to identify the hint needed to solve the equation
|
||||
hints['hint'] = hint
|
||||
hints.update({'func': func, 'eq': eq})
|
||||
return hints
|
||||
@@ -0,0 +1,5 @@
|
||||
from .diophantine import diophantine, classify_diop, diop_solve
|
||||
|
||||
__all__ = [
|
||||
'diophantine', 'classify_diop', 'diop_solve'
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
+1071
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,986 @@
|
||||
"""Tools for solving inequalities and systems of inequalities. """
|
||||
import itertools
|
||||
|
||||
from sympy.calculus.util import (continuous_domain, periodicity,
|
||||
function_range)
|
||||
from sympy.core import sympify
|
||||
from sympy.core.exprtools import factor_terms
|
||||
from sympy.core.relational import Relational, Lt, Ge, Eq
|
||||
from sympy.core.symbol import Symbol, Dummy
|
||||
from sympy.sets.sets import Interval, FiniteSet, Union, Intersection
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.function import expand_mul
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.logic import And
|
||||
from sympy.polys import Poly, PolynomialError, parallel_poly_from_expr
|
||||
from sympy.polys.polyutils import _nsort
|
||||
from sympy.solvers.solveset import solvify, solveset
|
||||
from sympy.utilities.iterables import sift, iterable
|
||||
from sympy.utilities.misc import filldedent
|
||||
|
||||
|
||||
def solve_poly_inequality(poly, rel):
|
||||
"""Solve a polynomial inequality with rational coefficients.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import solve_poly_inequality, Poly
|
||||
>>> from sympy.abc import x
|
||||
|
||||
>>> solve_poly_inequality(Poly(x, x, domain='ZZ'), '==')
|
||||
[{0}]
|
||||
|
||||
>>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '!=')
|
||||
[Interval.open(-oo, -1), Interval.open(-1, 1), Interval.open(1, oo)]
|
||||
|
||||
>>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '==')
|
||||
[{-1}, {1}]
|
||||
|
||||
See Also
|
||||
========
|
||||
solve_poly_inequalities
|
||||
"""
|
||||
if not isinstance(poly, Poly):
|
||||
raise ValueError(
|
||||
'For efficiency reasons, `poly` should be a Poly instance')
|
||||
if poly.as_expr().is_number:
|
||||
t = Relational(poly.as_expr(), 0, rel)
|
||||
if t is S.true:
|
||||
return [S.Reals]
|
||||
elif t is S.false:
|
||||
return [S.EmptySet]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"could not determine truth value of %s" % t)
|
||||
|
||||
reals, intervals = poly.real_roots(multiple=False), []
|
||||
|
||||
if rel == '==':
|
||||
for root, _ in reals:
|
||||
interval = Interval(root, root)
|
||||
intervals.append(interval)
|
||||
elif rel == '!=':
|
||||
left = S.NegativeInfinity
|
||||
|
||||
for right, _ in reals + [(S.Infinity, 1)]:
|
||||
interval = Interval(left, right, True, True)
|
||||
intervals.append(interval)
|
||||
left = right
|
||||
else:
|
||||
if poly.LC() > 0:
|
||||
sign = +1
|
||||
else:
|
||||
sign = -1
|
||||
|
||||
eq_sign, equal = None, False
|
||||
|
||||
if rel == '>':
|
||||
eq_sign = +1
|
||||
elif rel == '<':
|
||||
eq_sign = -1
|
||||
elif rel == '>=':
|
||||
eq_sign, equal = +1, True
|
||||
elif rel == '<=':
|
||||
eq_sign, equal = -1, True
|
||||
else:
|
||||
raise ValueError("'%s' is not a valid relation" % rel)
|
||||
|
||||
right, right_open = S.Infinity, True
|
||||
|
||||
for left, multiplicity in reversed(reals):
|
||||
if multiplicity % 2:
|
||||
if sign == eq_sign:
|
||||
intervals.insert(
|
||||
0, Interval(left, right, not equal, right_open))
|
||||
|
||||
sign, right, right_open = -sign, left, not equal
|
||||
else:
|
||||
if sign == eq_sign and not equal:
|
||||
intervals.insert(
|
||||
0, Interval(left, right, True, right_open))
|
||||
right, right_open = left, True
|
||||
elif sign != eq_sign and equal:
|
||||
intervals.insert(0, Interval(left, left))
|
||||
|
||||
if sign == eq_sign:
|
||||
intervals.insert(
|
||||
0, Interval(S.NegativeInfinity, right, True, right_open))
|
||||
|
||||
return intervals
|
||||
|
||||
|
||||
def solve_poly_inequalities(polys):
|
||||
"""Solve polynomial inequalities with rational coefficients.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Poly
|
||||
>>> from sympy.solvers.inequalities import solve_poly_inequalities
|
||||
>>> from sympy.abc import x
|
||||
>>> solve_poly_inequalities(((
|
||||
... Poly(x**2 - 3), ">"), (
|
||||
... Poly(-x**2 + 1), ">")))
|
||||
Union(Interval.open(-oo, -sqrt(3)), Interval.open(-1, 1), Interval.open(sqrt(3), oo))
|
||||
"""
|
||||
return Union(*[s for p in polys for s in solve_poly_inequality(*p)])
|
||||
|
||||
|
||||
def solve_rational_inequalities(eqs):
|
||||
"""Solve a system of rational inequalities with rational coefficients.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy import solve_rational_inequalities, Poly
|
||||
|
||||
>>> solve_rational_inequalities([[
|
||||
... ((Poly(-x + 1), Poly(1, x)), '>='),
|
||||
... ((Poly(-x + 1), Poly(1, x)), '<=')]])
|
||||
{1}
|
||||
|
||||
>>> solve_rational_inequalities([[
|
||||
... ((Poly(x), Poly(1, x)), '!='),
|
||||
... ((Poly(-x + 1), Poly(1, x)), '>=')]])
|
||||
Union(Interval.open(-oo, 0), Interval.Lopen(0, 1))
|
||||
|
||||
See Also
|
||||
========
|
||||
solve_poly_inequality
|
||||
"""
|
||||
result = S.EmptySet
|
||||
|
||||
for _eqs in eqs:
|
||||
if not _eqs:
|
||||
continue
|
||||
|
||||
global_intervals = [Interval(S.NegativeInfinity, S.Infinity)]
|
||||
|
||||
for (numer, denom), rel in _eqs:
|
||||
numer_intervals = solve_poly_inequality(numer*denom, rel)
|
||||
denom_intervals = solve_poly_inequality(denom, '==')
|
||||
|
||||
intervals = []
|
||||
|
||||
for numer_interval, global_interval in itertools.product(
|
||||
numer_intervals, global_intervals):
|
||||
interval = numer_interval.intersect(global_interval)
|
||||
|
||||
if interval is not S.EmptySet:
|
||||
intervals.append(interval)
|
||||
|
||||
global_intervals = intervals
|
||||
|
||||
intervals = []
|
||||
|
||||
for global_interval in global_intervals:
|
||||
for denom_interval in denom_intervals:
|
||||
global_interval -= denom_interval
|
||||
|
||||
if global_interval is not S.EmptySet:
|
||||
intervals.append(global_interval)
|
||||
|
||||
global_intervals = intervals
|
||||
|
||||
if not global_intervals:
|
||||
break
|
||||
|
||||
for interval in global_intervals:
|
||||
result = result.union(interval)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def reduce_rational_inequalities(exprs, gen, relational=True):
|
||||
"""Reduce a system of rational inequalities with rational coefficients.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol
|
||||
>>> from sympy.solvers.inequalities import reduce_rational_inequalities
|
||||
|
||||
>>> x = Symbol('x', real=True)
|
||||
|
||||
>>> reduce_rational_inequalities([[x**2 <= 0]], x)
|
||||
Eq(x, 0)
|
||||
|
||||
>>> reduce_rational_inequalities([[x + 2 > 0]], x)
|
||||
-2 < x
|
||||
>>> reduce_rational_inequalities([[(x + 2, ">")]], x)
|
||||
-2 < x
|
||||
>>> reduce_rational_inequalities([[x + 2]], x)
|
||||
Eq(x, -2)
|
||||
|
||||
This function find the non-infinite solution set so if the unknown symbol
|
||||
is declared as extended real rather than real then the result may include
|
||||
finiteness conditions:
|
||||
|
||||
>>> y = Symbol('y', extended_real=True)
|
||||
>>> reduce_rational_inequalities([[y + 2 > 0]], y)
|
||||
(-2 < y) & (y < oo)
|
||||
"""
|
||||
exact = True
|
||||
eqs = []
|
||||
solution = S.EmptySet # add pieces for each group
|
||||
for _exprs in exprs:
|
||||
if not _exprs:
|
||||
continue
|
||||
_eqs = []
|
||||
_sol = S.Reals
|
||||
for expr in _exprs:
|
||||
if isinstance(expr, tuple):
|
||||
expr, rel = expr
|
||||
else:
|
||||
if expr.is_Relational:
|
||||
expr, rel = expr.lhs - expr.rhs, expr.rel_op
|
||||
else:
|
||||
rel = '=='
|
||||
|
||||
if expr is S.true:
|
||||
numer, denom, rel = S.Zero, S.One, '=='
|
||||
elif expr is S.false:
|
||||
numer, denom, rel = S.One, S.One, '=='
|
||||
else:
|
||||
numer, denom = expr.together().as_numer_denom()
|
||||
|
||||
try:
|
||||
(numer, denom), opt = parallel_poly_from_expr(
|
||||
(numer, denom), gen)
|
||||
except PolynomialError:
|
||||
raise PolynomialError(filldedent('''
|
||||
only polynomials and rational functions are
|
||||
supported in this context.
|
||||
'''))
|
||||
|
||||
if not opt.domain.is_Exact:
|
||||
numer, denom, exact = numer.to_exact(), denom.to_exact(), False
|
||||
|
||||
domain = opt.domain.get_exact()
|
||||
|
||||
if not (domain.is_ZZ or domain.is_QQ):
|
||||
expr = numer/denom
|
||||
expr = Relational(expr, 0, rel)
|
||||
_sol &= solve_univariate_inequality(expr, gen, relational=False)
|
||||
else:
|
||||
_eqs.append(((numer, denom), rel))
|
||||
|
||||
if _eqs:
|
||||
_sol &= solve_rational_inequalities([_eqs])
|
||||
exclude = solve_rational_inequalities([[((d, d.one), '==')
|
||||
for i in eqs for ((n, d), _) in i if d.has(gen)]])
|
||||
_sol -= exclude
|
||||
|
||||
solution |= _sol
|
||||
|
||||
if not exact and solution:
|
||||
solution = solution.evalf()
|
||||
|
||||
if relational:
|
||||
solution = solution.as_relational(gen)
|
||||
|
||||
return solution
|
||||
|
||||
|
||||
def reduce_abs_inequality(expr, rel, gen):
|
||||
"""Reduce an inequality with nested absolute values.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import reduce_abs_inequality, Abs, Symbol
|
||||
>>> x = Symbol('x', real=True)
|
||||
|
||||
>>> reduce_abs_inequality(Abs(x - 5) - 3, '<', x)
|
||||
(2 < x) & (x < 8)
|
||||
|
||||
>>> reduce_abs_inequality(Abs(x + 2)*3 - 13, '<', x)
|
||||
(-19/3 < x) & (x < 7/3)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
reduce_abs_inequalities
|
||||
"""
|
||||
if gen.is_extended_real is False:
|
||||
raise TypeError(filldedent('''
|
||||
Cannot solve inequalities with absolute values containing
|
||||
non-real variables.
|
||||
'''))
|
||||
|
||||
def _bottom_up_scan(expr):
|
||||
exprs = []
|
||||
|
||||
if expr.is_Add or expr.is_Mul:
|
||||
op = expr.func
|
||||
|
||||
for arg in expr.args:
|
||||
_exprs = _bottom_up_scan(arg)
|
||||
|
||||
if not exprs:
|
||||
exprs = _exprs
|
||||
else:
|
||||
exprs = [(op(expr, _expr), conds + _conds) for (expr, conds), (_expr, _conds) in
|
||||
itertools.product(exprs, _exprs)]
|
||||
elif expr.is_Pow:
|
||||
n = expr.exp
|
||||
if not n.is_Integer:
|
||||
raise ValueError("Only Integer Powers are allowed on Abs.")
|
||||
|
||||
exprs.extend((expr**n, conds) for expr, conds in _bottom_up_scan(expr.base))
|
||||
elif isinstance(expr, Abs):
|
||||
_exprs = _bottom_up_scan(expr.args[0])
|
||||
|
||||
for expr, conds in _exprs:
|
||||
exprs.append(( expr, conds + [Ge(expr, 0)]))
|
||||
exprs.append((-expr, conds + [Lt(expr, 0)]))
|
||||
else:
|
||||
exprs = [(expr, [])]
|
||||
|
||||
return exprs
|
||||
|
||||
mapping = {'<': '>', '<=': '>='}
|
||||
inequalities = []
|
||||
|
||||
for expr, conds in _bottom_up_scan(expr):
|
||||
if rel not in mapping.keys():
|
||||
expr = Relational( expr, 0, rel)
|
||||
else:
|
||||
expr = Relational(-expr, 0, mapping[rel])
|
||||
|
||||
inequalities.append([expr] + conds)
|
||||
|
||||
return reduce_rational_inequalities(inequalities, gen)
|
||||
|
||||
|
||||
def reduce_abs_inequalities(exprs, gen):
|
||||
"""Reduce a system of inequalities with nested absolute values.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import reduce_abs_inequalities, Abs, Symbol
|
||||
>>> x = Symbol('x', extended_real=True)
|
||||
|
||||
>>> reduce_abs_inequalities([(Abs(3*x - 5) - 7, '<'),
|
||||
... (Abs(x + 25) - 13, '>')], x)
|
||||
(-2/3 < x) & (x < 4) & (((-oo < x) & (x < -38)) | ((-12 < x) & (x < oo)))
|
||||
|
||||
>>> reduce_abs_inequalities([(Abs(x - 4) + Abs(3*x - 5) - 7, '<')], x)
|
||||
(1/2 < x) & (x < 4)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
reduce_abs_inequality
|
||||
"""
|
||||
return And(*[ reduce_abs_inequality(expr, rel, gen)
|
||||
for expr, rel in exprs ])
|
||||
|
||||
|
||||
def solve_univariate_inequality(expr, gen, relational=True, domain=S.Reals, continuous=False):
|
||||
"""Solves a real univariate inequality.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : Relational
|
||||
The target inequality
|
||||
gen : Symbol
|
||||
The variable for which the inequality is solved
|
||||
relational : bool
|
||||
A Relational type output is expected or not
|
||||
domain : Set
|
||||
The domain over which the equation is solved
|
||||
continuous: bool
|
||||
True if expr is known to be continuous over the given domain
|
||||
(and so continuous_domain() does not need to be called on it)
|
||||
|
||||
Raises
|
||||
======
|
||||
|
||||
NotImplementedError
|
||||
The solution of the inequality cannot be determined due to limitation
|
||||
in :func:`sympy.solvers.solveset.solvify`.
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
Currently, we cannot solve all the inequalities due to limitations in
|
||||
:func:`sympy.solvers.solveset.solvify`. Also, the solution returned for trigonometric inequalities
|
||||
are restricted in its periodic interval.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
sympy.solvers.solveset.solvify: solver returning solveset solutions with solve's output API
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import solve_univariate_inequality, Symbol, sin, Interval, S
|
||||
>>> x = Symbol('x')
|
||||
|
||||
>>> solve_univariate_inequality(x**2 >= 4, x)
|
||||
((2 <= x) & (x < oo)) | ((-oo < x) & (x <= -2))
|
||||
|
||||
>>> solve_univariate_inequality(x**2 >= 4, x, relational=False)
|
||||
Union(Interval(-oo, -2), Interval(2, oo))
|
||||
|
||||
>>> domain = Interval(0, S.Infinity)
|
||||
>>> solve_univariate_inequality(x**2 >= 4, x, False, domain)
|
||||
Interval(2, oo)
|
||||
|
||||
>>> solve_univariate_inequality(sin(x) > 0, x, relational=False)
|
||||
Interval.open(0, pi)
|
||||
|
||||
"""
|
||||
from sympy.solvers.solvers import denoms
|
||||
|
||||
if domain.is_subset(S.Reals) is False:
|
||||
raise NotImplementedError(filldedent('''
|
||||
Inequalities in the complex domain are
|
||||
not supported. Try the real domain by
|
||||
setting domain=S.Reals'''))
|
||||
elif domain is not S.Reals:
|
||||
rv = solve_univariate_inequality(
|
||||
expr, gen, relational=False, continuous=continuous).intersection(domain)
|
||||
if relational:
|
||||
rv = rv.as_relational(gen)
|
||||
return rv
|
||||
else:
|
||||
pass # continue with attempt to solve in Real domain
|
||||
|
||||
# This keeps the function independent of the assumptions about `gen`.
|
||||
# `solveset` makes sure this function is called only when the domain is
|
||||
# real.
|
||||
_gen = gen
|
||||
_domain = domain
|
||||
if gen.is_extended_real is False:
|
||||
rv = S.EmptySet
|
||||
return rv if not relational else rv.as_relational(_gen)
|
||||
elif gen.is_extended_real is None:
|
||||
gen = Dummy('gen', extended_real=True)
|
||||
try:
|
||||
expr = expr.xreplace({_gen: gen})
|
||||
except TypeError:
|
||||
raise TypeError(filldedent('''
|
||||
When gen is real, the relational has a complex part
|
||||
which leads to an invalid comparison like I < 0.
|
||||
'''))
|
||||
|
||||
rv = None
|
||||
|
||||
if expr is S.true:
|
||||
rv = domain
|
||||
|
||||
elif expr is S.false:
|
||||
rv = S.EmptySet
|
||||
|
||||
else:
|
||||
e = expr.lhs - expr.rhs
|
||||
period = periodicity(e, gen)
|
||||
if period == S.Zero:
|
||||
e = expand_mul(e)
|
||||
const = expr.func(e, 0)
|
||||
if const is S.true:
|
||||
rv = domain
|
||||
elif const is S.false:
|
||||
rv = S.EmptySet
|
||||
elif period is not None:
|
||||
frange = function_range(e, gen, domain)
|
||||
|
||||
rel = expr.rel_op
|
||||
if rel in ('<', '<='):
|
||||
if expr.func(frange.sup, 0):
|
||||
rv = domain
|
||||
elif not expr.func(frange.inf, 0):
|
||||
rv = S.EmptySet
|
||||
|
||||
elif rel in ('>', '>='):
|
||||
if expr.func(frange.inf, 0):
|
||||
rv = domain
|
||||
elif not expr.func(frange.sup, 0):
|
||||
rv = S.EmptySet
|
||||
|
||||
inf, sup = domain.inf, domain.sup
|
||||
if sup - inf is S.Infinity:
|
||||
domain = Interval(0, period, False, True).intersect(_domain)
|
||||
_domain = domain
|
||||
|
||||
if rv is None:
|
||||
n, d = e.as_numer_denom()
|
||||
try:
|
||||
if gen not in n.free_symbols and len(e.free_symbols) > 1:
|
||||
raise ValueError
|
||||
# this might raise ValueError on its own
|
||||
# or it might give None...
|
||||
solns = solvify(e, gen, domain)
|
||||
if solns is None:
|
||||
# in which case we raise ValueError
|
||||
raise ValueError
|
||||
except (ValueError, NotImplementedError):
|
||||
# replace gen with generic x since it's
|
||||
# univariate anyway
|
||||
raise NotImplementedError(filldedent('''
|
||||
The inequality, %s, cannot be solved using
|
||||
solve_univariate_inequality.
|
||||
''' % expr.subs(gen, Symbol('x'))))
|
||||
|
||||
expanded_e = expand_mul(e)
|
||||
def valid(x):
|
||||
# this is used to see if gen=x satisfies the
|
||||
# relational by substituting it into the
|
||||
# expanded form and testing against 0, e.g.
|
||||
# if expr = x*(x + 1) < 2 then e = x*(x + 1) - 2
|
||||
# and expanded_e = x**2 + x - 2; the test is
|
||||
# whether a given value of x satisfies
|
||||
# x**2 + x - 2 < 0
|
||||
#
|
||||
# expanded_e, expr and gen used from enclosing scope
|
||||
v = expanded_e.subs(gen, expand_mul(x))
|
||||
try:
|
||||
r = expr.func(v, 0)
|
||||
except TypeError:
|
||||
r = S.false
|
||||
if r in (S.true, S.false):
|
||||
return r
|
||||
if v.is_extended_real is False:
|
||||
return S.false
|
||||
else:
|
||||
v = v.n(2)
|
||||
if v.is_comparable:
|
||||
return expr.func(v, 0)
|
||||
# not comparable or couldn't be evaluated
|
||||
raise NotImplementedError(
|
||||
'relationship did not evaluate: %s' % r)
|
||||
|
||||
singularities = []
|
||||
for d in denoms(expr, gen):
|
||||
singularities.extend(solvify(d, gen, domain))
|
||||
if not continuous:
|
||||
domain = continuous_domain(expanded_e, gen, domain)
|
||||
|
||||
include_x = '=' in expr.rel_op and expr.rel_op != '!='
|
||||
|
||||
try:
|
||||
discontinuities = set(domain.boundary -
|
||||
FiniteSet(domain.inf, domain.sup))
|
||||
# remove points that are not between inf and sup of domain
|
||||
critical_points = FiniteSet(*(solns + singularities + list(
|
||||
discontinuities))).intersection(
|
||||
Interval(domain.inf, domain.sup,
|
||||
domain.inf not in domain, domain.sup not in domain))
|
||||
if all(r.is_number for r in critical_points):
|
||||
reals = _nsort(critical_points, separated=True)[0]
|
||||
else:
|
||||
sifted = sift(critical_points, lambda x: x.is_extended_real)
|
||||
if sifted[None]:
|
||||
# there were some roots that weren't known
|
||||
# to be real
|
||||
raise NotImplementedError
|
||||
try:
|
||||
reals = sifted[True]
|
||||
if len(reals) > 1:
|
||||
reals = sorted(reals)
|
||||
except TypeError:
|
||||
raise NotImplementedError
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError('sorting of these roots is not supported')
|
||||
|
||||
# If expr contains imaginary coefficients, only take real
|
||||
# values of x for which the imaginary part is 0
|
||||
make_real = S.Reals
|
||||
if (coeffI := expanded_e.coeff(S.ImaginaryUnit)) != S.Zero:
|
||||
check = True
|
||||
im_sol = FiniteSet()
|
||||
try:
|
||||
a = solveset(coeffI, gen, domain)
|
||||
if not isinstance(a, Interval):
|
||||
for z in a:
|
||||
if z not in singularities and valid(z) and z.is_extended_real:
|
||||
im_sol += FiniteSet(z)
|
||||
else:
|
||||
start, end = a.inf, a.sup
|
||||
for z in _nsort(critical_points + FiniteSet(end)):
|
||||
valid_start = valid(start)
|
||||
if start != end:
|
||||
valid_z = valid(z)
|
||||
pt = _pt(start, z)
|
||||
if pt not in singularities and pt.is_extended_real and valid(pt):
|
||||
if valid_start and valid_z:
|
||||
im_sol += Interval(start, z)
|
||||
elif valid_start:
|
||||
im_sol += Interval.Ropen(start, z)
|
||||
elif valid_z:
|
||||
im_sol += Interval.Lopen(start, z)
|
||||
else:
|
||||
im_sol += Interval.open(start, z)
|
||||
start = z
|
||||
for s in singularities:
|
||||
im_sol -= FiniteSet(s)
|
||||
except (TypeError):
|
||||
im_sol = S.Reals
|
||||
check = False
|
||||
|
||||
if im_sol is S.EmptySet:
|
||||
raise ValueError(filldedent('''
|
||||
%s contains imaginary parts which cannot be
|
||||
made 0 for any value of %s satisfying the
|
||||
inequality, leading to relations like I < 0.
|
||||
''' % (expr.subs(gen, _gen), _gen)))
|
||||
|
||||
make_real = make_real.intersect(im_sol)
|
||||
|
||||
sol_sets = [S.EmptySet]
|
||||
|
||||
start = domain.inf
|
||||
if start in domain and valid(start) and start.is_finite:
|
||||
sol_sets.append(FiniteSet(start))
|
||||
|
||||
for x in reals:
|
||||
end = x
|
||||
|
||||
if valid(_pt(start, end)):
|
||||
sol_sets.append(Interval(start, end, True, True))
|
||||
|
||||
if x in singularities:
|
||||
singularities.remove(x)
|
||||
else:
|
||||
if x in discontinuities:
|
||||
discontinuities.remove(x)
|
||||
_valid = valid(x)
|
||||
else: # it's a solution
|
||||
_valid = include_x
|
||||
if _valid:
|
||||
sol_sets.append(FiniteSet(x))
|
||||
|
||||
start = end
|
||||
|
||||
end = domain.sup
|
||||
if end in domain and valid(end) and end.is_finite:
|
||||
sol_sets.append(FiniteSet(end))
|
||||
|
||||
if valid(_pt(start, end)):
|
||||
sol_sets.append(Interval.open(start, end))
|
||||
|
||||
if coeffI != S.Zero and check:
|
||||
rv = (make_real).intersect(_domain)
|
||||
else:
|
||||
rv = Intersection(
|
||||
(Union(*sol_sets)), make_real, _domain).subs(gen, _gen)
|
||||
|
||||
return rv if not relational else rv.as_relational(_gen)
|
||||
|
||||
|
||||
def _pt(start, end):
|
||||
"""Return a point between start and end"""
|
||||
if not start.is_infinite and not end.is_infinite:
|
||||
pt = (start + end)/2
|
||||
elif start.is_infinite and end.is_infinite:
|
||||
pt = S.Zero
|
||||
else:
|
||||
if (start.is_infinite and start.is_extended_positive is None or
|
||||
end.is_infinite and end.is_extended_positive is None):
|
||||
raise ValueError('cannot proceed with unsigned infinite values')
|
||||
if (end.is_infinite and end.is_extended_negative or
|
||||
start.is_infinite and start.is_extended_positive):
|
||||
start, end = end, start
|
||||
# if possible, use a multiple of self which has
|
||||
# better behavior when checking assumptions than
|
||||
# an expression obtained by adding or subtracting 1
|
||||
if end.is_infinite:
|
||||
if start.is_extended_positive:
|
||||
pt = start*2
|
||||
elif start.is_extended_negative:
|
||||
pt = start*S.Half
|
||||
else:
|
||||
pt = start + 1
|
||||
elif start.is_infinite:
|
||||
if end.is_extended_positive:
|
||||
pt = end*S.Half
|
||||
elif end.is_extended_negative:
|
||||
pt = end*2
|
||||
else:
|
||||
pt = end - 1
|
||||
return pt
|
||||
|
||||
|
||||
def _solve_inequality(ie, s, linear=False):
|
||||
"""Return the inequality with s isolated on the left, if possible.
|
||||
If the relationship is non-linear, a solution involving And or Or
|
||||
may be returned. False or True are returned if the relationship
|
||||
is never True or always True, respectively.
|
||||
|
||||
If `linear` is True (default is False) an `s`-dependent expression
|
||||
will be isolated on the left, if possible
|
||||
but it will not be solved for `s` unless the expression is linear
|
||||
in `s`. Furthermore, only "safe" operations which do not change the
|
||||
sense of the relationship are applied: no division by an unsigned
|
||||
value is attempted unless the relationship involves Eq or Ne and
|
||||
no division by a value not known to be nonzero is ever attempted.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Eq, Symbol
|
||||
>>> from sympy.solvers.inequalities import _solve_inequality as f
|
||||
>>> from sympy.abc import x, y
|
||||
|
||||
For linear expressions, the symbol can be isolated:
|
||||
|
||||
>>> f(x - 2 < 0, x)
|
||||
x < 2
|
||||
>>> f(-x - 6 < x, x)
|
||||
x > -3
|
||||
|
||||
Sometimes nonlinear relationships will be False
|
||||
|
||||
>>> f(x**2 + 4 < 0, x)
|
||||
False
|
||||
|
||||
Or they may involve more than one region of values:
|
||||
|
||||
>>> f(x**2 - 4 < 0, x)
|
||||
(-2 < x) & (x < 2)
|
||||
|
||||
To restrict the solution to a relational, set linear=True
|
||||
and only the x-dependent portion will be isolated on the left:
|
||||
|
||||
>>> f(x**2 - 4 < 0, x, linear=True)
|
||||
x**2 < 4
|
||||
|
||||
Division of only nonzero quantities is allowed, so x cannot
|
||||
be isolated by dividing by y:
|
||||
|
||||
>>> y.is_nonzero is None # it is unknown whether it is 0 or not
|
||||
True
|
||||
>>> f(x*y < 1, x)
|
||||
x*y < 1
|
||||
|
||||
And while an equality (or inequality) still holds after dividing by a
|
||||
non-zero quantity
|
||||
|
||||
>>> nz = Symbol('nz', nonzero=True)
|
||||
>>> f(Eq(x*nz, 1), x)
|
||||
Eq(x, 1/nz)
|
||||
|
||||
the sign must be known for other inequalities involving > or <:
|
||||
|
||||
>>> f(x*nz <= 1, x)
|
||||
nz*x <= 1
|
||||
>>> p = Symbol('p', positive=True)
|
||||
>>> f(x*p <= 1, x)
|
||||
x <= 1/p
|
||||
|
||||
When there are denominators in the original expression that
|
||||
are removed by expansion, conditions for them will be returned
|
||||
as part of the result:
|
||||
|
||||
>>> f(x < x*(2/x - 1), x)
|
||||
(x < 1) & Ne(x, 0)
|
||||
"""
|
||||
from sympy.solvers.solvers import denoms
|
||||
if s not in ie.free_symbols:
|
||||
return ie
|
||||
if ie.rhs == s:
|
||||
ie = ie.reversed
|
||||
if ie.lhs == s and s not in ie.rhs.free_symbols:
|
||||
return ie
|
||||
|
||||
def classify(ie, s, i):
|
||||
# return True or False if ie evaluates when substituting s with
|
||||
# i else None (if unevaluated) or NaN (when there is an error
|
||||
# in evaluating)
|
||||
try:
|
||||
v = ie.subs(s, i)
|
||||
if v is S.NaN:
|
||||
return v
|
||||
elif v not in (True, False):
|
||||
return
|
||||
return v
|
||||
except TypeError:
|
||||
return S.NaN
|
||||
|
||||
rv = None
|
||||
oo = S.Infinity
|
||||
expr = ie.lhs - ie.rhs
|
||||
try:
|
||||
p = Poly(expr, s)
|
||||
if p.degree() == 0:
|
||||
rv = ie.func(p.as_expr(), 0)
|
||||
elif not linear and p.degree() > 1:
|
||||
# handle in except clause
|
||||
raise NotImplementedError
|
||||
except (PolynomialError, NotImplementedError):
|
||||
if not linear:
|
||||
try:
|
||||
rv = reduce_rational_inequalities([[ie]], s)
|
||||
except PolynomialError:
|
||||
rv = solve_univariate_inequality(ie, s)
|
||||
# remove restrictions wrt +/-oo that may have been
|
||||
# applied when using sets to simplify the relationship
|
||||
okoo = classify(ie, s, oo)
|
||||
if okoo is S.true and classify(rv, s, oo) is S.false:
|
||||
rv = rv.subs(s < oo, True)
|
||||
oknoo = classify(ie, s, -oo)
|
||||
if (oknoo is S.true and
|
||||
classify(rv, s, -oo) is S.false):
|
||||
rv = rv.subs(-oo < s, True)
|
||||
rv = rv.subs(s > -oo, True)
|
||||
if rv is S.true:
|
||||
rv = (s <= oo) if okoo is S.true else (s < oo)
|
||||
if oknoo is not S.true:
|
||||
rv = And(-oo < s, rv)
|
||||
else:
|
||||
p = Poly(expr)
|
||||
|
||||
conds = []
|
||||
if rv is None:
|
||||
e = p.as_expr() # this is in expanded form
|
||||
# Do a safe inversion of e, moving non-s terms
|
||||
# to the rhs and dividing by a nonzero factor if
|
||||
# the relational is Eq/Ne; for other relationals
|
||||
# the sign must also be positive or negative
|
||||
rhs = 0
|
||||
b, ax = e.as_independent(s, as_Add=True)
|
||||
e -= b
|
||||
rhs -= b
|
||||
ef = factor_terms(e)
|
||||
a, e = ef.as_independent(s, as_Add=False)
|
||||
if (a.is_zero != False or # don't divide by potential 0
|
||||
a.is_negative ==
|
||||
a.is_positive is None and # if sign is not known then
|
||||
ie.rel_op not in ('!=', '==')): # reject if not Eq/Ne
|
||||
e = ef
|
||||
a = S.One
|
||||
rhs /= a
|
||||
if a.is_positive:
|
||||
rv = ie.func(e, rhs)
|
||||
else:
|
||||
rv = ie.reversed.func(e, rhs)
|
||||
|
||||
# return conditions under which the value is
|
||||
# valid, too.
|
||||
beginning_denoms = denoms(ie.lhs) | denoms(ie.rhs)
|
||||
current_denoms = denoms(rv)
|
||||
for d in beginning_denoms - current_denoms:
|
||||
c = _solve_inequality(Eq(d, 0), s, linear=linear)
|
||||
if isinstance(c, Eq) and c.lhs == s:
|
||||
if classify(rv, s, c.rhs) is S.true:
|
||||
# rv is permitting this value but it shouldn't
|
||||
conds.append(~c)
|
||||
for i in (-oo, oo):
|
||||
if (classify(rv, s, i) is S.true and
|
||||
classify(ie, s, i) is not S.true):
|
||||
conds.append(s < i if i is oo else i < s)
|
||||
|
||||
conds.append(rv)
|
||||
return And(*conds)
|
||||
|
||||
|
||||
def _reduce_inequalities(inequalities, symbols):
|
||||
# helper for reduce_inequalities
|
||||
|
||||
poly_part, abs_part = {}, {}
|
||||
other = []
|
||||
|
||||
for inequality in inequalities:
|
||||
|
||||
expr, rel = inequality.lhs, inequality.rel_op # rhs is 0
|
||||
|
||||
# check for gens using atoms which is more strict than free_symbols to
|
||||
# guard against EX domain which won't be handled by
|
||||
# reduce_rational_inequalities
|
||||
gens = expr.atoms(Symbol)
|
||||
|
||||
if len(gens) == 1:
|
||||
gen = gens.pop()
|
||||
else:
|
||||
common = expr.free_symbols & symbols
|
||||
if len(common) == 1:
|
||||
gen = common.pop()
|
||||
other.append(_solve_inequality(Relational(expr, 0, rel), gen))
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(filldedent('''
|
||||
inequality has more than one symbol of interest.
|
||||
'''))
|
||||
|
||||
if expr.is_polynomial(gen):
|
||||
poly_part.setdefault(gen, []).append((expr, rel))
|
||||
else:
|
||||
components = expr.find(lambda u:
|
||||
u.has(gen) and (
|
||||
u.is_Function or u.is_Pow and not u.exp.is_Integer))
|
||||
if components and all(isinstance(i, Abs) for i in components):
|
||||
abs_part.setdefault(gen, []).append((expr, rel))
|
||||
else:
|
||||
other.append(_solve_inequality(Relational(expr, 0, rel), gen))
|
||||
|
||||
poly_reduced = [reduce_rational_inequalities([exprs], gen) for gen, exprs in poly_part.items()]
|
||||
abs_reduced = [reduce_abs_inequalities(exprs, gen) for gen, exprs in abs_part.items()]
|
||||
|
||||
return And(*(poly_reduced + abs_reduced + other))
|
||||
|
||||
|
||||
def reduce_inequalities(inequalities, symbols=[]):
|
||||
"""Reduce a system of inequalities with rational coefficients.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy import reduce_inequalities
|
||||
|
||||
>>> reduce_inequalities(0 <= x + 3, [])
|
||||
(-3 <= x) & (x < oo)
|
||||
|
||||
>>> reduce_inequalities(0 <= x + y*2 - 1, [x])
|
||||
(x < oo) & (x >= 1 - 2*y)
|
||||
"""
|
||||
if not iterable(inequalities):
|
||||
inequalities = [inequalities]
|
||||
inequalities = [sympify(i) for i in inequalities]
|
||||
|
||||
gens = set().union(*[i.free_symbols for i in inequalities])
|
||||
|
||||
if not iterable(symbols):
|
||||
symbols = [symbols]
|
||||
symbols = (set(symbols) or gens) & gens
|
||||
if any(i.is_extended_real is False for i in symbols):
|
||||
raise TypeError(filldedent('''
|
||||
inequalities cannot contain symbols that are not real.
|
||||
'''))
|
||||
|
||||
# make vanilla symbol real
|
||||
recast = {i: Dummy(i.name, extended_real=True)
|
||||
for i in gens if i.is_extended_real is None}
|
||||
inequalities = [i.xreplace(recast) for i in inequalities]
|
||||
symbols = {i.xreplace(recast) for i in symbols}
|
||||
|
||||
# prefilter
|
||||
keep = []
|
||||
for i in inequalities:
|
||||
if isinstance(i, Relational):
|
||||
i = i.func(i.lhs.as_expr() - i.rhs.as_expr(), 0)
|
||||
elif i not in (True, False):
|
||||
i = Eq(i, 0)
|
||||
if i == True:
|
||||
continue
|
||||
elif i == False:
|
||||
return S.false
|
||||
if i.lhs.is_number:
|
||||
raise NotImplementedError(
|
||||
"could not determine truth value of %s" % i)
|
||||
keep.append(i)
|
||||
inequalities = keep
|
||||
del keep
|
||||
|
||||
# solve system
|
||||
rv = _reduce_inequalities(inequalities, symbols)
|
||||
|
||||
# restore original symbols and return
|
||||
return rv.xreplace({v: k for k, v in recast.items()})
|
||||
@@ -0,0 +1,16 @@
|
||||
from .ode import (allhints, checkinfsol, classify_ode,
|
||||
constantsimp, dsolve, homogeneous_order)
|
||||
|
||||
from .lie_group import infinitesimals
|
||||
|
||||
from .subscheck import checkodesol
|
||||
|
||||
from .systems import (canonical_odes, linear_ode_to_matrix,
|
||||
linodesolve)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp',
|
||||
'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix',
|
||||
'linodesolve'
|
||||
]
|
||||
@@ -0,0 +1,272 @@
|
||||
r'''
|
||||
This module contains the implementation of the 2nd_hypergeometric hint for
|
||||
dsolve. This is an incomplete implementation of the algorithm described in [1].
|
||||
The algorithm solves 2nd order linear ODEs of the form
|
||||
|
||||
.. math:: y'' + A(x) y' + B(x) y = 0\text{,}
|
||||
|
||||
where `A` and `B` are rational functions. The algorithm should find any
|
||||
solution of the form
|
||||
|
||||
.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,}
|
||||
|
||||
where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function".
|
||||
Currently only the 2F1 case is implemented in SymPy but the other cases are
|
||||
described in the paper and could be implemented in future (contributions
|
||||
welcome!).
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order
|
||||
linear ODEs, (2004).
|
||||
https://arxiv.org/abs/math-ph/0402063
|
||||
'''
|
||||
|
||||
from sympy.core import S, Pow
|
||||
from sympy.core.function import expand
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import Symbol, Wild
|
||||
from sympy.functions import exp, sqrt, hyper
|
||||
from sympy.integrals import Integral
|
||||
from sympy.polys import roots, gcd
|
||||
from sympy.polys.polytools import cancel, factor
|
||||
from sympy.simplify import collect, simplify, logcombine # type: ignore
|
||||
from sympy.simplify.powsimp import powdenest
|
||||
from sympy.solvers.ode.ode import get_numbered_constants
|
||||
|
||||
|
||||
def match_2nd_hypergeometric(eq, func):
|
||||
x = func.args[0]
|
||||
df = func.diff(x)
|
||||
a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)])
|
||||
b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)])
|
||||
c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)])
|
||||
deq = a3*(func.diff(x, 2)) + b3*df + c3*func
|
||||
r = collect(eq,
|
||||
[func.diff(x, 2), func.diff(x), func]).match(deq)
|
||||
if r:
|
||||
if not all(val.is_polynomial() for val in r.values()):
|
||||
n, d = eq.as_numer_denom()
|
||||
eq = expand(n)
|
||||
r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq)
|
||||
|
||||
if r and r[a3]!=0:
|
||||
A = cancel(r[b3]/r[a3])
|
||||
B = cancel(r[c3]/r[a3])
|
||||
return [A, B]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def equivalence_hypergeometric(A, B, func):
|
||||
# This method for finding the equivalence is only for 2F1 type.
|
||||
# We can extend it for 1F1 and 0F1 type also.
|
||||
x = func.args[0]
|
||||
|
||||
# making given equation in normal form
|
||||
I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B))
|
||||
|
||||
# computing shifted invariant(J1) of the equation
|
||||
J1 = factor(cancel(x**2*I1 + S(1)/4))
|
||||
num, dem = J1.as_numer_denom()
|
||||
num = powdenest(expand(num))
|
||||
dem = powdenest(expand(dem))
|
||||
# this function will compute the different powers of variable(x) in J1.
|
||||
# then it will help in finding value of k. k is power of x such that we can express
|
||||
# J1 = x**k * J0(x**k) then all the powers in J0 become integers.
|
||||
def _power_counting(num):
|
||||
_pow = {0}
|
||||
for val in num:
|
||||
if val.has(x):
|
||||
if isinstance(val, Pow) and val.as_base_exp()[0] == x:
|
||||
_pow.add(val.as_base_exp()[1])
|
||||
elif val == x:
|
||||
_pow.add(val.as_base_exp()[1])
|
||||
else:
|
||||
_pow.update(_power_counting(val.args))
|
||||
return _pow
|
||||
|
||||
pow_num = _power_counting((num, ))
|
||||
pow_dem = _power_counting((dem, ))
|
||||
pow_dem.update(pow_num)
|
||||
|
||||
_pow = pow_dem
|
||||
k = gcd(_pow)
|
||||
|
||||
# computing I0 of the given equation
|
||||
I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True)
|
||||
I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True)))
|
||||
|
||||
# Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing
|
||||
# x with x**(1/k) should result in I0 being a rational function of x or
|
||||
# otherwise the hypergeometric solver cannot be used. Note that k can be a
|
||||
# non-integer rational such as 2/7.
|
||||
if not I0.is_rational_function(x):
|
||||
return None
|
||||
|
||||
num, dem = I0.as_numer_denom()
|
||||
|
||||
max_num_pow = max(_power_counting((num, )))
|
||||
dem_args = dem.args
|
||||
sing_point = []
|
||||
dem_pow = []
|
||||
# calculating singular point of I0.
|
||||
for arg in dem_args:
|
||||
if arg.has(x):
|
||||
if isinstance(arg, Pow):
|
||||
# (x-a)**n
|
||||
dem_pow.append(arg.as_base_exp()[1])
|
||||
sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0])
|
||||
else:
|
||||
# (x-a) type
|
||||
dem_pow.append(arg.as_base_exp()[1])
|
||||
sing_point.append(list(roots(arg, x).keys())[0])
|
||||
|
||||
dem_pow.sort()
|
||||
# checking if equivalence is exists or not.
|
||||
|
||||
if equivalence(max_num_pow, dem_pow) == "2F1":
|
||||
return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"}
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def match_2nd_2F1_hypergeometric(I, k, sing_point, func):
|
||||
x = func.args[0]
|
||||
a = Wild("a")
|
||||
b = Wild("b")
|
||||
c = Wild("c")
|
||||
t = Wild("t")
|
||||
s = Wild("s")
|
||||
r = Wild("r")
|
||||
alpha = Wild("alpha")
|
||||
beta = Wild("beta")
|
||||
gamma = Wild("gamma")
|
||||
delta = Wild("delta")
|
||||
# I0 of the standard 2F1 equation.
|
||||
I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2)
|
||||
if sing_point != [0, 1]:
|
||||
# If singular point is [0, 1] then we have standard equation.
|
||||
eqs = []
|
||||
sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)]
|
||||
# making equations for the finding the mobius transformation
|
||||
for i in range(3):
|
||||
if i<len(sing_point):
|
||||
eqs.append(Eq(sing_eqs[i], sing_point[i]))
|
||||
else:
|
||||
eqs.append(Eq(1/sing_eqs[i], 0))
|
||||
# solving above equations for the mobius transformation
|
||||
_beta = -alpha*sing_point[0]
|
||||
_delta = -gamma*sing_point[1]
|
||||
_gamma = alpha
|
||||
if len(sing_point) == 3:
|
||||
_gamma = (_beta + sing_point[2]*alpha)/(sing_point[2] - sing_point[1])
|
||||
mob = (alpha*x + beta)/(gamma*x + delta)
|
||||
mob = mob.subs(beta, _beta)
|
||||
mob = mob.subs(delta, _delta)
|
||||
mob = mob.subs(gamma, _gamma)
|
||||
mob = cancel(mob)
|
||||
t = (beta - delta*x)/(gamma*x - alpha)
|
||||
t = cancel(((t.subs(beta, _beta)).subs(delta, _delta)).subs(gamma, _gamma))
|
||||
else:
|
||||
mob = x
|
||||
t = x
|
||||
|
||||
# applying mobius transformation in I to make it into I0.
|
||||
I = I.subs(x, t)
|
||||
I = I*(t.diff(x))**2
|
||||
I = factor(I)
|
||||
dict_I = {x**2:0, x:0, 1:0}
|
||||
I0_num, I0_dem = I0.as_numer_denom()
|
||||
# collecting coeff of (x**2, x), of the standard equation.
|
||||
# substituting (a-b) = s, (a+b) = r
|
||||
dict_I0 = {x**2:s**2 - 1, x:(2*(1-r)*c + (r+s)*(r-s)), 1:c*(c-2)}
|
||||
# collecting coeff of (x**2, x) from I0 of the given equation.
|
||||
dict_I.update(collect(expand(cancel(I*I0_dem)), [x**2, x], evaluate=False))
|
||||
eqs = []
|
||||
# We are comparing the coeff of powers of different x, for finding the values of
|
||||
# parameters of standard equation.
|
||||
for key in [x**2, x, 1]:
|
||||
eqs.append(Eq(dict_I[key], dict_I0[key]))
|
||||
|
||||
# We can have many possible roots for the equation.
|
||||
# I am selecting the root on the basis that when we have
|
||||
# standard equation eq = x*(x-1)*f(x).diff(x, 2) + ((a+b+1)*x-c)*f(x).diff(x) + a*b*f(x)
|
||||
# then root should be a, b, c.
|
||||
|
||||
_c = 1 - factor(sqrt(1+eqs[2].lhs))
|
||||
if not _c.has(Symbol):
|
||||
_c = min(list(roots(eqs[2], c)))
|
||||
_s = factor(sqrt(eqs[0].lhs + 1))
|
||||
_r = _c - factor(sqrt(_c**2 + _s**2 + eqs[1].lhs - 2*_c))
|
||||
_a = (_r + _s)/2
|
||||
_b = (_r - _s)/2
|
||||
|
||||
rn = {'a':simplify(_a), 'b':simplify(_b), 'c':simplify(_c), 'k':k, 'mobius':mob, 'type':"2F1"}
|
||||
return rn
|
||||
|
||||
|
||||
def equivalence(max_num_pow, dem_pow):
|
||||
# this function is made for checking the equivalence with 2F1 type of equation.
|
||||
# max_num_pow is the value of maximum power of x in numerator
|
||||
# and dem_pow is list of powers of different factor of form (a*x b).
|
||||
# reference from table 1 in paper - "Non-Liouvillian solutions for second order
|
||||
# linear ODEs" by L. Chan, E.S. Cheb-Terrab.
|
||||
# We can extend it for 1F1 and 0F1 type also.
|
||||
|
||||
if max_num_pow == 2:
|
||||
if dem_pow in [[2, 2], [2, 2, 2]]:
|
||||
return "2F1"
|
||||
elif max_num_pow == 1:
|
||||
if dem_pow in [[1, 2, 2], [2, 2, 2], [1, 2], [2, 2]]:
|
||||
return "2F1"
|
||||
elif max_num_pow == 0:
|
||||
if dem_pow in [[1, 1, 2], [2, 2], [1, 2, 2], [1, 1], [2], [1, 2], [2, 2]]:
|
||||
return "2F1"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_sol_2F1_hypergeometric(eq, func, match_object):
|
||||
x = func.args[0]
|
||||
from sympy.simplify.hyperexpand import hyperexpand
|
||||
from sympy.polys.polytools import factor
|
||||
C0, C1 = get_numbered_constants(eq, num=2)
|
||||
a = match_object['a']
|
||||
b = match_object['b']
|
||||
c = match_object['c']
|
||||
A = match_object['A']
|
||||
|
||||
sol = None
|
||||
|
||||
if c.is_integer == False:
|
||||
sol = C0*hyper([a, b], [c], x) + C1*hyper([a-c+1, b-c+1], [2-c], x)*x**(1-c)
|
||||
elif c == 1:
|
||||
y2 = Integral(exp(Integral((-(a+b+1)*x + c)/(x**2-x), x))/(hyperexpand(hyper([a, b], [c], x))**2), x)*hyper([a, b], [c], x)
|
||||
sol = C0*hyper([a, b], [c], x) + C1*y2
|
||||
elif (c-a-b).is_integer == False:
|
||||
sol = C0*hyper([a, b], [1+a+b-c], 1-x) + C1*hyper([c-a, c-b], [1+c-a-b], 1-x)*(1-x)**(c-a-b)
|
||||
|
||||
if sol:
|
||||
# applying transformation in the solution
|
||||
subs = match_object['mobius']
|
||||
dtdx = simplify(1/(subs.diff(x)))
|
||||
_B = ((a + b + 1)*x - c).subs(x, subs)*dtdx
|
||||
_B = factor(_B + ((x**2 -x).subs(x, subs))*(dtdx.diff(x)*dtdx))
|
||||
_A = factor((x**2 - x).subs(x, subs)*(dtdx**2))
|
||||
e = exp(logcombine(Integral(cancel(_B/(2*_A)), x), force=True))
|
||||
sol = sol.subs(x, match_object['mobius'])
|
||||
sol = sol.subs(x, x**match_object['k'])
|
||||
e = e.subs(x, x**match_object['k'])
|
||||
|
||||
if not A.is_zero:
|
||||
e1 = Integral(A/2, x)
|
||||
e1 = exp(logcombine(e1, force=True))
|
||||
sol = cancel((e/e1)*x**((-match_object['k']+1)/2))*sol
|
||||
sol = Eq(func, sol)
|
||||
return sol
|
||||
|
||||
sol = cancel((e)*x**((-match_object['k']+1)/2))*sol
|
||||
sol = Eq(func, sol)
|
||||
return sol
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,484 @@
|
||||
r"""
|
||||
This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients,
|
||||
nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients,
|
||||
nth_linear_constant_coeff_variation_of_parameters,
|
||||
and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters.
|
||||
|
||||
All the functions in this file are used by more than one solvers so, instead of creating
|
||||
instances in other classes for using them it is better to keep it here as separate helpers.
|
||||
|
||||
"""
|
||||
from collections import Counter
|
||||
from sympy.core import Add, S
|
||||
from sympy.core.function import diff, expand, _mexpand, expand_mul
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.symbol import Dummy, Wild
|
||||
from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \
|
||||
atan2, conjugate
|
||||
from sympy.integrals import Integral
|
||||
from sympy.polys import (Poly, RootOf, rootof, roots)
|
||||
from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore
|
||||
from sympy.utilities import numbered_symbols
|
||||
from sympy.solvers.solvers import solve
|
||||
from sympy.matrices import wronskian
|
||||
from .subscheck import sub_func_doit
|
||||
from sympy.solvers.ode.ode import get_numbered_constants
|
||||
|
||||
|
||||
def _test_term(coeff, func, order):
|
||||
r"""
|
||||
Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x),
|
||||
where K is independent of x and y(x), order>= 0.
|
||||
So we need to check that for each term, coeff == K*x**order from
|
||||
some K. We have a few cases, since coeff may have several
|
||||
different types.
|
||||
"""
|
||||
x = func.args[0]
|
||||
f = func.func
|
||||
if order < 0:
|
||||
raise ValueError("order should be greater than 0")
|
||||
if coeff == 0:
|
||||
return True
|
||||
if order == 0:
|
||||
if x in coeff.free_symbols:
|
||||
return False
|
||||
return True
|
||||
if coeff.is_Mul:
|
||||
if coeff.has(f(x)):
|
||||
return False
|
||||
return x**order in coeff.args
|
||||
elif coeff.is_Pow:
|
||||
return coeff.as_base_exp() == (x, order)
|
||||
elif order == 1:
|
||||
return x == coeff
|
||||
return False
|
||||
|
||||
|
||||
def _get_euler_characteristic_eq_sols(eq, func, match_obj):
|
||||
r"""
|
||||
Returns the solution of homogeneous part of the linear euler ODE and
|
||||
the list of roots of characteristic equation.
|
||||
|
||||
The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order
|
||||
of the derivative on each term, and coeff is the coefficient of that derivative.
|
||||
|
||||
"""
|
||||
x = func.args[0]
|
||||
f = func.func
|
||||
|
||||
# First, set up characteristic equation.
|
||||
chareq, symbol = S.Zero, Dummy('x')
|
||||
|
||||
for i in match_obj:
|
||||
if i >= 0:
|
||||
chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand()
|
||||
|
||||
chareq = Poly(chareq, symbol)
|
||||
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
|
||||
collectterms = []
|
||||
|
||||
# A generator of constants
|
||||
constants = list(get_numbered_constants(eq, num=chareq.degree()*2))
|
||||
constants.reverse()
|
||||
|
||||
# Create a dict root: multiplicity or charroots
|
||||
charroots = Counter(chareqroots)
|
||||
gsol = S.Zero
|
||||
ln = log
|
||||
for root, multiplicity in charroots.items():
|
||||
for i in range(multiplicity):
|
||||
if isinstance(root, RootOf):
|
||||
gsol += (x**root) * constants.pop()
|
||||
if multiplicity != 1:
|
||||
raise ValueError("Value should be 1")
|
||||
collectterms = [(0, root, 0)] + collectterms
|
||||
elif root.is_real:
|
||||
gsol += ln(x)**i*(x**root) * constants.pop()
|
||||
collectterms = [(i, root, 0)] + collectterms
|
||||
else:
|
||||
reroot = re(root)
|
||||
imroot = im(root)
|
||||
gsol += ln(x)**i * (x**reroot) * (
|
||||
constants.pop() * sin(abs(imroot)*ln(x))
|
||||
+ constants.pop() * cos(imroot*ln(x)))
|
||||
collectterms = [(i, reroot, imroot)] + collectterms
|
||||
|
||||
gsol = Eq(f(x), gsol)
|
||||
|
||||
gensols = []
|
||||
# Keep track of when to use sin or cos for nonzero imroot
|
||||
for i, reroot, imroot in collectterms:
|
||||
if imroot == 0:
|
||||
gensols.append(ln(x)**i*x**reroot)
|
||||
else:
|
||||
sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x))
|
||||
if sin_form in gensols:
|
||||
cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x))
|
||||
gensols.append(cos_form)
|
||||
else:
|
||||
gensols.append(sin_form)
|
||||
return gsol, gensols
|
||||
|
||||
|
||||
def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True):
|
||||
r"""
|
||||
Helper function for the method of variation of parameters and nonhomogeneous euler eq.
|
||||
|
||||
See the
|
||||
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters`
|
||||
docstring for more information on this method.
|
||||
|
||||
The parameter are ``match_obj`` should be a dictionary that has the following
|
||||
keys:
|
||||
|
||||
``list``
|
||||
A list of solutions to the homogeneous equation.
|
||||
|
||||
``sol``
|
||||
The general solution.
|
||||
|
||||
"""
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
r = match_obj
|
||||
psol = 0
|
||||
wr = wronskian(roots, x)
|
||||
|
||||
if simplify_flag:
|
||||
wr = simplify(wr) # We need much better simplification for
|
||||
# some ODEs. See issue 4662, for example.
|
||||
# To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1
|
||||
wr = trigsimp(wr, deep=True, recursive=True)
|
||||
if not wr:
|
||||
# The wronskian will be 0 iff the solutions are not linearly
|
||||
# independent.
|
||||
raise NotImplementedError("Cannot find " + str(order) +
|
||||
" solutions to the homogeneous equation necessary to apply " +
|
||||
"variation of parameters to " + str(eq) + " (Wronskian == 0)")
|
||||
if len(roots) != order:
|
||||
raise NotImplementedError("Cannot find " + str(order) +
|
||||
" solutions to the homogeneous equation necessary to apply " +
|
||||
"variation of parameters to " +
|
||||
str(eq) + " (number of terms != order)")
|
||||
negoneterm = S.NegativeOne**(order)
|
||||
for i in roots:
|
||||
psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order]
|
||||
negoneterm *= -1
|
||||
|
||||
if simplify_flag:
|
||||
psol = simplify(psol)
|
||||
psol = trigsimp(psol, deep=True)
|
||||
return Eq(f(x), homogen_sol.rhs + psol)
|
||||
|
||||
|
||||
def _get_const_characteristic_eq_sols(r, func, order):
|
||||
r"""
|
||||
Returns the roots of characteristic equation of constant coefficient
|
||||
linear ODE and list of collectterms which is later on used by simplification
|
||||
to use collect on solution.
|
||||
|
||||
The parameter `r` is a dict of order:coeff terms, where order is the order of the
|
||||
derivative on each term, and coeff is the coefficient of that derivative.
|
||||
|
||||
"""
|
||||
x = func.args[0]
|
||||
# First, set up characteristic equation.
|
||||
chareq, symbol = S.Zero, Dummy('x')
|
||||
|
||||
for i in r.keys():
|
||||
if isinstance(i, str) or i < 0:
|
||||
pass
|
||||
else:
|
||||
chareq += r[i]*symbol**i
|
||||
|
||||
chareq = Poly(chareq, symbol)
|
||||
# Can't just call roots because it doesn't return rootof for unsolveable
|
||||
# polynomials.
|
||||
chareqroots = roots(chareq, multiple=True)
|
||||
if len(chareqroots) != order:
|
||||
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
|
||||
|
||||
chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs())
|
||||
|
||||
# Create a dict root: multiplicity or charroots
|
||||
charroots = Counter(chareqroots)
|
||||
# We need to keep track of terms so we can run collect() at the end.
|
||||
# This is necessary for constantsimp to work properly.
|
||||
collectterms = []
|
||||
gensols = []
|
||||
conjugate_roots = [] # used to prevent double-use of conjugate roots
|
||||
# Loop over roots in theorder provided by roots/rootof...
|
||||
for root in chareqroots:
|
||||
# but don't repoeat multiple roots.
|
||||
if root not in charroots:
|
||||
continue
|
||||
multiplicity = charroots.pop(root)
|
||||
for i in range(multiplicity):
|
||||
if chareq_is_complex:
|
||||
gensols.append(x**i*exp(root*x))
|
||||
collectterms = [(i, root, 0)] + collectterms
|
||||
continue
|
||||
reroot = re(root)
|
||||
imroot = im(root)
|
||||
if imroot.has(atan2) and reroot.has(atan2):
|
||||
# Remove this condition when re and im stop returning
|
||||
# circular atan2 usages.
|
||||
gensols.append(x**i*exp(root*x))
|
||||
collectterms = [(i, root, 0)] + collectterms
|
||||
else:
|
||||
if root in conjugate_roots:
|
||||
collectterms = [(i, reroot, imroot)] + collectterms
|
||||
continue
|
||||
if imroot == 0:
|
||||
gensols.append(x**i*exp(reroot*x))
|
||||
collectterms = [(i, reroot, 0)] + collectterms
|
||||
continue
|
||||
conjugate_roots.append(conjugate(root))
|
||||
gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x))
|
||||
gensols.append(x**i*exp(reroot*x) * cos( imroot * x))
|
||||
|
||||
# This ordering is important
|
||||
collectterms = [(i, reroot, imroot)] + collectterms
|
||||
return gensols, collectterms
|
||||
|
||||
|
||||
# Ideally these kind of simplification functions shouldn't be part of solvers.
|
||||
# odesimp should be improved to handle these kind of specific simplifications.
|
||||
def _get_simplified_sol(sol, func, collectterms):
|
||||
r"""
|
||||
Helper function which collects the solution on
|
||||
collectterms. Ideally this should be handled by odesimp.It is used
|
||||
only when the simplify is set to True in dsolve.
|
||||
|
||||
The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is
|
||||
the multiplicity of the root, reroot is real part and imroot being the imaginary part.
|
||||
|
||||
"""
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
collectterms.sort(key=default_sort_key)
|
||||
collectterms.reverse()
|
||||
assert len(sol) == 1 and sol[0].lhs == f(x)
|
||||
sol = sol[0].rhs
|
||||
sol = expand_mul(sol)
|
||||
for i, reroot, imroot in collectterms:
|
||||
sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x))
|
||||
sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x))
|
||||
for i, reroot, imroot in collectterms:
|
||||
sol = collect(sol, x**i*exp(reroot*x))
|
||||
sol = powsimp(sol)
|
||||
return Eq(f(x), sol)
|
||||
|
||||
|
||||
def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero):
|
||||
r"""
|
||||
Returns a trial function match if undetermined coefficients can be applied
|
||||
to ``expr``, and ``None`` otherwise.
|
||||
|
||||
A trial expression can be found for an expression for use with the method
|
||||
of undetermined coefficients if the expression is an
|
||||
additive/multiplicative combination of constants, polynomials in `x` (the
|
||||
independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and
|
||||
`e^{a x}` terms (in other words, it has a finite number of linearly
|
||||
independent derivatives).
|
||||
|
||||
Note that you may still need to multiply each term returned here by
|
||||
sufficient `x` to make it linearly independent with the solutions to the
|
||||
homogeneous equation.
|
||||
|
||||
This is intended for internal use by ``undetermined_coefficients`` hints.
|
||||
|
||||
SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of
|
||||
only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So,
|
||||
for example, you will need to manually convert `\sin^2(x)` into `[1 +
|
||||
\cos(2 x)]/2` to properly apply the method of undetermined coefficients on
|
||||
it.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import log, exp
|
||||
>>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match
|
||||
>>> from sympy.abc import x
|
||||
>>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x)
|
||||
{'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}}
|
||||
>>> _undetermined_coefficients_match(log(x), x)
|
||||
{'test': False}
|
||||
|
||||
"""
|
||||
a = Wild('a', exclude=[x])
|
||||
b = Wild('b', exclude=[x])
|
||||
expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1)
|
||||
retdict = {}
|
||||
|
||||
def _test_term(expr, x) -> bool:
|
||||
r"""
|
||||
Test if ``expr`` fits the proper form for undetermined coefficients.
|
||||
"""
|
||||
if not expr.has(x):
|
||||
return True
|
||||
if expr.is_Add:
|
||||
return all(_test_term(i, x) for i in expr.args)
|
||||
if expr.is_Mul:
|
||||
if expr.has(sin, cos):
|
||||
foundtrig = False
|
||||
# Make sure that there is only one trig function in the args.
|
||||
# See the docstring.
|
||||
for i in expr.args:
|
||||
if i.has(sin, cos):
|
||||
if foundtrig:
|
||||
return False
|
||||
else:
|
||||
foundtrig = True
|
||||
return all(_test_term(i, x) for i in expr.args)
|
||||
if expr.is_Function:
|
||||
return expr.func in (sin, cos, exp, sinh, cosh) and \
|
||||
bool(expr.args[0].match(a*x + b))
|
||||
if expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \
|
||||
expr.exp >= 0:
|
||||
return True
|
||||
if expr.is_Pow and expr.base.is_number:
|
||||
return bool(expr.exp.match(a*x + b))
|
||||
return expr.is_Symbol or bool(expr.is_number)
|
||||
|
||||
def _get_trial_set(expr, x, exprs=set()):
|
||||
r"""
|
||||
Returns a set of trial terms for undetermined coefficients.
|
||||
|
||||
The idea behind undetermined coefficients is that the terms expression
|
||||
repeat themselves after a finite number of derivatives, except for the
|
||||
coefficients (they are linearly dependent). So if we collect these,
|
||||
we should have the terms of our trial function.
|
||||
"""
|
||||
def _remove_coefficient(expr, x):
|
||||
r"""
|
||||
Returns the expression without a coefficient.
|
||||
|
||||
Similar to expr.as_independent(x)[1], except it only works
|
||||
multiplicatively.
|
||||
"""
|
||||
term = S.One
|
||||
if expr.is_Mul:
|
||||
for i in expr.args:
|
||||
if i.has(x):
|
||||
term *= i
|
||||
elif expr.has(x):
|
||||
term = expr
|
||||
return term
|
||||
|
||||
expr = expand_mul(expr)
|
||||
if expr.is_Add:
|
||||
for term in expr.args:
|
||||
if _remove_coefficient(term, x) in exprs:
|
||||
pass
|
||||
else:
|
||||
exprs.add(_remove_coefficient(term, x))
|
||||
exprs = exprs.union(_get_trial_set(term, x, exprs))
|
||||
else:
|
||||
term = _remove_coefficient(expr, x)
|
||||
tmpset = exprs.union({term})
|
||||
oldset = set()
|
||||
while tmpset != oldset:
|
||||
# If you get stuck in this loop, then _test_term is probably
|
||||
# broken
|
||||
oldset = tmpset.copy()
|
||||
expr = expr.diff(x)
|
||||
term = _remove_coefficient(expr, x)
|
||||
if term.is_Add:
|
||||
tmpset = tmpset.union(_get_trial_set(term, x, tmpset))
|
||||
else:
|
||||
tmpset.add(term)
|
||||
exprs = tmpset
|
||||
return exprs
|
||||
|
||||
def is_homogeneous_solution(term):
|
||||
r""" This function checks whether the given trialset contains any root
|
||||
of homogeneous equation"""
|
||||
return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero
|
||||
|
||||
retdict['test'] = _test_term(expr, x)
|
||||
if retdict['test']:
|
||||
# Try to generate a list of trial solutions that will have the
|
||||
# undetermined coefficients. Note that if any of these are not linearly
|
||||
# independent with any of the solutions to the homogeneous equation,
|
||||
# then they will need to be multiplied by sufficient x to make them so.
|
||||
# This function DOES NOT do that (it doesn't even look at the
|
||||
# homogeneous equation).
|
||||
temp_set = set()
|
||||
for i in Add.make_args(expr):
|
||||
act = _get_trial_set(i, x)
|
||||
if eq_homogeneous is not S.Zero:
|
||||
while any(is_homogeneous_solution(ts) for ts in act):
|
||||
act = {x*ts for ts in act}
|
||||
temp_set = temp_set.union(act)
|
||||
|
||||
retdict['trialset'] = temp_set
|
||||
return retdict
|
||||
|
||||
|
||||
def _solve_undetermined_coefficients(eq, func, order, match, trialset):
|
||||
r"""
|
||||
Helper function for the method of undetermined coefficients.
|
||||
|
||||
See the
|
||||
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients`
|
||||
docstring for more information on this method.
|
||||
|
||||
The parameter ``trialset`` is the set of trial functions as returned by
|
||||
``_undetermined_coefficients_match()['trialset']``.
|
||||
|
||||
The parameter ``match`` should be a dictionary that has the following
|
||||
keys:
|
||||
|
||||
``list``
|
||||
A list of solutions to the homogeneous equation.
|
||||
|
||||
``sol``
|
||||
The general solution.
|
||||
|
||||
"""
|
||||
r = match
|
||||
coeffs = numbered_symbols('a', cls=Dummy)
|
||||
coefflist = []
|
||||
gensols = r['list']
|
||||
gsol = r['sol']
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
|
||||
if len(gensols) != order:
|
||||
raise NotImplementedError("Cannot find " + str(order) +
|
||||
" solutions to the homogeneous equation necessary to apply" +
|
||||
" undetermined coefficients to " + str(eq) +
|
||||
" (number of terms != order)")
|
||||
|
||||
trialfunc = 0
|
||||
for i in trialset:
|
||||
c = next(coeffs)
|
||||
coefflist.append(c)
|
||||
trialfunc += c*i
|
||||
|
||||
eqs = sub_func_doit(eq, f(x), trialfunc)
|
||||
|
||||
coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1))))
|
||||
|
||||
eqs = _mexpand(eqs)
|
||||
|
||||
for i in Add.make_args(eqs):
|
||||
s = separatevars(i, dict=True, symbols=[x])
|
||||
if coeffsdict.get(s[x]):
|
||||
coeffsdict[s[x]] += s['coeff']
|
||||
else:
|
||||
coeffsdict[s[x]] = s['coeff']
|
||||
|
||||
coeffvals = solve(list(coeffsdict.values()), coefflist)
|
||||
|
||||
if not coeffvals:
|
||||
raise NotImplementedError(
|
||||
"Could not solve `%s` using the "
|
||||
"method of undetermined coefficients "
|
||||
"(unable to solve for coefficients)." % eq)
|
||||
|
||||
psol = trialfunc.subs(coeffvals)
|
||||
|
||||
return Eq(f(x), gsol.rhs + psol)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,893 @@
|
||||
r"""
|
||||
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
|
||||
a function which gives all rational particular solutions to first order
|
||||
Riccati ODEs. A general first order Riccati ODE is given by -
|
||||
|
||||
.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2
|
||||
|
||||
where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
|
||||
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
|
||||
anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation
|
||||
is a Bernoulli ODE. The algorithm presented below can find rational
|
||||
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
|
||||
or prove that no rational solution exists for the equation.
|
||||
|
||||
Background
|
||||
==========
|
||||
|
||||
A Riccati equation can be transformed to its normal form
|
||||
|
||||
.. math:: y' + y^2 = a(x)
|
||||
|
||||
using the transformation
|
||||
|
||||
.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}
|
||||
|
||||
where `a(x)` is given by
|
||||
|
||||
.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2
|
||||
|
||||
Thus, we can develop an algorithm to solve for the Riccati equation
|
||||
in its normal form, which would in turn give us the solution for
|
||||
the original Riccati equation.
|
||||
|
||||
Algorithm
|
||||
=========
|
||||
|
||||
The algorithm implemented here is presented in the Ph.D thesis
|
||||
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
|
||||
by N. Thieu Vo. The entire thesis can be found here -
|
||||
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf
|
||||
|
||||
We have only implemented the Rational Riccati solver (Algorithm 11,
|
||||
Pg 78-82 in Thesis). Before we proceed towards the implementation
|
||||
of the algorithm, a few definitions to understand are -
|
||||
|
||||
1. Valuation of a Rational Function at `\infty`:
|
||||
The valuation of a rational function `p(x)` at `\infty` is equal
|
||||
to the difference between the degree of the denominator and the
|
||||
numerator of `p(x)`.
|
||||
|
||||
NOTE: A general definition of valuation of a rational function
|
||||
at any value of `x` can be found in Pg 63 of the thesis, but
|
||||
is not of any interest for this algorithm.
|
||||
|
||||
2. Zeros and Poles of a Rational Function:
|
||||
Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
|
||||
of `x`. Then -
|
||||
|
||||
a. The Zeros of `a(x)` are the roots of `S(x)`.
|
||||
b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
|
||||
can also be a pole of a(x). We say that `a(x)` has a pole at
|
||||
`\infty` if `a(\frac{1}{x})` has a pole at 0.
|
||||
|
||||
Every pole is associated with an order that is equal to the multiplicity
|
||||
of its appearance as a root of `T(x)`. A pole is called a simple pole if
|
||||
it has an order 1. Similarly, a pole is called a multiple pole if it has
|
||||
an order `\ge` 2.
|
||||
|
||||
Necessary Conditions
|
||||
====================
|
||||
|
||||
For a Riccati equation in its normal form,
|
||||
|
||||
.. math:: y' + y^2 = a(x)
|
||||
|
||||
we can define
|
||||
|
||||
a. A pole is called a movable pole if it is a pole of `y(x)` and is not
|
||||
a pole of `a(x)`.
|
||||
b. Similarly, a pole is called a non-movable pole if it is a pole of both
|
||||
`y(x)` and `a(x)`.
|
||||
|
||||
Then, the algorithm states that a rational solution exists only if -
|
||||
|
||||
a. Every pole of `a(x)` must be either a simple pole or a multiple pole
|
||||
of even order.
|
||||
b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2.
|
||||
|
||||
This algorithm finds all possible rational solutions for the Riccati ODE.
|
||||
If no rational solutions are found, it means that no rational solutions
|
||||
exist.
|
||||
|
||||
The algorithm works for Riccati ODEs where the coefficients are rational
|
||||
functions in the independent variable `x` with rational number coefficients
|
||||
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
|
||||
irrational numbers, symbols or any other kind of expression. The reasons
|
||||
for this are -
|
||||
|
||||
1. When using symbols, different symbols could take the same value and this
|
||||
would affect the multiplicity of poles if symbols are present here.
|
||||
|
||||
2. An integer degree bound is required to calculate a polynomial solution
|
||||
to an auxiliary differential equation, which in turn gives the particular
|
||||
solution for the original ODE. If symbols/floats/irrational numbers are
|
||||
present, we cannot determine if the expression for the degree bound is an
|
||||
integer or not.
|
||||
|
||||
Solution
|
||||
========
|
||||
|
||||
With these definitions, we can state a general form for the solution of
|
||||
the equation. `y(x)` must have the form -
|
||||
|
||||
.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i
|
||||
|
||||
where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`,
|
||||
`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values
|
||||
of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The
|
||||
coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})`
|
||||
can be determined from `a(x)`. We will have 2 choices each of these vectors
|
||||
and part of the procedure is figuring out which of the 2 should be used
|
||||
to get the solution correctly.
|
||||
|
||||
Implementation
|
||||
==============
|
||||
|
||||
In this implementation, we use ``Poly`` to represent a rational function
|
||||
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
|
||||
represent rational functions directly using ``Poly``, we instead represent
|
||||
a rational function with 2 ``Poly`` objects - one for its numerator and
|
||||
the other for its denominator.
|
||||
|
||||
The code is written to match the steps given in the thesis (Pg 82)
|
||||
|
||||
Step 0 : Match the equation -
|
||||
Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
|
||||
an error
|
||||
|
||||
Step 1 : Transform the equation to its normal form as explained in the
|
||||
theory section.
|
||||
|
||||
Step 2 : Initialize an empty set of solutions, ``sol``.
|
||||
|
||||
Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``.
|
||||
|
||||
Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
|
||||
to ``sol``.
|
||||
|
||||
Step 5 : Find the poles and their multiplicities of `a(x)`. Let
|
||||
the number of poles be `n`. Also find the valuation of `a(x)` at
|
||||
`\infty` using ``val_at_inf``.
|
||||
|
||||
NOTE: Although the algorithm considers `\infty` as a pole, it is
|
||||
not mentioned if it a part of the set of finite poles. `\infty`
|
||||
is NOT a part of the set of finite poles. If a pole exists at
|
||||
`\infty`, we use its multiplicity to find the laurent series of
|
||||
`a(x)` about `\infty`.
|
||||
|
||||
Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
|
||||
``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)``
|
||||
combinations of choosing between 2 choices for each of the `n` c-vectors
|
||||
and 1 d-vector.
|
||||
|
||||
NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig
|
||||
mistake. The term `- d_N` must be replaced with `-N d_N`. The same
|
||||
has been explained in the code as well.
|
||||
|
||||
For each of these above combinations, do
|
||||
|
||||
Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
|
||||
the polynomial solution we must find for the auxiliary equation.
|
||||
|
||||
Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is
|
||||
one part of y(x) -
|
||||
|
||||
.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i
|
||||
|
||||
Step 10 : If `m` is a non-negative integer -
|
||||
|
||||
Step 11: Find a polynomial solution of degree `m` for the auxiliary equation.
|
||||
|
||||
There are 2 cases possible -
|
||||
|
||||
a. `m` is a non-negative integer: We can solve for the coefficients
|
||||
in `p(x)` using Undetermined Coefficients.
|
||||
|
||||
b. `m` is not a non-negative integer: In this case, we cannot find
|
||||
a polynomial solution to the auxiliary equation, and hence, we ignore
|
||||
this value of `m`.
|
||||
|
||||
Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}`
|
||||
to ``sol``.
|
||||
|
||||
Step 13 : For each solution in ``sol``, apply an inverse transformation,
|
||||
so that the solutions of the original equation are found using the
|
||||
solutions of the equation in its normal form.
|
||||
"""
|
||||
|
||||
|
||||
from itertools import product
|
||||
from sympy.core import S
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.numbers import oo, Float
|
||||
from sympy.core.function import count_ops
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import symbols, Symbol, Dummy
|
||||
from sympy.functions import sqrt, exp
|
||||
from sympy.functions.elementary.complexes import sign
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.polys.domains import ZZ
|
||||
from sympy.polys.polytools import Poly
|
||||
from sympy.polys.polyroots import roots
|
||||
from sympy.solvers.solveset import linsolve
|
||||
|
||||
|
||||
def riccati_normal(w, x, b1, b2):
|
||||
"""
|
||||
Given a solution `w(x)` to the equation
|
||||
|
||||
.. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2
|
||||
|
||||
and rational function coefficients `b_1(x)` and
|
||||
`b_2(x)`, this function transforms the solution to
|
||||
give a solution `y(x)` for its corresponding normal
|
||||
Riccati ODE
|
||||
|
||||
.. math:: y'(x) + y(x)^2 = a(x)
|
||||
|
||||
using the transformation
|
||||
|
||||
.. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
|
||||
"""
|
||||
return -b2*w - b2.diff(x)/(2*b2) - b1/2
|
||||
|
||||
|
||||
def riccati_inverse_normal(y, x, b1, b2, bp=None):
|
||||
"""
|
||||
Inverse transforming the solution to the normal
|
||||
Riccati ODE to get the solution to the Riccati ODE.
|
||||
"""
|
||||
# bp is the expression which is independent of the solution
|
||||
# and hence, it need not be computed again
|
||||
if bp is None:
|
||||
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
|
||||
# w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
|
||||
return -y/b2 + bp
|
||||
|
||||
|
||||
def riccati_reduced(eq, f, x):
|
||||
"""
|
||||
Convert a Riccati ODE into its corresponding
|
||||
normal Riccati ODE.
|
||||
"""
|
||||
match, funcs = match_riccati(eq, f, x)
|
||||
# If equation is not a Riccati ODE, exit
|
||||
if not match:
|
||||
return False
|
||||
# Using the rational functions, find the expression for a(x)
|
||||
b0, b1, b2 = funcs
|
||||
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
|
||||
b2.diff(x, 2)/(2*b2)
|
||||
# Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
|
||||
return f(x).diff(x) + f(x)**2 - a
|
||||
|
||||
def linsolve_dict(eq, syms):
|
||||
"""
|
||||
Get the output of linsolve as a dict
|
||||
"""
|
||||
# Convert tuple type return value of linsolve
|
||||
# to a dictionary for ease of use
|
||||
sol = linsolve(eq, syms)
|
||||
if not sol:
|
||||
return {}
|
||||
return dict(zip(syms, list(sol)[0]))
|
||||
|
||||
|
||||
def match_riccati(eq, f, x):
|
||||
"""
|
||||
A function that matches and returns the coefficients
|
||||
if an equation is a Riccati ODE
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
eq: Equation to be matched
|
||||
f: Dependent variable
|
||||
x: Independent variable
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
match: True if equation is a Riccati ODE, False otherwise
|
||||
funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
|
||||
b0, b1 and b2 are rational functions which match the equation.
|
||||
"""
|
||||
# Group terms based on f(x)
|
||||
if isinstance(eq, Eq):
|
||||
eq = eq.lhs - eq.rhs
|
||||
eq = eq.expand().collect(f(x))
|
||||
cf = eq.coeff(f(x).diff(x))
|
||||
|
||||
# There must be an f(x).diff(x) term.
|
||||
# eq must be an Add object since we are using the expanded
|
||||
# equation and it must have atleast 2 terms (b2 != 0)
|
||||
if cf != 0 and isinstance(eq, Add):
|
||||
|
||||
# Divide all coefficients by the coefficient of f(x).diff(x)
|
||||
# and add the terms again to get the same equation
|
||||
eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x))
|
||||
|
||||
# Match the equation with the pattern
|
||||
b1 = -eq.coeff(f(x))
|
||||
b2 = -eq.coeff(f(x)**2)
|
||||
b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
|
||||
funcs = [b0, b1, b2]
|
||||
|
||||
# Check if coefficients are not symbols and floats
|
||||
if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs):
|
||||
return False, []
|
||||
|
||||
# If b_0(x) contains f(x), it is not a Riccati ODE
|
||||
if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x),
|
||||
b1.is_rational_function(x), b2.is_rational_function(x))):
|
||||
return False, []
|
||||
return True, funcs
|
||||
return False, []
|
||||
|
||||
|
||||
def val_at_inf(num, den, x):
|
||||
# Valuation of a rational function at oo = deg(denom) - deg(numer)
|
||||
return den.degree(x) - num.degree(x)
|
||||
|
||||
|
||||
def check_necessary_conds(val_inf, muls):
|
||||
"""
|
||||
The necessary conditions for a rational solution
|
||||
to exist are as follows -
|
||||
|
||||
i) Every pole of a(x) must be either a simple pole
|
||||
or a multiple pole of even order.
|
||||
|
||||
ii) The valuation of a(x) at infinity must be even
|
||||
or be greater than or equal to 2.
|
||||
|
||||
Here, a simple pole is a pole with multiplicity 1
|
||||
and a multiple pole is a pole with multiplicity
|
||||
greater than 1.
|
||||
"""
|
||||
return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
|
||||
all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls)
|
||||
|
||||
|
||||
def inverse_transform_poly(num, den, x):
|
||||
"""
|
||||
A function to make the substitution
|
||||
x -> 1/x in a rational function that
|
||||
is represented using Poly objects for
|
||||
numerator and denominator.
|
||||
"""
|
||||
# Declare for reuse
|
||||
one = Poly(1, x)
|
||||
xpoly = Poly(x, x)
|
||||
|
||||
# Check if degree of numerator is same as denominator
|
||||
pwr = val_at_inf(num, den, x)
|
||||
if pwr >= 0:
|
||||
# Denominator has greater degree. Substituting x with
|
||||
# 1/x would make the extra power go to the numerator
|
||||
if num.expr != 0:
|
||||
num = num.transform(one, xpoly) * x**pwr
|
||||
den = den.transform(one, xpoly)
|
||||
else:
|
||||
# Numerator has greater degree. Substituting x with
|
||||
# 1/x would make the extra power go to the denominator
|
||||
num = num.transform(one, xpoly)
|
||||
den = den.transform(one, xpoly) * x**(-pwr)
|
||||
return num.cancel(den, include=True)
|
||||
|
||||
|
||||
def limit_at_inf(num, den, x):
|
||||
"""
|
||||
Find the limit of a rational function
|
||||
at oo
|
||||
"""
|
||||
# pwr = degree(num) - degree(den)
|
||||
pwr = -val_at_inf(num, den, x)
|
||||
# Numerator has a greater degree than denominator
|
||||
# Limit at infinity would depend on the sign of the
|
||||
# leading coefficients of numerator and denominator
|
||||
if pwr > 0:
|
||||
return oo*sign(num.LC()/den.LC())
|
||||
# Degree of numerator is equal to that of denominator
|
||||
# Limit at infinity is just the ratio of leading coeffs
|
||||
elif pwr == 0:
|
||||
return num.LC()/den.LC()
|
||||
# Degree of numerator is less than that of denominator
|
||||
# Limit at infinity is just 0
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def construct_c_case_1(num, den, x, pole):
|
||||
# Find the coefficient of 1/(x - pole)**2 in the
|
||||
# Laurent series expansion of a(x) about pole.
|
||||
num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
|
||||
r = (num1.subs(x, pole))/(den1.subs(x, pole))
|
||||
|
||||
# If multiplicity is 2, the coefficient to be added
|
||||
# in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
|
||||
if r != -S(1)/4:
|
||||
return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
|
||||
return [[S.Half]]
|
||||
|
||||
|
||||
def construct_c_case_2(num, den, x, pole, mul):
|
||||
# Generate the coefficients using the recurrence
|
||||
# relation mentioned in (5.14) in the thesis (Pg 80)
|
||||
|
||||
# r_i = mul/2
|
||||
ri = mul//2
|
||||
|
||||
# Find the Laurent series coefficients about the pole
|
||||
ser = rational_laurent_series(num, den, x, pole, mul, 6)
|
||||
|
||||
# Start with an empty memo to store the coefficients
|
||||
# This is for the plus case
|
||||
cplus = [0 for i in range(ri)]
|
||||
|
||||
# Base Case
|
||||
cplus[ri-1] = sqrt(ser[2*ri])
|
||||
|
||||
# Iterate backwards to find all coefficients
|
||||
s = ri - 1
|
||||
sm = 0
|
||||
for s in range(ri-1, 0, -1):
|
||||
sm = 0
|
||||
for j in range(s+1, ri):
|
||||
sm += cplus[j-1]*cplus[ri+s-j-1]
|
||||
if s!= 1:
|
||||
cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])
|
||||
|
||||
# Memo for the minus case
|
||||
cminus = [-x for x in cplus]
|
||||
|
||||
# Find the 0th coefficient in the recurrence
|
||||
cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
|
||||
cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])
|
||||
|
||||
# Add both the plus and minus cases' coefficients
|
||||
if cplus != cminus:
|
||||
return [cplus, cminus]
|
||||
return cplus
|
||||
|
||||
|
||||
def construct_c_case_3():
|
||||
# If multiplicity is 1, the coefficient to be added
|
||||
# in the c-vector is 1 (no choice)
|
||||
return [[1]]
|
||||
|
||||
|
||||
def construct_c(num, den, x, poles, muls):
|
||||
"""
|
||||
Helper function to calculate the coefficients
|
||||
in the c-vector for each pole.
|
||||
"""
|
||||
c = []
|
||||
for pole, mul in zip(poles, muls):
|
||||
c.append([])
|
||||
|
||||
# Case 3
|
||||
if mul == 1:
|
||||
# Add the coefficients from Case 3
|
||||
c[-1].extend(construct_c_case_3())
|
||||
|
||||
# Case 1
|
||||
elif mul == 2:
|
||||
# Add the coefficients from Case 1
|
||||
c[-1].extend(construct_c_case_1(num, den, x, pole))
|
||||
|
||||
# Case 2
|
||||
else:
|
||||
# Add the coefficients from Case 2
|
||||
c[-1].extend(construct_c_case_2(num, den, x, pole, mul))
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def construct_d_case_4(ser, N):
|
||||
# Initialize an empty vector
|
||||
dplus = [0 for i in range(N+2)]
|
||||
# d_N = sqrt(a_{2*N})
|
||||
dplus[N] = sqrt(ser[2*N])
|
||||
|
||||
# Use the recurrence relations to find
|
||||
# the value of d_s
|
||||
for s in range(N-1, -2, -1):
|
||||
sm = 0
|
||||
for j in range(s+1, N):
|
||||
sm += dplus[j]*dplus[N+s-j]
|
||||
if s != -1:
|
||||
dplus[s] = (ser[N+s] - sm)/(2*dplus[N])
|
||||
|
||||
# Coefficients for the case of d_N = -sqrt(a_{2*N})
|
||||
dminus = [-x for x in dplus]
|
||||
|
||||
# The third equation in Eq 5.15 of the thesis is WRONG!
|
||||
# d_N must be replaced with N*d_N in that equation.
|
||||
dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
|
||||
dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])
|
||||
|
||||
if dplus != dminus:
|
||||
return [dplus, dminus]
|
||||
return dplus
|
||||
|
||||
|
||||
def construct_d_case_5(ser):
|
||||
# List to store coefficients for plus case
|
||||
dplus = [0, 0]
|
||||
|
||||
# d_0 = sqrt(a_0)
|
||||
dplus[0] = sqrt(ser[0])
|
||||
|
||||
# d_(-1) = a_(-1)/(2*d_0)
|
||||
dplus[-1] = ser[-1]/(2*dplus[0])
|
||||
|
||||
# Coefficients for the minus case are just the negative
|
||||
# of the coefficients for the positive case.
|
||||
dminus = [-x for x in dplus]
|
||||
|
||||
if dplus != dminus:
|
||||
return [dplus, dminus]
|
||||
return dplus
|
||||
|
||||
|
||||
def construct_d_case_6(num, den, x):
|
||||
# s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
|
||||
# s_oo = lim x->oo x**2 * a(x)
|
||||
s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)
|
||||
|
||||
# d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
|
||||
if s_inf != -S(1)/4:
|
||||
return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
|
||||
return [[S.Half]]
|
||||
|
||||
|
||||
def construct_d(num, den, x, val_inf):
|
||||
"""
|
||||
Helper function to calculate the coefficients
|
||||
in the d-vector based on the valuation of the
|
||||
function at oo.
|
||||
"""
|
||||
N = -val_inf//2
|
||||
# Multiplicity of oo as a pole
|
||||
mul = -val_inf if val_inf < 0 else 0
|
||||
ser = rational_laurent_series(num, den, x, oo, mul, 1)
|
||||
|
||||
# Case 4
|
||||
if val_inf < 0:
|
||||
d = construct_d_case_4(ser, N)
|
||||
|
||||
# Case 5
|
||||
elif val_inf == 0:
|
||||
d = construct_d_case_5(ser)
|
||||
|
||||
# Case 6
|
||||
else:
|
||||
d = construct_d_case_6(num, den, x)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def rational_laurent_series(num, den, x, r, m, n):
|
||||
r"""
|
||||
The function computes the Laurent series coefficients
|
||||
of a rational function.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
num: A Poly object that is the numerator of `f(x)`.
|
||||
den: A Poly object that is the denominator of `f(x)`.
|
||||
x: The variable of expansion of the series.
|
||||
r: The point of expansion of the series.
|
||||
m: Multiplicity of r if r is a pole of `f(x)`. Should
|
||||
be zero otherwise.
|
||||
n: Order of the term upto which the series is expanded.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
series: A dictionary that has power of the term as key
|
||||
and coefficient of that term as value.
|
||||
|
||||
Below is a basic outline of how the Laurent series of a
|
||||
rational function `f(x)` about `x_0` is being calculated -
|
||||
|
||||
1. Substitute `x + x_0` in place of `x`. If `x_0`
|
||||
is a pole of `f(x)`, multiply the expression by `x^m`
|
||||
where `m` is the multiplicity of `x_0`. Denote the
|
||||
the resulting expression as g(x). We do this substitution
|
||||
so that we can now find the Laurent series of g(x) about
|
||||
`x = 0`.
|
||||
|
||||
2. We can then assume that the Laurent series of `g(x)`
|
||||
takes the following form -
|
||||
|
||||
.. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m
|
||||
|
||||
where `a_m` denotes the Laurent series coefficients.
|
||||
|
||||
3. Multiply the denominator to the RHS of the equation
|
||||
and form a recurrence relation for the coefficients `a_m`.
|
||||
"""
|
||||
one = Poly(1, x, extension=True)
|
||||
|
||||
if r == oo:
|
||||
# Series at x = oo is equal to first transforming
|
||||
# the function from x -> 1/x and finding the
|
||||
# series at x = 0
|
||||
num, den = inverse_transform_poly(num, den, x)
|
||||
r = S(0)
|
||||
|
||||
if r:
|
||||
# For an expansion about a non-zero point, a
|
||||
# transformation from x -> x + r must be made
|
||||
num = num.transform(Poly(x + r, x, extension=True), one)
|
||||
den = den.transform(Poly(x + r, x, extension=True), one)
|
||||
|
||||
# Remove the pole from the denominator if the series
|
||||
# expansion is about one of the poles
|
||||
num, den = (num*x**m).cancel(den, include=True)
|
||||
|
||||
# Equate coefficients for the first terms (base case)
|
||||
maxdegree = 1 + max(num.degree(), den.degree())
|
||||
syms = symbols(f'a:{maxdegree}', cls=Dummy)
|
||||
diff = num - den * Poly(syms[::-1], x)
|
||||
coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
|
||||
(coeffs, ) = linsolve(coeff_diffs, syms)
|
||||
|
||||
# Use the recursion relation for the rest
|
||||
recursion = den.all_coeffs()[::-1]
|
||||
div, rec_rhs = recursion[0], recursion[1:]
|
||||
series = list(coeffs)
|
||||
while len(series) < n:
|
||||
next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div
|
||||
series.append(-next_coeff)
|
||||
series = {m - i: val for i, val in enumerate(series)}
|
||||
return series
|
||||
|
||||
def compute_m_ybar(x, poles, choice, N):
|
||||
"""
|
||||
Helper function to calculate -
|
||||
|
||||
1. m - The degree bound for the polynomial
|
||||
solution that must be found for the auxiliary
|
||||
differential equation.
|
||||
|
||||
2. ybar - Part of the solution which can be
|
||||
computed using the poles, c and d vectors.
|
||||
"""
|
||||
ybar = 0
|
||||
m = Poly(choice[-1][-1], x, extension=True)
|
||||
|
||||
# Calculate the first (nested) summation for ybar
|
||||
# as given in Step 9 of the Thesis (Pg 82)
|
||||
dybar = []
|
||||
for i, polei in enumerate(poles):
|
||||
for j, cij in enumerate(choice[i]):
|
||||
dybar.append(cij/(x - polei)**(j + 1))
|
||||
m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add
|
||||
ybar += Add(*dybar)
|
||||
|
||||
# Calculate the second summation for ybar
|
||||
for i in range(N+1):
|
||||
ybar += choice[-1][i]*x**i
|
||||
return (m.expr, ybar)
|
||||
|
||||
|
||||
def solve_aux_eq(numa, dena, numy, deny, x, m):
|
||||
"""
|
||||
Helper function to find a polynomial solution
|
||||
of degree m for the auxiliary differential
|
||||
equation.
|
||||
"""
|
||||
# Assume that the solution is of the type
|
||||
# p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m
|
||||
psyms = symbols(f'C0:{m}', cls=Dummy)
|
||||
K = ZZ[psyms]
|
||||
psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)
|
||||
|
||||
# Eq (5.16) in Thesis - Pg 81
|
||||
auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
|
||||
if m >= 1:
|
||||
px = psol.diff(x)
|
||||
auxeq += px*(2*numy*deny*dena)
|
||||
if m >= 2:
|
||||
auxeq += px.diff(x)*(deny**2*dena)
|
||||
if m != 0:
|
||||
# m is a non-zero integer. Find the constant terms using undetermined coefficients
|
||||
return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
|
||||
else:
|
||||
# m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation
|
||||
return S.One, auxeq, auxeq == 0
|
||||
|
||||
|
||||
def remove_redundant_sols(sol1, sol2, x):
|
||||
"""
|
||||
Helper function to remove redundant
|
||||
solutions to the differential equation.
|
||||
"""
|
||||
# If y1 and y2 are redundant solutions, there is
|
||||
# some value of the arbitrary constant for which
|
||||
# they will be equal
|
||||
|
||||
syms1 = sol1.atoms(Symbol, Dummy)
|
||||
syms2 = sol2.atoms(Symbol, Dummy)
|
||||
num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
|
||||
num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
|
||||
# Cross multiply
|
||||
e = num1*den2 - den1*num2
|
||||
# Check if there are any constants
|
||||
syms = list(e.atoms(Symbol, Dummy))
|
||||
if len(syms):
|
||||
# Find values of constants for which solutions are equal
|
||||
redn = linsolve(e.all_coeffs(), syms)
|
||||
if len(redn):
|
||||
# Return the general solution over a particular solution
|
||||
if len(syms1) > len(syms2):
|
||||
return sol2
|
||||
# If both have constants, return the lesser complex solution
|
||||
elif len(syms1) == len(syms2):
|
||||
return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2
|
||||
else:
|
||||
return sol1
|
||||
|
||||
|
||||
def get_gen_sol_from_part_sol(part_sols, a, x):
|
||||
""""
|
||||
Helper function which computes the general
|
||||
solution for a Riccati ODE from its particular
|
||||
solutions.
|
||||
|
||||
There are 3 cases to find the general solution
|
||||
from the particular solutions for a Riccati ODE
|
||||
depending on the number of particular solution(s)
|
||||
we have - 1, 2 or 3.
|
||||
|
||||
For more information, see Section 6 of
|
||||
"Methods of Solution of the Riccati Differential Equation"
|
||||
by D. R. Haaheim and F. M. Stein
|
||||
"""
|
||||
|
||||
# If no particular solutions are found, a general
|
||||
# solution cannot be found
|
||||
if len(part_sols) == 0:
|
||||
return []
|
||||
|
||||
# In case of a single particular solution, the general
|
||||
# solution can be found by using the substitution
|
||||
# y = y1 + 1/z and solving a Bernoulli ODE to find z.
|
||||
elif len(part_sols) == 1:
|
||||
y1 = part_sols[0]
|
||||
i = exp(Integral(2*y1, x))
|
||||
z = i * Integral(a/i, x)
|
||||
z = z.doit()
|
||||
if a == 0 or z == 0:
|
||||
return y1
|
||||
return y1 + 1/z
|
||||
|
||||
# In case of 2 particular solutions, the general solution
|
||||
# can be found by solving a separable equation. This is
|
||||
# the most common case, i.e. most Riccati ODEs have 2
|
||||
# rational particular solutions.
|
||||
elif len(part_sols) == 2:
|
||||
y1, y2 = part_sols
|
||||
# One of them already has a constant
|
||||
if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
|
||||
u = exp(Integral(y2 - y1, x)).doit()
|
||||
# Introduce a constant
|
||||
else:
|
||||
C1 = Dummy('C1')
|
||||
u = C1*exp(Integral(y2 - y1, x)).doit()
|
||||
if u == 1:
|
||||
return y2
|
||||
return (y2*u - y1)/(u - 1)
|
||||
|
||||
# In case of 3 particular solutions, a closed form
|
||||
# of the general solution can be obtained directly
|
||||
else:
|
||||
y1, y2, y3 = part_sols[:3]
|
||||
C1 = Dummy('C1')
|
||||
return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)
|
||||
|
||||
|
||||
def solve_riccati(fx, x, b0, b1, b2, gensol=False):
|
||||
"""
|
||||
The main function that gives particular/general
|
||||
solutions to Riccati ODEs that have atleast 1
|
||||
rational particular solution.
|
||||
"""
|
||||
# Step 1 : Convert to Normal Form
|
||||
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
|
||||
b2.diff(x, 2)/(2*b2)
|
||||
a_t = a.together()
|
||||
num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
|
||||
num, den = num.cancel(den, include=True)
|
||||
|
||||
# Step 2
|
||||
presol = []
|
||||
|
||||
# Step 3 : a(x) is 0
|
||||
if num == 0:
|
||||
presol.append(1/(x + Dummy('C1')))
|
||||
|
||||
# Step 4 : a(x) is a non-zero constant
|
||||
elif x not in num.free_symbols.union(den.free_symbols):
|
||||
presol.extend([sqrt(a), -sqrt(a)])
|
||||
|
||||
# Step 5 : Find poles and valuation at infinity
|
||||
poles = roots(den, x)
|
||||
poles, muls = list(poles.keys()), list(poles.values())
|
||||
val_inf = val_at_inf(num, den, x)
|
||||
|
||||
if len(poles):
|
||||
# Check necessary conditions (outlined in the module docstring)
|
||||
if not check_necessary_conds(val_inf, muls):
|
||||
raise ValueError("Rational Solution doesn't exist")
|
||||
|
||||
# Step 6
|
||||
# Construct c-vectors for each singular point
|
||||
c = construct_c(num, den, x, poles, muls)
|
||||
|
||||
# Construct d vectors for each singular point
|
||||
d = construct_d(num, den, x, val_inf)
|
||||
|
||||
# Step 7 : Iterate over all possible combinations and return solutions
|
||||
# For each possible combination, generate an array of 0's and 1's
|
||||
# where 0 means pick 1st choice and 1 means pick the second choice.
|
||||
|
||||
# NOTE: We could exit from the loop if we find 3 particular solutions,
|
||||
# but it is not implemented here as -
|
||||
# a. Finding 3 particular solutions is very rare. Most of the time,
|
||||
# only 2 particular solutions are found.
|
||||
# b. In case we exit after finding 3 particular solutions, it might
|
||||
# happen that 1 or 2 of them are redundant solutions. So, instead of
|
||||
# spending some more time in computing the particular solutions,
|
||||
# we will end up computing the general solution from a single
|
||||
# particular solution which is usually slower than computing the
|
||||
# general solution from 2 or 3 particular solutions.
|
||||
c.append(d)
|
||||
choices = product(*c)
|
||||
for choice in choices:
|
||||
m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
|
||||
numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
|
||||
# Step 10 : Check if a valid solution exists. If yes, also check
|
||||
# if m is a non-negative integer
|
||||
if m.is_nonnegative == True and m.is_integer == True:
|
||||
|
||||
# Step 11 : Find polynomial solutions of degree m for the auxiliary equation
|
||||
psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)
|
||||
|
||||
# Step 12 : If valid polynomial solution exists, append solution.
|
||||
if exists:
|
||||
# m == 0 case
|
||||
if psol == 1 and coeffs == 0:
|
||||
# p(x) = 1, so p'(x)/p(x) term need not be added
|
||||
presol.append(ybar)
|
||||
# m is a positive integer and there are valid coefficients
|
||||
elif len(coeffs):
|
||||
# Substitute the valid coefficients to get p(x)
|
||||
psol = psol.xreplace(coeffs)
|
||||
# y(x) = ybar(x) + p'(x)/p(x)
|
||||
presol.append(ybar + psol.diff(x)/psol)
|
||||
|
||||
# Remove redundant solutions from the list of existing solutions
|
||||
remove = set()
|
||||
for i in range(len(presol)):
|
||||
for j in range(i+1, len(presol)):
|
||||
rem = remove_redundant_sols(presol[i], presol[j], x)
|
||||
if rem is not None:
|
||||
remove.add(rem)
|
||||
sols = [x for x in presol if x not in remove]
|
||||
|
||||
# Step 15 : Inverse transform the solutions of the equation in normal form
|
||||
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
|
||||
|
||||
# If general solution is required, compute it from the particular solutions
|
||||
if gensol:
|
||||
sols = [get_gen_sol_from_part_sol(sols, a, x)]
|
||||
|
||||
# Inverse transform the particular solutions
|
||||
presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
|
||||
return presol
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,392 @@
|
||||
from sympy.core import S, Pow
|
||||
from sympy.core.function import (Derivative, AppliedUndef, diff)
|
||||
from sympy.core.relational import Equality, Eq
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
from sympy.logic.boolalg import BooleanAtom
|
||||
from sympy.functions import exp
|
||||
from sympy.series import Order
|
||||
from sympy.simplify.simplify import simplify, posify, besselsimp
|
||||
from sympy.simplify.trigsimp import trigsimp
|
||||
from sympy.simplify.sqrtdenest import sqrtdenest
|
||||
from sympy.solvers import solve
|
||||
from sympy.solvers.deutils import _preprocess, ode_order
|
||||
from sympy.utilities.iterables import iterable, is_sequence
|
||||
|
||||
|
||||
def sub_func_doit(eq, func, new):
|
||||
r"""
|
||||
When replacing the func with something else, we usually want the
|
||||
derivative evaluated, so this function helps in making that happen.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Derivative, symbols, Function
|
||||
>>> from sympy.solvers.ode.subscheck import sub_func_doit
|
||||
>>> x, z = symbols('x, z')
|
||||
>>> y = Function('y')
|
||||
|
||||
>>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
|
||||
2
|
||||
|
||||
>>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
|
||||
... 1/(x*(z + 1/x)))
|
||||
x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
|
||||
...- 1/(x**2*(z + 1/x)**2)
|
||||
"""
|
||||
reps= {func: new}
|
||||
for d in eq.atoms(Derivative):
|
||||
if d.expr == func:
|
||||
reps[d] = new.diff(*d.variable_count)
|
||||
else:
|
||||
reps[d] = d.xreplace({func: new}).doit(deep=False)
|
||||
return eq.xreplace(reps)
|
||||
|
||||
|
||||
def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
|
||||
r"""
|
||||
Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.
|
||||
|
||||
This works when ``func`` is one function, like `f(x)` or a list of
|
||||
functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can
|
||||
be a single solution or a list of solutions. Each solution may be an
|
||||
:py:class:`~sympy.core.relational.Equality` that the solution satisfies,
|
||||
e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
|
||||
:py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
|
||||
will not be necessary to explicitly identify the function, but if the
|
||||
function cannot be inferred from the original equation it can be supplied
|
||||
through the ``func`` argument.
|
||||
|
||||
If a sequence of solutions is passed, the same sort of container will be
|
||||
used to return the result for each solution.
|
||||
|
||||
It tries the following methods, in order, until it finds zero equivalence:
|
||||
|
||||
1. Substitute the solution for `f` in the original equation. This only
|
||||
works if ``ode`` is solved for `f`. It will attempt to solve it first
|
||||
unless ``solve_for_func == False``.
|
||||
2. Take `n` derivatives of the solution, where `n` is the order of
|
||||
``ode``, and check to see if that is equal to the solution. This only
|
||||
works on exact ODEs.
|
||||
3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
|
||||
solving for the derivative of `f` of that order (this will always be
|
||||
possible because `f` is a linear operator). Then back substitute each
|
||||
derivative into ``ode`` in reverse order.
|
||||
|
||||
This function returns a tuple. The first item in the tuple is ``True`` if
|
||||
the substitution results in ``0``, and ``False`` otherwise. The second
|
||||
item in the tuple is what the substitution results in. It should always
|
||||
be ``0`` if the first item is ``True``. Sometimes this function will
|
||||
return ``False`` even when an expression is identically equal to ``0``.
|
||||
This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
|
||||
reduce the expression to ``0``. If an expression returned by this
|
||||
function vanishes identically, then ``sol`` really is a solution to
|
||||
the ``ode``.
|
||||
|
||||
If this function seems to hang, it is probably because of a hard
|
||||
simplification.
|
||||
|
||||
To use this function to test, test the first item of the tuple.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import (Eq, Function, checkodesol, symbols,
|
||||
... Derivative, exp)
|
||||
>>> x, C1, C2 = symbols('x,C1,C2')
|
||||
>>> f, g = symbols('f g', cls=Function)
|
||||
>>> checkodesol(f(x).diff(x), Eq(f(x), C1))
|
||||
(True, 0)
|
||||
>>> assert checkodesol(f(x).diff(x), C1)[0]
|
||||
>>> assert not checkodesol(f(x).diff(x), x)[0]
|
||||
>>> checkodesol(f(x).diff(x, 2), x**2)
|
||||
(False, 2)
|
||||
|
||||
>>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
|
||||
>>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
|
||||
>>> checkodesol(eqs, sol)
|
||||
(True, [0, 0])
|
||||
|
||||
"""
|
||||
if iterable(ode):
|
||||
return checksysodesol(ode, sol, func=func)
|
||||
|
||||
if not isinstance(ode, Equality):
|
||||
ode = Eq(ode, 0)
|
||||
if func is None:
|
||||
try:
|
||||
_, func = _preprocess(ode.lhs)
|
||||
except ValueError:
|
||||
funcs = [s.atoms(AppliedUndef) for s in (
|
||||
sol if is_sequence(sol, set) else [sol])]
|
||||
funcs = set().union(*funcs)
|
||||
if len(funcs) != 1:
|
||||
raise ValueError(
|
||||
'must pass func arg to checkodesol for this case.')
|
||||
func = funcs.pop()
|
||||
if not isinstance(func, AppliedUndef) or len(func.args) != 1:
|
||||
raise ValueError(
|
||||
"func must be a function of one variable, not %s" % func)
|
||||
if is_sequence(sol, set):
|
||||
return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol])
|
||||
|
||||
if not isinstance(sol, Equality):
|
||||
sol = Eq(func, sol)
|
||||
elif sol.rhs == func:
|
||||
sol = sol.reversed
|
||||
|
||||
if order == 'auto':
|
||||
order = ode_order(ode, func)
|
||||
solved = sol.lhs == func and not sol.rhs.has(func)
|
||||
if solve_for_func and not solved:
|
||||
rhs = solve(sol, func)
|
||||
if rhs:
|
||||
eqs = [Eq(func, t) for t in rhs]
|
||||
if len(rhs) == 1:
|
||||
eqs = eqs[0]
|
||||
return checkodesol(ode, eqs, order=order,
|
||||
solve_for_func=False)
|
||||
|
||||
x = func.args[0]
|
||||
|
||||
# Handle series solutions here
|
||||
if sol.has(Order):
|
||||
assert sol.lhs == func
|
||||
Oterm = sol.rhs.getO()
|
||||
solrhs = sol.rhs.removeO()
|
||||
|
||||
Oexpr = Oterm.expr
|
||||
assert isinstance(Oexpr, Pow)
|
||||
sorder = Oexpr.exp
|
||||
assert Oterm == Order(x**sorder)
|
||||
|
||||
odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()
|
||||
|
||||
neworder = Order(x**(sorder - order))
|
||||
odesubs = odesubs + neworder
|
||||
assert odesubs.getO() == neworder
|
||||
residual = odesubs.removeO()
|
||||
|
||||
return (residual == 0, residual)
|
||||
|
||||
s = True
|
||||
testnum = 0
|
||||
while s:
|
||||
if testnum == 0:
|
||||
# First pass, try substituting a solved solution directly into the
|
||||
# ODE. This has the highest chance of succeeding.
|
||||
ode_diff = ode.lhs - ode.rhs
|
||||
|
||||
if sol.lhs == func:
|
||||
s = sub_func_doit(ode_diff, func, sol.rhs)
|
||||
s = besselsimp(s)
|
||||
else:
|
||||
testnum += 1
|
||||
continue
|
||||
ss = simplify(s.rewrite(exp))
|
||||
if ss:
|
||||
# with the new numer_denom in power.py, if we do a simple
|
||||
# expansion then testnum == 0 verifies all solutions.
|
||||
s = ss.expand(force=True)
|
||||
else:
|
||||
s = 0
|
||||
testnum += 1
|
||||
elif testnum == 1:
|
||||
# Second pass. If we cannot substitute f, try seeing if the nth
|
||||
# derivative is equal, this will only work for odes that are exact,
|
||||
# by definition.
|
||||
s = simplify(
|
||||
trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
|
||||
trigsimp(ode.lhs) + trigsimp(ode.rhs))
|
||||
# s2 = simplify(
|
||||
# diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
|
||||
# ode.lhs + ode.rhs)
|
||||
testnum += 1
|
||||
elif testnum == 2:
|
||||
# Third pass. Try solving for df/dx and substituting that into the
|
||||
# ODE. Thanks to Chris Smith for suggesting this method. Many of
|
||||
# the comments below are his, too.
|
||||
# The method:
|
||||
# - Take each of 1..n derivatives of the solution.
|
||||
# - Solve each nth derivative for d^(n)f/dx^(n)
|
||||
# (the differential of that order)
|
||||
# - Back substitute into the ODE in decreasing order
|
||||
# (i.e., n, n-1, ...)
|
||||
# - Check the result for zero equivalence
|
||||
if sol.lhs == func and not sol.rhs.has(func):
|
||||
diffsols = {0: sol.rhs}
|
||||
elif sol.rhs == func and not sol.lhs.has(func):
|
||||
diffsols = {0: sol.lhs}
|
||||
else:
|
||||
diffsols = {}
|
||||
sol = sol.lhs - sol.rhs
|
||||
for i in range(1, order + 1):
|
||||
# Differentiation is a linear operator, so there should always
|
||||
# be 1 solution. Nonetheless, we test just to make sure.
|
||||
# We only need to solve once. After that, we automatically
|
||||
# have the solution to the differential in the order we want.
|
||||
if i == 1:
|
||||
ds = sol.diff(x)
|
||||
try:
|
||||
sdf = solve(ds, func.diff(x, i))
|
||||
if not sdf:
|
||||
raise NotImplementedError
|
||||
except NotImplementedError:
|
||||
testnum += 1
|
||||
break
|
||||
else:
|
||||
diffsols[i] = sdf[0]
|
||||
else:
|
||||
# This is what the solution says df/dx should be.
|
||||
diffsols[i] = diffsols[i - 1].diff(x)
|
||||
|
||||
# Make sure the above didn't fail.
|
||||
if testnum > 2:
|
||||
continue
|
||||
else:
|
||||
# Substitute it into ODE to check for self consistency.
|
||||
lhs, rhs = ode.lhs, ode.rhs
|
||||
for i in range(order, -1, -1):
|
||||
if i == 0 and 0 not in diffsols:
|
||||
# We can only substitute f(x) if the solution was
|
||||
# solved for f(x).
|
||||
break
|
||||
lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
|
||||
rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
|
||||
ode_or_bool = Eq(lhs, rhs)
|
||||
ode_or_bool = simplify(ode_or_bool)
|
||||
|
||||
if isinstance(ode_or_bool, (bool, BooleanAtom)):
|
||||
if ode_or_bool:
|
||||
lhs = rhs = S.Zero
|
||||
else:
|
||||
lhs = ode_or_bool.lhs
|
||||
rhs = ode_or_bool.rhs
|
||||
# No sense in overworking simplify -- just prove that the
|
||||
# numerator goes to zero
|
||||
num = trigsimp((lhs - rhs).as_numer_denom()[0])
|
||||
# since solutions are obtained using force=True we test
|
||||
# using the same level of assumptions
|
||||
## replace function with dummy so assumptions will work
|
||||
_func = Dummy('func')
|
||||
num = num.subs(func, _func)
|
||||
## posify the expression
|
||||
num, reps = posify(num)
|
||||
s = simplify(num).xreplace(reps).xreplace({_func: func})
|
||||
testnum += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if not s:
|
||||
return (True, s)
|
||||
elif s is True: # The code above never was able to change s
|
||||
raise NotImplementedError("Unable to test if " + str(sol) +
|
||||
" is a solution to " + str(ode) + ".")
|
||||
else:
|
||||
return (False, s)
|
||||
|
||||
|
||||
def checksysodesol(eqs, sols, func=None):
|
||||
r"""
|
||||
Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
|
||||
checks that the result of substitutions for each equation is ``0``. The
|
||||
equations and solutions passed can be any iterable.
|
||||
|
||||
This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
|
||||
For each function, ``sols`` can have a single solution or a list of solutions.
|
||||
In most cases it will not be necessary to explicitly identify the function,
|
||||
but if the function cannot be inferred from the original equation it
|
||||
can be supplied through the ``func`` argument.
|
||||
|
||||
When a sequence of equations is passed, the same sequence is used to return
|
||||
the result for each equation with each function substituted with corresponding
|
||||
solutions.
|
||||
|
||||
It tries the following method to find zero equivalence for each equation:
|
||||
|
||||
Substitute the solutions for functions, like `x(t)` and `y(t)` into the
|
||||
original equations containing those functions.
|
||||
This function returns a tuple. The first item in the tuple is ``True`` if
|
||||
the substitution results for each equation is ``0``, and ``False`` otherwise.
|
||||
The second item in the tuple is what the substitution results in. Each element
|
||||
of the ``list`` should always be ``0`` corresponding to each equation if the
|
||||
first item is ``True``. Note that sometimes this function may return ``False``,
|
||||
but with an expression that is identically equal to ``0``, instead of returning
|
||||
``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
|
||||
reduce the expression to ``0``. If an expression returned by each function
|
||||
vanishes identically, then ``sols`` really is a solution to ``eqs``.
|
||||
|
||||
If this function seems to hang, it is probably because of a difficult simplification.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
|
||||
>>> from sympy.solvers.ode.subscheck import checksysodesol
|
||||
>>> C1, C2 = symbols('C1:3')
|
||||
>>> t = symbols('t')
|
||||
>>> x, y = symbols('x, y', cls=Function)
|
||||
>>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
|
||||
>>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
|
||||
... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
|
||||
>>> checksysodesol(eq, sol)
|
||||
(True, [0, 0])
|
||||
>>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
|
||||
>>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
|
||||
... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
|
||||
>>> checksysodesol(eq, sol)
|
||||
(True, [0, 0])
|
||||
|
||||
"""
|
||||
def _sympify(eq):
|
||||
return list(map(sympify, eq if iterable(eq) else [eq]))
|
||||
eqs = _sympify(eqs)
|
||||
for i in range(len(eqs)):
|
||||
if isinstance(eqs[i], Equality):
|
||||
eqs[i] = eqs[i].lhs - eqs[i].rhs
|
||||
if func is None:
|
||||
funcs = []
|
||||
for eq in eqs:
|
||||
derivs = eq.atoms(Derivative)
|
||||
func = set().union(*[d.atoms(AppliedUndef) for d in derivs])
|
||||
funcs.extend(func)
|
||||
funcs = list(set(funcs))
|
||||
if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\
|
||||
and len({func.args for func in funcs})!=1:
|
||||
raise ValueError("func must be a function of one variable, not %s" % func)
|
||||
for sol in sols:
|
||||
if len(sol.atoms(AppliedUndef)) != 1:
|
||||
raise ValueError("solutions should have one function only")
|
||||
if len(funcs) != len({sol.lhs for sol in sols}):
|
||||
raise ValueError("number of solutions provided does not match the number of equations")
|
||||
dictsol = {}
|
||||
for sol in sols:
|
||||
func = list(sol.atoms(AppliedUndef))[0]
|
||||
if sol.rhs == func:
|
||||
sol = sol.reversed
|
||||
solved = sol.lhs == func and not sol.rhs.has(func)
|
||||
if not solved:
|
||||
rhs = solve(sol, func)
|
||||
if not rhs:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
rhs = sol.rhs
|
||||
dictsol[func] = rhs
|
||||
checkeq = []
|
||||
for eq in eqs:
|
||||
for func in funcs:
|
||||
eq = sub_func_doit(eq, func, dictsol[func])
|
||||
ss = simplify(eq)
|
||||
if ss != 0:
|
||||
eq = ss.expand(force=True)
|
||||
if eq != 0:
|
||||
eq = sqrtdenest(eq).simplify()
|
||||
else:
|
||||
eq = 0
|
||||
checkeq.append(eq)
|
||||
if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0:
|
||||
return (True, checkeq)
|
||||
else:
|
||||
return (False, checkeq)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,152 @@
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import (atan, sin, tan)
|
||||
|
||||
from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals)
|
||||
|
||||
from sympy.solvers.ode.subscheck import checkodesol
|
||||
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
|
||||
C1 = Symbol('C1')
|
||||
x, y = symbols("x y")
|
||||
f = Function('f')
|
||||
xi = Function('xi')
|
||||
eta = Function('eta')
|
||||
|
||||
|
||||
def test_heuristic1():
|
||||
a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0")
|
||||
df = f(x).diff(x)
|
||||
eq = Eq(df, x**2*f(x))
|
||||
eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x)
|
||||
eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2)
|
||||
eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x))
|
||||
eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2)
|
||||
eq5 = x**2*df - f(x) + x**2*exp(x - (1/x))
|
||||
eqlist = [eq, eq1, eq2, eq3, eq4, eq5]
|
||||
|
||||
i = infinitesimals(eq, hint='abaco1_simple')
|
||||
assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0},
|
||||
{eta(x, f(x)): f(x), xi(x, f(x)): 0},
|
||||
{eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}]
|
||||
i1 = infinitesimals(eq1, hint='abaco1_simple')
|
||||
assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}]
|
||||
i2 = infinitesimals(eq2, hint='abaco1_simple')
|
||||
assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}]
|
||||
i3 = infinitesimals(eq3, hint='abaco1_simple')
|
||||
assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1},
|
||||
{eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}]
|
||||
i4 = infinitesimals(eq4, hint='abaco1_simple')
|
||||
assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0},
|
||||
{eta(x, f(x)): 0,
|
||||
xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}]
|
||||
i5 = infinitesimals(eq5, hint='abaco1_simple')
|
||||
assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}]
|
||||
|
||||
ilist = [i, i1, i2, i3, i4, i5]
|
||||
for eq, i in (zip(eqlist, ilist)):
|
||||
check = checkinfsol(eq, i)
|
||||
assert check[0]
|
||||
|
||||
# This ODE can be solved by the Lie Group method, when there are
|
||||
# better assumptions
|
||||
eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2)
|
||||
i = infinitesimals(eq6, hint='abaco1_product')
|
||||
assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}]
|
||||
assert checkinfsol(eq6, i)[0]
|
||||
|
||||
eq7 = x*(f(x).diff(x)) + 1 - f(x)**2
|
||||
i = infinitesimals(eq7, hint='chi')
|
||||
assert checkinfsol(eq7, i)[0]
|
||||
|
||||
|
||||
def test_heuristic3():
|
||||
a, b = symbols("a b")
|
||||
df = f(x).diff(x)
|
||||
|
||||
eq = x**2*df + x*f(x) + f(x)**2 + x**2
|
||||
i = infinitesimals(eq, hint='bivariate')
|
||||
assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x
|
||||
i = infinitesimals(eq, hint='bivariate')
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
def test_heuristic_function_sum():
|
||||
eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x +
|
||||
(1 - 3*f(x))*(x/f(x)**2))
|
||||
i = infinitesimals(eq, hint='function_sum')
|
||||
assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
def test_heuristic_abaco2_similar():
|
||||
a, b = symbols("a b")
|
||||
F = Function('F')
|
||||
eq = f(x).diff(x) - F(a*x + b*f(x))
|
||||
i = infinitesimals(eq, hint='abaco2_similar')
|
||||
assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x)))
|
||||
i = infinitesimals(eq, hint='abaco2_similar')
|
||||
assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
def test_heuristic_abaco2_unique_unknown():
|
||||
|
||||
a, b = symbols("a b")
|
||||
F = Function('F')
|
||||
eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b)
|
||||
i = infinitesimals(eq, hint='abaco2_unique_unknown')
|
||||
assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x)))
|
||||
i = infinitesimals(eq, hint='abaco2_unique_unknown')
|
||||
assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}]
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a
|
||||
i = infinitesimals(eq, hint='abaco2_unique_unknown')
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
def test_heuristic_linear():
|
||||
a, b, m, n = symbols("a b m n")
|
||||
|
||||
eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1))
|
||||
i = infinitesimals(eq, hint='linear')
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_kamke():
|
||||
a, b, alpha, c = symbols("a b alpha c")
|
||||
eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c
|
||||
i = infinitesimals(eq, hint='sum_function') # XFAIL
|
||||
assert checkinfsol(eq, i)[0]
|
||||
|
||||
|
||||
def test_user_infinitesimals():
|
||||
x = Symbol("x") # assuming x is real generates an error
|
||||
eq = x*(f(x).diff(x)) + 1 - f(x)**2
|
||||
sol = Eq(f(x), (C1 + x**2)/(C1 - x**2))
|
||||
infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0}
|
||||
assert dsolve(eq, hint='lie_group', **infinitesimals) == sol
|
||||
assert checkodesol(eq, sol) == (True, 0)
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_lie_group_issue15219():
|
||||
eqn = exp(f(x).diff(x)-f(x))
|
||||
assert 'lie_group' not in classify_ode(eqn, f(x))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,877 @@
|
||||
from sympy.core.random import randint
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.numbers import (I, Rational, oo)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.hyperbolic import tanh
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.polys.polytools import Poly
|
||||
from sympy.simplify.ratsimp import ratsimp
|
||||
from sympy.solvers.ode.subscheck import checkodesol
|
||||
from sympy.testing.pytest import slow
|
||||
from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal,
|
||||
riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf,
|
||||
check_necessary_conds, val_at_inf, construct_c_case_1,
|
||||
construct_c_case_2, construct_c_case_3, construct_d_case_4,
|
||||
construct_d_case_5, construct_d_case_6, rational_laurent_series,
|
||||
solve_riccati)
|
||||
|
||||
f = Function('f')
|
||||
x = symbols('x')
|
||||
|
||||
# These are the functions used to generate the tests
|
||||
# SHOULD NOT BE USED DIRECTLY IN TESTS
|
||||
|
||||
def rand_rational(maxint):
|
||||
return Rational(randint(-maxint, maxint), randint(1, maxint))
|
||||
|
||||
|
||||
def rand_poly(x, degree, maxint):
|
||||
return Poly([rand_rational(maxint) for _ in range(degree+1)], x)
|
||||
|
||||
|
||||
def rand_rational_function(x, degree, maxint):
|
||||
degnum = randint(1, degree)
|
||||
degden = randint(1, degree)
|
||||
num = rand_poly(x, degnum, maxint)
|
||||
den = rand_poly(x, degden, maxint)
|
||||
while den == Poly(0, x):
|
||||
den = rand_poly(x, degden, maxint)
|
||||
return num / den
|
||||
|
||||
|
||||
def find_riccati_ode(ratfunc, x, yf):
|
||||
y = ratfunc
|
||||
yp = y.diff(x)
|
||||
q1 = rand_rational_function(x, 1, 3)
|
||||
q2 = rand_rational_function(x, 1, 3)
|
||||
while q2 == 0:
|
||||
q2 = rand_rational_function(x, 1, 3)
|
||||
q0 = ratsimp(yp - q1*y - q2*y**2)
|
||||
eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2)
|
||||
sol = Eq(yf, y)
|
||||
assert checkodesol(eq, sol) == (True, 0)
|
||||
return eq, q0, q1, q2
|
||||
|
||||
|
||||
# Testing functions start
|
||||
|
||||
def test_riccati_transformation():
|
||||
"""
|
||||
This function tests the transformation of the
|
||||
solution of a Riccati ODE to the solution of
|
||||
its corresponding normal Riccati ODE.
|
||||
|
||||
Each test case 4 values -
|
||||
|
||||
1. w - The solution to be transformed
|
||||
2. b1 - The coefficient of f(x) in the ODE.
|
||||
3. b2 - The coefficient of f(x)**2 in the ODE.
|
||||
4. y - The solution to the normal Riccati ODE.
|
||||
"""
|
||||
tests = [
|
||||
(
|
||||
x/(x - 1),
|
||||
(x**2 + 7)/3*x,
|
||||
x,
|
||||
-x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x)
|
||||
),
|
||||
(
|
||||
(2*x + 3)/(2*x + 2),
|
||||
(3 - 3*x)/(x + 1),
|
||||
5*x,
|
||||
-5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x)
|
||||
),
|
||||
(
|
||||
-1/(2*x**2 - 1),
|
||||
0,
|
||||
(2 - x)/(4*x - 2),
|
||||
(2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \
|
||||
2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False))
|
||||
),
|
||||
(
|
||||
x,
|
||||
(8*x - 12)/(12*x + 9),
|
||||
x**3/(6*x - 9),
|
||||
-x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \
|
||||
- 9)**2 + 3*x**2/(6*x - 9))/(2*x**3)
|
||||
)]
|
||||
for w, b1, b2, y in tests:
|
||||
assert y == riccati_normal(w, x, b1, b2)
|
||||
assert w == riccati_inverse_normal(y, x, b1, b2).cancel()
|
||||
|
||||
# Test bp parameter in riccati_inverse_normal
|
||||
tests = [
|
||||
(
|
||||
(-2*x - 1)/(2*x**2 + 2*x - 2),
|
||||
-2/x,
|
||||
(-x - 1)/(4*x),
|
||||
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
|
||||
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \
|
||||
- 2)) + 1/x
|
||||
),
|
||||
(
|
||||
3/(2*x**2),
|
||||
-2/x,
|
||||
(-x - 1)/(4*x),
|
||||
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
|
||||
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3)
|
||||
)]
|
||||
for w, b1, b2, bp, y in tests:
|
||||
assert y == riccati_normal(w, x, b1, b2)
|
||||
assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel()
|
||||
|
||||
|
||||
def test_riccati_reduced():
|
||||
"""
|
||||
This function tests the transformation of a
|
||||
Riccati ODE to its normal Riccati ODE.
|
||||
|
||||
Each test case 2 values -
|
||||
|
||||
1. eq - A Riccati ODE.
|
||||
2. normal_eq - The normal Riccati ODE of eq.
|
||||
"""
|
||||
tests = [
|
||||
(
|
||||
f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2,
|
||||
|
||||
f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2)
|
||||
),
|
||||
(
|
||||
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x,
|
||||
|
||||
-3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \
|
||||
-x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \
|
||||
- (-1 + (x + 1)/x)/(x*(-x - 1))
|
||||
),
|
||||
(
|
||||
f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2),
|
||||
|
||||
-(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \
|
||||
f(x)**2 + f(x).diff(x) - 1/(2*x + 1)
|
||||
),
|
||||
(
|
||||
f(x).diff(x) - f(x)**2/x,
|
||||
|
||||
f(x)**2 + f(x).diff(x) + 1/(4*x**2)
|
||||
),
|
||||
(
|
||||
-3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x,
|
||||
|
||||
f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \
|
||||
+ 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2)
|
||||
),
|
||||
(
|
||||
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x,
|
||||
|
||||
False
|
||||
),
|
||||
(
|
||||
f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2),
|
||||
|
||||
False
|
||||
)]
|
||||
for eq, normal_eq in tests:
|
||||
assert normal_eq == riccati_reduced(eq, f, x)
|
||||
|
||||
|
||||
def test_match_riccati():
|
||||
"""
|
||||
This function tests if an ODE is Riccati or not.
|
||||
|
||||
Each test case has 5 values -
|
||||
|
||||
1. eq - The Riccati ODE.
|
||||
2. match - Boolean indicating if eq is a Riccati ODE.
|
||||
3. b0 -
|
||||
4. b1 - Coefficient of f(x) in eq.
|
||||
5. b2 - Coefficient of f(x)**2 in eq.
|
||||
"""
|
||||
tests = [
|
||||
# Test Rational Riccati ODEs
|
||||
(
|
||||
f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \
|
||||
- 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \
|
||||
- (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2),
|
||||
|
||||
True,
|
||||
|
||||
45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \
|
||||
(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \
|
||||
315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \
|
||||
+ 846*x**2 + 180*x - 72),
|
||||
|
||||
Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2),
|
||||
|
||||
1/(3*x + 1)
|
||||
),
|
||||
(
|
||||
f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \
|
||||
1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \
|
||||
/(324*x**3 + 216*x**2),
|
||||
|
||||
True,
|
||||
|
||||
-4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \
|
||||
+ 24*x) + 265/(324*x + 216),
|
||||
|
||||
Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6),
|
||||
|
||||
x/3 - 1
|
||||
),
|
||||
(
|
||||
f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \
|
||||
+ 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \
|
||||
360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \
|
||||
S(3)/2) - (x/3 - 3)*f(x)**2/(3*x),
|
||||
|
||||
True,
|
||||
|
||||
304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \
|
||||
108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \
|
||||
360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \
|
||||
x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \
|
||||
189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \
|
||||
53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \
|
||||
+ 53*x**3 - 63*x**2 + 40*x - 12),
|
||||
|
||||
Mul(-1, 3 - 2*x, evaluate=False)/(x - 3),
|
||||
|
||||
Mul(-1, 9 - x, evaluate=False)/(9*x)
|
||||
),
|
||||
# Test Non-Rational Riccati ODEs
|
||||
(
|
||||
f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \
|
||||
x*f(x)**2/(x**(S(3)/4)),
|
||||
False, 0, 0, 0
|
||||
),
|
||||
(
|
||||
f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2,
|
||||
False, 0, 0, 0
|
||||
),
|
||||
(
|
||||
f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2,
|
||||
False, 0, 0, 0
|
||||
),
|
||||
# Test Non-Riccati ODEs
|
||||
(
|
||||
(1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x),
|
||||
False, 0, 0, 0
|
||||
),
|
||||
(
|
||||
f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3,
|
||||
False, 0, 0, 0
|
||||
),
|
||||
(
|
||||
f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \
|
||||
+ 3) + f(x)**2,
|
||||
False, 0, 0, 0
|
||||
)]
|
||||
for eq, res, b0, b1, b2 in tests:
|
||||
match, funcs = match_riccati(eq, f, x)
|
||||
assert match == res
|
||||
if res:
|
||||
assert [b0, b1, b2] == funcs
|
||||
|
||||
|
||||
def test_val_at_inf():
|
||||
"""
|
||||
This function tests the valuation of rational
|
||||
function at oo.
|
||||
|
||||
Each test case has 3 values -
|
||||
|
||||
1. num - Numerator of rational function.
|
||||
2. den - Denominator of rational function.
|
||||
3. val_inf - Valuation of rational function at oo
|
||||
"""
|
||||
tests = [
|
||||
# degree(denom) > degree(numer)
|
||||
(
|
||||
Poly(10*x**3 + 8*x**2 - 13*x + 6, x),
|
||||
Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x),
|
||||
7
|
||||
),
|
||||
(
|
||||
Poly(1, x),
|
||||
Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x),
|
||||
4
|
||||
),
|
||||
# degree(denom) == degree(numer)
|
||||
(
|
||||
Poly(-6*x**3 - 8*x**2 + 8*x - 6, x),
|
||||
Poly(-5*x**3 + 12*x**2 - 6*x - 9, x),
|
||||
0
|
||||
),
|
||||
# degree(denom) < degree(numer)
|
||||
(
|
||||
Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x),
|
||||
Poly(-14*x**2 + x, x),
|
||||
-6
|
||||
),
|
||||
(
|
||||
Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x),
|
||||
Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x),
|
||||
-2
|
||||
)]
|
||||
for num, den, val in tests:
|
||||
assert val_at_inf(num, den, x) == val
|
||||
|
||||
|
||||
def test_necessary_conds():
|
||||
"""
|
||||
This function tests the necessary conditions for
|
||||
a Riccati ODE to have a rational particular solution.
|
||||
"""
|
||||
# Valuation at Infinity is an odd negative integer
|
||||
assert check_necessary_conds(-3, [1, 2, 4]) == False
|
||||
# Valuation at Infinity is a positive integer lesser than 2
|
||||
assert check_necessary_conds(1, [1, 2, 4]) == False
|
||||
# Multiplicity of a pole is an odd integer greater than 1
|
||||
assert check_necessary_conds(2, [3, 1, 6]) == False
|
||||
# All values are correct
|
||||
assert check_necessary_conds(-10, [1, 2, 8, 12]) == True
|
||||
|
||||
|
||||
def test_inverse_transform_poly():
|
||||
"""
|
||||
This function tests the substitution x -> 1/x
|
||||
in rational functions represented using Poly.
|
||||
"""
|
||||
fns = [
|
||||
(15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6),
|
||||
|
||||
(180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12),
|
||||
|
||||
(-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80),
|
||||
|
||||
(60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30),
|
||||
|
||||
(30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \
|
||||
+ 15*x**2 + 45*x - 15)
|
||||
]
|
||||
for f in fns:
|
||||
num, den = [Poly(e, x) for e in f.as_numer_denom()]
|
||||
num, den = inverse_transform_poly(num, den, x)
|
||||
assert f.subs(x, 1/x).cancel() == num/den
|
||||
|
||||
|
||||
def test_limit_at_inf():
|
||||
"""
|
||||
This function tests the limit at oo of a
|
||||
rational function.
|
||||
|
||||
Each test case has 3 values -
|
||||
|
||||
1. num - Numerator of rational function.
|
||||
2. den - Denominator of rational function.
|
||||
3. limit_at_inf - Limit of rational function at oo
|
||||
"""
|
||||
tests = [
|
||||
# deg(denom) > deg(numer)
|
||||
(
|
||||
Poly(-12*x**2 + 20*x + 32, x),
|
||||
Poly(32*x**3 + 72*x**2 + 3*x - 32, x),
|
||||
0
|
||||
),
|
||||
# deg(denom) < deg(numer)
|
||||
(
|
||||
Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x),
|
||||
Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x),
|
||||
oo
|
||||
),
|
||||
# deg(denom) < deg(numer), one of the leading coefficients is negative
|
||||
(
|
||||
Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \
|
||||
+ 630*x + 3780, x),
|
||||
Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x),
|
||||
-oo
|
||||
),
|
||||
# deg(denom) == deg(numer)
|
||||
(
|
||||
Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x),
|
||||
Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x),
|
||||
S(1)/7
|
||||
),
|
||||
(
|
||||
Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x),
|
||||
Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x),
|
||||
S(288)/607
|
||||
)]
|
||||
for num, den, lim in tests:
|
||||
assert limit_at_inf(num, den, x) == lim
|
||||
|
||||
|
||||
def test_construct_c_case_1():
|
||||
"""
|
||||
This function tests the Case 1 in the step
|
||||
to calculate coefficients of c-vectors.
|
||||
|
||||
Each test case has 4 values -
|
||||
|
||||
1. num - Numerator of the rational function a(x).
|
||||
2. den - Denominator of the rational function a(x).
|
||||
3. pole - Pole of a(x) for which c-vector is being
|
||||
calculated.
|
||||
4. c - The c-vector for the pole.
|
||||
"""
|
||||
tests = [
|
||||
(
|
||||
Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True),
|
||||
Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True),
|
||||
S(0),
|
||||
[[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]]
|
||||
),
|
||||
(
|
||||
Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True),
|
||||
Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True),
|
||||
S(7)/4,
|
||||
[[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]]
|
||||
),
|
||||
(
|
||||
Poly(4*x + 2, x, extension=True),
|
||||
Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \
|
||||
+ 8*sqrt(3) + 16, x, domain='QQ<sqrt(3)>'),
|
||||
(S(1) + sqrt(3))/2,
|
||||
[[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \
|
||||
[S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]]
|
||||
)]
|
||||
for num, den, pole, c in tests:
|
||||
assert construct_c_case_1(num, den, x, pole) == c
|
||||
|
||||
|
||||
def test_construct_c_case_2():
|
||||
"""
|
||||
This function tests the Case 2 in the step
|
||||
to calculate coefficients of c-vectors.
|
||||
|
||||
Each test case has 5 values -
|
||||
|
||||
1. num - Numerator of the rational function a(x).
|
||||
2. den - Denominator of the rational function a(x).
|
||||
3. pole - Pole of a(x) for which c-vector is being
|
||||
calculated.
|
||||
4. mul - The multiplicity of the pole.
|
||||
5. c - The c-vector for the pole.
|
||||
"""
|
||||
tests = [
|
||||
# Testing poles with multiplicity 2
|
||||
(
|
||||
Poly(1, x, extension=True),
|
||||
Poly((x - 1)**2*(x - 2), x, extension=True),
|
||||
1, 2,
|
||||
[[-I*(-1 - I)/2], [I*(-1 + I)/2]]
|
||||
),
|
||||
(
|
||||
Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True),
|
||||
Poly((3*x - 1)**2*(x + 2)**2, x, extension=True),
|
||||
S(1)/3, 2,
|
||||
[[-S(89)/98], [-S(9)/98]]
|
||||
),
|
||||
# Testing poles with multiplicity 4
|
||||
(
|
||||
Poly(x**3 - x**2 + 4*x, x, extension=True),
|
||||
Poly((x - 2)**4*(x + 5)**2, x, extension=True),
|
||||
2, 4,
|
||||
[[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \
|
||||
[-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]]
|
||||
),
|
||||
(
|
||||
Poly(3*x**5 + x**4 + 3, x, extension=True),
|
||||
Poly((4*x + 1)**4*(x + 2), x, extension=True),
|
||||
-S(1)/4, 4,
|
||||
[[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \
|
||||
[-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]]
|
||||
),
|
||||
# Testing poles with multiplicity 6
|
||||
(
|
||||
Poly(x**3 + 2, x, extension=True),
|
||||
Poly((3*x - 1)**6*(x**2 + 1), x, extension=True),
|
||||
S(1)/3, 6,
|
||||
[[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \
|
||||
[-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]]
|
||||
),
|
||||
(
|
||||
Poly(x**2 + 12, x, extension=True),
|
||||
Poly((x - sqrt(2))**6, x, extension=True),
|
||||
sqrt(2), 6,
|
||||
[[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \
|
||||
[-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]]
|
||||
)]
|
||||
for num, den, pole, mul, c in tests:
|
||||
assert construct_c_case_2(num, den, x, pole, mul) == c
|
||||
|
||||
|
||||
def test_construct_c_case_3():
|
||||
"""
|
||||
This function tests the Case 3 in the step
|
||||
to calculate coefficients of c-vectors.
|
||||
"""
|
||||
assert construct_c_case_3() == [[1]]
|
||||
|
||||
|
||||
def test_construct_d_case_4():
|
||||
"""
|
||||
This function tests the Case 4 in the step
|
||||
to calculate coefficients of the d-vector.
|
||||
|
||||
Each test case has 4 values -
|
||||
|
||||
1. num - Numerator of the rational function a(x).
|
||||
2. den - Denominator of the rational function a(x).
|
||||
3. mul - Multiplicity of oo as a pole.
|
||||
4. d - The d-vector.
|
||||
"""
|
||||
tests = [
|
||||
# Tests with multiplicity at oo = 2
|
||||
(
|
||||
Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True),
|
||||
Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True),
|
||||
2,
|
||||
[[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \
|
||||
[-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]]
|
||||
),
|
||||
(
|
||||
Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True),
|
||||
Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True),
|
||||
2,
|
||||
[[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]]
|
||||
),
|
||||
# Tests with multiplicity at oo = 4
|
||||
(
|
||||
Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True),
|
||||
Poly(3*x**2 + 10*x + 7, x, extension=True),
|
||||
4,
|
||||
[[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \
|
||||
- 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \
|
||||
sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]]
|
||||
),
|
||||
(
|
||||
Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True),
|
||||
Poly(12*x - 2, x, extension=True),
|
||||
4,
|
||||
[[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \
|
||||
[-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]]
|
||||
),
|
||||
# Tests with multiplicity at oo = 4
|
||||
(
|
||||
Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True),
|
||||
Poly(x + 2, x, extension=True),
|
||||
6,
|
||||
[[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]]
|
||||
),
|
||||
(
|
||||
Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True),
|
||||
Poly(2*x - 2, x, extension=True),
|
||||
6,
|
||||
[[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \
|
||||
3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \
|
||||
sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]]
|
||||
)]
|
||||
for num, den, mul, d in tests:
|
||||
ser = rational_laurent_series(num, den, x, oo, mul, 1)
|
||||
assert construct_d_case_4(ser, mul//2) == d
|
||||
|
||||
|
||||
def test_construct_d_case_5():
|
||||
"""
|
||||
This function tests the Case 5 in the step
|
||||
to calculate coefficients of the d-vector.
|
||||
|
||||
Each test case has 3 values -
|
||||
|
||||
1. num - Numerator of the rational function a(x).
|
||||
2. den - Denominator of the rational function a(x).
|
||||
3. d - The d-vector.
|
||||
"""
|
||||
tests = [
|
||||
(
|
||||
Poly(2*x**3 + x**2 + x - 2, x, extension=True),
|
||||
Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True),
|
||||
[[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]]
|
||||
),
|
||||
(
|
||||
Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'),
|
||||
Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'),
|
||||
[[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]]
|
||||
),
|
||||
(
|
||||
Poly(x**2 - x + 1, x, domain='ZZ'),
|
||||
Poly(3*x**2 + 7*x + 3, x, domain='ZZ'),
|
||||
[[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]]
|
||||
)]
|
||||
for num, den, d in tests:
|
||||
# Multiplicity of oo is 0
|
||||
ser = rational_laurent_series(num, den, x, oo, 0, 1)
|
||||
assert construct_d_case_5(ser) == d
|
||||
|
||||
|
||||
def test_construct_d_case_6():
|
||||
"""
|
||||
This function tests the Case 6 in the step
|
||||
to calculate coefficients of the d-vector.
|
||||
|
||||
Each test case has 3 values -
|
||||
|
||||
1. num - Numerator of the rational function a(x).
|
||||
2. den - Denominator of the rational function a(x).
|
||||
3. d - The d-vector.
|
||||
"""
|
||||
tests = [
|
||||
(
|
||||
Poly(-2*x**2 - 5, x, domain='ZZ'),
|
||||
Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'),
|
||||
[[S(1)/2 + I/2], [S(1)/2 - I/2]]
|
||||
),
|
||||
(
|
||||
Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'),
|
||||
Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'),
|
||||
[[1], [0]]
|
||||
),
|
||||
(
|
||||
Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'),
|
||||
Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \
|
||||
- 27*x - 42, x, domain='ZZ'),
|
||||
[[1], [0]]
|
||||
)]
|
||||
for num, den, d in tests:
|
||||
assert construct_d_case_6(num, den, x) == d
|
||||
|
||||
|
||||
def test_rational_laurent_series():
|
||||
"""
|
||||
This function tests the computation of coefficients
|
||||
of Laurent series of a rational function.
|
||||
|
||||
Each test case has 5 values -
|
||||
|
||||
1. num - Numerator of the rational function.
|
||||
2. den - Denominator of the rational function.
|
||||
3. x0 - Point about which Laurent series is to
|
||||
be calculated.
|
||||
4. mul - Multiplicity of x0 if x0 is a pole of
|
||||
the rational function (0 otherwise).
|
||||
5. n - Number of terms upto which the series
|
||||
is to be calculated.
|
||||
"""
|
||||
tests = [
|
||||
# Laurent series about simple pole (Multiplicity = 1)
|
||||
(
|
||||
Poly(x**2 - 3*x + 9, x, extension=True),
|
||||
Poly(x**2 - x, x, extension=True),
|
||||
S(1), 1, 6,
|
||||
{1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9}
|
||||
),
|
||||
# Laurent series about multiple pole (Multiplicity > 1)
|
||||
(
|
||||
Poly(64*x**3 - 1728*x + 1216, x, extension=True),
|
||||
Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True),
|
||||
S(9)/8, 2, 3,
|
||||
{0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \
|
||||
1: S(209149)/75645}
|
||||
),
|
||||
(
|
||||
Poly(1, x, extension=True),
|
||||
Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \
|
||||
+ (4 + 8*sqrt(2))*x - 4, x, extension=True),
|
||||
sqrt(2), 4, 6,
|
||||
{4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \
|
||||
+ sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \
|
||||
)/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4}
|
||||
),
|
||||
# Laurent series about oo
|
||||
(
|
||||
Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True),
|
||||
Poly(x**2 - 5, x, extension=True),
|
||||
oo, 3, 6,
|
||||
{3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17}
|
||||
),
|
||||
# Laurent series at x0 where x0 is not a pole of the function
|
||||
# Using multiplicity as 0 (as x0 will not be a pole)
|
||||
(
|
||||
Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True),
|
||||
Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True),
|
||||
S(2)/5, 0, 1,
|
||||
{0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \
|
||||
-3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016}
|
||||
),
|
||||
(
|
||||
Poly(-7*x**2 + 2*x - 4, x, extension=True),
|
||||
Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True),
|
||||
oo, 0, 6,
|
||||
{0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7}
|
||||
)]
|
||||
for num, den, x0, mul, n, ser in tests:
|
||||
assert ser == rational_laurent_series(num, den, x, x0, mul, n)
|
||||
|
||||
|
||||
def check_dummy_sol(eq, solse, dummy_sym):
|
||||
"""
|
||||
Helper function to check if actual solution
|
||||
matches expected solution if actual solution
|
||||
contains dummy symbols.
|
||||
"""
|
||||
if isinstance(eq, Eq):
|
||||
eq = eq.lhs - eq.rhs
|
||||
_, funcs = match_riccati(eq, f, x)
|
||||
|
||||
sols = solve_riccati(f(x), x, *funcs)
|
||||
C1 = Dummy('C1')
|
||||
sols = [sol.subs(C1, dummy_sym) for sol in sols]
|
||||
|
||||
assert all(x[0] for x in checkodesol(eq, sols))
|
||||
assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse))
|
||||
|
||||
|
||||
def test_solve_riccati():
|
||||
"""
|
||||
This function tests the computation of rational
|
||||
particular solutions for a Riccati ODE.
|
||||
|
||||
Each test case has 2 values -
|
||||
|
||||
1. eq - Riccati ODE to be solved.
|
||||
2. sol - Expected solution to the equation.
|
||||
|
||||
Some examples have been taken from the paper - "Statistical Investigation of
|
||||
First-Order Algebraic ODEs and their Rational General Solutions" by
|
||||
Georg Grasegger, N. Thieu Vo, Franz Winkler
|
||||
|
||||
https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf
|
||||
"""
|
||||
C0 = Dummy('C0')
|
||||
# Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2,
|
||||
# a, b, c are rational functions of x
|
||||
|
||||
tests = [
|
||||
# a(x) is a constant
|
||||
(
|
||||
Eq(f(x).diff(x) + f(x)**2 - 2, 0),
|
||||
[Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))]
|
||||
),
|
||||
# a(x) is a constant
|
||||
(
|
||||
f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2,
|
||||
[Eq(f(x), (-2*C0 - x)/(C0*x + x**2))]
|
||||
),
|
||||
# a(x) is a constant
|
||||
(
|
||||
2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x),
|
||||
[Eq(f(x), (C0 + 2*x**2)/(C0 + x))]
|
||||
),
|
||||
# Pole with multiplicity 1
|
||||
(
|
||||
Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)),
|
||||
[Eq(f(x), 1/(x**2 - x))]
|
||||
),
|
||||
# One pole of multiplicity 2
|
||||
(
|
||||
x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x),
|
||||
[Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)]
|
||||
),
|
||||
(
|
||||
x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x),
|
||||
[Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)]
|
||||
),
|
||||
# Multiple poles of multiplicity 2
|
||||
(
|
||||
-f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \
|
||||
- 1)**2),
|
||||
[Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \
|
||||
- 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \
|
||||
57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \
|
||||
- 3*x + 1))]
|
||||
),
|
||||
# Regression: Poles with even multiplicity > 2 fixed
|
||||
(
|
||||
f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \
|
||||
7*x**2 - 20*x + 4)/(4*x**4),
|
||||
[Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \
|
||||
- 2*x**2))]
|
||||
),
|
||||
# Regression: Poles with even multiplicity > 2 fixed
|
||||
(
|
||||
Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\
|
||||
(x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \
|
||||
792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)),
|
||||
[Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))]
|
||||
),
|
||||
# More than 2 poles with multiplicity 2
|
||||
# Regression: Fixed mistake in necessary conditions
|
||||
(
|
||||
Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \
|
||||
(8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2),
|
||||
[Eq(f(x), (1 - 4*x)/(2*x - 2))]
|
||||
),
|
||||
# Regression: Fixed mistake in necessary conditions
|
||||
(
|
||||
Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \
|
||||
+ 3*f(x)**2/(6*x + 2)),
|
||||
[Eq(f(x), (2*x + 1)/(2*x - 2))]
|
||||
),
|
||||
# Imaginary poles
|
||||
(
|
||||
f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \
|
||||
- 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2),
|
||||
[Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \
|
||||
- x)), Eq(f(x), -1/(x - 1))],
|
||||
),
|
||||
# Imaginary coefficients in equation
|
||||
(
|
||||
f(x).diff(x) - 2*I*(f(x)**2 + 1)/x,
|
||||
[Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)]
|
||||
),
|
||||
# Regression: linsolve returning empty solution
|
||||
# Large value of m (> 10)
|
||||
(
|
||||
Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\
|
||||
(2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)),
|
||||
[Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \
|
||||
18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\
|
||||
x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\
|
||||
((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \
|
||||
415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\
|
||||
*x**3 + 32744250*x**2 + 16372125*x)))]
|
||||
),
|
||||
# Regression: Fixed bug due to a typo in paper
|
||||
(
|
||||
Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6),
|
||||
[Eq(f(x), 6*x)]
|
||||
),
|
||||
# Regression: Fixed bug due to a typo in paper
|
||||
(
|
||||
Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \
|
||||
+ 9 + (1 - x)*f(x)/x + 3/x),
|
||||
[Eq(f(x), -3*x/2 - 3)]
|
||||
)]
|
||||
for eq, sol in tests:
|
||||
check_dummy_sol(eq, sol, C0)
|
||||
|
||||
|
||||
@slow
|
||||
def test_solve_riccati_slow():
|
||||
"""
|
||||
This function tests the computation of rational
|
||||
particular solutions for a Riccati ODE.
|
||||
|
||||
Each test case has 2 values -
|
||||
|
||||
1. eq - Riccati ODE to be solved.
|
||||
2. sol - Expected solution to the equation.
|
||||
"""
|
||||
C0 = Dummy('C0')
|
||||
tests = [
|
||||
# Very large values of m (989 and 991)
|
||||
(
|
||||
Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \
|
||||
(54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \
|
||||
2187*x + 2187) + 495),
|
||||
[Eq(f(x), (18*x + 6)/(2*x - 9))]
|
||||
)]
|
||||
for eq, sol in tests:
|
||||
check_dummy_sol(eq, sol, C0)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,203 @@
|
||||
from sympy.core.function import (Derivative, Function, diff)
|
||||
from sympy.core.numbers import (I, Rational, pi)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.functions.special.error_functions import (Ei, erf, erfi)
|
||||
from sympy.integrals.integrals import Integral
|
||||
|
||||
from sympy.solvers.ode.subscheck import checkodesol, checksysodesol
|
||||
|
||||
from sympy.functions import besselj, bessely
|
||||
|
||||
from sympy.testing.pytest import raises, slow
|
||||
|
||||
|
||||
C0, C1, C2, C3, C4 = symbols('C0:5')
|
||||
u, x, y, z = symbols('u,x:z', real=True)
|
||||
f = Function('f')
|
||||
g = Function('g')
|
||||
h = Function('h')
|
||||
|
||||
|
||||
@slow
|
||||
def test_checkodesol():
|
||||
# For the most part, checkodesol is well tested in the tests below.
|
||||
# These tests only handle cases not checked below.
|
||||
raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x)))
|
||||
raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y),
|
||||
x), f(x, y)))
|
||||
assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \
|
||||
(False, -f(x).diff(x) + f(x, y).diff(x) - 1)
|
||||
assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True
|
||||
assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1)
|
||||
sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0)
|
||||
assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \
|
||||
(False, 60*x**4*((log(x) + 1)**2 + log(x))*(
|
||||
log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9)
|
||||
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \
|
||||
(True, 0)
|
||||
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0),
|
||||
solve_for_func=False) == (True, 0)
|
||||
assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x),
|
||||
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \
|
||||
[(True, 0), (True, 0), (False, C2)]
|
||||
assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x),
|
||||
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \
|
||||
{(True, 0), (True, 0), (False, C2)}
|
||||
assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \
|
||||
[(True, 0), (True, 0)]
|
||||
assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0)
|
||||
# Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that
|
||||
# checkodesol tries back substituting f(x) when it can.
|
||||
eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x)
|
||||
sol3 = Eq(f(x), log(log(C1/x)**(-x)))
|
||||
assert not checkodesol(eq3, sol3)[1].has(f(x))
|
||||
# This case was failing intermittently depending on hash-seed:
|
||||
eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x))
|
||||
sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))
|
||||
assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0]
|
||||
eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x)
|
||||
sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x))
|
||||
assert checkodesol(eq, sol) == (True, 0)
|
||||
|
||||
eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))]
|
||||
sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]
|
||||
assert checkodesol(eqs, sol) == (True, [0, 0])
|
||||
|
||||
|
||||
def test_checksysodesol():
|
||||
x, y, z = symbols('x, y, z', cls=Function)
|
||||
t = Symbol('t')
|
||||
eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t)))
|
||||
sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \
|
||||
Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t)))
|
||||
sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \
|
||||
Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \
|
||||
Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t)))
|
||||
sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \
|
||||
Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \
|
||||
C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23))
|
||||
sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \
|
||||
Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \
|
||||
sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23))
|
||||
sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \
|
||||
Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t)))
|
||||
sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\
|
||||
exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \
|
||||
C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t)))
|
||||
sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\
|
||||
exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \
|
||||
C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t)))
|
||||
sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
|
||||
C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \
|
||||
Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
|
||||
C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t)))
|
||||
root0 = -sqrt(-sqrt(47) + 7)
|
||||
root1 = sqrt(-sqrt(47) + 7)
|
||||
root2 = -sqrt(sqrt(47) + 7)
|
||||
root3 = sqrt(sqrt(47) + 7)
|
||||
sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \
|
||||
Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \
|
||||
C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12))
|
||||
root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2))
|
||||
root1 = sqrt(-sqrt(109)/2 + Rational(15, 2))
|
||||
root2 = -sqrt(sqrt(109)/2 + Rational(15, 2))
|
||||
root3 = sqrt(sqrt(109)/2 + Rational(15, 2))
|
||||
sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \
|
||||
Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \
|
||||
C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0))
|
||||
sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \
|
||||
C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \
|
||||
+ C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t)))
|
||||
I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t
|
||||
I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t
|
||||
sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t)))
|
||||
sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \
|
||||
Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t)))
|
||||
sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \
|
||||
Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \
|
||||
Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
|
||||
|
||||
eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t))))
|
||||
sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \
|
||||
Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \
|
||||
Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t)))
|
||||
sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \
|
||||
Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \
|
||||
Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t)))
|
||||
sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \
|
||||
Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5))
|
||||
sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5))
|
||||
sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
|
||||
Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
|
||||
eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2))
|
||||
sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)}
|
||||
assert checksysodesol(eq, sol) == (True, [0, 0])
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,966 @@
|
||||
"""
|
||||
This module contains pdsolve() and different helper functions that it
|
||||
uses. It is heavily inspired by the ode module and hence the basic
|
||||
infrastructure remains the same.
|
||||
|
||||
**Functions in this module**
|
||||
|
||||
These are the user functions in this module:
|
||||
|
||||
- pdsolve() - Solves PDE's
|
||||
- classify_pde() - Classifies PDEs into possible hints for dsolve().
|
||||
- pde_separate() - Separate variables in partial differential equation either by
|
||||
additive or multiplicative separation approach.
|
||||
|
||||
These are the helper functions in this module:
|
||||
|
||||
- pde_separate_add() - Helper function for searching additive separable solutions.
|
||||
- pde_separate_mul() - Helper function for searching multiplicative
|
||||
separable solutions.
|
||||
|
||||
**Currently implemented solver methods**
|
||||
|
||||
The following methods are implemented for solving partial differential
|
||||
equations. See the docstrings of the various pde_hint() functions for
|
||||
more information on each (run help(pde)):
|
||||
|
||||
- 1st order linear homogeneous partial differential equations
|
||||
with constant coefficients.
|
||||
- 1st order linear general partial differential equations
|
||||
with constant coefficients.
|
||||
- 1st order linear partial differential equations with
|
||||
variable coefficients.
|
||||
|
||||
"""
|
||||
from functools import reduce
|
||||
|
||||
from itertools import combinations_with_replacement
|
||||
from sympy.simplify import simplify # type: ignore
|
||||
from sympy.core import Add, S
|
||||
from sympy.core.function import Function, expand, AppliedUndef, Subs
|
||||
from sympy.core.relational import Equality, Eq
|
||||
from sympy.core.symbol import Symbol, Wild, symbols
|
||||
from sympy.functions import exp
|
||||
from sympy.integrals.integrals import Integral, integrate
|
||||
from sympy.utilities.iterables import has_dups, is_sequence
|
||||
from sympy.utilities.misc import filldedent
|
||||
|
||||
from sympy.solvers.deutils import _preprocess, ode_order, _desolve
|
||||
from sympy.solvers.solvers import solve
|
||||
from sympy.simplify.radsimp import collect
|
||||
|
||||
import operator
|
||||
|
||||
|
||||
allhints = (
|
||||
"1st_linear_constant_coeff_homogeneous",
|
||||
"1st_linear_constant_coeff",
|
||||
"1st_linear_constant_coeff_Integral",
|
||||
"1st_linear_variable_coeff"
|
||||
)
|
||||
|
||||
|
||||
def pdsolve(eq, func=None, hint='default', dict=False, solvefun=None, **kwargs):
|
||||
"""
|
||||
Solves any (supported) kind of partial differential equation.
|
||||
|
||||
**Usage**
|
||||
|
||||
pdsolve(eq, f(x,y), hint) -> Solve partial differential equation
|
||||
eq for function f(x,y), using method hint.
|
||||
|
||||
**Details**
|
||||
|
||||
``eq`` can be any supported partial differential equation (see
|
||||
the pde docstring for supported methods). This can either
|
||||
be an Equality, or an expression, which is assumed to be
|
||||
equal to 0.
|
||||
|
||||
``f(x,y)`` is a function of two variables whose derivatives in that
|
||||
variable make up the partial differential equation. In many
|
||||
cases it is not necessary to provide this; it will be autodetected
|
||||
(and an error raised if it could not be detected).
|
||||
|
||||
``hint`` is the solving method that you want pdsolve to use. Use
|
||||
classify_pde(eq, f(x,y)) to get all of the possible hints for
|
||||
a PDE. The default hint, 'default', will use whatever hint
|
||||
is returned first by classify_pde(). See Hints below for
|
||||
more options that you can use for hint.
|
||||
|
||||
``solvefun`` is the convention used for arbitrary functions returned
|
||||
by the PDE solver. If not set by the user, it is set by default
|
||||
to be F.
|
||||
|
||||
**Hints**
|
||||
|
||||
Aside from the various solving methods, there are also some
|
||||
meta-hints that you can pass to pdsolve():
|
||||
|
||||
"default":
|
||||
This uses whatever hint is returned first by
|
||||
classify_pde(). This is the default argument to
|
||||
pdsolve().
|
||||
|
||||
"all":
|
||||
To make pdsolve apply all relevant classification hints,
|
||||
use pdsolve(PDE, func, hint="all"). This will return a
|
||||
dictionary of hint:solution terms. If a hint causes
|
||||
pdsolve to raise the NotImplementedError, value of that
|
||||
hint's key will be the exception object raised. The
|
||||
dictionary will also include some special keys:
|
||||
|
||||
- order: The order of the PDE. See also ode_order() in
|
||||
deutils.py
|
||||
- default: The solution that would be returned by
|
||||
default. This is the one produced by the hint that
|
||||
appears first in the tuple returned by classify_pde().
|
||||
|
||||
"all_Integral":
|
||||
This is the same as "all", except if a hint also has a
|
||||
corresponding "_Integral" hint, it only returns the
|
||||
"_Integral" hint. This is useful if "all" causes
|
||||
pdsolve() to hang because of a difficult or impossible
|
||||
integral. This meta-hint will also be much faster than
|
||||
"all", because integrate() is an expensive routine.
|
||||
|
||||
See also the classify_pde() docstring for more info on hints,
|
||||
and the pde docstring for a list of all supported hints.
|
||||
|
||||
**Tips**
|
||||
- You can declare the derivative of an unknown function this way:
|
||||
|
||||
>>> from sympy import Function, Derivative
|
||||
>>> from sympy.abc import x, y # x and y are the independent variables
|
||||
>>> f = Function("f")(x, y) # f is a function of x and y
|
||||
>>> # fx will be the partial derivative of f with respect to x
|
||||
>>> fx = Derivative(f, x)
|
||||
>>> # fy will be the partial derivative of f with respect to y
|
||||
>>> fy = Derivative(f, y)
|
||||
|
||||
- See test_pde.py for many tests, which serves also as a set of
|
||||
examples for how to use pdsolve().
|
||||
- pdsolve always returns an Equality class (except for the case
|
||||
when the hint is "all" or "all_Integral"). Note that it is not possible
|
||||
to get an explicit solution for f(x, y) as in the case of ODE's
|
||||
- Do help(pde.pde_hintname) to get help more information on a
|
||||
specific hint
|
||||
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.pde import pdsolve
|
||||
>>> from sympy import Function, Eq
|
||||
>>> from sympy.abc import x, y
|
||||
>>> f = Function('f')
|
||||
>>> u = f(x, y)
|
||||
>>> ux = u.diff(x)
|
||||
>>> uy = u.diff(y)
|
||||
>>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0)
|
||||
>>> pdsolve(eq)
|
||||
Eq(f(x, y), F(3*x - 2*y)*exp(-2*x/13 - 3*y/13))
|
||||
|
||||
"""
|
||||
|
||||
if not solvefun:
|
||||
solvefun = Function('F')
|
||||
|
||||
# See the docstring of _desolve for more details.
|
||||
hints = _desolve(eq, func=func, hint=hint, simplify=True,
|
||||
type='pde', **kwargs)
|
||||
eq = hints.pop('eq', False)
|
||||
all_ = hints.pop('all', False)
|
||||
|
||||
if all_:
|
||||
# TODO : 'best' hint should be implemented when adequate
|
||||
# number of hints are added.
|
||||
pdedict = {}
|
||||
failed_hints = {}
|
||||
gethints = classify_pde(eq, dict=True)
|
||||
pdedict.update({'order': gethints['order'],
|
||||
'default': gethints['default']})
|
||||
for hint in hints:
|
||||
try:
|
||||
rv = _helper_simplify(eq, hint, hints[hint]['func'],
|
||||
hints[hint]['order'], hints[hint][hint], solvefun)
|
||||
except NotImplementedError as detail:
|
||||
failed_hints[hint] = detail
|
||||
else:
|
||||
pdedict[hint] = rv
|
||||
pdedict.update(failed_hints)
|
||||
return pdedict
|
||||
|
||||
else:
|
||||
return _helper_simplify(eq, hints['hint'], hints['func'],
|
||||
hints['order'], hints[hints['hint']], solvefun)
|
||||
|
||||
|
||||
def _helper_simplify(eq, hint, func, order, match, solvefun):
|
||||
"""Helper function of pdsolve that calls the respective
|
||||
pde functions to solve for the partial differential
|
||||
equations. This minimizes the computation in
|
||||
calling _desolve multiple times.
|
||||
"""
|
||||
solvefunc = globals()["pde_" + hint.removesuffix("_Integral")]
|
||||
return _handle_Integral(solvefunc(eq, func, order,
|
||||
match, solvefun), func, order, hint)
|
||||
|
||||
|
||||
def _handle_Integral(expr, func, order, hint):
|
||||
r"""
|
||||
Converts a solution with integrals in it into an actual solution.
|
||||
|
||||
Simplifies the integral mainly using doit()
|
||||
"""
|
||||
if hint.endswith("_Integral"):
|
||||
return expr
|
||||
|
||||
elif hint == "1st_linear_constant_coeff":
|
||||
return simplify(expr.doit())
|
||||
|
||||
else:
|
||||
return expr
|
||||
|
||||
|
||||
def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
|
||||
"""
|
||||
Returns a tuple of possible pdsolve() classifications for a PDE.
|
||||
|
||||
The tuple is ordered so that first item is the classification that
|
||||
pdsolve() uses to solve the PDE by default. In general,
|
||||
classifications near the beginning of the list will produce
|
||||
better solutions faster than those near the end, though there are
|
||||
always exceptions. To make pdsolve use a different classification,
|
||||
use pdsolve(PDE, func, hint=<classification>). See also the pdsolve()
|
||||
docstring for different meta-hints you can use.
|
||||
|
||||
If ``dict`` is true, classify_pde() will return a dictionary of
|
||||
hint:match expression terms. This is intended for internal use by
|
||||
pdsolve(). Note that because dictionaries are ordered arbitrarily,
|
||||
this will most likely not be in the same order as the tuple.
|
||||
|
||||
You can get help on different hints by doing help(pde.pde_hintname),
|
||||
where hintname is the name of the hint without "_Integral".
|
||||
|
||||
See sympy.pde.allhints or the sympy.pde docstring for a list of all
|
||||
supported hints that can be returned from classify_pde.
|
||||
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.pde import classify_pde
|
||||
>>> from sympy import Function, Eq
|
||||
>>> from sympy.abc import x, y
|
||||
>>> f = Function('f')
|
||||
>>> u = f(x, y)
|
||||
>>> ux = u.diff(x)
|
||||
>>> uy = u.diff(y)
|
||||
>>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0)
|
||||
>>> classify_pde(eq)
|
||||
('1st_linear_constant_coeff_homogeneous',)
|
||||
"""
|
||||
|
||||
if func and len(func.args) != 2:
|
||||
raise NotImplementedError("Right now only partial "
|
||||
"differential equations of two variables are supported")
|
||||
|
||||
if prep or func is None:
|
||||
prep, func_ = _preprocess(eq, func)
|
||||
if func is None:
|
||||
func = func_
|
||||
|
||||
if isinstance(eq, Equality):
|
||||
if eq.rhs != 0:
|
||||
return classify_pde(eq.lhs - eq.rhs, func)
|
||||
eq = eq.lhs
|
||||
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
y = func.args[1]
|
||||
fx = f(x,y).diff(x)
|
||||
fy = f(x,y).diff(y)
|
||||
|
||||
# TODO : For now pde.py uses support offered by the ode_order function
|
||||
# to find the order with respect to a multi-variable function. An
|
||||
# improvement could be to classify the order of the PDE on the basis of
|
||||
# individual variables.
|
||||
order = ode_order(eq, f(x,y))
|
||||
|
||||
# hint:matchdict or hint:(tuple of matchdicts)
|
||||
# Also will contain "default":<default hint> and "order":order items.
|
||||
matching_hints = {'order': order}
|
||||
|
||||
if not order:
|
||||
if dict:
|
||||
matching_hints["default"] = None
|
||||
return matching_hints
|
||||
return ()
|
||||
|
||||
eq = expand(eq)
|
||||
|
||||
a = Wild('a', exclude = [f(x,y)])
|
||||
b = Wild('b', exclude = [f(x,y), fx, fy, x, y])
|
||||
c = Wild('c', exclude = [f(x,y), fx, fy, x, y])
|
||||
d = Wild('d', exclude = [f(x,y), fx, fy, x, y])
|
||||
e = Wild('e', exclude = [f(x,y), fx, fy])
|
||||
n = Wild('n', exclude = [x, y])
|
||||
# Try removing the smallest power of f(x,y)
|
||||
# from the highest partial derivatives of f(x,y)
|
||||
reduced_eq = eq
|
||||
if eq.is_Add:
|
||||
power = None
|
||||
for i in set(combinations_with_replacement((x,y), order)):
|
||||
coeff = eq.coeff(f(x,y).diff(*i))
|
||||
if coeff == 1:
|
||||
continue
|
||||
match = coeff.match(a*f(x,y)**n)
|
||||
if match and match[a]:
|
||||
if power is None or match[n] < power:
|
||||
power = match[n]
|
||||
if power:
|
||||
den = f(x,y)**power
|
||||
reduced_eq = Add(*[arg/den for arg in eq.args])
|
||||
|
||||
if order == 1:
|
||||
reduced_eq = collect(reduced_eq, f(x, y))
|
||||
r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e)
|
||||
if r:
|
||||
if not r[e]:
|
||||
## Linear first-order homogeneous partial-differential
|
||||
## equation with constant coefficients
|
||||
r.update({'b': b, 'c': c, 'd': d})
|
||||
matching_hints["1st_linear_constant_coeff_homogeneous"] = r
|
||||
elif r[b]**2 + r[c]**2 != 0:
|
||||
## Linear first-order general partial-differential
|
||||
## equation with constant coefficients
|
||||
r.update({'b': b, 'c': c, 'd': d, 'e': e})
|
||||
matching_hints["1st_linear_constant_coeff"] = r
|
||||
matching_hints["1st_linear_constant_coeff_Integral"] = r
|
||||
|
||||
else:
|
||||
b = Wild('b', exclude=[f(x, y), fx, fy])
|
||||
c = Wild('c', exclude=[f(x, y), fx, fy])
|
||||
d = Wild('d', exclude=[f(x, y), fx, fy])
|
||||
r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e)
|
||||
if r:
|
||||
r.update({'b': b, 'c': c, 'd': d, 'e': e})
|
||||
matching_hints["1st_linear_variable_coeff"] = r
|
||||
|
||||
# Order keys based on allhints.
|
||||
rettuple = tuple(i for i in allhints if i in matching_hints)
|
||||
|
||||
if dict:
|
||||
# Dictionaries are ordered arbitrarily, so make note of which
|
||||
# hint would come first for pdsolve(). Use an ordered dict in Py 3.
|
||||
matching_hints["default"] = None
|
||||
matching_hints["ordered_hints"] = rettuple
|
||||
for i in allhints:
|
||||
if i in matching_hints:
|
||||
matching_hints["default"] = i
|
||||
break
|
||||
return matching_hints
|
||||
return rettuple
|
||||
|
||||
|
||||
def checkpdesol(pde, sol, func=None, solve_for_func=True):
|
||||
"""
|
||||
Checks if the given solution satisfies the partial differential
|
||||
equation.
|
||||
|
||||
pde is the partial differential equation which can be given in the
|
||||
form of an equation or an expression. sol is the solution for which
|
||||
the pde is to be checked. This can also be given in an equation or
|
||||
an expression form. If the function is not provided, the helper
|
||||
function _preprocess from deutils is used to identify the function.
|
||||
|
||||
If a sequence of solutions is passed, the same sort of container will be
|
||||
used to return the result for each solution.
|
||||
|
||||
The following methods are currently being implemented to check if the
|
||||
solution satisfies the PDE:
|
||||
|
||||
1. Directly substitute the solution in the PDE and check. If the
|
||||
solution has not been solved for f, then it will solve for f
|
||||
provided solve_for_func has not been set to False.
|
||||
|
||||
If the solution satisfies the PDE, then a tuple (True, 0) is returned.
|
||||
Otherwise a tuple (False, expr) where expr is the value obtained
|
||||
after substituting the solution in the PDE. However if a known solution
|
||||
returns False, it may be due to the inability of doit() to simplify it to zero.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Function, symbols
|
||||
>>> from sympy.solvers.pde import checkpdesol, pdsolve
|
||||
>>> x, y = symbols('x y')
|
||||
>>> f = Function('f')
|
||||
>>> eq = 2*f(x,y) + 3*f(x,y).diff(x) + 4*f(x,y).diff(y)
|
||||
>>> sol = pdsolve(eq)
|
||||
>>> assert checkpdesol(eq, sol)[0]
|
||||
>>> eq = x*f(x,y) + f(x,y).diff(x)
|
||||
>>> checkpdesol(eq, sol)
|
||||
(False, (x*F(4*x - 3*y) - 6*F(4*x - 3*y)/25 + 4*Subs(Derivative(F(_xi_1), _xi_1), _xi_1, 4*x - 3*y))*exp(-6*x/25 - 8*y/25))
|
||||
"""
|
||||
|
||||
# Converting the pde into an equation
|
||||
if not isinstance(pde, Equality):
|
||||
pde = Eq(pde, 0)
|
||||
|
||||
# If no function is given, try finding the function present.
|
||||
if func is None:
|
||||
try:
|
||||
_, func = _preprocess(pde.lhs)
|
||||
except ValueError:
|
||||
funcs = [s.atoms(AppliedUndef) for s in (
|
||||
sol if is_sequence(sol, set) else [sol])]
|
||||
funcs = set().union(funcs)
|
||||
if len(funcs) != 1:
|
||||
raise ValueError(
|
||||
'must pass func arg to checkpdesol for this case.')
|
||||
func = funcs.pop()
|
||||
|
||||
# If the given solution is in the form of a list or a set
|
||||
# then return a list or set of tuples.
|
||||
if is_sequence(sol, set):
|
||||
return type(sol)([checkpdesol(
|
||||
pde, i, func=func,
|
||||
solve_for_func=solve_for_func) for i in sol])
|
||||
|
||||
# Convert solution into an equation
|
||||
if not isinstance(sol, Equality):
|
||||
sol = Eq(func, sol)
|
||||
elif sol.rhs == func:
|
||||
sol = sol.reversed
|
||||
|
||||
# Try solving for the function
|
||||
solved = sol.lhs == func and not sol.rhs.has(func)
|
||||
if solve_for_func and not solved:
|
||||
solved = solve(sol, func)
|
||||
if solved:
|
||||
if len(solved) == 1:
|
||||
return checkpdesol(pde, Eq(func, solved[0]),
|
||||
func=func, solve_for_func=False)
|
||||
else:
|
||||
return checkpdesol(pde, [Eq(func, t) for t in solved],
|
||||
func=func, solve_for_func=False)
|
||||
|
||||
# try direct substitution of the solution into the PDE and simplify
|
||||
if sol.lhs == func:
|
||||
pde = pde.lhs - pde.rhs
|
||||
s = simplify(pde.subs(func, sol.rhs).doit())
|
||||
return s is S.Zero, s
|
||||
|
||||
raise NotImplementedError(filldedent('''
|
||||
Unable to test if %s is a solution to %s.''' % (sol, pde)))
|
||||
|
||||
|
||||
|
||||
def pde_1st_linear_constant_coeff_homogeneous(eq, func, order, match, solvefun):
|
||||
r"""
|
||||
Solves a first order linear homogeneous
|
||||
partial differential equation with constant coefficients.
|
||||
|
||||
The general form of this partial differential equation is
|
||||
|
||||
.. math:: a \frac{\partial f(x,y)}{\partial x}
|
||||
+ b \frac{\partial f(x,y)}{\partial y} + c f(x,y) = 0
|
||||
|
||||
where `a`, `b` and `c` are constants.
|
||||
|
||||
The general solution is of the form:
|
||||
|
||||
.. math::
|
||||
f(x, y) = F(- a y + b x ) e^{- \frac{c (a x + b y)}{a^2 + b^2}}
|
||||
|
||||
and can be found in SymPy with ``pdsolve``::
|
||||
|
||||
>>> from sympy.solvers import pdsolve
|
||||
>>> from sympy.abc import x, y, a, b, c
|
||||
>>> from sympy import Function, pprint
|
||||
>>> f = Function('f')
|
||||
>>> u = f(x,y)
|
||||
>>> ux = u.diff(x)
|
||||
>>> uy = u.diff(y)
|
||||
>>> genform = a*ux + b*uy + c*u
|
||||
>>> pprint(genform)
|
||||
d d
|
||||
a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y)
|
||||
dx dy
|
||||
|
||||
>>> pprint(pdsolve(genform))
|
||||
-c*(a*x + b*y)
|
||||
---------------
|
||||
2 2
|
||||
a + b
|
||||
f(x, y) = F(-a*y + b*x)*e
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import pdsolve
|
||||
>>> from sympy import Function, pprint
|
||||
>>> from sympy.abc import x,y
|
||||
>>> f = Function('f')
|
||||
>>> pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y))
|
||||
Eq(f(x, y), F(x - y)*exp(-x/2 - y/2))
|
||||
>>> pprint(pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)))
|
||||
x y
|
||||
- - - -
|
||||
2 2
|
||||
f(x, y) = F(x - y)*e
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
- Viktor Grigoryan, "Partial Differential Equations"
|
||||
Math 124A - Fall 2010, pp.7
|
||||
|
||||
"""
|
||||
# TODO : For now homogeneous first order linear PDE's having
|
||||
# two variables are implemented. Once there is support for
|
||||
# solving systems of ODE's, this can be extended to n variables.
|
||||
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
y = func.args[1]
|
||||
b = match[match['b']]
|
||||
c = match[match['c']]
|
||||
d = match[match['d']]
|
||||
return Eq(f(x,y), exp(-S(d)/(b**2 + c**2)*(b*x + c*y))*solvefun(c*x - b*y))
|
||||
|
||||
|
||||
def pde_1st_linear_constant_coeff(eq, func, order, match, solvefun):
|
||||
r"""
|
||||
Solves a first order linear partial differential equation
|
||||
with constant coefficients.
|
||||
|
||||
The general form of this partial differential equation is
|
||||
|
||||
.. math:: a \frac{\partial f(x,y)}{\partial x}
|
||||
+ b \frac{\partial f(x,y)}{\partial y}
|
||||
+ c f(x,y) = G(x,y)
|
||||
|
||||
where `a`, `b` and `c` are constants and `G(x, y)` can be an arbitrary
|
||||
function in `x` and `y`.
|
||||
|
||||
The general solution of the PDE is:
|
||||
|
||||
.. math::
|
||||
f(x, y) = \left. \left[F(\eta) + \frac{1}{a^2 + b^2}
|
||||
\int\limits^{a x + b y} G\left(\frac{a \xi + b \eta}{a^2 + b^2},
|
||||
\frac{- a \eta + b \xi}{a^2 + b^2} \right)
|
||||
e^{\frac{c \xi}{a^2 + b^2}}\, d\xi\right]
|
||||
e^{- \frac{c \xi}{a^2 + b^2}}
|
||||
\right|_{\substack{\eta=- a y + b x\\ \xi=a x + b y }}\, ,
|
||||
|
||||
where `F(\eta)` is an arbitrary single-valued function. The solution
|
||||
can be found in SymPy with ``pdsolve``::
|
||||
|
||||
>>> from sympy.solvers import pdsolve
|
||||
>>> from sympy.abc import x, y, a, b, c
|
||||
>>> from sympy import Function, pprint
|
||||
>>> f = Function('f')
|
||||
>>> G = Function('G')
|
||||
>>> u = f(x, y)
|
||||
>>> ux = u.diff(x)
|
||||
>>> uy = u.diff(y)
|
||||
>>> genform = a*ux + b*uy + c*u - G(x,y)
|
||||
>>> pprint(genform)
|
||||
d d
|
||||
a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) - G(x, y)
|
||||
dx dy
|
||||
>>> pprint(pdsolve(genform, hint='1st_linear_constant_coeff_Integral'))
|
||||
// a*x + b*y \ \|
|
||||
|| / | ||
|
||||
|| | | ||
|
||||
|| | c*xi | ||
|
||||
|| | ------- | ||
|
||||
|| | 2 2 | ||
|
||||
|| | /a*xi + b*eta -a*eta + b*xi\ a + b | ||
|
||||
|| | G|------------, -------------|*e d(xi)| ||
|
||||
|| | | 2 2 2 2 | | ||
|
||||
|| | \ a + b a + b / | -c*xi ||
|
||||
|| | | -------||
|
||||
|| / | 2 2||
|
||||
|| | a + b ||
|
||||
f(x, y) = ||F(eta) + -------------------------------------------------------|*e ||
|
||||
|| 2 2 | ||
|
||||
\\ a + b / /|eta=-a*y + b*x, xi=a*x + b*y
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.pde import pdsolve
|
||||
>>> from sympy import Function, pprint, exp
|
||||
>>> from sympy.abc import x,y
|
||||
>>> f = Function('f')
|
||||
>>> eq = -2*f(x,y).diff(x) + 4*f(x,y).diff(y) + 5*f(x,y) - exp(x + 3*y)
|
||||
>>> pdsolve(eq)
|
||||
Eq(f(x, y), (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y))
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
- Viktor Grigoryan, "Partial Differential Equations"
|
||||
Math 124A - Fall 2010, pp.7
|
||||
|
||||
"""
|
||||
|
||||
# TODO : For now homogeneous first order linear PDE's having
|
||||
# two variables are implemented. Once there is support for
|
||||
# solving systems of ODE's, this can be extended to n variables.
|
||||
xi, eta = symbols("xi eta")
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
y = func.args[1]
|
||||
b = match[match['b']]
|
||||
c = match[match['c']]
|
||||
d = match[match['d']]
|
||||
e = -match[match['e']]
|
||||
expterm = exp(-S(d)/(b**2 + c**2)*xi)
|
||||
functerm = solvefun(eta)
|
||||
solvedict = solve((b*x + c*y - xi, c*x - b*y - eta), x, y)
|
||||
# Integral should remain as it is in terms of xi,
|
||||
# doit() should be done in _handle_Integral.
|
||||
genterm = (1/S(b**2 + c**2))*Integral(
|
||||
(1/expterm*e).subs(solvedict), (xi, b*x + c*y))
|
||||
return Eq(f(x,y), Subs(expterm*(functerm + genterm),
|
||||
(eta, xi), (c*x - b*y, b*x + c*y)))
|
||||
|
||||
|
||||
def pde_1st_linear_variable_coeff(eq, func, order, match, solvefun):
|
||||
r"""
|
||||
Solves a first order linear partial differential equation
|
||||
with variable coefficients. The general form of this partial
|
||||
differential equation is
|
||||
|
||||
.. math:: a(x, y) \frac{\partial f(x, y)}{\partial x}
|
||||
+ b(x, y) \frac{\partial f(x, y)}{\partial y}
|
||||
+ c(x, y) f(x, y) = G(x, y)
|
||||
|
||||
where `a(x, y)`, `b(x, y)`, `c(x, y)` and `G(x, y)` are arbitrary
|
||||
functions in `x` and `y`. This PDE is converted into an ODE by
|
||||
making the following transformation:
|
||||
|
||||
1. `\xi` as `x`
|
||||
|
||||
2. `\eta` as the constant in the solution to the differential
|
||||
equation `\frac{dy}{dx} = -\frac{b}{a}`
|
||||
|
||||
Making the previous substitutions reduces it to the linear ODE
|
||||
|
||||
.. math:: a(\xi, \eta)\frac{du}{d\xi} + c(\xi, \eta)u - G(\xi, \eta) = 0
|
||||
|
||||
which can be solved using ``dsolve``.
|
||||
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy import Function, pprint
|
||||
>>> a, b, c, G, f= [Function(i) for i in ['a', 'b', 'c', 'G', 'f']]
|
||||
>>> u = f(x,y)
|
||||
>>> ux = u.diff(x)
|
||||
>>> uy = u.diff(y)
|
||||
>>> genform = a(x, y)*u + b(x, y)*ux + c(x, y)*uy - G(x,y)
|
||||
>>> pprint(genform)
|
||||
d d
|
||||
-G(x, y) + a(x, y)*f(x, y) + b(x, y)*--(f(x, y)) + c(x, y)*--(f(x, y))
|
||||
dx dy
|
||||
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.pde import pdsolve
|
||||
>>> from sympy import Function, pprint
|
||||
>>> from sympy.abc import x,y
|
||||
>>> f = Function('f')
|
||||
>>> eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2
|
||||
>>> pdsolve(eq)
|
||||
Eq(f(x, y), F(x*y)*exp(y**2/2) + 1)
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
- Viktor Grigoryan, "Partial Differential Equations"
|
||||
Math 124A - Fall 2010, pp.7
|
||||
|
||||
"""
|
||||
from sympy.solvers.ode import dsolve
|
||||
|
||||
eta = symbols("eta")
|
||||
f = func.func
|
||||
x = func.args[0]
|
||||
y = func.args[1]
|
||||
b = match[match['b']]
|
||||
c = match[match['c']]
|
||||
d = match[match['d']]
|
||||
e = -match[match['e']]
|
||||
|
||||
|
||||
if not d:
|
||||
# To deal with cases like b*ux = e or c*uy = e
|
||||
if not (b and c):
|
||||
if c:
|
||||
try:
|
||||
tsol = integrate(e/c, y)
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError("Unable to find a solution"
|
||||
" due to inability of integrate")
|
||||
else:
|
||||
return Eq(f(x,y), solvefun(x) + tsol)
|
||||
if b:
|
||||
try:
|
||||
tsol = integrate(e/b, x)
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError("Unable to find a solution"
|
||||
" due to inability of integrate")
|
||||
else:
|
||||
return Eq(f(x,y), solvefun(y) + tsol)
|
||||
|
||||
if not c:
|
||||
# To deal with cases when c is 0, a simpler method is used.
|
||||
# The PDE reduces to b*(u.diff(x)) + d*u = e, which is a linear ODE in x
|
||||
plode = f(x).diff(x)*b + d*f(x) - e
|
||||
sol = dsolve(plode, f(x))
|
||||
syms = sol.free_symbols - plode.free_symbols - {x, y}
|
||||
rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, y)
|
||||
return Eq(f(x, y), rhs)
|
||||
|
||||
if not b:
|
||||
# To deal with cases when b is 0, a simpler method is used.
|
||||
# The PDE reduces to c*(u.diff(y)) + d*u = e, which is a linear ODE in y
|
||||
plode = f(y).diff(y)*c + d*f(y) - e
|
||||
sol = dsolve(plode, f(y))
|
||||
syms = sol.free_symbols - plode.free_symbols - {x, y}
|
||||
rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, x)
|
||||
return Eq(f(x, y), rhs)
|
||||
|
||||
dummy = Function('d')
|
||||
h = (c/b).subs(y, dummy(x))
|
||||
sol = dsolve(dummy(x).diff(x) - h, dummy(x))
|
||||
if isinstance(sol, list):
|
||||
sol = sol[0]
|
||||
solsym = sol.free_symbols - h.free_symbols - {x, y}
|
||||
if len(solsym) == 1:
|
||||
solsym = solsym.pop()
|
||||
etat = (solve(sol, solsym)[0]).subs(dummy(x), y)
|
||||
ysub = solve(eta - etat, y)[0]
|
||||
deq = (b*(f(x).diff(x)) + d*f(x) - e).subs(y, ysub)
|
||||
final = (dsolve(deq, f(x), hint='1st_linear')).rhs
|
||||
if isinstance(final, list):
|
||||
final = final[0]
|
||||
finsyms = final.free_symbols - deq.free_symbols - {x, y}
|
||||
rhs = _simplify_variable_coeff(final, finsyms, solvefun, etat)
|
||||
return Eq(f(x, y), rhs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Cannot solve the partial differential equation due"
|
||||
" to inability of constantsimp")
|
||||
|
||||
|
||||
def _simplify_variable_coeff(sol, syms, func, funcarg):
|
||||
r"""
|
||||
Helper function to replace constants by functions in 1st_linear_variable_coeff
|
||||
"""
|
||||
eta = Symbol("eta")
|
||||
if len(syms) == 1:
|
||||
sym = syms.pop()
|
||||
final = sol.subs(sym, func(funcarg))
|
||||
|
||||
else:
|
||||
for sym in syms:
|
||||
final = sol.subs(sym, func(funcarg))
|
||||
|
||||
return simplify(final.subs(eta, funcarg))
|
||||
|
||||
|
||||
def pde_separate(eq, fun, sep, strategy='mul'):
|
||||
"""Separate variables in partial differential equation either by additive
|
||||
or multiplicative separation approach. It tries to rewrite an equation so
|
||||
that one of the specified variables occurs on a different side of the
|
||||
equation than the others.
|
||||
|
||||
:param eq: Partial differential equation
|
||||
|
||||
:param fun: Original function F(x, y, z)
|
||||
|
||||
:param sep: List of separated functions [X(x), u(y, z)]
|
||||
|
||||
:param strategy: Separation strategy. You can choose between additive
|
||||
separation ('add') and multiplicative separation ('mul') which is
|
||||
default.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import E, Eq, Function, pde_separate, Derivative as D
|
||||
>>> from sympy.abc import x, t
|
||||
>>> u, X, T = map(Function, 'uXT')
|
||||
|
||||
>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
|
||||
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='add')
|
||||
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]
|
||||
|
||||
>>> eq = Eq(D(u(x, t), x, 2), D(u(x, t), t, 2))
|
||||
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='mul')
|
||||
[Derivative(X(x), (x, 2))/X(x), Derivative(T(t), (t, 2))/T(t)]
|
||||
|
||||
See Also
|
||||
========
|
||||
pde_separate_add, pde_separate_mul
|
||||
"""
|
||||
|
||||
do_add = False
|
||||
if strategy == 'add':
|
||||
do_add = True
|
||||
elif strategy == 'mul':
|
||||
do_add = False
|
||||
else:
|
||||
raise ValueError('Unknown strategy: %s' % strategy)
|
||||
|
||||
if isinstance(eq, Equality):
|
||||
if eq.rhs != 0:
|
||||
return pde_separate(Eq(eq.lhs - eq.rhs, 0), fun, sep, strategy)
|
||||
else:
|
||||
return pde_separate(Eq(eq, 0), fun, sep, strategy)
|
||||
|
||||
if eq.rhs != 0:
|
||||
raise ValueError("Value should be 0")
|
||||
|
||||
# Handle arguments
|
||||
orig_args = list(fun.args)
|
||||
subs_args = [arg for s in sep for arg in s.args]
|
||||
|
||||
if do_add:
|
||||
functions = reduce(operator.add, sep)
|
||||
else:
|
||||
functions = reduce(operator.mul, sep)
|
||||
|
||||
# Check whether variables match
|
||||
if len(subs_args) != len(orig_args):
|
||||
raise ValueError("Variable counts do not match")
|
||||
# Check for duplicate arguments like [X(x), u(x, y)]
|
||||
if has_dups(subs_args):
|
||||
raise ValueError("Duplicate substitution arguments detected")
|
||||
# Check whether the variables match
|
||||
if set(orig_args) != set(subs_args):
|
||||
raise ValueError("Arguments do not match")
|
||||
|
||||
# Substitute original function with separated...
|
||||
result = eq.lhs.subs(fun, functions).doit()
|
||||
|
||||
# Divide by terms when doing multiplicative separation
|
||||
if not do_add:
|
||||
eq = 0
|
||||
for i in result.args:
|
||||
eq += i/functions
|
||||
result = eq
|
||||
|
||||
svar = subs_args[0]
|
||||
dvar = subs_args[1:]
|
||||
return _separate(result, svar, dvar)
|
||||
|
||||
|
||||
def pde_separate_add(eq, fun, sep):
|
||||
"""
|
||||
Helper function for searching additive separable solutions.
|
||||
|
||||
Consider an equation of two independent variables x, y and a dependent
|
||||
variable w, we look for the product of two functions depending on different
|
||||
arguments:
|
||||
|
||||
`w(x, y, z) = X(x) + y(y, z)`
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import E, Eq, Function, pde_separate_add, Derivative as D
|
||||
>>> from sympy.abc import x, t
|
||||
>>> u, X, T = map(Function, 'uXT')
|
||||
|
||||
>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
|
||||
>>> pde_separate_add(eq, u(x, t), [X(x), T(t)])
|
||||
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]
|
||||
|
||||
"""
|
||||
return pde_separate(eq, fun, sep, strategy='add')
|
||||
|
||||
|
||||
def pde_separate_mul(eq, fun, sep):
|
||||
"""
|
||||
Helper function for searching multiplicative separable solutions.
|
||||
|
||||
Consider an equation of two independent variables x, y and a dependent
|
||||
variable w, we look for the product of two functions depending on different
|
||||
arguments:
|
||||
|
||||
`w(x, y, z) = X(x)*u(y, z)`
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Function, Eq, pde_separate_mul, Derivative as D
|
||||
>>> from sympy.abc import x, y
|
||||
>>> u, X, Y = map(Function, 'uXY')
|
||||
|
||||
>>> eq = Eq(D(u(x, y), x, 2), D(u(x, y), y, 2))
|
||||
>>> pde_separate_mul(eq, u(x, y), [X(x), Y(y)])
|
||||
[Derivative(X(x), (x, 2))/X(x), Derivative(Y(y), (y, 2))/Y(y)]
|
||||
|
||||
"""
|
||||
return pde_separate(eq, fun, sep, strategy='mul')
|
||||
|
||||
|
||||
def _separate(eq, dep, others):
|
||||
"""Separate expression into two parts based on dependencies of variables."""
|
||||
|
||||
# FIRST PASS
|
||||
# Extract derivatives depending our separable variable...
|
||||
terms = set()
|
||||
for term in eq.args:
|
||||
if term.is_Mul:
|
||||
for i in term.args:
|
||||
if i.is_Derivative and not i.has(*others):
|
||||
terms.add(term)
|
||||
continue
|
||||
elif term.is_Derivative and not term.has(*others):
|
||||
terms.add(term)
|
||||
# Find the factor that we need to divide by
|
||||
div = set()
|
||||
for term in terms:
|
||||
ext, sep = term.expand().as_independent(dep)
|
||||
# Failed?
|
||||
if sep.has(*others):
|
||||
return None
|
||||
div.add(ext)
|
||||
# FIXME: Find lcm() of all the divisors and divide with it, instead of
|
||||
# current hack :(
|
||||
# https://github.com/sympy/sympy/issues/4597
|
||||
if len(div) > 0:
|
||||
# double sum required or some tests will fail
|
||||
eq = Add(*[simplify(Add(*[term/i for i in div])) for term in eq.args])
|
||||
# SECOND PASS - separate the derivatives
|
||||
div = set()
|
||||
lhs = rhs = 0
|
||||
for term in eq.args:
|
||||
# Check, whether we have already term with independent variable...
|
||||
if not term.has(*others):
|
||||
lhs += term
|
||||
continue
|
||||
# ...otherwise, try to separate
|
||||
temp, sep = term.expand().as_independent(dep)
|
||||
# Failed?
|
||||
if sep.has(*others):
|
||||
return None
|
||||
# Extract the divisors
|
||||
div.add(sep)
|
||||
rhs -= term.expand()
|
||||
# Do the division
|
||||
fulldiv = reduce(operator.add, div)
|
||||
lhs = simplify(lhs/fulldiv).expand()
|
||||
rhs = simplify(rhs/fulldiv).expand()
|
||||
# ...and check whether we were successful :)
|
||||
if lhs.has(*others) or rhs.has(dep):
|
||||
return None
|
||||
return [lhs, rhs]
|
||||
@@ -0,0 +1,872 @@
|
||||
"""Solvers of systems of polynomial equations. """
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
import itertools
|
||||
|
||||
from sympy import Dummy
|
||||
from sympy.core import S
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.exprtools import factor_terms
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.logic.boolalg import Boolean
|
||||
from sympy.polys import Poly, groebner, roots
|
||||
from sympy.polys.domains import ZZ
|
||||
from sympy.polys.polyoptions import build_options
|
||||
from sympy.polys.polytools import parallel_poly_from_expr, sqf_part
|
||||
from sympy.polys.polyerrors import (
|
||||
ComputationFailed,
|
||||
PolificationFailed,
|
||||
CoercionFailed,
|
||||
GeneratorsNeeded,
|
||||
DomainError
|
||||
)
|
||||
from sympy.simplify import rcollect
|
||||
from sympy.utilities import postfixes
|
||||
from sympy.utilities.iterables import cartes
|
||||
from sympy.utilities.misc import filldedent
|
||||
from sympy.logic.boolalg import Or, And
|
||||
from sympy.core.relational import Eq
|
||||
|
||||
|
||||
class SolveFailed(Exception):
|
||||
"""Raised when solver's conditions were not met. """
|
||||
|
||||
|
||||
def solve_poly_system(seq, *gens, strict=False, **args):
|
||||
"""
|
||||
Return a list of solutions for the system of polynomial equations
|
||||
or else None.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
seq: a list/tuple/set
|
||||
Listing all the equations that are needed to be solved
|
||||
gens: generators
|
||||
generators of the equations in seq for which we want the
|
||||
solutions
|
||||
strict: a boolean (default is False)
|
||||
if strict is True, NotImplementedError will be raised if
|
||||
the solution is known to be incomplete (which can occur if
|
||||
not all solutions are expressible in radicals)
|
||||
args: Keyword arguments
|
||||
Special options for solving the equations.
|
||||
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
List[Tuple]
|
||||
a list of tuples with elements being solutions for the
|
||||
symbols in the order they were passed as gens
|
||||
None
|
||||
None is returned when the computed basis contains only the ground.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import solve_poly_system
|
||||
>>> from sympy.abc import x, y
|
||||
|
||||
>>> solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y)
|
||||
[(0, 0), (2, -sqrt(2)), (2, sqrt(2))]
|
||||
|
||||
>>> solve_poly_system([x**5 - x + y**3, y**2 - 1], x, y, strict=True)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
UnsolvableFactorError
|
||||
|
||||
"""
|
||||
try:
|
||||
polys, opt = parallel_poly_from_expr(seq, *gens, **args)
|
||||
except PolificationFailed as exc:
|
||||
raise ComputationFailed('solve_poly_system', len(seq), exc)
|
||||
|
||||
if len(polys) == len(opt.gens) == 2:
|
||||
f, g = polys
|
||||
|
||||
if all(i <= 2 for i in f.degree_list() + g.degree_list()):
|
||||
try:
|
||||
return solve_biquadratic(f, g, opt)
|
||||
except SolveFailed:
|
||||
pass
|
||||
|
||||
return solve_generic(polys, opt, strict=strict)
|
||||
|
||||
|
||||
def solve_biquadratic(f, g, opt):
|
||||
"""Solve a system of two bivariate quadratic polynomial equations.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
f: a single Expr or Poly
|
||||
First equation
|
||||
g: a single Expr or Poly
|
||||
Second Equation
|
||||
opt: an Options object
|
||||
For specifying keyword arguments and generators
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
List[Tuple]
|
||||
a list of tuples with elements being solutions for the
|
||||
symbols in the order they were passed as gens
|
||||
None
|
||||
None is returned when the computed basis contains only the ground.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Options, Poly
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy.solvers.polysys import solve_biquadratic
|
||||
>>> NewOption = Options((x, y), {'domain': 'ZZ'})
|
||||
|
||||
>>> a = Poly(y**2 - 4 + x, y, x, domain='ZZ')
|
||||
>>> b = Poly(y*2 + 3*x - 7, y, x, domain='ZZ')
|
||||
>>> solve_biquadratic(a, b, NewOption)
|
||||
[(1/3, 3), (41/27, 11/9)]
|
||||
|
||||
>>> a = Poly(y + x**2 - 3, y, x, domain='ZZ')
|
||||
>>> b = Poly(-y + x - 4, y, x, domain='ZZ')
|
||||
>>> solve_biquadratic(a, b, NewOption)
|
||||
[(7/2 - sqrt(29)/2, -sqrt(29)/2 - 1/2), (sqrt(29)/2 + 7/2, -1/2 + \
|
||||
sqrt(29)/2)]
|
||||
"""
|
||||
G = groebner([f, g])
|
||||
|
||||
if len(G) == 1 and G[0].is_ground:
|
||||
return None
|
||||
|
||||
if len(G) != 2:
|
||||
raise SolveFailed
|
||||
|
||||
x, y = opt.gens
|
||||
p, q = G
|
||||
if not p.gcd(q).is_ground:
|
||||
# not 0-dimensional
|
||||
raise SolveFailed
|
||||
|
||||
p = Poly(p, x, expand=False)
|
||||
p_roots = [rcollect(expr, y) for expr in roots(p).keys()]
|
||||
|
||||
q = q.ltrim(-1)
|
||||
q_roots = list(roots(q).keys())
|
||||
|
||||
solutions = [(p_root.subs(y, q_root), q_root) for q_root, p_root in
|
||||
itertools.product(q_roots, p_roots)]
|
||||
|
||||
return sorted(solutions, key=default_sort_key)
|
||||
|
||||
|
||||
def solve_generic(polys, opt, strict=False):
|
||||
"""
|
||||
Solve a generic system of polynomial equations.
|
||||
|
||||
Returns all possible solutions over C[x_1, x_2, ..., x_m] of a
|
||||
set F = { f_1, f_2, ..., f_n } of polynomial equations, using
|
||||
Groebner basis approach. For now only zero-dimensional systems
|
||||
are supported, which means F can have at most a finite number
|
||||
of solutions. If the basis contains only the ground, None is
|
||||
returned.
|
||||
|
||||
The algorithm works by the fact that, supposing G is the basis
|
||||
of F with respect to an elimination order (here lexicographic
|
||||
order is used), G and F generate the same ideal, they have the
|
||||
same set of solutions. By the elimination property, if G is a
|
||||
reduced, zero-dimensional Groebner basis, then there exists an
|
||||
univariate polynomial in G (in its last variable). This can be
|
||||
solved by computing its roots. Substituting all computed roots
|
||||
for the last (eliminated) variable in other elements of G, new
|
||||
polynomial system is generated. Applying the above procedure
|
||||
recursively, a finite number of solutions can be found.
|
||||
|
||||
The ability of finding all solutions by this procedure depends
|
||||
on the root finding algorithms. If no solutions were found, it
|
||||
means only that roots() failed, but the system is solvable. To
|
||||
overcome this difficulty use numerical algorithms instead.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
polys: a list/tuple/set
|
||||
Listing all the polynomial equations that are needed to be solved
|
||||
opt: an Options object
|
||||
For specifying keyword arguments and generators
|
||||
strict: a boolean
|
||||
If strict is True, NotImplementedError will be raised if the solution
|
||||
is known to be incomplete
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
List[Tuple]
|
||||
a list of tuples with elements being solutions for the
|
||||
symbols in the order they were passed as gens
|
||||
None
|
||||
None is returned when the computed basis contains only the ground.
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [Buchberger01] B. Buchberger, Groebner Bases: A Short
|
||||
Introduction for Systems Theorists, In: R. Moreno-Diaz,
|
||||
B. Buchberger, J.L. Freire, Proceedings of EUROCAST'01,
|
||||
February, 2001
|
||||
|
||||
.. [Cox97] D. Cox, J. Little, D. O'Shea, Ideals, Varieties
|
||||
and Algorithms, Springer, Second Edition, 1997, pp. 112
|
||||
|
||||
Raises
|
||||
========
|
||||
|
||||
NotImplementedError
|
||||
If the system is not zero-dimensional (does not have a finite
|
||||
number of solutions)
|
||||
|
||||
UnsolvableFactorError
|
||||
If ``strict`` is True and not all solution components are
|
||||
expressible in radicals
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Poly, Options
|
||||
>>> from sympy.solvers.polysys import solve_generic
|
||||
>>> from sympy.abc import x, y
|
||||
>>> NewOption = Options((x, y), {'domain': 'ZZ'})
|
||||
|
||||
>>> a = Poly(x - y + 5, x, y, domain='ZZ')
|
||||
>>> b = Poly(x + y - 3, x, y, domain='ZZ')
|
||||
>>> solve_generic([a, b], NewOption)
|
||||
[(-1, 4)]
|
||||
|
||||
>>> a = Poly(x - 2*y + 5, x, y, domain='ZZ')
|
||||
>>> b = Poly(2*x - y - 3, x, y, domain='ZZ')
|
||||
>>> solve_generic([a, b], NewOption)
|
||||
[(11/3, 13/3)]
|
||||
|
||||
>>> a = Poly(x**2 + y, x, y, domain='ZZ')
|
||||
>>> b = Poly(x + y*4, x, y, domain='ZZ')
|
||||
>>> solve_generic([a, b], NewOption)
|
||||
[(0, 0), (1/4, -1/16)]
|
||||
|
||||
>>> a = Poly(x**5 - x + y**3, x, y, domain='ZZ')
|
||||
>>> b = Poly(y**2 - 1, x, y, domain='ZZ')
|
||||
>>> solve_generic([a, b], NewOption, strict=True)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
UnsolvableFactorError
|
||||
|
||||
"""
|
||||
def _is_univariate(f):
|
||||
"""Returns True if 'f' is univariate in its last variable. """
|
||||
for monom in f.monoms():
|
||||
if any(monom[:-1]):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _subs_root(f, gen, zero):
|
||||
"""Replace generator with a root so that the result is nice. """
|
||||
p = f.as_expr({gen: zero})
|
||||
|
||||
if f.degree(gen) >= 2:
|
||||
p = p.expand(deep=False)
|
||||
|
||||
return p
|
||||
|
||||
def _solve_reduced_system(system, gens, entry=False):
|
||||
"""Recursively solves reduced polynomial systems. """
|
||||
if len(system) == len(gens) == 1:
|
||||
# the below line will produce UnsolvableFactorError if
|
||||
# strict=True and the solution from `roots` is incomplete
|
||||
zeros = list(roots(system[0], gens[-1], strict=strict).keys())
|
||||
return [(zero,) for zero in zeros]
|
||||
|
||||
basis = groebner(system, gens, polys=True)
|
||||
|
||||
if len(basis) == 1 and basis[0].is_ground:
|
||||
if not entry:
|
||||
return []
|
||||
else:
|
||||
return None
|
||||
|
||||
univariate = list(filter(_is_univariate, basis))
|
||||
|
||||
if len(basis) < len(gens):
|
||||
raise NotImplementedError(filldedent('''
|
||||
only zero-dimensional systems supported
|
||||
(finite number of solutions)
|
||||
'''))
|
||||
|
||||
if len(univariate) == 1:
|
||||
f = univariate.pop()
|
||||
else:
|
||||
raise NotImplementedError(filldedent('''
|
||||
only zero-dimensional systems supported
|
||||
(finite number of solutions)
|
||||
'''))
|
||||
|
||||
gens = f.gens
|
||||
gen = gens[-1]
|
||||
|
||||
# the below line will produce UnsolvableFactorError if
|
||||
# strict=True and the solution from `roots` is incomplete
|
||||
zeros = list(roots(f.ltrim(gen), strict=strict).keys())
|
||||
|
||||
if not zeros:
|
||||
return []
|
||||
|
||||
if len(basis) == 1:
|
||||
return [(zero,) for zero in zeros]
|
||||
|
||||
solutions = []
|
||||
|
||||
for zero in zeros:
|
||||
new_system = []
|
||||
new_gens = gens[:-1]
|
||||
|
||||
for b in basis[:-1]:
|
||||
eq = _subs_root(b, gen, zero)
|
||||
|
||||
if eq is not S.Zero:
|
||||
new_system.append(eq)
|
||||
|
||||
for solution in _solve_reduced_system(new_system, new_gens):
|
||||
solutions.append(solution + (zero,))
|
||||
|
||||
if solutions and len(solutions[0]) != len(gens):
|
||||
raise NotImplementedError(filldedent('''
|
||||
only zero-dimensional systems supported
|
||||
(finite number of solutions)
|
||||
'''))
|
||||
return solutions
|
||||
|
||||
try:
|
||||
result = _solve_reduced_system(polys, opt.gens, entry=True)
|
||||
except CoercionFailed:
|
||||
raise NotImplementedError
|
||||
|
||||
if result is not None:
|
||||
return sorted(result, key=default_sort_key)
|
||||
|
||||
|
||||
def solve_triangulated(polys, *gens, **args):
|
||||
"""
|
||||
Solve a polynomial system using Gianni-Kalkbrenner algorithm.
|
||||
|
||||
The algorithm proceeds by computing one Groebner basis in the ground
|
||||
domain and then by iteratively computing polynomial factorizations in
|
||||
appropriately constructed algebraic extensions of the ground domain.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
polys: a list/tuple/set
|
||||
Listing all the equations that are needed to be solved
|
||||
gens: generators
|
||||
generators of the equations in polys for which we want the
|
||||
solutions
|
||||
args: Keyword arguments
|
||||
Special options for solving the equations
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
List[Tuple]
|
||||
A List of tuples. Solutions for symbols that satisfy the
|
||||
equations listed in polys
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import solve_triangulated
|
||||
>>> from sympy.abc import x, y, z
|
||||
|
||||
>>> F = [x**2 + y + z - 1, x + y**2 + z - 1, x + y + z**2 - 1]
|
||||
|
||||
>>> solve_triangulated(F, x, y, z)
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0)]
|
||||
|
||||
Using extension for algebraic solutions.
|
||||
|
||||
>>> solve_triangulated(F, x, y, z, extension=True) #doctest: +NORMALIZE_WHITESPACE
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0),
|
||||
(CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0)),
|
||||
(CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1))]
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
1. Patrizia Gianni, Teo Mora, Algebraic Solution of System of
|
||||
Polynomial Equations using Groebner Bases, AAECC-5 on Applied Algebra,
|
||||
Algebraic Algorithms and Error-Correcting Codes, LNCS 356 247--257, 1989
|
||||
|
||||
"""
|
||||
opt = build_options(gens, args)
|
||||
|
||||
G = groebner(polys, gens, polys=True)
|
||||
G = list(reversed(G))
|
||||
|
||||
extension = opt.get('extension', False)
|
||||
if extension:
|
||||
def _solve_univariate(f):
|
||||
return [r for r, _ in f.all_roots(multiple=False, radicals=False)]
|
||||
else:
|
||||
domain = opt.get('domain')
|
||||
|
||||
if domain is not None:
|
||||
for i, g in enumerate(G):
|
||||
G[i] = g.set_domain(domain)
|
||||
|
||||
def _solve_univariate(f):
|
||||
return list(f.ground_roots().keys())
|
||||
|
||||
f, G = G[0].ltrim(-1), G[1:]
|
||||
dom = f.get_domain()
|
||||
|
||||
zeros = _solve_univariate(f)
|
||||
|
||||
if extension:
|
||||
solutions = {((zero,), dom.algebraic_field(zero)) for zero in zeros}
|
||||
else:
|
||||
solutions = {((zero,), dom) for zero in zeros}
|
||||
|
||||
var_seq = reversed(gens[:-1])
|
||||
vars_seq = postfixes(gens[1:])
|
||||
|
||||
for var, vars in zip(var_seq, vars_seq):
|
||||
_solutions = set()
|
||||
|
||||
for values, dom in solutions:
|
||||
H, mapping = [], list(zip(vars, values))
|
||||
|
||||
for g in G:
|
||||
_vars = (var,) + vars
|
||||
|
||||
if g.has_only_gens(*_vars) and g.degree(var) != 0:
|
||||
if extension:
|
||||
g = g.set_domain(g.domain.unify(dom))
|
||||
h = g.ltrim(var).eval(dict(mapping))
|
||||
|
||||
if g.degree(var) == h.degree():
|
||||
H.append(h)
|
||||
|
||||
p = min(H, key=lambda h: h.degree())
|
||||
zeros = _solve_univariate(p)
|
||||
|
||||
for zero in zeros:
|
||||
if not (zero in dom):
|
||||
dom_zero = dom.algebraic_field(zero)
|
||||
else:
|
||||
dom_zero = dom
|
||||
|
||||
_solutions.add(((zero,) + values, dom_zero))
|
||||
|
||||
solutions = _solutions
|
||||
return sorted((s for s, _ in solutions), key=default_sort_key)
|
||||
|
||||
|
||||
def factor_system(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]:
|
||||
"""
|
||||
Factorizes a system of polynomial equations into
|
||||
irreducible subsystems.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
eqs : list
|
||||
List of expressions to be factored.
|
||||
Each expression is assumed to be equal to zero.
|
||||
|
||||
gens : list, optional
|
||||
Generator(s) of the polynomial ring.
|
||||
If not provided, all free symbols will be used.
|
||||
|
||||
**kwargs : dict, optional
|
||||
Same optional arguments taken by ``factor``
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
list[list[Expr]]
|
||||
A list of lists of expressions, where each sublist represents
|
||||
an irreducible subsystem. When solved, each subsystem gives
|
||||
one component of the solution. Only generic solutions are
|
||||
returned (cases not requiring parameters to be zero).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.polysys import factor_system, factor_system_cond
|
||||
>>> from sympy.abc import x, y, a, b, c
|
||||
|
||||
A simple system with multiple solutions:
|
||||
|
||||
>>> factor_system([x**2 - 1, y - 1])
|
||||
[[x + 1, y - 1], [x - 1, y - 1]]
|
||||
|
||||
A system with no solution:
|
||||
|
||||
>>> factor_system([x, 1])
|
||||
[]
|
||||
|
||||
A system where any value of the symbol(s) is a solution:
|
||||
|
||||
>>> factor_system([x - x, (x + 1)**2 - (x**2 + 2*x + 1)])
|
||||
[[]]
|
||||
|
||||
A system with no generic solution:
|
||||
|
||||
>>> factor_system([a*x*(x-1), b*y, c], [x, y])
|
||||
[]
|
||||
|
||||
If c is added to the unknowns then the system has a generic solution:
|
||||
|
||||
>>> factor_system([a*x*(x-1), b*y, c], [x, y, c])
|
||||
[[x - 1, y, c], [x, y, c]]
|
||||
|
||||
Alternatively :func:`factor_system_cond` can be used to get degenerate
|
||||
cases as well:
|
||||
|
||||
>>> factor_system_cond([a*x*(x-1), b*y, c], [x, y])
|
||||
[[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]]
|
||||
|
||||
Each of the above cases is only satisfiable in the degenerate case `c = 0`.
|
||||
|
||||
The solution set of the original system represented
|
||||
by eqs is the union of the solution sets of the
|
||||
factorized systems.
|
||||
|
||||
An empty list [] means no generic solution exists.
|
||||
A list containing an empty list [[]] means any value of
|
||||
the symbol(s) is a solution.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
factor_system_cond : Returns both generic and degenerate solutions
|
||||
factor_system_bool : Returns a Boolean combination representing all solutions
|
||||
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
|
||||
over the rational numbers
|
||||
"""
|
||||
|
||||
systems = _factor_system_poly_from_expr(eqs, gens, **kwargs)
|
||||
systems_generic = [sys for sys in systems if not _is_degenerate(sys)]
|
||||
systems_expr = [[p.as_expr() for p in system] for system in systems_generic]
|
||||
return systems_expr
|
||||
|
||||
|
||||
def _is_degenerate(system: list[Poly]) -> bool:
|
||||
"""Helper function to check if a system is degenerate"""
|
||||
return any(p.is_ground for p in system)
|
||||
|
||||
|
||||
def factor_system_bool(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> Boolean:
|
||||
"""
|
||||
Factorizes a system of polynomial equations into irreducible DNF.
|
||||
|
||||
The system of expressions(eqs) is taken and a Boolean combination
|
||||
of equations is returned that represents the same solution set.
|
||||
The result is in disjunctive normal form (OR of ANDs).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
eqs : list
|
||||
List of expressions to be factored.
|
||||
Each expression is assumed to be equal to zero.
|
||||
|
||||
gens : list, optional
|
||||
Generator(s) of the polynomial ring.
|
||||
If not provided, all free symbols will be used.
|
||||
|
||||
**kwargs : dict, optional
|
||||
Optional keyword arguments
|
||||
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
Boolean:
|
||||
A Boolean combination of equations. The result is typically in
|
||||
the form of a conjunction (AND) of a disjunctive normal form
|
||||
with additional conditions.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.polysys import factor_system_bool
|
||||
>>> from sympy.abc import x, y, a, b, c
|
||||
>>> factor_system_bool([x**2 - 1])
|
||||
Eq(x - 1, 0) | Eq(x + 1, 0)
|
||||
|
||||
>>> factor_system_bool([x**2 - 1, y - 1])
|
||||
(Eq(x - 1, 0) & Eq(y - 1, 0)) | (Eq(x + 1, 0) & Eq(y - 1, 0))
|
||||
|
||||
>>> eqs = [a * (x - 1), b]
|
||||
>>> factor_system_bool([a*(x - 1), b])
|
||||
(Eq(a, 0) & Eq(b, 0)) | (Eq(b, 0) & Eq(x - 1, 0))
|
||||
|
||||
>>> factor_system_bool([a*x**2 - a, b*(x + 1), c], [x])
|
||||
(Eq(c, 0) & Eq(x + 1, 0)) | (Eq(a, 0) & Eq(b, 0) & Eq(c, 0)) | (Eq(b, 0) & Eq(c, 0) & Eq(x - 1, 0))
|
||||
|
||||
>>> factor_system_bool([x**2 + 2*x + 1 - (x + 1)**2])
|
||||
True
|
||||
|
||||
The result is logically equivalent to the system of equations
|
||||
i.e. eqs. The function returns ``True`` when all values of
|
||||
the symbol(s) is a solution and ``False`` when the system
|
||||
cannot be solved.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
factor_system : Returns factors and solvability condition separately
|
||||
factor_system_cond : Returns both factors and conditions
|
||||
|
||||
"""
|
||||
|
||||
systems = factor_system_cond(eqs, gens, **kwargs)
|
||||
return Or(*[And(*[Eq(eq, 0) for eq in sys]) for sys in systems])
|
||||
|
||||
|
||||
def factor_system_cond(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]:
|
||||
"""
|
||||
Factorizes a polynomial system into irreducible components and returns
|
||||
both generic and degenerate solutions.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
eqs : list
|
||||
List of expressions to be factored.
|
||||
Each expression is assumed to be equal to zero.
|
||||
|
||||
gens : list, optional
|
||||
Generator(s) of the polynomial ring.
|
||||
If not provided, all free symbols will be used.
|
||||
|
||||
**kwargs : dict, optional
|
||||
Optional keyword arguments.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
list[list[Expr]]
|
||||
A list of lists of expressions, where each sublist represents
|
||||
an irreducible subsystem. Includes both generic solutions and
|
||||
degenerate cases requiring equality conditions on parameters.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers.polysys import factor_system_cond
|
||||
>>> from sympy.abc import x, y, a, b, c
|
||||
|
||||
>>> factor_system_cond([x**2 - 4, a*y, b], [x, y])
|
||||
[[x + 2, y, b], [x - 2, y, b], [x + 2, a, b], [x - 2, a, b]]
|
||||
|
||||
>>> factor_system_cond([a*x*(x-1), b*y, c], [x, y])
|
||||
[[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]]
|
||||
|
||||
An empty list [] means no solution exists.
|
||||
A list containing an empty list [[]] means any value of
|
||||
the symbol(s) is a solution.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
factor_system : Returns only generic solutions
|
||||
factor_system_bool : Returns a Boolean combination representing all solutions
|
||||
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
|
||||
over the rational numbers
|
||||
"""
|
||||
systems_poly = _factor_system_poly_from_expr(eqs, gens, **kwargs)
|
||||
systems = [[p.as_expr() for p in system] for system in systems_poly]
|
||||
return systems
|
||||
|
||||
|
||||
def _factor_system_poly_from_expr(
|
||||
eqs: Sequence[Expr | complex], gens: Sequence[Expr], **kwargs: Any
|
||||
) -> list[list[Poly]]:
|
||||
"""
|
||||
Convert expressions to polynomials and factor the system.
|
||||
|
||||
Takes a sequence of expressions, converts them to
|
||||
polynomials, and factors the resulting system. Handles both regular
|
||||
polynomial systems and purely numerical cases.
|
||||
"""
|
||||
try:
|
||||
polys, opts = parallel_poly_from_expr(eqs, *gens, **kwargs)
|
||||
only_numbers = False
|
||||
except (GeneratorsNeeded, PolificationFailed):
|
||||
_u = Dummy('u')
|
||||
polys, opts = parallel_poly_from_expr(eqs, [_u], **kwargs)
|
||||
assert opts['domain'].is_Numerical
|
||||
only_numbers = True
|
||||
|
||||
if only_numbers:
|
||||
return [[]] if all(p == 0 for p in polys) else []
|
||||
|
||||
return factor_system_poly(polys)
|
||||
|
||||
|
||||
def factor_system_poly(polys: list[Poly]) -> list[list[Poly]]:
|
||||
"""
|
||||
Factors a system of polynomial equations into irreducible subsystems
|
||||
|
||||
Core implementation that works directly with Poly instances.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
polys : list[Poly]
|
||||
A list of Poly instances to be factored.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
list[list[Poly]]
|
||||
A list of lists of polynomials, where each sublist represents
|
||||
an irreducible component of the solution. Includes both
|
||||
generic and degenerate cases.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, Poly, ZZ
|
||||
>>> from sympy.solvers.polysys import factor_system_poly
|
||||
>>> a, b, c, x = symbols('a b c x')
|
||||
>>> p1 = Poly((a - 1)*(x - 2), x, domain=ZZ[a,b,c])
|
||||
>>> p2 = Poly((b - 3)*(x - 2), x, domain=ZZ[a,b,c])
|
||||
>>> p3 = Poly(c, x, domain=ZZ[a,b,c])
|
||||
|
||||
The equation to be solved for x is ``x - 2 = 0`` provided either
|
||||
of the two conditions on the parameters ``a`` and ``b`` is nonzero
|
||||
and the constant parameter ``c`` should be zero.
|
||||
|
||||
>>> sys1, sys2 = factor_system_poly([p1, p2, p3])
|
||||
>>> sys1
|
||||
[Poly(x - 2, x, domain='ZZ[a,b,c]'),
|
||||
Poly(c, x, domain='ZZ[a,b,c]')]
|
||||
>>> sys2
|
||||
[Poly(a - 1, x, domain='ZZ[a,b,c]'),
|
||||
Poly(b - 3, x, domain='ZZ[a,b,c]'),
|
||||
Poly(c, x, domain='ZZ[a,b,c]')]
|
||||
|
||||
An empty list [] when returned means no solution exists.
|
||||
Whereas a list containing an empty list [[]] means any value is a solution.
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
factor_system : Returns only generic solutions
|
||||
factor_system_bool : Returns a Boolean combination representing the solutions
|
||||
factor_system_cond : Returns both generic and degenerate solutions
|
||||
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
|
||||
over the rational numbers
|
||||
"""
|
||||
if not all(isinstance(poly, Poly) for poly in polys):
|
||||
raise TypeError("polys should be a list of Poly instances")
|
||||
if not polys:
|
||||
return [[]]
|
||||
|
||||
base_domain = polys[0].domain
|
||||
base_gens = polys[0].gens
|
||||
if not all(poly.domain == base_domain and poly.gens == base_gens for poly in polys[1:]):
|
||||
raise DomainError("All polynomials must have the same domain and generators")
|
||||
|
||||
factor_sets = []
|
||||
for poly in polys:
|
||||
constant, factors_mult = poly.factor_list()
|
||||
|
||||
if constant.is_zero is True:
|
||||
continue
|
||||
elif constant.is_zero is False:
|
||||
if not factors_mult:
|
||||
return []
|
||||
factor_sets.append([f for f, _ in factors_mult])
|
||||
else:
|
||||
constant = sqf_part(factor_terms(constant).as_coeff_Mul()[1])
|
||||
constp = Poly(constant, base_gens, domain=base_domain)
|
||||
factors = [f for f, _ in factors_mult]
|
||||
factors.append(constp)
|
||||
factor_sets.append(factors)
|
||||
|
||||
if not factor_sets:
|
||||
return [[]]
|
||||
|
||||
result = _factor_sets(factor_sets)
|
||||
return _sort_systems(result)
|
||||
|
||||
|
||||
def _factor_sets_slow(eqs: list[list]) -> set[frozenset]:
|
||||
"""
|
||||
Helper to find the minimal set of factorised subsystems that is
|
||||
equivalent to the original system.
|
||||
|
||||
The result is in DNF.
|
||||
"""
|
||||
if not eqs:
|
||||
return {frozenset()}
|
||||
systems_set = {frozenset(sys) for sys in cartes(*eqs)}
|
||||
return {s1 for s1 in systems_set if not any(s1 > s2 for s2 in systems_set)}
|
||||
|
||||
|
||||
def _factor_sets(eqs: list[list]) -> set[frozenset]:
|
||||
"""
|
||||
Helper that builds factor combinations.
|
||||
"""
|
||||
if not eqs:
|
||||
return {frozenset()}
|
||||
|
||||
current_set = min(eqs, key=len)
|
||||
other_sets = [s for s in eqs if s is not current_set]
|
||||
|
||||
stack = [(factor, [s for s in other_sets if factor not in s], {factor})
|
||||
for factor in current_set]
|
||||
|
||||
result = set()
|
||||
|
||||
while stack:
|
||||
factor, remaining_sets, current_solution = stack.pop()
|
||||
|
||||
if not remaining_sets:
|
||||
result.add(frozenset(current_solution))
|
||||
continue
|
||||
|
||||
next_set = min(remaining_sets, key=len)
|
||||
next_remaining = [s for s in remaining_sets if s is not next_set]
|
||||
|
||||
for next_factor in next_set:
|
||||
valid_remaining = [s for s in next_remaining if next_factor not in s]
|
||||
new_solution = current_solution | {next_factor}
|
||||
stack.append((next_factor, valid_remaining, new_solution))
|
||||
|
||||
return {s1 for s1 in result if not any(s1 > s2 for s2 in result)}
|
||||
|
||||
|
||||
def _sort_systems(systems: Iterable[Iterable[Poly]]) -> list[list[Poly]]:
|
||||
"""Sorts a list of lists of polynomials"""
|
||||
systems_list = [sorted(s, key=_poly_sort_key, reverse=True) for s in systems]
|
||||
return sorted(systems_list, key=_sys_sort_key, reverse=True)
|
||||
|
||||
|
||||
def _poly_sort_key(poly):
|
||||
"""Sort key for polynomials"""
|
||||
if poly.domain.is_FF:
|
||||
poly = poly.set_domain(ZZ)
|
||||
return poly.degree_list(), poly.rep.to_list()
|
||||
|
||||
|
||||
def _sys_sort_key(sys):
|
||||
"""Sort key for lists of polynomials"""
|
||||
return list(zip(*map(_poly_sort_key, sys)))
|
||||
@@ -0,0 +1,843 @@
|
||||
r"""
|
||||
This module is intended for solving recurrences or, in other words,
|
||||
difference equations. Currently supported are linear, inhomogeneous
|
||||
equations with polynomial or rational coefficients.
|
||||
|
||||
The solutions are obtained among polynomials, rational functions,
|
||||
hypergeometric terms, or combinations of hypergeometric term which
|
||||
are pairwise dissimilar.
|
||||
|
||||
``rsolve_X`` functions were meant as a low level interface
|
||||
for ``rsolve`` which would use Mathematica's syntax.
|
||||
|
||||
Given a recurrence relation:
|
||||
|
||||
.. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) +
|
||||
... + a_{0}(n) y(n) = f(n)
|
||||
|
||||
where `k > 0` and `a_{i}(n)` are polynomials in `n`. To use
|
||||
``rsolve_X`` we need to put all coefficients in to a list ``L`` of
|
||||
`k+1` elements the following way:
|
||||
|
||||
``L = [a_{0}(n), ..., a_{k-1}(n), a_{k}(n)]``
|
||||
|
||||
where ``L[i]``, for `i=0, \ldots, k`, maps to
|
||||
`a_{i}(n) y(n+i)` (`y(n+i)` is implicit).
|
||||
|
||||
For example if we would like to compute `m`-th Bernoulli polynomial
|
||||
up to a constant (example was taken from rsolve_poly docstring),
|
||||
then we would use `b(n+1) - b(n) = m n^{m-1}` recurrence, which
|
||||
has solution `b(n) = B_m + C`.
|
||||
|
||||
Then ``L = [-1, 1]`` and `f(n) = m n^(m-1)` and finally for `m=4`:
|
||||
|
||||
>>> from sympy import Symbol, bernoulli, rsolve_poly
|
||||
>>> n = Symbol('n', integer=True)
|
||||
|
||||
>>> rsolve_poly([-1, 1], 4*n**3, n)
|
||||
C0 + n**4 - 2*n**3 + n**2
|
||||
|
||||
>>> bernoulli(4, n)
|
||||
n**4 - 2*n**3 + n**2 - 1/30
|
||||
|
||||
For the sake of completeness, `f(n)` can be:
|
||||
|
||||
[1] a polynomial -> rsolve_poly
|
||||
[2] a rational function -> rsolve_ratio
|
||||
[3] a hypergeometric function -> rsolve_hyper
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from sympy.concrete import product
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.numbers import Rational, I
|
||||
from sympy.core.symbol import Symbol, Wild, Dummy
|
||||
from sympy.core.relational import Equality
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
from sympy.simplify import simplify, hypersimp, hypersimilar # type: ignore
|
||||
from sympy.solvers import solve, solve_undetermined_coeffs
|
||||
from sympy.polys import Poly, quo, gcd, lcm, roots, resultant
|
||||
from sympy.functions import binomial, factorial, FallingFactorial, RisingFactorial
|
||||
from sympy.matrices import Matrix, casoratian
|
||||
from sympy.utilities.iterables import numbered_symbols
|
||||
|
||||
|
||||
def rsolve_poly(coeffs, f, n, shift=0, **hints):
|
||||
r"""
|
||||
Given linear recurrence operator `\operatorname{L}` of order
|
||||
`k` with polynomial coefficients and inhomogeneous equation
|
||||
`\operatorname{L} y = f`, where `f` is a polynomial, we seek for
|
||||
all polynomial solutions over field `K` of characteristic zero.
|
||||
|
||||
The algorithm performs two basic steps:
|
||||
|
||||
(1) Compute degree `N` of the general polynomial solution.
|
||||
(2) Find all polynomials of degree `N` or less
|
||||
of `\operatorname{L} y = f`.
|
||||
|
||||
There are two methods for computing the polynomial solutions.
|
||||
If the degree bound is relatively small, i.e. it's smaller than
|
||||
or equal to the order of the recurrence, then naive method of
|
||||
undetermined coefficients is being used. This gives a system
|
||||
of algebraic equations with `N+1` unknowns.
|
||||
|
||||
In the other case, the algorithm performs transformation of the
|
||||
initial equation to an equivalent one for which the system of
|
||||
algebraic equations has only `r` indeterminates. This method is
|
||||
quite sophisticated (in comparison with the naive one) and was
|
||||
invented together by Abramov, Bronstein and Petkovsek.
|
||||
|
||||
It is possible to generalize the algorithm implemented here to
|
||||
the case of linear q-difference and differential equations.
|
||||
|
||||
Lets say that we would like to compute `m`-th Bernoulli polynomial
|
||||
up to a constant. For this we can use `b(n+1) - b(n) = m n^{m-1}`
|
||||
recurrence, which has solution `b(n) = B_m + C`. For example:
|
||||
|
||||
>>> from sympy import Symbol, rsolve_poly
|
||||
>>> n = Symbol('n', integer=True)
|
||||
|
||||
>>> rsolve_poly([-1, 1], 4*n**3, n)
|
||||
C0 + n**4 - 2*n**3 + n**2
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] S. A. Abramov, M. Bronstein and M. Petkovsek, On polynomial
|
||||
solutions of linear operator equations, in: T. Levelt, ed.,
|
||||
Proc. ISSAC '95, ACM Press, New York, 1995, 290-296.
|
||||
|
||||
.. [2] M. Petkovsek, Hypergeometric solutions of linear recurrences
|
||||
with polynomial coefficients, J. Symbolic Computation,
|
||||
14 (1992), 243-264.
|
||||
|
||||
.. [3] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
|
||||
|
||||
"""
|
||||
f = sympify(f)
|
||||
|
||||
if not f.is_polynomial(n):
|
||||
return None
|
||||
|
||||
homogeneous = f.is_zero
|
||||
|
||||
r = len(coeffs) - 1
|
||||
|
||||
coeffs = [Poly(coeff, n) for coeff in coeffs]
|
||||
|
||||
polys = [Poly(0, n)]*(r + 1)
|
||||
terms = [(S.Zero, S.NegativeInfinity)]*(r + 1)
|
||||
|
||||
for i in range(r + 1):
|
||||
for j in range(i, r + 1):
|
||||
polys[i] += coeffs[j]*(binomial(j, i).as_poly(n))
|
||||
|
||||
if not polys[i].is_zero:
|
||||
(exp,), coeff = polys[i].LT()
|
||||
terms[i] = (coeff, exp)
|
||||
|
||||
d = b = terms[0][1]
|
||||
|
||||
for i in range(1, r + 1):
|
||||
if terms[i][1] > d:
|
||||
d = terms[i][1]
|
||||
|
||||
if terms[i][1] - i > b:
|
||||
b = terms[i][1] - i
|
||||
|
||||
d, b = int(d), int(b)
|
||||
|
||||
x = Dummy('x')
|
||||
|
||||
degree_poly = S.Zero
|
||||
|
||||
for i in range(r + 1):
|
||||
if terms[i][1] - i == b:
|
||||
degree_poly += terms[i][0]*FallingFactorial(x, i)
|
||||
|
||||
nni_roots = list(roots(degree_poly, x, filter='Z',
|
||||
predicate=lambda r: r >= 0).keys())
|
||||
|
||||
if nni_roots:
|
||||
N = [max(nni_roots)]
|
||||
else:
|
||||
N = []
|
||||
|
||||
if homogeneous:
|
||||
N += [-b - 1]
|
||||
else:
|
||||
N += [f.as_poly(n).degree() - b, -b - 1]
|
||||
|
||||
N = int(max(N))
|
||||
|
||||
if N < 0:
|
||||
if homogeneous:
|
||||
if hints.get('symbols', False):
|
||||
return (S.Zero, [])
|
||||
else:
|
||||
return S.Zero
|
||||
else:
|
||||
return None
|
||||
|
||||
if N <= r:
|
||||
C = []
|
||||
y = E = S.Zero
|
||||
|
||||
for i in range(N + 1):
|
||||
C.append(Symbol('C' + str(i + shift)))
|
||||
y += C[i] * n**i
|
||||
|
||||
for i in range(r + 1):
|
||||
E += coeffs[i].as_expr()*y.subs(n, n + i)
|
||||
|
||||
solutions = solve_undetermined_coeffs(E - f, C, n)
|
||||
|
||||
if solutions is not None:
|
||||
_C = C
|
||||
C = [c for c in C if (c not in solutions)]
|
||||
result = y.subs(solutions)
|
||||
else:
|
||||
return None # TBD
|
||||
else:
|
||||
A = r
|
||||
U = N + A + b + 1
|
||||
|
||||
nni_roots = list(roots(polys[r], filter='Z',
|
||||
predicate=lambda r: r >= 0).keys())
|
||||
|
||||
if nni_roots != []:
|
||||
a = max(nni_roots) + 1
|
||||
else:
|
||||
a = S.Zero
|
||||
|
||||
def _zero_vector(k):
|
||||
return [S.Zero] * k
|
||||
|
||||
def _one_vector(k):
|
||||
return [S.One] * k
|
||||
|
||||
def _delta(p, k):
|
||||
B = S.One
|
||||
D = p.subs(n, a + k)
|
||||
|
||||
for i in range(1, k + 1):
|
||||
B *= Rational(i - k - 1, i)
|
||||
D += B * p.subs(n, a + k - i)
|
||||
|
||||
return D
|
||||
|
||||
alpha = {}
|
||||
|
||||
for i in range(-A, d + 1):
|
||||
I = _one_vector(d + 1)
|
||||
|
||||
for k in range(1, d + 1):
|
||||
I[k] = I[k - 1] * (x + i - k + 1)/k
|
||||
|
||||
alpha[i] = S.Zero
|
||||
|
||||
for j in range(A + 1):
|
||||
for k in range(d + 1):
|
||||
B = binomial(k, i + j)
|
||||
D = _delta(polys[j].as_expr(), k)
|
||||
|
||||
alpha[i] += I[k]*B*D
|
||||
|
||||
V = Matrix(U, A, lambda i, j: int(i == j))
|
||||
|
||||
if homogeneous:
|
||||
for i in range(A, U):
|
||||
v = _zero_vector(A)
|
||||
|
||||
for k in range(1, A + b + 1):
|
||||
if i - k < 0:
|
||||
break
|
||||
|
||||
B = alpha[k - A].subs(x, i - k)
|
||||
|
||||
for j in range(A):
|
||||
v[j] += B * V[i - k, j]
|
||||
|
||||
denom = alpha[-A].subs(x, i)
|
||||
|
||||
for j in range(A):
|
||||
V[i, j] = -v[j] / denom
|
||||
else:
|
||||
G = _zero_vector(U)
|
||||
|
||||
for i in range(A, U):
|
||||
v = _zero_vector(A)
|
||||
g = S.Zero
|
||||
|
||||
for k in range(1, A + b + 1):
|
||||
if i - k < 0:
|
||||
break
|
||||
|
||||
B = alpha[k - A].subs(x, i - k)
|
||||
|
||||
for j in range(A):
|
||||
v[j] += B * V[i - k, j]
|
||||
|
||||
g += B * G[i - k]
|
||||
|
||||
denom = alpha[-A].subs(x, i)
|
||||
|
||||
for j in range(A):
|
||||
V[i, j] = -v[j] / denom
|
||||
|
||||
G[i] = (_delta(f, i - A) - g) / denom
|
||||
|
||||
P, Q = _one_vector(U), _zero_vector(A)
|
||||
|
||||
for i in range(1, U):
|
||||
P[i] = (P[i - 1] * (n - a - i + 1)/i).expand()
|
||||
|
||||
for i in range(A):
|
||||
Q[i] = Add(*[(v*p).expand() for v, p in zip(V[:, i], P)])
|
||||
|
||||
if not homogeneous:
|
||||
h = Add(*[(g*p).expand() for g, p in zip(G, P)])
|
||||
|
||||
C = [Symbol('C' + str(i + shift)) for i in range(A)]
|
||||
|
||||
g = lambda i: Add(*[c*_delta(q, i) for c, q in zip(C, Q)])
|
||||
|
||||
if homogeneous:
|
||||
E = [g(i) for i in range(N + 1, U)]
|
||||
else:
|
||||
E = [g(i) + _delta(h, i) for i in range(N + 1, U)]
|
||||
|
||||
if E != []:
|
||||
solutions = solve(E, *C)
|
||||
|
||||
if not solutions:
|
||||
if homogeneous:
|
||||
if hints.get('symbols', False):
|
||||
return (S.Zero, [])
|
||||
else:
|
||||
return S.Zero
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
solutions = {}
|
||||
|
||||
if homogeneous:
|
||||
result = S.Zero
|
||||
else:
|
||||
result = h
|
||||
|
||||
_C = C[:]
|
||||
for c, q in list(zip(C, Q)):
|
||||
if c in solutions:
|
||||
s = solutions[c]*q
|
||||
C.remove(c)
|
||||
else:
|
||||
s = c*q
|
||||
|
||||
result += s.expand()
|
||||
|
||||
if C != _C:
|
||||
# renumber so they are contiguous
|
||||
result = result.xreplace(dict(zip(C, _C)))
|
||||
C = _C[:len(C)]
|
||||
|
||||
if hints.get('symbols', False):
|
||||
return (result, C)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def rsolve_ratio(coeffs, f, n, **hints):
|
||||
r"""
|
||||
Given linear recurrence operator `\operatorname{L}` of order `k`
|
||||
with polynomial coefficients and inhomogeneous equation
|
||||
`\operatorname{L} y = f`, where `f` is a polynomial, we seek
|
||||
for all rational solutions over field `K` of characteristic zero.
|
||||
|
||||
This procedure accepts only polynomials, however if you are
|
||||
interested in solving recurrence with rational coefficients
|
||||
then use ``rsolve`` which will pre-process the given equation
|
||||
and run this procedure with polynomial arguments.
|
||||
|
||||
The algorithm performs two basic steps:
|
||||
|
||||
(1) Compute polynomial `v(n)` which can be used as universal
|
||||
denominator of any rational solution of equation
|
||||
`\operatorname{L} y = f`.
|
||||
|
||||
(2) Construct new linear difference equation by substitution
|
||||
`y(n) = u(n)/v(n)` and solve it for `u(n)` finding all its
|
||||
polynomial solutions. Return ``None`` if none were found.
|
||||
|
||||
The algorithm implemented here is a revised version of the original
|
||||
Abramov's algorithm, developed in 1989. The new approach is much
|
||||
simpler to implement and has better overall efficiency. This
|
||||
method can be easily adapted to the q-difference equations case.
|
||||
|
||||
Besides finding rational solutions alone, this functions is
|
||||
an important part of Hyper algorithm where it is used to find
|
||||
a particular solution for the inhomogeneous part of a recurrence.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.solvers.recurr import rsolve_ratio
|
||||
>>> rsolve_ratio([-2*x**3 + x**2 + 2*x - 1, 2*x**3 + x**2 - 6*x,
|
||||
... - 2*x**3 - 11*x**2 - 18*x - 9, 2*x**3 + 13*x**2 + 22*x + 8], 0, x)
|
||||
C0*(2*x - 3)/(2*(x**2 - 1))
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] S. A. Abramov, Rational solutions of linear difference
|
||||
and q-difference equations with polynomial coefficients,
|
||||
in: T. Levelt, ed., Proc. ISSAC '95, ACM Press, New York,
|
||||
1995, 285-289
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
rsolve_hyper
|
||||
"""
|
||||
f = sympify(f)
|
||||
|
||||
if not f.is_polynomial(n):
|
||||
return None
|
||||
|
||||
coeffs = list(map(sympify, coeffs))
|
||||
|
||||
r = len(coeffs) - 1
|
||||
|
||||
A, B = coeffs[r], coeffs[0]
|
||||
A = A.subs(n, n - r).expand()
|
||||
|
||||
h = Dummy('h')
|
||||
|
||||
res = resultant(A, B.subs(n, n + h), n)
|
||||
|
||||
if not res.is_polynomial(h):
|
||||
p, q = res.as_numer_denom()
|
||||
res = quo(p, q, h)
|
||||
|
||||
nni_roots = list(roots(res, h, filter='Z',
|
||||
predicate=lambda r: r >= 0).keys())
|
||||
|
||||
if not nni_roots:
|
||||
return rsolve_poly(coeffs, f, n, **hints)
|
||||
else:
|
||||
C, numers = S.One, [S.Zero]*(r + 1)
|
||||
|
||||
for i in range(int(max(nni_roots)), -1, -1):
|
||||
d = gcd(A, B.subs(n, n + i), n)
|
||||
|
||||
A = quo(A, d, n)
|
||||
B = quo(B, d.subs(n, n - i), n)
|
||||
|
||||
C *= Mul(*[d.subs(n, n - j) for j in range(i + 1)])
|
||||
|
||||
denoms = [C.subs(n, n + i) for i in range(r + 1)]
|
||||
|
||||
for i in range(r + 1):
|
||||
g = gcd(coeffs[i], denoms[i], n)
|
||||
|
||||
numers[i] = quo(coeffs[i], g, n)
|
||||
denoms[i] = quo(denoms[i], g, n)
|
||||
|
||||
for i in range(r + 1):
|
||||
numers[i] *= Mul(*(denoms[:i] + denoms[i + 1:]))
|
||||
|
||||
result = rsolve_poly(numers, f * Mul(*denoms), n, **hints)
|
||||
|
||||
if result is not None:
|
||||
if hints.get('symbols', False):
|
||||
return (simplify(result[0] / C), result[1])
|
||||
else:
|
||||
return simplify(result / C)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def rsolve_hyper(coeffs, f, n, **hints):
|
||||
r"""
|
||||
Given linear recurrence operator `\operatorname{L}` of order `k`
|
||||
with polynomial coefficients and inhomogeneous equation
|
||||
`\operatorname{L} y = f` we seek for all hypergeometric solutions
|
||||
over field `K` of characteristic zero.
|
||||
|
||||
The inhomogeneous part can be either hypergeometric or a sum
|
||||
of a fixed number of pairwise dissimilar hypergeometric terms.
|
||||
|
||||
The algorithm performs three basic steps:
|
||||
|
||||
(1) Group together similar hypergeometric terms in the
|
||||
inhomogeneous part of `\operatorname{L} y = f`, and find
|
||||
particular solution using Abramov's algorithm.
|
||||
|
||||
(2) Compute generating set of `\operatorname{L}` and find basis
|
||||
in it, so that all solutions are linearly independent.
|
||||
|
||||
(3) Form final solution with the number of arbitrary
|
||||
constants equal to dimension of basis of `\operatorname{L}`.
|
||||
|
||||
Term `a(n)` is hypergeometric if it is annihilated by first order
|
||||
linear difference equations with polynomial coefficients or, in
|
||||
simpler words, if consecutive term ratio is a rational function.
|
||||
|
||||
The output of this procedure is a linear combination of fixed
|
||||
number of hypergeometric terms. However the underlying method
|
||||
can generate larger class of solutions - D'Alembertian terms.
|
||||
|
||||
Note also that this method not only computes the kernel of the
|
||||
inhomogeneous equation, but also reduces in to a basis so that
|
||||
solutions generated by this procedure are linearly independent
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.solvers import rsolve_hyper
|
||||
>>> from sympy.abc import x
|
||||
|
||||
>>> rsolve_hyper([-1, -1, 1], 0, x)
|
||||
C0*(1/2 - sqrt(5)/2)**x + C1*(1/2 + sqrt(5)/2)**x
|
||||
|
||||
>>> rsolve_hyper([-1, 1], 1 + x, x)
|
||||
C0 + x*(x + 1)/2
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences
|
||||
with polynomial coefficients, J. Symbolic Computation,
|
||||
14 (1992), 243-264.
|
||||
|
||||
.. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
|
||||
"""
|
||||
coeffs = list(map(sympify, coeffs))
|
||||
|
||||
f = sympify(f)
|
||||
|
||||
r, kernel, symbols = len(coeffs) - 1, [], set()
|
||||
|
||||
if not f.is_zero:
|
||||
if f.is_Add:
|
||||
similar = {}
|
||||
|
||||
for g in f.expand().args:
|
||||
if not g.is_hypergeometric(n):
|
||||
return None
|
||||
|
||||
for h in similar.keys():
|
||||
if hypersimilar(g, h, n):
|
||||
similar[h] += g
|
||||
break
|
||||
else:
|
||||
similar[g] = S.Zero
|
||||
|
||||
inhomogeneous = [g + h for g, h in similar.items()]
|
||||
elif f.is_hypergeometric(n):
|
||||
inhomogeneous = [f]
|
||||
else:
|
||||
return None
|
||||
|
||||
for i, g in enumerate(inhomogeneous):
|
||||
coeff, polys = S.One, coeffs[:]
|
||||
denoms = [S.One]*(r + 1)
|
||||
|
||||
s = hypersimp(g, n)
|
||||
|
||||
for j in range(1, r + 1):
|
||||
coeff *= s.subs(n, n + j - 1)
|
||||
|
||||
p, q = coeff.as_numer_denom()
|
||||
|
||||
polys[j] *= p
|
||||
denoms[j] = q
|
||||
|
||||
for j in range(r + 1):
|
||||
polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:]))
|
||||
|
||||
# FIXME: The call to rsolve_ratio below should suffice (rsolve_poly
|
||||
# call can be removed) but the XFAIL test_rsolve_ratio_missed must
|
||||
# be fixed first.
|
||||
R = rsolve_ratio(polys, Mul(*denoms), n, symbols=True)
|
||||
if R is not None:
|
||||
R, syms = R
|
||||
if syms:
|
||||
R = R.subs(zip(syms, [0]*len(syms)))
|
||||
else:
|
||||
R = rsolve_poly(polys, Mul(*denoms), n)
|
||||
|
||||
if R:
|
||||
inhomogeneous[i] *= R
|
||||
else:
|
||||
return None
|
||||
|
||||
result = Add(*inhomogeneous)
|
||||
result = simplify(result)
|
||||
else:
|
||||
result = S.Zero
|
||||
|
||||
Z = Dummy('Z')
|
||||
|
||||
p, q = coeffs[0], coeffs[r].subs(n, n - r + 1)
|
||||
|
||||
p_factors = list(roots(p, n).keys())
|
||||
q_factors = list(roots(q, n).keys())
|
||||
|
||||
factors = [(S.One, S.One)]
|
||||
|
||||
for p in p_factors:
|
||||
for q in q_factors:
|
||||
if p.is_integer and q.is_integer and p <= q:
|
||||
continue
|
||||
else:
|
||||
factors += [(n - p, n - q)]
|
||||
|
||||
p = [(n - p, S.One) for p in p_factors]
|
||||
q = [(S.One, n - q) for q in q_factors]
|
||||
|
||||
factors = p + factors + q
|
||||
|
||||
for A, B in factors:
|
||||
polys, degrees = [], []
|
||||
D = A*B.subs(n, n + r - 1)
|
||||
|
||||
for i in range(r + 1):
|
||||
a = Mul(*[A.subs(n, n + j) for j in range(i)])
|
||||
b = Mul(*[B.subs(n, n + j) for j in range(i, r)])
|
||||
|
||||
poly = quo(coeffs[i]*a*b, D, n)
|
||||
polys.append(poly.as_poly(n))
|
||||
|
||||
if not poly.is_zero:
|
||||
degrees.append(polys[i].degree())
|
||||
|
||||
if degrees:
|
||||
d, poly = max(degrees), S.Zero
|
||||
else:
|
||||
return None
|
||||
|
||||
for i in range(r + 1):
|
||||
coeff = polys[i].nth(d)
|
||||
|
||||
if coeff is not S.Zero:
|
||||
poly += coeff * Z**i
|
||||
|
||||
for z in roots(poly, Z).keys():
|
||||
if z.is_zero:
|
||||
continue
|
||||
|
||||
recurr_coeffs = [polys[i].as_expr()*z**i for i in range(r + 1)]
|
||||
if d == 0 and 0 != Add(*[recurr_coeffs[j]*j for j in range(1, r + 1)]):
|
||||
# faster inline check (than calling rsolve_poly) for a
|
||||
# constant solution to a constant coefficient recurrence.
|
||||
sol = [Symbol("C" + str(len(symbols)))]
|
||||
else:
|
||||
sol, syms = rsolve_poly(recurr_coeffs, 0, n, len(symbols), symbols=True)
|
||||
sol = sol.collect(syms)
|
||||
sol = [sol.coeff(s) for s in syms]
|
||||
|
||||
for C in sol:
|
||||
ratio = z * A * C.subs(n, n + 1) / B / C
|
||||
ratio = simplify(ratio)
|
||||
# If there is a nonnegative root in the denominator of the ratio,
|
||||
# this indicates that the term y(n_root) is zero, and one should
|
||||
# start the product with the term y(n_root + 1).
|
||||
n0 = 0
|
||||
for n_root in roots(ratio.as_numer_denom()[1], n).keys():
|
||||
if n_root.has(I):
|
||||
return None
|
||||
elif (n0 < (n_root + 1)) == True:
|
||||
n0 = n_root + 1
|
||||
K = product(ratio, (n, n0, n - 1))
|
||||
if K.has(factorial, FallingFactorial, RisingFactorial):
|
||||
K = simplify(K)
|
||||
|
||||
if casoratian(kernel + [K], n, zero=False) != 0:
|
||||
kernel.append(K)
|
||||
|
||||
kernel.sort(key=default_sort_key)
|
||||
sk = list(zip(numbered_symbols('C'), kernel))
|
||||
|
||||
for C, ker in sk:
|
||||
result += C * ker
|
||||
|
||||
if hints.get('symbols', False):
|
||||
# XXX: This returns the symbols in a non-deterministic order
|
||||
symbols |= {s for s, k in sk}
|
||||
return (result, list(symbols))
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def rsolve(f, y, init=None):
|
||||
r"""
|
||||
Solve univariate recurrence with rational coefficients.
|
||||
|
||||
Given `k`-th order linear recurrence `\operatorname{L} y = f`,
|
||||
or equivalently:
|
||||
|
||||
.. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) +
|
||||
\cdots + a_{0}(n) y(n) = f(n)
|
||||
|
||||
where `a_{i}(n)`, for `i=0, \ldots, k`, are polynomials or rational
|
||||
functions in `n`, and `f` is a hypergeometric function or a sum
|
||||
of a fixed number of pairwise dissimilar hypergeometric terms in
|
||||
`n`, finds all solutions or returns ``None``, if none were found.
|
||||
|
||||
Initial conditions can be given as a dictionary in two forms:
|
||||
|
||||
(1) ``{ n_0 : v_0, n_1 : v_1, ..., n_m : v_m}``
|
||||
(2) ``{y(n_0) : v_0, y(n_1) : v_1, ..., y(n_m) : v_m}``
|
||||
|
||||
or as a list ``L`` of values:
|
||||
|
||||
``L = [v_0, v_1, ..., v_m]``
|
||||
|
||||
where ``L[i] = v_i``, for `i=0, \ldots, m`, maps to `y(n_i)`.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
Lets consider the following recurrence:
|
||||
|
||||
.. math:: (n - 1) y(n + 2) - (n^2 + 3 n - 2) y(n + 1) +
|
||||
2 n (n + 1) y(n) = 0
|
||||
|
||||
>>> from sympy import Function, rsolve
|
||||
>>> from sympy.abc import n
|
||||
>>> y = Function('y')
|
||||
|
||||
>>> f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n)
|
||||
|
||||
>>> rsolve(f, y(n))
|
||||
2**n*C0 + C1*factorial(n)
|
||||
|
||||
>>> rsolve(f, y(n), {y(0):0, y(1):3})
|
||||
3*2**n - 3*factorial(n)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
rsolve_poly, rsolve_ratio, rsolve_hyper
|
||||
|
||||
"""
|
||||
if isinstance(f, Equality):
|
||||
f = f.lhs - f.rhs
|
||||
|
||||
n = y.args[0]
|
||||
k = Wild('k', exclude=(n,))
|
||||
|
||||
# Preprocess user input to allow things like
|
||||
# y(n) + a*(y(n + 1) + y(n - 1))/2
|
||||
f = f.expand().collect(y.func(Wild('m', integer=True)))
|
||||
|
||||
h_part = defaultdict(list)
|
||||
i_part = []
|
||||
for g in Add.make_args(f):
|
||||
coeff, dep = g.as_coeff_mul(y.func)
|
||||
if not dep:
|
||||
i_part.append(coeff)
|
||||
continue
|
||||
for h in dep:
|
||||
if h.is_Function and h.func == y.func:
|
||||
result = h.args[0].match(n + k)
|
||||
if result is not None:
|
||||
h_part[int(result[k])].append(coeff)
|
||||
continue
|
||||
raise ValueError(
|
||||
"'%s(%s + k)' expected, got '%s'" % (y.func, n, h))
|
||||
for k in h_part:
|
||||
h_part[k] = Add(*h_part[k])
|
||||
h_part.default_factory = lambda: 0
|
||||
i_part = Add(*i_part)
|
||||
|
||||
for k, coeff in h_part.items():
|
||||
h_part[k] = simplify(coeff)
|
||||
|
||||
common = S.One
|
||||
|
||||
if not i_part.is_zero and not i_part.is_hypergeometric(n) and \
|
||||
not (i_part.is_Add and all((x.is_hypergeometric(n) for x in i_part.expand().args))):
|
||||
raise ValueError("The independent term should be a sum of hypergeometric functions, got '%s'" % i_part)
|
||||
|
||||
for coeff in h_part.values():
|
||||
if coeff.is_rational_function(n):
|
||||
if not coeff.is_polynomial(n):
|
||||
common = lcm(common, coeff.as_numer_denom()[1], n)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Polynomial or rational function expected, got '%s'" % coeff)
|
||||
|
||||
i_numer, i_denom = i_part.as_numer_denom()
|
||||
|
||||
if i_denom.is_polynomial(n):
|
||||
common = lcm(common, i_denom, n)
|
||||
|
||||
if common is not S.One:
|
||||
for k, coeff in h_part.items():
|
||||
numer, denom = coeff.as_numer_denom()
|
||||
h_part[k] = numer*quo(common, denom, n)
|
||||
|
||||
i_part = i_numer*quo(common, i_denom, n)
|
||||
|
||||
K_min = min(h_part.keys())
|
||||
|
||||
if K_min < 0:
|
||||
K = abs(K_min)
|
||||
|
||||
H_part = defaultdict(lambda: S.Zero)
|
||||
i_part = i_part.subs(n, n + K).expand()
|
||||
common = common.subs(n, n + K).expand()
|
||||
|
||||
for k, coeff in h_part.items():
|
||||
H_part[k + K] = coeff.subs(n, n + K).expand()
|
||||
else:
|
||||
H_part = h_part
|
||||
|
||||
K_max = max(H_part.keys())
|
||||
coeffs = [H_part[i] for i in range(K_max + 1)]
|
||||
|
||||
result = rsolve_hyper(coeffs, -i_part, n, symbols=True)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
solution, symbols = result
|
||||
|
||||
if init in ({}, []):
|
||||
init = None
|
||||
|
||||
if symbols and init is not None:
|
||||
if isinstance(init, list):
|
||||
init = {i: init[i] for i in range(len(init))}
|
||||
|
||||
equations = []
|
||||
|
||||
for k, v in init.items():
|
||||
try:
|
||||
i = int(k)
|
||||
except TypeError:
|
||||
if k.is_Function and k.func == y.func:
|
||||
i = int(k.args[0])
|
||||
else:
|
||||
raise ValueError("Integer or term expected, got '%s'" % k)
|
||||
|
||||
eq = solution.subs(n, i) - v
|
||||
if eq.has(S.NaN):
|
||||
eq = solution.limit(n, i) - v
|
||||
equations.append(eq)
|
||||
|
||||
result = solve(equations, *symbols)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
else:
|
||||
solution = solution.subs(result)
|
||||
|
||||
return solution
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
If the arbitrary constant class from issue 4435 is ever implemented, this
|
||||
should serve as a set of test cases.
|
||||
"""
|
||||
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import I
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.hyperbolic import (cosh, sinh)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import (acos, cos, sin)
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.solvers.ode.ode import constantsimp, constant_renumber
|
||||
from sympy.testing.pytest import XFAIL
|
||||
|
||||
|
||||
x = Symbol('x')
|
||||
y = Symbol('y')
|
||||
z = Symbol('z')
|
||||
u2 = Symbol('u2')
|
||||
_a = Symbol('_a')
|
||||
C1 = Symbol('C1')
|
||||
C2 = Symbol('C2')
|
||||
C3 = Symbol('C3')
|
||||
f = Function('f')
|
||||
|
||||
|
||||
def test_constant_mul():
|
||||
# We want C1 (Constant) below to absorb the y's, but not the x's
|
||||
assert constant_renumber(constantsimp(y*C1, [C1])) == C1*y
|
||||
assert constant_renumber(constantsimp(C1*y, [C1])) == C1*y
|
||||
assert constant_renumber(constantsimp(x*C1, [C1])) == x*C1
|
||||
assert constant_renumber(constantsimp(C1*x, [C1])) == x*C1
|
||||
assert constant_renumber(constantsimp(2*C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1*2, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(y*C1*x, [C1, y])) == C1*x
|
||||
assert constant_renumber(constantsimp(x*y*C1, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp(y*x*C1, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp(C1*x*y, [C1, y])) == C1*x
|
||||
assert constant_renumber(constantsimp(x*C1*y, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp(C1*y*(y + 1), [C1])) == C1*y*(y+1)
|
||||
assert constant_renumber(constantsimp(y*C1*(y + 1), [C1])) == C1*y*(y+1)
|
||||
assert constant_renumber(constantsimp(x*(y*C1), [C1])) == x*y*C1
|
||||
assert constant_renumber(constantsimp(x*(C1*y), [C1])) == x*y*C1
|
||||
assert constant_renumber(constantsimp(C1*(x*y), [C1, y])) == C1*x
|
||||
assert constant_renumber(constantsimp((x*y)*C1, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp((y*x)*C1, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp((C1*x)*y, [C1, y])) == C1*x
|
||||
assert constant_renumber(constantsimp(y*(x*C1), [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp((x*C1)*y, [C1, y])) == x*C1
|
||||
assert constant_renumber(constantsimp(C1*x*y*x*y*2, [C1, y])) == C1*x**2
|
||||
assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z])) == C1*x
|
||||
assert constant_renumber(constantsimp(C1*x*y**2*sin(z), [C1, y, z])) == C1*x
|
||||
assert constant_renumber(constantsimp(C1*C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1*C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C2*C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C1*x*2**x, [C1])) == C1*x*2**x
|
||||
|
||||
def test_constant_add():
|
||||
assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1 + 2, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(2 + C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1 + y, [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp(C1 + x, [C1])) == C1 + x
|
||||
assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1 + C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C2 + C1, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2])) == C1
|
||||
|
||||
|
||||
def test_constant_power_as_base():
|
||||
assert constant_renumber(constantsimp(C1**C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(Pow(C1, C1), [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1**C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(C1**C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C2**C1, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C2**C2, [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(C1**y, [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp(C1**x, [C1])) == C1**x
|
||||
assert constant_renumber(constantsimp(C1**2, [C1])) == C1
|
||||
assert constant_renumber(
|
||||
constantsimp(C1**(x*y), [C1])) == C1**(x*y)
|
||||
|
||||
|
||||
def test_constant_power_as_exp():
|
||||
assert constant_renumber(constantsimp(x**C1, [C1])) == x**C1
|
||||
assert constant_renumber(constantsimp(y**C1, [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp(x**y**C1, [C1, y])) == x**C1
|
||||
assert constant_renumber(
|
||||
constantsimp((x**y)**C1, [C1])) == (x**y)**C1
|
||||
assert constant_renumber(
|
||||
constantsimp(x**(y**C1), [C1, y])) == x**C1
|
||||
assert constant_renumber(constantsimp(x**C1**y, [C1, y])) == x**C1
|
||||
assert constant_renumber(
|
||||
constantsimp(x**(C1**y), [C1, y])) == x**C1
|
||||
assert constant_renumber(
|
||||
constantsimp((x**C1)**y, [C1])) == (x**C1)**y
|
||||
assert constant_renumber(constantsimp(2**C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(S(2)**C1, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(exp(C1), [C1])) == C1
|
||||
assert constant_renumber(
|
||||
constantsimp(exp(C1 + x), [C1])) == C1*exp(x)
|
||||
assert constant_renumber(constantsimp(Pow(2, C1), [C1])) == C1
|
||||
|
||||
|
||||
def test_constant_function():
|
||||
assert constant_renumber(constantsimp(sin(C1), [C1])) == C1
|
||||
assert constant_renumber(constantsimp(f(C1), [C1])) == C1
|
||||
assert constant_renumber(constantsimp(f(C1, C1), [C1])) == C1
|
||||
assert constant_renumber(constantsimp(f(C1, C2), [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(f(C2, C1), [C1, C2])) == C1
|
||||
assert constant_renumber(constantsimp(f(C2, C2), [C1, C2])) == C1
|
||||
assert constant_renumber(
|
||||
constantsimp(f(C1, x), [C1])) == f(C1, x)
|
||||
assert constant_renumber(constantsimp(f(C1, y), [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp(f(y, C1), [C1, y])) == C1
|
||||
assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y])) == C1
|
||||
|
||||
|
||||
def test_constant_function_multiple():
|
||||
# The rules to not renumber in this case would be too complicated, and
|
||||
# dsolve is not likely to ever encounter anything remotely like this.
|
||||
assert constant_renumber(
|
||||
constantsimp(f(C1, C1, x), [C1])) == f(C1, C1, x)
|
||||
|
||||
|
||||
def test_constant_multiple():
|
||||
assert constant_renumber(constantsimp(C1*2 + 2, [C1])) == C1
|
||||
assert constant_renumber(constantsimp(x*2/C1, [C1])) == C1*x
|
||||
assert constant_renumber(constantsimp(C1**2*2 + 2, [C1])) == C1
|
||||
assert constant_renumber(
|
||||
constantsimp(sin(2*C1) + x + sqrt(2), [C1])) == C1 + x
|
||||
assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2])) == C1
|
||||
|
||||
def test_constant_repeated():
|
||||
assert C1 + C1*x == constant_renumber( C1 + C1*x)
|
||||
|
||||
def test_ode_solutions():
|
||||
# only a few examples here, the rest will be tested in the actual dsolve tests
|
||||
assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3])) == \
|
||||
constant_renumber(C1*exp(x) + C2*exp(2*x))
|
||||
assert constant_renumber(
|
||||
constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2])
|
||||
) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3)))
|
||||
assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1])) == \
|
||||
Eq(f(x), acos(C1/cos(x)))
|
||||
assert constant_renumber(
|
||||
constantsimp(Eq(log(f(x)/C1) + 2*exp(x/f(x)), 0), [C1])
|
||||
) == Eq(log(C1*f(x)) + 2*exp(x/f(x)), 0)
|
||||
assert constant_renumber(constantsimp(Eq(log(x*sqrt(2)*sqrt(1/x)*sqrt(f(x))
|
||||
/C1) + x**2/(2*f(x)**2), 0), [C1])) == \
|
||||
Eq(log(C1*sqrt(x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0)
|
||||
assert constant_renumber(constantsimp(Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(x/C1) -
|
||||
cos(f(x)/x)*exp(-f(x)/x)/2, 0), [C1])) == \
|
||||
Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(C1*x) - cos(f(x)/x)*
|
||||
exp(-f(x)/x)/2, 0)
|
||||
assert constant_renumber(constantsimp(Eq(-Integral(-1/(sqrt(1 - u2**2)*u2),
|
||||
(u2, _a, x/f(x))) + log(f(x)/C1), 0), [C1])) == \
|
||||
Eq(-Integral(-1/(u2*sqrt(1 - u2**2)), (u2, _a, x/f(x))) +
|
||||
log(C1*f(x)), 0)
|
||||
assert [constantsimp(i, [C1]) for i in [Eq(f(x), sqrt(-C1*x + x**2)), Eq(f(x), -sqrt(-C1*x + x**2))]] == \
|
||||
[Eq(f(x), sqrt(x*(C1 + x))), Eq(f(x), -sqrt(x*(C1 + x)))]
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_nonlocal_simplification():
|
||||
assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x
|
||||
|
||||
|
||||
def test_constant_Eq():
|
||||
# C1 on the rhs is well-tested, but the lhs is only tested here
|
||||
assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1)
|
||||
assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1)
|
||||
@@ -0,0 +1,59 @@
|
||||
from sympy.solvers.decompogen import decompogen, compogen
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.functions.elementary.miscellaneous import sqrt, Max
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.testing.pytest import XFAIL, raises
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
|
||||
def test_decompogen():
|
||||
assert decompogen(sin(cos(x)), x) == [sin(x), cos(x)]
|
||||
assert decompogen(sin(x)**2 + sin(x) + 1, x) == [x**2 + x + 1, sin(x)]
|
||||
assert decompogen(sqrt(6*x**2 - 5), x) == [sqrt(x), 6*x**2 - 5]
|
||||
assert decompogen(sin(sqrt(cos(x**2 + 1))), x) == [sin(x), sqrt(x), cos(x), x**2 + 1]
|
||||
assert decompogen(Abs(cos(x)**2 + 3*cos(x) - 4), x) == [Abs(x), x**2 + 3*x - 4, cos(x)]
|
||||
assert decompogen(sin(x)**2 + sin(x) - sqrt(3)/2, x) == [x**2 + x - sqrt(3)/2, sin(x)]
|
||||
assert decompogen(Abs(cos(y)**2 + 3*cos(x) - 4), x) == [Abs(x), 3*x + cos(y)**2 - 4, cos(x)]
|
||||
assert decompogen(x, y) == [x]
|
||||
assert decompogen(1, x) == [1]
|
||||
assert decompogen(Max(3, x), x) == [Max(3, x)]
|
||||
raises(TypeError, lambda: decompogen(x < 5, x))
|
||||
u = 2*x + 3
|
||||
assert decompogen(Max(sqrt(u),(u)**2), x) == [Max(sqrt(x), x**2), u]
|
||||
assert decompogen(Max(u, u**2, y), x) == [Max(x, x**2, y), u]
|
||||
assert decompogen(Max(sin(x), u), x) == [Max(2*x + 3, sin(x))]
|
||||
|
||||
|
||||
def test_decompogen_poly():
|
||||
assert decompogen(x**4 + 2*x**2 + 1, x) == [x**2 + 2*x + 1, x**2]
|
||||
assert decompogen(x**4 + 2*x**3 - x - 1, x) == [x**2 - x - 1, x**2 + x]
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_decompogen_fails():
|
||||
A = lambda x: x**2 + 2*x + 3
|
||||
B = lambda x: 4*x**2 + 5*x + 6
|
||||
assert decompogen(A(x*exp(x)), x) == [x**2 + 2*x + 3, x*exp(x)]
|
||||
assert decompogen(A(B(x)), x) == [x**2 + 2*x + 3, 4*x**2 + 5*x + 6]
|
||||
assert decompogen(A(1/x + 1/x**2), x) == [x**2 + 2*x + 3, 1/x + 1/x**2]
|
||||
assert decompogen(A(1/x + 2/(x + 1)), x) == [x**2 + 2*x + 3, 1/x + 2/(x + 1)]
|
||||
|
||||
|
||||
def test_compogen():
|
||||
assert compogen([sin(x), cos(x)], x) == sin(cos(x))
|
||||
assert compogen([x**2 + x + 1, sin(x)], x) == sin(x)**2 + sin(x) + 1
|
||||
assert compogen([sqrt(x), 6*x**2 - 5], x) == sqrt(6*x**2 - 5)
|
||||
assert compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) == sin(sqrt(
|
||||
cos(x**2 + 1)))
|
||||
assert compogen([Abs(x), x**2 + 3*x - 4, cos(x)], x) == Abs(cos(x)**2 +
|
||||
3*cos(x) - 4)
|
||||
assert compogen([x**2 + x - sqrt(3)/2, sin(x)], x) == (sin(x)**2 + sin(x) -
|
||||
sqrt(3)/2)
|
||||
assert compogen([Abs(x), 3*x + cos(y)**2 - 4, cos(x)], x) == \
|
||||
Abs(3*cos(x) + cos(y)**2 - 4)
|
||||
assert compogen([x**2 + 2*x + 1, x**2], x) == x**4 + 2*x**2 + 1
|
||||
# the result is in unsimplified form
|
||||
assert compogen([x**2 - x - 1, x**2 + x], x) == -x**2 - x + (x**2 + x)**2 - 1
|
||||
@@ -0,0 +1,500 @@
|
||||
"""Tests for tools for solving inequalities and systems of inequalities. """
|
||||
|
||||
from sympy.concrete.summations import Sum
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import I, Rational, oo, pi
|
||||
from sympy.core.relational import Eq, Ge, Gt, Le, Lt, Ne
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, Symbol)
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.miscellaneous import root, sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import cos, sin, tan
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.logic.boolalg import And, Or
|
||||
from sympy.polys.polytools import Poly, PurePoly
|
||||
from sympy.sets.sets import FiniteSet, Interval, Union
|
||||
from sympy.solvers.inequalities import (reduce_inequalities,
|
||||
solve_poly_inequality as psolve,
|
||||
reduce_rational_inequalities,
|
||||
solve_univariate_inequality as isolve,
|
||||
reduce_abs_inequality,
|
||||
_solve_inequality)
|
||||
from sympy.polys.rootoftools import rootof
|
||||
from sympy.solvers.solvers import solve
|
||||
from sympy.solvers.solveset import solveset
|
||||
from sympy.core.mod import Mod
|
||||
from sympy.abc import x, y
|
||||
|
||||
from sympy.testing.pytest import raises, XFAIL
|
||||
|
||||
|
||||
inf = oo.evalf()
|
||||
|
||||
|
||||
def test_solve_poly_inequality():
|
||||
assert psolve(Poly(0, x), '==') == [S.Reals]
|
||||
assert psolve(Poly(1, x), '==') == [S.EmptySet]
|
||||
assert psolve(PurePoly(x + 1, x), ">") == [Interval(-1, oo, True, False)]
|
||||
|
||||
|
||||
def test_reduce_poly_inequalities_real_interval():
|
||||
assert reduce_rational_inequalities(
|
||||
[[Eq(x**2, 0)]], x, relational=False) == FiniteSet(0)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2, 0)]], x, relational=False) == FiniteSet(0)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2, 0)]], x, relational=False) == S.EmptySet
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ge(x**2, 0)]], x, relational=False) == \
|
||||
S.Reals if x.is_real else Interval(-oo, oo)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Gt(x**2, 0)]], x, relational=False) == \
|
||||
FiniteSet(0).complement(S.Reals)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ne(x**2, 0)]], x, relational=False) == \
|
||||
FiniteSet(0).complement(S.Reals)
|
||||
|
||||
assert reduce_rational_inequalities(
|
||||
[[Eq(x**2, 1)]], x, relational=False) == FiniteSet(-1, 1)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2, 1)]], x, relational=False) == Interval(-1, 1)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2, 1)]], x, relational=False) == Interval(-1, 1, True, True)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ge(x**2, 1)]], x, relational=False) == \
|
||||
Union(Interval(-oo, -1), Interval(1, oo))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Gt(x**2, 1)]], x, relational=False) == \
|
||||
Interval(-1, 1).complement(S.Reals)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ne(x**2, 1)]], x, relational=False) == \
|
||||
FiniteSet(-1, 1).complement(S.Reals)
|
||||
assert reduce_rational_inequalities([[Eq(
|
||||
x**2, 1.0)]], x, relational=False) == FiniteSet(-1.0, 1.0).evalf()
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0)
|
||||
assert reduce_rational_inequalities([[Lt(
|
||||
x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0, True, True)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ge(x**2, 1.0)]], x, relational=False) == \
|
||||
Union(Interval(-inf, -1.0), Interval(1.0, inf))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Gt(x**2, 1.0)]], x, relational=False) == \
|
||||
Union(Interval(-inf, -1.0, right_open=True),
|
||||
Interval(1.0, inf, left_open=True))
|
||||
assert reduce_rational_inequalities([[Ne(
|
||||
x**2, 1.0)]], x, relational=False) == \
|
||||
FiniteSet(-1.0, 1.0).complement(S.Reals)
|
||||
|
||||
s = sqrt(2)
|
||||
|
||||
assert reduce_rational_inequalities([[Lt(
|
||||
x**2 - 1, 0), Gt(x**2 - 1, 0)]], x, relational=False) == S.EmptySet
|
||||
assert reduce_rational_inequalities([[Le(x**2 - 1, 0), Ge(
|
||||
x**2 - 1, 0)]], x, relational=False) == FiniteSet(-1, 1)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False
|
||||
) == Union(Interval(-s, -1, False, False), Interval(1, s, False, False))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False
|
||||
) == Union(Interval(-s, -1, False, True), Interval(1, s, True, False))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False
|
||||
) == Union(Interval(-s, -1, True, False), Interval(1, s, False, True))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False
|
||||
) == Union(Interval(-s, -1, True, True), Interval(1, s, True, True))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2 - 2, 0), Ne(x**2 - 1, 0)]], x, relational=False
|
||||
) == Union(Interval(-s, -1, True, True), Interval(-1, 1, True, True),
|
||||
Interval(1, s, True, True))
|
||||
|
||||
assert reduce_rational_inequalities([[Lt(x**2, -1.)]], x) is S.false
|
||||
|
||||
|
||||
def test_reduce_poly_inequalities_complex_relational():
|
||||
assert reduce_rational_inequalities(
|
||||
[[Eq(x**2, 0)]], x, relational=True) == Eq(x, 0)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2, 0)]], x, relational=True) == Eq(x, 0)
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2, 0)]], x, relational=True) == False
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ge(x**2, 0)]], x, relational=True) == And(Lt(-oo, x), Lt(x, oo))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Gt(x**2, 0)]], x, relational=True) == \
|
||||
And(Gt(x, -oo), Lt(x, oo), Ne(x, 0))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ne(x**2, 0)]], x, relational=True) == \
|
||||
And(Gt(x, -oo), Lt(x, oo), Ne(x, 0))
|
||||
|
||||
for one in (S.One, S(1.0)):
|
||||
inf = one*oo
|
||||
assert reduce_rational_inequalities(
|
||||
[[Eq(x**2, one)]], x, relational=True) == \
|
||||
Or(Eq(x, -one), Eq(x, one))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Le(x**2, one)]], x, relational=True) == \
|
||||
And(And(Le(-one, x), Le(x, one)))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Lt(x**2, one)]], x, relational=True) == \
|
||||
And(And(Lt(-one, x), Lt(x, one)))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ge(x**2, one)]], x, relational=True) == \
|
||||
And(Or(And(Le(one, x), Lt(x, inf)), And(Le(x, -one), Lt(-inf, x))))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Gt(x**2, one)]], x, relational=True) == \
|
||||
And(Or(And(Lt(-inf, x), Lt(x, -one)), And(Lt(one, x), Lt(x, inf))))
|
||||
assert reduce_rational_inequalities(
|
||||
[[Ne(x**2, one)]], x, relational=True) == \
|
||||
Or(And(Lt(-inf, x), Lt(x, -one)),
|
||||
And(Lt(-one, x), Lt(x, one)),
|
||||
And(Lt(one, x), Lt(x, inf)))
|
||||
|
||||
|
||||
def test_reduce_rational_inequalities_real_relational():
|
||||
assert reduce_rational_inequalities([], x) == False
|
||||
assert reduce_rational_inequalities(
|
||||
[[(x**2 + 3*x + 2)/(x**2 - 16) >= 0]], x, relational=False) == \
|
||||
Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo))
|
||||
|
||||
assert reduce_rational_inequalities(
|
||||
[[((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0]], x,
|
||||
relational=False) == \
|
||||
Union(Interval.open(-5, 2), Interval.open(2, 3))
|
||||
|
||||
assert reduce_rational_inequalities([[(x + 1)/(x - 5) <= 0]], x,
|
||||
relational=False) == \
|
||||
Interval.Ropen(-1, 5)
|
||||
|
||||
assert reduce_rational_inequalities([[(x**2 + 4*x + 3)/(x - 1) > 0]], x,
|
||||
relational=False) == \
|
||||
Union(Interval.open(-3, -1), Interval.open(1, oo))
|
||||
|
||||
assert reduce_rational_inequalities([[(x**2 - 16)/(x - 1)**2 < 0]], x,
|
||||
relational=False) == \
|
||||
Union(Interval.open(-4, 1), Interval.open(1, 4))
|
||||
|
||||
assert reduce_rational_inequalities([[(3*x + 1)/(x + 4) >= 1]], x,
|
||||
relational=False) == \
|
||||
Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo))
|
||||
|
||||
assert reduce_rational_inequalities([[(x - 8)/x <= 3 - x]], x,
|
||||
relational=False) == \
|
||||
Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4))
|
||||
|
||||
# issue sympy/sympy#10237
|
||||
assert reduce_rational_inequalities(
|
||||
[[x < oo, x >= 0, -oo < x]], x, relational=False) == Interval(0, oo)
|
||||
|
||||
|
||||
def test_reduce_abs_inequalities():
|
||||
e = abs(x - 5) < 3
|
||||
ans = And(Lt(2, x), Lt(x, 8))
|
||||
assert reduce_inequalities(e) == ans
|
||||
assert reduce_inequalities(e, x) == ans
|
||||
assert reduce_inequalities(abs(x - 5)) == Eq(x, 5)
|
||||
assert reduce_inequalities(
|
||||
abs(2*x + 3) >= 8) == Or(And(Le(Rational(5, 2), x), Lt(x, oo)),
|
||||
And(Le(x, Rational(-11, 2)), Lt(-oo, x)))
|
||||
assert reduce_inequalities(abs(x - 4) + abs(
|
||||
3*x - 5) < 7) == And(Lt(S.Half, x), Lt(x, 4))
|
||||
assert reduce_inequalities(abs(x - 4) + abs(3*abs(x) - 5) < 7) == \
|
||||
Or(And(S(-2) < x, x < -1), And(S.Half < x, x < 4))
|
||||
|
||||
nr = Symbol('nr', extended_real=False)
|
||||
raises(TypeError, lambda: reduce_inequalities(abs(nr - 5) < 3))
|
||||
assert reduce_inequalities(x < 3, symbols=[x, nr]) == And(-oo < x, x < 3)
|
||||
|
||||
|
||||
def test_reduce_inequalities_general():
|
||||
assert reduce_inequalities(Ge(sqrt(2)*x, 1)) == And(sqrt(2)/2 <= x, x < oo)
|
||||
assert reduce_inequalities(x + 1 > 0) == And(S.NegativeOne < x, x < oo)
|
||||
|
||||
|
||||
def test_reduce_inequalities_boolean():
|
||||
assert reduce_inequalities(
|
||||
[Eq(x**2, 0), True]) == Eq(x, 0)
|
||||
assert reduce_inequalities([Eq(x**2, 0), False]) == False
|
||||
assert reduce_inequalities(x**2 >= 0) is S.true # issue 10196
|
||||
|
||||
|
||||
def test_reduce_inequalities_multivariate():
|
||||
assert reduce_inequalities([Ge(x**2, 1), Ge(y**2, 1)]) == And(
|
||||
Or(And(Le(S.One, x), Lt(x, oo)), And(Le(x, -1), Lt(-oo, x))),
|
||||
Or(And(Le(S.One, y), Lt(y, oo)), And(Le(y, -1), Lt(-oo, y))))
|
||||
|
||||
|
||||
def test_reduce_inequalities_errors():
|
||||
raises(NotImplementedError, lambda: reduce_inequalities(Ge(sin(x) + x, 1)))
|
||||
raises(NotImplementedError, lambda: reduce_inequalities(Ge(x**2*y + y, 1)))
|
||||
|
||||
|
||||
def test__solve_inequalities():
|
||||
assert reduce_inequalities(x + y < 1, symbols=[x]) == (x < 1 - y)
|
||||
assert reduce_inequalities(x + y >= 1, symbols=[x]) == (x < oo) & (x >= -y + 1)
|
||||
assert reduce_inequalities(Eq(0, x - y), symbols=[x]) == Eq(x, y)
|
||||
assert reduce_inequalities(Ne(0, x - y), symbols=[x]) == Ne(x, y)
|
||||
|
||||
|
||||
def test_issue_6343():
|
||||
eq = -3*x**2/2 - x*Rational(45, 4) + Rational(33, 2) > 0
|
||||
assert reduce_inequalities(eq) == \
|
||||
And(x < Rational(-15, 4) + sqrt(401)/4, -sqrt(401)/4 - Rational(15, 4) < x)
|
||||
|
||||
|
||||
def test_issue_8235():
|
||||
assert reduce_inequalities(x**2 - 1 < 0) == \
|
||||
And(S.NegativeOne < x, x < 1)
|
||||
assert reduce_inequalities(x**2 - 1 <= 0) == \
|
||||
And(S.NegativeOne <= x, x <= 1)
|
||||
assert reduce_inequalities(x**2 - 1 > 0) == \
|
||||
Or(And(-oo < x, x < -1), And(x < oo, S.One < x))
|
||||
assert reduce_inequalities(x**2 - 1 >= 0) == \
|
||||
Or(And(-oo < x, x <= -1), And(S.One <= x, x < oo))
|
||||
|
||||
eq = x**8 + x - 9 # we want CRootOf solns here
|
||||
sol = solve(eq >= 0)
|
||||
tru = Or(And(rootof(eq, 1) <= x, x < oo), And(-oo < x, x <= rootof(eq, 0)))
|
||||
assert sol == tru
|
||||
|
||||
# recast vanilla as real
|
||||
assert solve(sqrt((-x + 1)**2) < 1) == And(S.Zero < x, x < 2)
|
||||
|
||||
|
||||
def test_issue_5526():
|
||||
assert reduce_inequalities(0 <=
|
||||
x + Integral(y**2, (y, 1, 3)) - 1, [x]) == \
|
||||
(x >= -Integral(y**2, (y, 1, 3)) + 1)
|
||||
f = Function('f')
|
||||
e = Sum(f(x), (x, 1, 3))
|
||||
assert reduce_inequalities(0 <= x + e + y**2, [x]) == \
|
||||
(x >= -y**2 - Sum(f(x), (x, 1, 3)))
|
||||
|
||||
|
||||
def test_solve_univariate_inequality():
|
||||
assert isolve(x**2 >= 4, x, relational=False) == Union(Interval(-oo, -2),
|
||||
Interval(2, oo))
|
||||
assert isolve(x**2 >= 4, x) == Or(And(Le(2, x), Lt(x, oo)), And(Le(x, -2),
|
||||
Lt(-oo, x)))
|
||||
assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x, relational=False) == \
|
||||
Union(Interval(1, 2), Interval(3, oo))
|
||||
assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x) == \
|
||||
Or(And(Le(1, x), Le(x, 2)), And(Le(3, x), Lt(x, oo)))
|
||||
assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain = FiniteSet(0, 3)) == \
|
||||
Or(Eq(x, 0), Eq(x, 3))
|
||||
# issue 2785:
|
||||
assert isolve(x**3 - 2*x - 1 > 0, x, relational=False) == \
|
||||
Union(Interval(-1, -sqrt(5)/2 + S.Half, True, True),
|
||||
Interval(S.Half + sqrt(5)/2, oo, True, True))
|
||||
# issue 2794:
|
||||
assert isolve(x**3 - x**2 + x - 1 > 0, x, relational=False) == \
|
||||
Interval(1, oo, True)
|
||||
#issue 13105
|
||||
assert isolve((x + I)*(x + 2*I) < 0, x) == Eq(x, 0)
|
||||
assert isolve(((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I) < 0, x) == Or(Eq(x, 1), Eq(x, 2))
|
||||
assert isolve((((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I))/(x - 2) > 0, x) == Eq(x, 1)
|
||||
raises (ValueError, lambda: isolve((x**2 - 3*x*I + 2)/x < 0, x))
|
||||
|
||||
# numerical testing in valid() is needed
|
||||
assert isolve(x**7 - x - 2 > 0, x) == \
|
||||
And(rootof(x**7 - x - 2, 0) < x, x < oo)
|
||||
|
||||
# handle numerator and denominator; although these would be handled as
|
||||
# rational inequalities, these test confirm that the right thing is done
|
||||
# when the domain is EX (e.g. when 2 is replaced with sqrt(2))
|
||||
assert isolve(1/(x - 2) > 0, x) == And(S(2) < x, x < oo)
|
||||
den = ((x - 1)*(x - 2)).expand()
|
||||
assert isolve((x - 1)/den <= 0, x) == \
|
||||
(x > -oo) & (x < 2) & Ne(x, 1)
|
||||
|
||||
n = Dummy('n')
|
||||
raises(NotImplementedError, lambda: isolve(Abs(x) <= n, x, relational=False))
|
||||
c1 = Dummy("c1", positive=True)
|
||||
raises(NotImplementedError, lambda: isolve(n/c1 < 0, c1))
|
||||
n = Dummy('n', negative=True)
|
||||
assert isolve(n/c1 > -2, c1) == (-n/2 < c1)
|
||||
assert isolve(n/c1 < 0, c1) == True
|
||||
assert isolve(n/c1 > 0, c1) == False
|
||||
|
||||
zero = cos(1)**2 + sin(1)**2 - 1
|
||||
raises(NotImplementedError, lambda: isolve(x**2 < zero, x))
|
||||
raises(NotImplementedError, lambda: isolve(
|
||||
x**2 < zero*I, x))
|
||||
raises(NotImplementedError, lambda: isolve(1/(x - y) < 2, x))
|
||||
raises(NotImplementedError, lambda: isolve(1/(x - y) < 0, x))
|
||||
raises(TypeError, lambda: isolve(x - I < 0, x))
|
||||
|
||||
zero = x**2 + x - x*(x + 1)
|
||||
assert isolve(zero < 0, x, relational=False) is S.EmptySet
|
||||
assert isolve(zero <= 0, x, relational=False) is S.Reals
|
||||
|
||||
# make sure iter_solutions gets a default value
|
||||
raises(NotImplementedError, lambda: isolve(
|
||||
Eq(cos(x)**2 + sin(x)**2, 1), x))
|
||||
|
||||
|
||||
def test_trig_inequalities():
|
||||
# all the inequalities are solved in a periodic interval.
|
||||
assert isolve(sin(x) < S.Half, x, relational=False) == \
|
||||
Union(Interval(0, pi/6, False, True), Interval.open(pi*Rational(5, 6), 2*pi))
|
||||
assert isolve(sin(x) > S.Half, x, relational=False) == \
|
||||
Interval(pi/6, pi*Rational(5, 6), True, True)
|
||||
assert isolve(cos(x) < S.Zero, x, relational=False) == \
|
||||
Interval(pi/2, pi*Rational(3, 2), True, True)
|
||||
assert isolve(cos(x) >= S.Zero, x, relational=False) == \
|
||||
Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi))
|
||||
|
||||
assert isolve(tan(x) < S.One, x, relational=False) == \
|
||||
Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi))
|
||||
|
||||
assert isolve(sin(x) <= S.Zero, x, relational=False) == \
|
||||
Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi))
|
||||
|
||||
assert isolve(sin(x) <= S.One, x, relational=False) == S.Reals
|
||||
assert isolve(cos(x) < S(-2), x, relational=False) == S.EmptySet
|
||||
assert isolve(sin(x) >= S.NegativeOne, x, relational=False) == S.Reals
|
||||
assert isolve(cos(x) > S.One, x, relational=False) == S.EmptySet
|
||||
|
||||
|
||||
def test_issue_9954():
|
||||
assert isolve(x**2 >= 0, x, relational=False) == S.Reals
|
||||
assert isolve(x**2 >= 0, x, relational=True) == S.Reals.as_relational(x)
|
||||
assert isolve(x**2 < 0, x, relational=False) == S.EmptySet
|
||||
assert isolve(x**2 < 0, x, relational=True) == S.EmptySet.as_relational(x)
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_slow_general_univariate():
|
||||
r = rootof(x**5 - x**2 + 1, 0)
|
||||
assert solve(sqrt(x) + 1/root(x, 3) > 1) == \
|
||||
Or(And(0 < x, x < r**6), And(r**6 < x, x < oo))
|
||||
|
||||
|
||||
def test_issue_8545():
|
||||
eq = 1 - x - abs(1 - x)
|
||||
ans = And(Lt(1, x), Lt(x, oo))
|
||||
assert reduce_abs_inequality(eq, '<', x) == ans
|
||||
eq = 1 - x - sqrt((1 - x)**2)
|
||||
assert reduce_inequalities(eq < 0) == ans
|
||||
|
||||
|
||||
def test_issue_8974():
|
||||
assert isolve(-oo < x, x) == And(-oo < x, x < oo)
|
||||
assert isolve(oo > x, x) == And(-oo < x, x < oo)
|
||||
|
||||
|
||||
def test_issue_10198():
|
||||
assert reduce_inequalities(
|
||||
-1 + 1/abs(1/x - 1) < 0) == (x > -oo) & (x < S(1)/2) & Ne(x, 0)
|
||||
|
||||
assert reduce_inequalities(abs(1/sqrt(x)) - 1, x) == Eq(x, 1)
|
||||
assert reduce_abs_inequality(-3 + 1/abs(1 - 1/x), '<', x) == \
|
||||
Or(And(-oo < x, x < 0),
|
||||
And(S.Zero < x, x < Rational(3, 4)), And(Rational(3, 2) < x, x < oo))
|
||||
raises(ValueError,lambda: reduce_abs_inequality(-3 + 1/abs(
|
||||
1 - 1/sqrt(x)), '<', x))
|
||||
|
||||
|
||||
def test_issue_10047():
|
||||
# issue 10047: this must remain an inequality, not True, since if x
|
||||
# is not real the inequality is invalid
|
||||
# assert solve(sin(x) < 2) == (x <= oo)
|
||||
|
||||
# with PR 16956, (x <= oo) autoevaluates when x is extended_real
|
||||
# which is assumed in the current implementation of inequality solvers
|
||||
assert solve(sin(x) < 2) == True
|
||||
assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals
|
||||
|
||||
|
||||
def test_issue_10268():
|
||||
assert solve(log(x) < 1000) == And(S.Zero < x, x < exp(1000))
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_isolve_Sets():
|
||||
n = Dummy('n')
|
||||
assert isolve(Abs(x) <= n, x, relational=False) == \
|
||||
Piecewise((S.EmptySet, n < 0), (Interval(-n, n), True))
|
||||
|
||||
|
||||
def test_integer_domain_relational_isolve():
|
||||
|
||||
dom = FiniteSet(0, 3)
|
||||
x = Symbol('x',zero=False)
|
||||
assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain=dom) == Eq(x, 3)
|
||||
|
||||
x = Symbol('x')
|
||||
assert isolve(x + 2 < 0, x, domain=S.Integers) == \
|
||||
(x <= -3) & (x > -oo) & Eq(Mod(x, 1), 0)
|
||||
assert isolve(2 * x + 3 > 0, x, domain=S.Integers) == \
|
||||
(x >= -1) & (x < oo) & Eq(Mod(x, 1), 0)
|
||||
assert isolve((x ** 2 + 3 * x - 2) < 0, x, domain=S.Integers) == \
|
||||
(x >= -3) & (x <= 0) & Eq(Mod(x, 1), 0)
|
||||
assert isolve((x ** 2 + 3 * x - 2) > 0, x, domain=S.Integers) == \
|
||||
((x >= 1) & (x < oo) & Eq(Mod(x, 1), 0)) | (
|
||||
(x <= -4) & (x > -oo) & Eq(Mod(x, 1), 0))
|
||||
|
||||
|
||||
def test_issue_10671_12466():
|
||||
assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi)
|
||||
i = Interval(1, 10)
|
||||
assert solveset((1/x).diff(x) < 0, x, i) == i
|
||||
assert solveset((log(x - 6)/x) <= 0, x, S.Reals) == \
|
||||
Interval.Lopen(6, 7)
|
||||
|
||||
|
||||
def test__solve_inequality():
|
||||
for op in (Gt, Lt, Le, Ge, Eq, Ne):
|
||||
assert _solve_inequality(op(x, 1), x).lhs == x
|
||||
assert _solve_inequality(op(S.One, x), x).lhs == x
|
||||
# don't get tricked by symbol on right: solve it
|
||||
assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1)
|
||||
ie = Eq(S.One, y)
|
||||
assert _solve_inequality(ie, x) == ie
|
||||
for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)):
|
||||
for c in (0, 1):
|
||||
e = 2*fx - c > 0
|
||||
assert _solve_inequality(e, x, linear=True) == (
|
||||
fx > c/S(2))
|
||||
assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == (
|
||||
x*(x + 1) < S.Half)
|
||||
assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1)
|
||||
nz = Symbol('nz', nonzero=True)
|
||||
assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz)
|
||||
assert _solve_inequality(x*nz < 1, x) == (x*nz < 1)
|
||||
a = Symbol('a', positive=True)
|
||||
assert _solve_inequality(a/x > 1, x) == (S.Zero < x) & (x < a)
|
||||
assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a)
|
||||
# make sure to include conditions under which solution is valid
|
||||
e = Eq(1 - x, x*(1/x - 1))
|
||||
assert _solve_inequality(e, x) == Ne(x, 0)
|
||||
assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0)
|
||||
|
||||
|
||||
def test__pt():
|
||||
from sympy.solvers.inequalities import _pt
|
||||
assert _pt(-oo, oo) == 0
|
||||
assert _pt(S.One, S(3)) == 2
|
||||
assert _pt(S.One, oo) == _pt(oo, S.One) == 2
|
||||
assert _pt(S.One, -oo) == _pt(-oo, S.One) == S.Half
|
||||
assert _pt(S.NegativeOne, oo) == _pt(oo, S.NegativeOne) == Rational(-1, 2)
|
||||
assert _pt(S.NegativeOne, -oo) == _pt(-oo, S.NegativeOne) == -2
|
||||
assert _pt(x, oo) == _pt(oo, x) == x + 1
|
||||
assert _pt(x, -oo) == _pt(-oo, x) == x - 1
|
||||
raises(ValueError, lambda: _pt(Dummy('i', infinite=True), S.One))
|
||||
|
||||
|
||||
def test_issue_25697():
|
||||
assert _solve_inequality(log(x, 3) <= 2, x) == (x <= 9) & (S.Zero < x)
|
||||
|
||||
|
||||
def test_issue_25738():
|
||||
assert reduce_inequalities(3 < abs(x)
|
||||
) == reduce_inequalities(pi < abs(x)).subs(pi, 3)
|
||||
|
||||
|
||||
def test_issue_25983():
|
||||
assert(reduce_inequalities(pi/Abs(x) <= 1) == ((pi <= x) & (x < oo)) | ((-oo < x) & (x <= -pi)))
|
||||
@@ -0,0 +1,139 @@
|
||||
from sympy.core.function import nfloat
|
||||
from sympy.core.numbers import (Float, I, Rational, pi)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.piecewise import Piecewise
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.integrals.integrals import Integral
|
||||
from sympy.matrices.dense import Matrix
|
||||
from mpmath import mnorm, mpf
|
||||
from sympy.solvers import nsolve
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.testing.pytest import raises, XFAIL
|
||||
from sympy.utilities.decorator import conserve_mpmath_dps
|
||||
|
||||
@XFAIL
|
||||
def test_nsolve_fail():
|
||||
x = symbols('x')
|
||||
# Sometimes it is better to use the numerator (issue 4829)
|
||||
# but sometimes it is not (issue 11768) so leave this to
|
||||
# the discretion of the user
|
||||
ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0)
|
||||
assert ans > 0.46 and ans < 0.47
|
||||
|
||||
|
||||
def test_nsolve_denominator():
|
||||
x = symbols('x')
|
||||
# Test that nsolve uses the full expression (numerator and denominator).
|
||||
ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1)
|
||||
# The root -2 was divided out, so make sure we don't find it.
|
||||
assert ans == -1.0
|
||||
|
||||
def test_nsolve():
|
||||
# onedimensional
|
||||
x = Symbol('x')
|
||||
assert nsolve(sin(x), 2) - pi.evalf() < 1e-15
|
||||
assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10)
|
||||
# Testing checks on number of inputs
|
||||
raises(TypeError, lambda: nsolve(Eq(2*x, 2)))
|
||||
raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2))
|
||||
# multidimensional
|
||||
x1 = Symbol('x1')
|
||||
x2 = Symbol('x2')
|
||||
f1 = 3 * x1**2 - 2 * x2**2 - 1
|
||||
f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
|
||||
f = Matrix((f1, f2)).T
|
||||
F = lambdify((x1, x2), f.T, modules='mpmath')
|
||||
for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]:
|
||||
x = nsolve(f, (x1, x2), x0, tol=1.e-8)
|
||||
assert mnorm(F(*x), 1) <= 1.e-10
|
||||
# The Chinese mathematician Zhu Shijie was the very first to solve this
|
||||
# nonlinear system 700 years ago (z was added to make it 3-dimensional)
|
||||
x = Symbol('x')
|
||||
y = Symbol('y')
|
||||
z = Symbol('z')
|
||||
f1 = -x + 2*y
|
||||
f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
|
||||
f3 = sqrt(x**2 + y**2)*z
|
||||
f = Matrix((f1, f2, f3)).T
|
||||
F = lambdify((x, y, z), f.T, modules='mpmath')
|
||||
|
||||
def getroot(x0):
|
||||
root = nsolve(f, (x, y, z), x0)
|
||||
assert mnorm(F(*root), 1) <= 1.e-8
|
||||
return root
|
||||
assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0]
|
||||
assert nsolve([Eq(
|
||||
f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works
|
||||
a = Symbol('a')
|
||||
assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) -
|
||||
mpf('0.31883011387318591')) < 1e-15
|
||||
|
||||
|
||||
def test_issue_6408():
|
||||
x = Symbol('x')
|
||||
assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0
|
||||
|
||||
|
||||
def test_issue_6408_integral():
|
||||
x, y = symbols('x y')
|
||||
assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0
|
||||
|
||||
|
||||
@conserve_mpmath_dps
|
||||
def test_increased_dps():
|
||||
# Issue 8564
|
||||
import mpmath
|
||||
mpmath.mp.dps = 128
|
||||
x = Symbol('x')
|
||||
e1 = x**2 - pi
|
||||
q = nsolve(e1, x, 3.0)
|
||||
|
||||
assert abs(sqrt(pi).evalf(128) - q) < 1e-128
|
||||
|
||||
def test_nsolve_precision():
|
||||
x, y = symbols('x y')
|
||||
sol = nsolve(x**2 - pi, x, 3, prec=128)
|
||||
assert abs(sqrt(pi).evalf(128) - sol) < 1e-128
|
||||
assert isinstance(sol, Float)
|
||||
|
||||
sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128)
|
||||
assert isinstance(sols, Matrix)
|
||||
assert sols.shape == (2, 1)
|
||||
assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128
|
||||
assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128
|
||||
assert all(isinstance(i, Float) for i in sols)
|
||||
|
||||
def test_nsolve_complex():
|
||||
x, y = symbols('x y')
|
||||
|
||||
assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I
|
||||
assert nsolve(x**2 + 2, I) == sqrt(2.)*I
|
||||
|
||||
assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
|
||||
assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
|
||||
|
||||
def test_nsolve_dict_kwarg():
|
||||
x, y = symbols('x y')
|
||||
# one variable
|
||||
assert nsolve(x**2 - 2, 1, dict = True) == \
|
||||
[{x: sqrt(2.)}]
|
||||
# one variable with complex solution
|
||||
assert nsolve(x**2 + 2, I, dict = True) == \
|
||||
[{x: sqrt(2.)*I}]
|
||||
# two variables
|
||||
assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \
|
||||
[{x: sqrt(2.), y: sqrt(3.)}]
|
||||
|
||||
def test_nsolve_rational():
|
||||
x = symbols('x')
|
||||
assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100)
|
||||
|
||||
|
||||
def test_issue_14950():
|
||||
x = Matrix(symbols('t s'))
|
||||
x0 = Matrix([17, 23])
|
||||
eqn = x + x0
|
||||
assert nsolve(eqn, x, x0) == nfloat(-x0)
|
||||
assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0)
|
||||
@@ -0,0 +1,239 @@
|
||||
from sympy.core.function import (Derivative as D, Function)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.core import S
|
||||
from sympy.solvers.pde import (pde_separate, pde_separate_add, pde_separate_mul,
|
||||
pdsolve, classify_pde, checkpdesol)
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
a, b, c, x, y = symbols('a b c x y')
|
||||
|
||||
def test_pde_separate_add():
|
||||
x, y, z, t = symbols("x,y,z,t")
|
||||
F, T, X, Y, Z, u = map(Function, 'FTXYZu')
|
||||
|
||||
eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t)))
|
||||
res = pde_separate_add(eq, u(x, t), [X(x), T(t)])
|
||||
assert res == [D(X(x), x)*exp(-X(x)), D(T(t), t)*exp(T(t))]
|
||||
|
||||
|
||||
def test_pde_separate():
|
||||
x, y, z, t = symbols("x,y,z,t")
|
||||
F, T, X, Y, Z, u = map(Function, 'FTXYZu')
|
||||
|
||||
eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t)))
|
||||
raises(ValueError, lambda: pde_separate(eq, u(x, t), [X(x), T(t)], 'div'))
|
||||
|
||||
|
||||
def test_pde_separate_mul():
|
||||
x, y, z, t = symbols("x,y,z,t")
|
||||
c = Symbol("C", real=True)
|
||||
Phi = Function('Phi')
|
||||
F, R, T, X, Y, Z, u = map(Function, 'FRTXYZu')
|
||||
r, theta, z = symbols('r,theta,z')
|
||||
|
||||
# Something simple :)
|
||||
eq = Eq(D(F(x, y, z), x) + D(F(x, y, z), y) + D(F(x, y, z), z), 0)
|
||||
|
||||
# Duplicate arguments in functions
|
||||
raises(
|
||||
ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), u(z, z)]))
|
||||
# Wrong number of arguments
|
||||
raises(ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), Y(y)]))
|
||||
# Wrong variables: [x, y] -> [x, z]
|
||||
raises(
|
||||
ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(t), Y(x, y)]))
|
||||
|
||||
assert pde_separate_mul(eq, F(x, y, z), [Y(y), u(x, z)]) == \
|
||||
[D(Y(y), y)/Y(y), -D(u(x, z), x)/u(x, z) - D(u(x, z), z)/u(x, z)]
|
||||
assert pde_separate_mul(eq, F(x, y, z), [X(x), Y(y), Z(z)]) == \
|
||||
[D(X(x), x)/X(x), -D(Z(z), z)/Z(z) - D(Y(y), y)/Y(y)]
|
||||
|
||||
# wave equation
|
||||
wave = Eq(D(u(x, t), t, t), c**2*D(u(x, t), x, x))
|
||||
res = pde_separate_mul(wave, u(x, t), [X(x), T(t)])
|
||||
assert res == [D(X(x), x, x)/X(x), D(T(t), t, t)/(c**2*T(t))]
|
||||
|
||||
# Laplace equation in cylindrical coords
|
||||
eq = Eq(1/r * D(Phi(r, theta, z), r) + D(Phi(r, theta, z), r, 2) +
|
||||
1/r**2 * D(Phi(r, theta, z), theta, 2) + D(Phi(r, theta, z), z, 2), 0)
|
||||
# Separate z
|
||||
res = pde_separate_mul(eq, Phi(r, theta, z), [Z(z), u(theta, r)])
|
||||
assert res == [D(Z(z), z, z)/Z(z),
|
||||
-D(u(theta, r), r, r)/u(theta, r) -
|
||||
D(u(theta, r), r)/(r*u(theta, r)) -
|
||||
D(u(theta, r), theta, theta)/(r**2*u(theta, r))]
|
||||
# Lets use the result to create a new equation...
|
||||
eq = Eq(res[1], c)
|
||||
# ...and separate theta...
|
||||
res = pde_separate_mul(eq, u(theta, r), [T(theta), R(r)])
|
||||
assert res == [D(T(theta), theta, theta)/T(theta),
|
||||
-r*D(R(r), r)/R(r) - r**2*D(R(r), r, r)/R(r) - c*r**2]
|
||||
# ...or r...
|
||||
res = pde_separate_mul(eq, u(theta, r), [R(r), T(theta)])
|
||||
assert res == [r*D(R(r), r)/R(r) + r**2*D(R(r), r, r)/R(r) + c*r**2,
|
||||
-D(T(theta), theta, theta)/T(theta)]
|
||||
|
||||
|
||||
def test_issue_11726():
|
||||
x, t = symbols("x t")
|
||||
f = symbols("f", cls=Function)
|
||||
X, T = symbols("X T", cls=Function)
|
||||
|
||||
u = f(x, t)
|
||||
eq = u.diff(x, 2) - u.diff(t, 2)
|
||||
res = pde_separate(eq, u, [T(x), X(t)])
|
||||
assert res == [D(T(x), x, x)/T(x),D(X(t), t, t)/X(t)]
|
||||
|
||||
|
||||
def test_pde_classify():
|
||||
# When more number of hints are added, add tests for classifying here.
|
||||
f = Function('f')
|
||||
eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y)
|
||||
eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y)
|
||||
eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y)
|
||||
eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y)
|
||||
eq5 = x**2*f(x,y) + x*f(x,y).diff(x) + x*y*f(x,y).diff(y)
|
||||
eq6 = y*x**2*f(x,y) + y*f(x,y).diff(x) + f(x,y).diff(y)
|
||||
for eq in [eq1, eq2, eq3]:
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
|
||||
for eq in [eq4, eq5, eq6]:
|
||||
assert classify_pde(eq) == ('1st_linear_variable_coeff',)
|
||||
|
||||
|
||||
def test_checkpdesol():
|
||||
f, F = map(Function, ['f', 'F'])
|
||||
eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y)
|
||||
eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y)
|
||||
eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y)
|
||||
for eq in [eq1, eq2, eq3]:
|
||||
assert checkpdesol(eq, pdsolve(eq))[0]
|
||||
eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y)
|
||||
eq5 = 2*f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y)
|
||||
eq6 = f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y)
|
||||
assert checkpdesol(eq4, [pdsolve(eq5), pdsolve(eq6)]) == [
|
||||
(False, (x - 2)*F(3*x - y)*exp(-x/S(5) - 3*y/S(5))),
|
||||
(False, (x - 1)*F(3*x - y)*exp(-x/S(10) - 3*y/S(10)))]
|
||||
for eq in [eq4, eq5, eq6]:
|
||||
assert checkpdesol(eq, pdsolve(eq))[0]
|
||||
sol = pdsolve(eq4)
|
||||
sol4 = Eq(sol.lhs - sol.rhs, 0)
|
||||
raises(NotImplementedError, lambda:
|
||||
checkpdesol(eq4, sol4, solve_for_func=False))
|
||||
|
||||
|
||||
def test_solvefun():
|
||||
f, F, G, H = map(Function, ['f', 'F', 'G', 'H'])
|
||||
eq1 = f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)
|
||||
assert pdsolve(eq1) == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2))
|
||||
assert pdsolve(eq1, solvefun=G) == Eq(f(x, y), G(x - y)*exp(-x/2 - y/2))
|
||||
assert pdsolve(eq1, solvefun=H) == Eq(f(x, y), H(x - y)*exp(-x/2 - y/2))
|
||||
|
||||
|
||||
def test_pde_1st_linear_constant_coeff_homogeneous():
|
||||
f, F = map(Function, ['f', 'F'])
|
||||
u = f(x, y)
|
||||
eq = 2*u + u.diff(x) + u.diff(y)
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(u, F(x - y)*exp(-x - y))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = 4 + (3*u.diff(x)/u) + (2*u.diff(y)/u)
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(u, F(2*x - 3*y)*exp(-S(12)*x/13 - S(8)*y/13))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = u + (6*u.diff(x)) + (7*u.diff(y))
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(u, F(7*x - 6*y)*exp(-6*x/S(85) - 7*y/S(85)))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = a*u + b*u.diff(x) + c*u.diff(y)
|
||||
sol = pdsolve(eq)
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
|
||||
def test_pde_1st_linear_constant_coeff():
|
||||
f, F = map(Function, ['f', 'F'])
|
||||
u = f(x,y)
|
||||
eq = -2*u.diff(x) + 4*u.diff(y) + 5*u - exp(x + 3*y)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(f(x,y),
|
||||
(F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y))
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral')
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = (u.diff(x)/u) + (u.diff(y)/u) + 1 - (exp(x + y)/u)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2) + exp(x + y)/3)
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral')
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = 2*u + -u.diff(x) + 3*u.diff(y) + sin(x)
|
||||
sol = pdsolve(eq)
|
||||
assert sol == Eq(f(x, y),
|
||||
F(3*x + y)*exp(x/5 - 3*y/5) - 2*sin(x)/5 - cos(x)/5)
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral')
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = u + u.diff(x) + u.diff(y) + x*y
|
||||
sol = pdsolve(eq)
|
||||
assert sol.expand() == Eq(f(x, y),
|
||||
x + y + (x - y)**2/4 - (x + y)**2/4 + F(x - y)*exp(-x/2 - y/2) - 2).expand()
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral')
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
eq = u + u.diff(x) + u.diff(y) + log(x)
|
||||
assert classify_pde(eq) == ('1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral')
|
||||
|
||||
|
||||
def test_pdsolve_all():
|
||||
f, F = map(Function, ['f', 'F'])
|
||||
u = f(x,y)
|
||||
eq = u + u.diff(x) + u.diff(y) + x**2*y
|
||||
sol = pdsolve(eq, hint = 'all')
|
||||
keys = ['1st_linear_constant_coeff',
|
||||
'1st_linear_constant_coeff_Integral', 'default', 'order']
|
||||
assert sorted(sol.keys()) == keys
|
||||
assert sol['order'] == 1
|
||||
assert sol['default'] == '1st_linear_constant_coeff'
|
||||
assert sol['1st_linear_constant_coeff'].expand() == Eq(f(x, y),
|
||||
-x**2*y + x**2 + 2*x*y - 4*x - 2*y + F(x - y)*exp(-x/2 - y/2) + 6).expand()
|
||||
|
||||
|
||||
def test_pdsolve_variable_coeff():
|
||||
f, F = map(Function, ['f', 'F'])
|
||||
u = f(x, y)
|
||||
eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2
|
||||
sol = pdsolve(eq, hint="1st_linear_variable_coeff")
|
||||
assert sol == Eq(u, F(x*y)*exp(y**2/2) + 1)
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = x**2*u + x*u.diff(x) + x*y*u.diff(y)
|
||||
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
|
||||
assert sol == Eq(u, F(y*exp(-x))*exp(-x**2/2))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = y*x**2*u + y*u.diff(x) + u.diff(y)
|
||||
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
|
||||
assert sol == Eq(u, F(-2*x + y**2)*exp(-x**3/3))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = exp(x)**2*(u.diff(x)) + y
|
||||
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
|
||||
assert sol == Eq(u, y*exp(-2*x)/2 + F(y))
|
||||
assert checkpdesol(eq, sol)[0]
|
||||
|
||||
eq = exp(2*x)*(u.diff(y)) + y*u - u
|
||||
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
|
||||
assert sol == Eq(u, F(x)*exp(-y*(y - 2)*exp(-2*x)/2))
|
||||
@@ -0,0 +1,462 @@
|
||||
"""Tests for solvers of systems of polynomial equations. """
|
||||
from sympy.polys.domains import ZZ, QQ_I
|
||||
from sympy.core.numbers import (I, Integer, Rational)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.polys.domains.rationalfield import QQ
|
||||
from sympy.polys.polyerrors import UnsolvableFactorError
|
||||
from sympy.polys.polyoptions import Options
|
||||
from sympy.polys.polytools import Poly
|
||||
from sympy.polys.rootoftools import CRootOf
|
||||
from sympy.solvers.solvers import solve
|
||||
from sympy.utilities.iterables import flatten
|
||||
from sympy.abc import a, b, c, x, y, z
|
||||
from sympy.polys import PolynomialError
|
||||
from sympy.solvers.polysys import (solve_poly_system,
|
||||
solve_triangulated,
|
||||
solve_biquadratic, SolveFailed,
|
||||
solve_generic, factor_system_bool,
|
||||
factor_system_cond, factor_system_poly,
|
||||
factor_system, _factor_sets, _factor_sets_slow)
|
||||
from sympy.polys.polytools import parallel_poly_from_expr
|
||||
from sympy.testing.pytest import raises
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.functions.elementary.trigonometric import sin, cos
|
||||
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
|
||||
|
||||
def test_solve_poly_system():
|
||||
assert solve_poly_system([x - 1], x) == [(S.One,)]
|
||||
|
||||
assert solve_poly_system([y - x, y - x - 1], x, y) is None
|
||||
|
||||
assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)]
|
||||
|
||||
assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \
|
||||
[(Rational(3, 2), Integer(2), Integer(10))]
|
||||
|
||||
assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \
|
||||
[(0, 0), (2, -sqrt(2)), (2, sqrt(2))]
|
||||
|
||||
assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \
|
||||
[(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))]
|
||||
|
||||
f_1 = x**2 + y + z - 1
|
||||
f_2 = x + y**2 + z - 1
|
||||
f_3 = x + y + z**2 - 1
|
||||
|
||||
a, b = sqrt(2) - 1, -sqrt(2) - 1
|
||||
|
||||
assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
|
||||
|
||||
solution = [(1, -1), (1, 1)]
|
||||
|
||||
assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution
|
||||
assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution
|
||||
assert solve_poly_system([x**2 - y**2, x - 1]) == solution
|
||||
|
||||
assert solve_poly_system(
|
||||
[x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)]
|
||||
|
||||
raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y))
|
||||
raises(NotImplementedError, lambda: solve_poly_system(
|
||||
[z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2]))
|
||||
raises(PolynomialError, lambda: solve_poly_system([1/x], x))
|
||||
|
||||
raises(NotImplementedError, lambda: solve_poly_system(
|
||||
[x-1,], (x, y)))
|
||||
raises(NotImplementedError, lambda: solve_poly_system(
|
||||
[y-1,], (x, y)))
|
||||
|
||||
# solve_poly_system should ideally construct solutions using
|
||||
# CRootOf for the following four tests
|
||||
assert solve_poly_system([x**5 - x + 1], [x], strict=False) == []
|
||||
raises(UnsolvableFactorError, lambda: solve_poly_system(
|
||||
[x**5 - x + 1], [x], strict=True))
|
||||
|
||||
assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y],
|
||||
strict=False) == [(1, -1), (1, 1)]
|
||||
raises(UnsolvableFactorError,
|
||||
lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1],
|
||||
[x, y], strict=True))
|
||||
|
||||
|
||||
def test_solve_generic():
|
||||
NewOption = Options((x, y), {'domain': 'ZZ'})
|
||||
assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \
|
||||
[(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
|
||||
(sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
|
||||
(-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2),
|
||||
(sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)]
|
||||
|
||||
# solve_generic should ideally construct solutions using
|
||||
# CRootOf for the following two tests
|
||||
assert solve_generic(
|
||||
[2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \
|
||||
[(Rational(1, 2), 1)]
|
||||
raises(UnsolvableFactorError, lambda: solve_generic(
|
||||
[2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True))
|
||||
|
||||
|
||||
def test_solve_biquadratic():
|
||||
x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r')
|
||||
|
||||
f_1 = (x - 1)**2 + (y - 1)**2 - r**2
|
||||
f_2 = (x - 2)**2 + (y - 2)**2 - r**2
|
||||
s = sqrt(2*r**2 - 1)
|
||||
a = (3 - s)/2
|
||||
b = (3 + s)/2
|
||||
assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)]
|
||||
|
||||
f_1 = (x - 1)**2 + (y - 2)**2 - r**2
|
||||
f_2 = (x - 1)**2 + (y - 1)**2 - r**2
|
||||
|
||||
assert solve_poly_system([f_1, f_2], x, y) == \
|
||||
[(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)),
|
||||
(1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))]
|
||||
|
||||
query = lambda expr: expr.is_Pow and expr.exp is S.Half
|
||||
|
||||
f_1 = (x - 1 )**2 + (y - 2)**2 - r**2
|
||||
f_2 = (x - x1)**2 + (y - 1)**2 - r**2
|
||||
|
||||
result = solve_poly_system([f_1, f_2], x, y)
|
||||
|
||||
assert len(result) == 2 and all(len(r) == 2 for r in result)
|
||||
assert all(r.count(query) == 1 for r in flatten(result))
|
||||
|
||||
f_1 = (x - x0)**2 + (y - y0)**2 - r**2
|
||||
f_2 = (x - x1)**2 + (y - y1)**2 - r**2
|
||||
|
||||
result = solve_poly_system([f_1, f_2], x, y)
|
||||
|
||||
assert len(result) == 2 and all(len(r) == 2 for r in result)
|
||||
assert all(len(r.find(query)) == 1 for r in flatten(result))
|
||||
|
||||
s1 = (x*y - y, x**2 - x)
|
||||
assert solve(s1) == [{x: 1}, {x: 0, y: 0}]
|
||||
s2 = (x*y - x, y**2 - y)
|
||||
assert solve(s2) == [{y: 1}, {x: 0, y: 0}]
|
||||
gens = (x, y)
|
||||
for seq in (s1, s2):
|
||||
(f, g), opt = parallel_poly_from_expr(seq, *gens)
|
||||
raises(SolveFailed, lambda: solve_biquadratic(f, g, opt))
|
||||
seq = (x**2 + y**2 - 2, y**2 - 1)
|
||||
(f, g), opt = parallel_poly_from_expr(seq, *gens)
|
||||
assert solve_biquadratic(f, g, opt) == [
|
||||
(-1, -1), (-1, 1), (1, -1), (1, 1)]
|
||||
ans = [(0, -1), (0, 1)]
|
||||
seq = (x**2 + y**2 - 1, y**2 - 1)
|
||||
(f, g), opt = parallel_poly_from_expr(seq, *gens)
|
||||
assert solve_biquadratic(f, g, opt) == ans
|
||||
seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1)
|
||||
(f, g), opt = parallel_poly_from_expr(seq, *gens)
|
||||
assert solve_biquadratic(f, g, opt) == ans
|
||||
|
||||
|
||||
def test_solve_triangulated():
|
||||
f_1 = x**2 + y + z - 1
|
||||
f_2 = x + y**2 + z - 1
|
||||
f_3 = x + y + z**2 - 1
|
||||
|
||||
a, b = sqrt(2) - 1, -sqrt(2) - 1
|
||||
|
||||
assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0)]
|
||||
|
||||
dom = QQ.algebraic_field(sqrt(2))
|
||||
|
||||
assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
|
||||
|
||||
a, b = CRootOf(z**2 + 2*z - 1, 0), CRootOf(z**2 + 2*z - 1, 1)
|
||||
assert solve_triangulated([f_1, f_2, f_3], x, y, z, extension=True) == \
|
||||
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
|
||||
|
||||
|
||||
def test_solve_issue_3686():
|
||||
roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y)
|
||||
assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))]
|
||||
|
||||
roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y)
|
||||
# TODO: does this really have to be so complicated?!
|
||||
assert len(roots) == 2
|
||||
assert roots[0][0] == 0
|
||||
assert roots[0][1].epsilon_eq(-499.474999374969, 1e12)
|
||||
assert roots[1][0] == 0
|
||||
assert roots[1][1].epsilon_eq(500.474999374969, 1e12)
|
||||
|
||||
|
||||
def test_factor_system():
|
||||
|
||||
assert factor_system([x**2 + 2*x + 1]) == [[x + 1]]
|
||||
assert factor_system([x**2 + 2*x + 1, y**2 + 2*y + 1]) == [[x + 1, y + 1]]
|
||||
assert factor_system([x**2 + 1]) == [[x**2 + 1]]
|
||||
assert factor_system([]) == [[]]
|
||||
|
||||
assert factor_system([x**2 + y**2 + 2*x*y, x**2 - 2], extension=sqrt(2)) == [
|
||||
[x + y, x + sqrt(2)],
|
||||
[x + y, x - sqrt(2)],
|
||||
]
|
||||
|
||||
assert factor_system([x**2 + 1, y**2 + 1], gaussian=True) == [
|
||||
[x + I, y + I],
|
||||
[x + I, y - I],
|
||||
[x - I, y + I],
|
||||
[x - I, y - I],
|
||||
]
|
||||
|
||||
assert factor_system([x**2 + 1, y**2 + 1], domain=QQ_I) == [
|
||||
[x + I, y + I],
|
||||
[x + I, y - I],
|
||||
[x - I, y + I],
|
||||
[x - I, y - I],
|
||||
]
|
||||
|
||||
assert factor_system([0]) == [[]]
|
||||
assert factor_system([1]) == []
|
||||
assert factor_system([0 , x]) == [[x]]
|
||||
assert factor_system([1, 0, x]) == []
|
||||
|
||||
assert factor_system([x**4 - 1, y**6 - 1]) == [
|
||||
[x**2 + 1, y**2 + y + 1],
|
||||
[x**2 + 1, y**2 - y + 1],
|
||||
[x**2 + 1, y + 1],
|
||||
[x**2 + 1, y - 1],
|
||||
[x + 1, y**2 + y + 1],
|
||||
[x + 1, y**2 - y + 1],
|
||||
[x - 1, y**2 + y + 1],
|
||||
[x - 1, y**2 - y + 1],
|
||||
[x + 1, y + 1],
|
||||
[x + 1, y - 1],
|
||||
[x - 1, y + 1],
|
||||
[x - 1, y - 1],
|
||||
]
|
||||
|
||||
assert factor_system([(x - 1)*(y - 2), (y - 2)*(z - 3)]) == [
|
||||
[x - 1, z - 3],
|
||||
[y - 2]
|
||||
]
|
||||
|
||||
assert factor_system([sin(x)**2 + cos(x)**2 - 1, x]) == [
|
||||
[x, sin(x)**2 + cos(x)**2 - 1],
|
||||
]
|
||||
|
||||
assert factor_system([sin(x)**2 + cos(x)**2 - 1]) == [
|
||||
[sin(x)**2 + cos(x)**2 - 1]
|
||||
]
|
||||
|
||||
assert factor_system([sin(x)**2 + cos(x)**2]) == [
|
||||
[sin(x)**2 + cos(x)**2]
|
||||
]
|
||||
|
||||
assert factor_system([a*x, y, a]) == [[y, a]]
|
||||
|
||||
assert factor_system([a*x, y, a], [x, y]) == []
|
||||
|
||||
assert factor_system([a ** 2 * x, y], [x, y]) == [[x, y]]
|
||||
|
||||
assert factor_system([a*x*(x - 1), b*y, c], [x, y]) == []
|
||||
|
||||
assert factor_system([a*x*(x - 1), b*y, c], [x, y, c]) == [
|
||||
[x - 1, y, c],
|
||||
[x, y, c],
|
||||
]
|
||||
|
||||
assert factor_system([a*x*(x - 1), b*y, c]) == [
|
||||
[x - 1, y, c],
|
||||
[x, y, c],
|
||||
[x - 1, b, c],
|
||||
[x, b, c],
|
||||
[y, a, c],
|
||||
[a, b, c],
|
||||
]
|
||||
|
||||
assert factor_system([x**2 - 2], [y]) == []
|
||||
|
||||
assert factor_system([x**2 - 2], [x]) == [[x**2 - 2]]
|
||||
|
||||
assert factor_system([cos(x)**2 - sin(x)**2, cos(x)**2 + sin(x)**2 - 1]) == [
|
||||
[sin(x)**2 + cos(x)**2 - 1, sin(x) + cos(x)],
|
||||
[sin(x)**2 + cos(x)**2 - 1, -sin(x) + cos(x)],
|
||||
]
|
||||
|
||||
assert factor_system([(cos(x) + sin(x))**2 - 1, cos(x)**2 - sin(x)**2 - cos(2*x)]) == [
|
||||
[sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) + 1],
|
||||
[sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) - 1],
|
||||
]
|
||||
|
||||
assert factor_system([(cos(x) + sin(x))*exp(y) - 1, (cos(x) - sin(x))*exp(y) - 1]) == [
|
||||
[exp(y)*sin(x) + exp(y)*cos(x) - 1, -exp(y)*sin(x) + exp(y)*cos(x) - 1]
|
||||
]
|
||||
|
||||
|
||||
def test_factor_system_poly():
|
||||
|
||||
px = lambda e: Poly(e, x)
|
||||
pxab = lambda e: Poly(e, x, domain=ZZ[a, b])
|
||||
pxI = lambda e: Poly(e, x, domain=QQ_I)
|
||||
pxyz = lambda e: Poly(e, (x, y, z))
|
||||
|
||||
assert factor_system_poly([px(x**2 - 1), px(x**2 - 4)]) == [
|
||||
[px(x + 2), px(x + 1)],
|
||||
[px(x + 2), px(x - 1)],
|
||||
[px(x + 1), px(x - 2)],
|
||||
[px(x - 1), px(x - 2)],
|
||||
]
|
||||
|
||||
assert factor_system_poly([px(x**2 - 1)]) == [[px(x + 1)], [px(x - 1)]]
|
||||
|
||||
assert factor_system_poly([pxyz(x**2*y - y), pxyz(x**2*z - z)]) == [
|
||||
[pxyz(x + 1)],
|
||||
[pxyz(x - 1)],
|
||||
[pxyz(y), pxyz(z)],
|
||||
]
|
||||
|
||||
assert factor_system_poly([px(x**2*(x - 1)**2), px(x*(x - 1))]) == [
|
||||
[px(x)],
|
||||
[px(x - 1)],
|
||||
]
|
||||
|
||||
assert factor_system_poly([pxyz(x**2 + y*x), pxyz(x**2 + z*x)]) == [
|
||||
[pxyz(x + y), pxyz(x + z)],
|
||||
[pxyz(x)],
|
||||
]
|
||||
|
||||
assert factor_system_poly([pxab((a - 1)*(x - 2)), pxab((b - 3)*(x - 2))]) == [
|
||||
[pxab(x - 2)],
|
||||
[pxab(a - 1), pxab(b - 3)],
|
||||
]
|
||||
|
||||
assert factor_system_poly([pxI(x**2 + 1)]) == [[pxI(x + I)], [pxI(x - I)]]
|
||||
|
||||
assert factor_system_poly([]) == [[]]
|
||||
|
||||
assert factor_system_poly([px(1)]) == []
|
||||
assert factor_system_poly([px(0), px(x)]) == [[px(x)]]
|
||||
|
||||
|
||||
def test_factor_system_cond():
|
||||
|
||||
assert factor_system_cond([x ** 2 - 1, x ** 2 - 4]) == [
|
||||
[x + 2, x + 1],
|
||||
[x + 2, x - 1],
|
||||
[x + 1, x - 2],
|
||||
[x - 1, x - 2],
|
||||
]
|
||||
|
||||
assert factor_system_cond([1]) == []
|
||||
assert factor_system_cond([0]) == [[]]
|
||||
assert factor_system_cond([1, x]) == []
|
||||
assert factor_system_cond([0, x]) == [[x]]
|
||||
assert factor_system_cond([]) == [[]]
|
||||
|
||||
assert factor_system_cond([x**2 + y*x]) == [[x + y], [x]]
|
||||
|
||||
assert factor_system_cond([(a - 1)*(x - 2), (b - 3)*(x - 2)], [x]) == [
|
||||
[x - 2],
|
||||
[a - 1, b - 3],
|
||||
]
|
||||
|
||||
assert factor_system_cond([a * (x - 1), b], [x]) == [[x - 1, b], [a, b]]
|
||||
|
||||
assert factor_system_cond([a*x*(x-1), b*y, c], [x, y]) == [
|
||||
[x - 1, y, c],
|
||||
[x, y, c],
|
||||
[x - 1, b, c],
|
||||
[x, b, c],
|
||||
[y, a, c],
|
||||
[a, b, c],
|
||||
]
|
||||
|
||||
assert factor_system_cond([x*(x-1), y], [x, y]) == [[x - 1, y], [x, y]]
|
||||
|
||||
assert factor_system_cond([a*x, y, a], [x, y]) == [[y, a]]
|
||||
|
||||
assert factor_system_cond([a*x, b*x], [x, y]) == [[x], [a, b]]
|
||||
|
||||
assert factor_system_cond([a*b*x, y], [x, y]) == [[x, y], [y, a*b]]
|
||||
|
||||
assert factor_system_cond([a*b*x, y]) == [[x, y], [y, a], [y, b]]
|
||||
|
||||
assert factor_system_cond([a**2*x, y], [x, y]) == [[x, y], [y, a]]
|
||||
|
||||
def test_factor_system_bool():
|
||||
|
||||
eqs = [a*(x - 1)*(y - 1), b*(x - 2)*(y - 1)*(y - 2)]
|
||||
assert factor_system_bool(eqs, [x, y]) == (
|
||||
Eq(y - 1, 0)
|
||||
| (Eq(a, 0) & Eq(b, 0))
|
||||
| (Eq(a, 0) & Eq(x - 2, 0))
|
||||
| (Eq(a, 0) & Eq(y - 2, 0))
|
||||
| (Eq(b, 0) & Eq(x - 1, 0))
|
||||
| (Eq(x - 2, 0) & Eq(x - 1, 0))
|
||||
| (Eq(x - 1, 0) & Eq(y - 2, 0))
|
||||
)
|
||||
|
||||
assert factor_system_bool([x - 1], [x]) == Eq(x - 1, 0)
|
||||
|
||||
assert factor_system_bool([(x - 1)*(x - 2)], [x]) == Eq(x - 2, 0) | Eq(x - 1, 0)
|
||||
|
||||
assert factor_system_bool([], [x]) == True
|
||||
assert factor_system_bool([0], [x]) == True
|
||||
assert factor_system_bool([1], [x]) == False
|
||||
assert factor_system_bool([a], [x]) == Eq(a, 0)
|
||||
|
||||
assert factor_system_bool([a * x, y, a], [x, y]) == Eq(a, 0) & Eq(y, 0)
|
||||
|
||||
assert (factor_system_bool([a*x, b*y*x, a], [x, y]) == (
|
||||
Eq(a, 0) & Eq(b, 0))
|
||||
| (Eq(a, 0) & Eq(x, 0))
|
||||
| (Eq(a, 0) & Eq(y, 0)))
|
||||
|
||||
assert (factor_system_bool([a*x, b*x], [x, y]) == Eq(x, 0) |
|
||||
(Eq(a, 0) & Eq(b, 0)))
|
||||
|
||||
assert (factor_system_bool([a*b*x, y], [x, y]) == (
|
||||
Eq(x, 0) & Eq(y, 0)) |
|
||||
(Eq(y, 0) & Eq(a*b, 0)))
|
||||
|
||||
assert (factor_system_bool([a**2*x, y], [x, y]) == (
|
||||
Eq(a, 0) & Eq(y, 0)) |
|
||||
(Eq(x, 0) & Eq(y, 0)))
|
||||
|
||||
assert factor_system_bool([a*x*y, b*y*z], [x, y, z]) == (
|
||||
Eq(y, 0)
|
||||
| (Eq(a, 0) & Eq(b, 0))
|
||||
| (Eq(a, 0) & Eq(z, 0))
|
||||
| (Eq(b, 0) & Eq(x, 0))
|
||||
| (Eq(x, 0) & Eq(z, 0))
|
||||
)
|
||||
|
||||
assert factor_system_bool([a*(x - 1), b], [x]) == (
|
||||
(Eq(a, 0) & Eq(b, 0))
|
||||
| (Eq(x - 1, 0) & Eq(b, 0))
|
||||
)
|
||||
|
||||
|
||||
def test_factor_sets():
|
||||
#
|
||||
from random import randint
|
||||
|
||||
def generate_random_system(n_eqs=3, n_factors=2, max_val=10):
|
||||
return [
|
||||
[randint(0, max_val) for _ in range(randint(1, n_factors))]
|
||||
for _ in range(n_eqs)
|
||||
]
|
||||
|
||||
test_cases = [
|
||||
[[1, 2], [1, 3]],
|
||||
[[1, 2], [3, 4]],
|
||||
[[1], [1, 2], [2]],
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
assert _factor_sets(case) == _factor_sets_slow(case)
|
||||
|
||||
for _ in range(100):
|
||||
system = generate_random_system()
|
||||
assert _factor_sets(system) == _factor_sets_slow(system)
|
||||
@@ -0,0 +1,295 @@
|
||||
from sympy.core.function import (Function, Lambda, expand)
|
||||
from sympy.core.numbers import (I, Rational)
|
||||
from sympy.core.relational import Eq
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin)
|
||||
from sympy.polys.polytools import factor
|
||||
from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio
|
||||
from sympy.testing.pytest import raises, slow, XFAIL
|
||||
from sympy.abc import a, b
|
||||
|
||||
y = Function('y')
|
||||
n, k = symbols('n,k', integer=True)
|
||||
C0, C1, C2 = symbols('C0,C1,C2')
|
||||
|
||||
|
||||
def test_rsolve_poly():
|
||||
assert rsolve_poly([-1, -1, 1], 0, n) == 0
|
||||
assert rsolve_poly([-1, -1, 1], 1, n) == -1
|
||||
|
||||
assert rsolve_poly([-1, n + 1], n, n) == 1
|
||||
assert rsolve_poly([-1, 1], n, n) == C0 + (n**2 - n)/2
|
||||
assert rsolve_poly([-n - 1, n], 1, n) == C0*n - 1
|
||||
assert rsolve_poly([-4*n - 2, 1], 4*n + 1, n) == -1
|
||||
|
||||
assert rsolve_poly([-1, 1], n**5 + n**3, n) == \
|
||||
C0 - n**3 / 2 - n**5 / 2 + n**2 / 6 + n**6 / 6 + 2*n**4 / 3
|
||||
|
||||
|
||||
def test_rsolve_ratio():
|
||||
solution = rsolve_ratio([-2*n**3 + n**2 + 2*n - 1, 2*n**3 + n**2 - 6*n,
|
||||
-2*n**3 - 11*n**2 - 18*n - 9, 2*n**3 + 13*n**2 + 22*n + 8], 0, n)
|
||||
assert solution == C0*(2*n - 3)/(n**2 - 1)/2
|
||||
|
||||
|
||||
def test_rsolve_hyper():
|
||||
assert rsolve_hyper([-1, -1, 1], 0, n) in [
|
||||
C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n,
|
||||
C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n,
|
||||
]
|
||||
|
||||
assert rsolve_hyper([n**2 - 2, -2*n - 1, 1], 0, n) in [
|
||||
C0*rf(sqrt(2), n) + C1*rf(-sqrt(2), n),
|
||||
C1*rf(sqrt(2), n) + C0*rf(-sqrt(2), n),
|
||||
]
|
||||
|
||||
assert rsolve_hyper([n**2 - k, -2*n - 1, 1], 0, n) in [
|
||||
C0*rf(sqrt(k), n) + C1*rf(-sqrt(k), n),
|
||||
C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n),
|
||||
]
|
||||
|
||||
assert rsolve_hyper(
|
||||
[2*n*(n + 1), -n**2 - 3*n + 2, n - 1], 0, n) == C1*factorial(n) + C0*2**n
|
||||
|
||||
assert rsolve_hyper(
|
||||
[n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1], 0, n) == 0
|
||||
|
||||
assert rsolve_hyper([-n - 1, -1, 1], 0, n) == 0
|
||||
|
||||
assert rsolve_hyper([-1, 1], n, n).expand() == C0 + n**2/2 - n/2
|
||||
|
||||
assert rsolve_hyper([-1, 1], 1 + n, n).expand() == C0 + n**2/2 + n/2
|
||||
|
||||
assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n
|
||||
|
||||
assert rsolve_hyper([-a, 1],0,n).expand() == C0*a**n
|
||||
|
||||
assert rsolve_hyper([-a, 0, 1], 0, n).expand() == (-1)**n*C1*a**(n/2) + C0*a**(n/2)
|
||||
|
||||
assert rsolve_hyper([1, 1, 1], 0, n).expand() == \
|
||||
C0*(Rational(-1, 2) - sqrt(3)*I/2)**n + C1*(Rational(-1, 2) + sqrt(3)*I/2)**n
|
||||
|
||||
assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == 0
|
||||
|
||||
|
||||
@XFAIL
|
||||
def test_rsolve_ratio_missed():
|
||||
# this arises during computation
|
||||
# assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n
|
||||
assert rsolve_ratio([-n, n + 2], n, n) is not None
|
||||
|
||||
|
||||
def recurrence_term(c, f):
|
||||
"""Compute RHS of recurrence in f(n) with coefficients in c."""
|
||||
return sum(c[i]*f.subs(n, n + i) for i in range(len(c)))
|
||||
|
||||
|
||||
def test_rsolve_bulk():
|
||||
"""Some bulk-generated tests."""
|
||||
funcs = [ n, n + 1, n**2, n**3, n**4, n + n**2, 27*n + 52*n**2 - 3*
|
||||
n**3 + 12*n**4 - 52*n**5 ]
|
||||
coeffs = [ [-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1], [-n, 1], [n**2 -
|
||||
n + 12, 1] ]
|
||||
for p in funcs:
|
||||
# compute difference
|
||||
for c in coeffs:
|
||||
q = recurrence_term(c, p)
|
||||
if p.is_polynomial(n):
|
||||
assert rsolve_poly(c, q, n) == p
|
||||
# See issue 3956:
|
||||
if p.is_hypergeometric(n) and len(c) <= 3:
|
||||
assert rsolve_hyper(c, q, n).subs(zip(symbols('C:3'), [0, 0, 0])).expand() == p
|
||||
|
||||
|
||||
def test_rsolve_0_sol_homogeneous():
|
||||
# fixed by cherry-pick from
|
||||
# https://github.com/diofant/diofant/commit/e1d2e52125199eb3df59f12e8944f8a5f24b00a5
|
||||
assert rsolve_hyper([n**2 - n + 12, 1], n*(n**2 - n + 12) + n + 1, n) == n
|
||||
|
||||
|
||||
def test_rsolve():
|
||||
f = y(n + 2) - y(n + 1) - y(n)
|
||||
h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \
|
||||
- sqrt(5)*(S.Half - S.Half*sqrt(5))**n
|
||||
|
||||
assert rsolve(f, y(n)) in [
|
||||
C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n,
|
||||
C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n,
|
||||
]
|
||||
|
||||
assert rsolve(f, y(n), [0, 5]) == h
|
||||
assert rsolve(f, y(n), {0: 0, 1: 5}) == h
|
||||
assert rsolve(f, y(n), {y(0): 0, y(1): 5}) == h
|
||||
assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == h
|
||||
assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == h
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n)
|
||||
g = C1*factorial(n) + C0*2**n
|
||||
h = -3*factorial(n) + 3*2**n
|
||||
|
||||
assert rsolve(f, y(n)) == g
|
||||
assert rsolve(f, y(n), []) == g
|
||||
assert rsolve(f, y(n), {}) == g
|
||||
|
||||
assert rsolve(f, y(n), [0, 3]) == h
|
||||
assert rsolve(f, y(n), {0: 0, 1: 3}) == h
|
||||
assert rsolve(f, y(n), {y(0): 0, y(1): 3}) == h
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = y(n) - y(n - 1) - 2
|
||||
|
||||
assert rsolve(f, y(n), {y(0): 0}) == 2*n
|
||||
assert rsolve(f, y(n), {y(0): 1}) == 2*n + 1
|
||||
assert rsolve(f, y(n), {y(0): 0, y(1): 1}) is None
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = 3*y(n - 1) - y(n) - 1
|
||||
|
||||
assert rsolve(f, y(n), {y(0): 0}) == -3**n/2 + S.Half
|
||||
assert rsolve(f, y(n), {y(0): 1}) == 3**n/2 + S.Half
|
||||
assert rsolve(f, y(n), {y(0): 2}) == 3*3**n/2 + S.Half
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = y(n) - 1/n*y(n - 1)
|
||||
assert rsolve(f, y(n)) == C0/factorial(n)
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = y(n) - 1/n*y(n - 1) - 1
|
||||
assert rsolve(f, y(n)) is None
|
||||
|
||||
f = 2*y(n - 1) + (1 - n)*y(n)/n
|
||||
|
||||
assert rsolve(f, y(n), {y(1): 1}) == 2**(n - 1)*n
|
||||
assert rsolve(f, y(n), {y(1): 2}) == 2**(n - 1)*n*2
|
||||
assert rsolve(f, y(n), {y(1): 3}) == 2**(n - 1)*n*3
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
f = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n)
|
||||
|
||||
assert rsolve(f, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2)
|
||||
assert rsolve(
|
||||
f, y(n), {y(3): 6, y(4): -24}) == -n*(n - 1)*(n - 2)*(-1)**(n)
|
||||
|
||||
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
|
||||
|
||||
assert rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a}).simplify() == a**n
|
||||
|
||||
assert rsolve(y(n) - a*y(n-2),y(n), \
|
||||
{y(1): sqrt(a)*(a + b), y(2): a*(a - b)}).simplify() == \
|
||||
a**(n/2 + 1) - b*(-sqrt(a))**n
|
||||
|
||||
f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n)
|
||||
|
||||
yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)})
|
||||
sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12
|
||||
assert factor(expand(yn, func=True)) == sol
|
||||
|
||||
sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n))
|
||||
assert str(sol) == 'C0*((-sqrt(1 - a**2) - 1)/a)**n + C1*((sqrt(1 - a**2) - 1)/a)**n'
|
||||
|
||||
assert rsolve((k + 1)*y(k), y(k)) is None
|
||||
assert (rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2), y(k))
|
||||
is None)
|
||||
|
||||
assert rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n)) == (-1)**n*C0 - 2**n/3 - 3**n/4
|
||||
|
||||
|
||||
def test_rsolve_raises():
|
||||
x = Function('x')
|
||||
raises(ValueError, lambda: rsolve(y(n) - y(k + 1), y(n)))
|
||||
raises(ValueError, lambda: rsolve(y(n) - y(n + 1), x(n)))
|
||||
raises(ValueError, lambda: rsolve(y(n) - x(n + 1), y(n)))
|
||||
raises(ValueError, lambda: rsolve(y(n) - sqrt(n)*y(n + 1), y(n)))
|
||||
raises(ValueError, lambda: rsolve(y(n) - y(n + 1), y(n), {x(0): 0}))
|
||||
raises(ValueError, lambda: rsolve(y(n) + y(n + 1) + 2**n + cos(n), y(n)))
|
||||
|
||||
|
||||
def test_issue_6844():
|
||||
f = y(n + 2) - y(n + 1) + y(n)/4
|
||||
assert rsolve(f, y(n)) == 2**(-n + 1)*C1*n + 2**(-n)*C0
|
||||
assert rsolve(f, y(n), {y(0): 0, y(1): 1}) == 2**(1 - n)*n
|
||||
|
||||
|
||||
def test_issue_18751():
|
||||
r = Symbol('r', positive=True)
|
||||
theta = Symbol('theta', real=True)
|
||||
f = y(n) - 2 * r * cos(theta) * y(n - 1) + r**2 * y(n - 2)
|
||||
assert rsolve(f, y(n)) == \
|
||||
C0*(r*(cos(theta) - I*Abs(sin(theta))))**n + C1*(r*(cos(theta) + I*Abs(sin(theta))))**n
|
||||
|
||||
|
||||
def test_constant_naming():
|
||||
#issue 8697
|
||||
assert rsolve(y(n+3) - y(n+2) - y(n+1) + y(n), y(n)) == (-1)**n*C1 + C0 + C2*n
|
||||
assert rsolve(y(n+3)+3*y(n+2)+3*y(n+1)+y(n), y(n)).expand() == (-1)**n*C0 - (-1)**n*C1*n - (-1)**n*C2*n**2
|
||||
assert rsolve(y(n) - 2*y(n - 3) + 5*y(n - 2) - 4*y(n - 1),y(n),[1,3,8]) == 3*2**n - n - 2
|
||||
|
||||
#issue 19630
|
||||
assert rsolve(y(n+3) - 3*y(n+1) + 2*y(n), y(n), {y(1):0, y(2):8, y(3):-2}) == (-2)**n + 2*n
|
||||
|
||||
|
||||
@slow
|
||||
def test_issue_15751():
|
||||
f = y(n) + 21*y(n + 1) - 273*y(n + 2) - 1092*y(n + 3) + 1820*y(n + 4) + 1092*y(n + 5) - 273*y(n + 6) - 21*y(n + 7) + y(n + 8)
|
||||
assert rsolve(f, y(n)) is not None
|
||||
|
||||
|
||||
def test_issue_17990():
|
||||
f = -10*y(n) + 4*y(n + 1) + 6*y(n + 2) + 46*y(n + 3)
|
||||
sol = rsolve(f, y(n))
|
||||
expected = C0*((86*18**(S(1)/3)/69 + (-12 + (-1 + sqrt(3)*I)*(290412 +
|
||||
3036*sqrt(9165))**(S(1)/3))*(1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**
|
||||
(S(1)/3)/276)/((1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3))
|
||||
)**n + C1*((86*18**(S(1)/3)/69 + (-12 + (-1 - sqrt(3)*I)*(290412 + 3036
|
||||
*sqrt(9165))**(S(1)/3))*(1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**
|
||||
(S(1)/3)/276)/((1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3))
|
||||
)**n + C2*(-43*18**(S(1)/3)/(69*(24201 + 253*sqrt(9165))**(S(1)/3)) -
|
||||
S(1)/23 + (290412 + 3036*sqrt(9165))**(S(1)/3)/138)**n
|
||||
assert sol == expected
|
||||
e = sol.subs({C0: 1, C1: 1, C2: 1, n: 1}).evalf()
|
||||
assert abs(e + 0.130434782608696) < 1e-13
|
||||
|
||||
|
||||
def test_issue_8697():
|
||||
a = Function('a')
|
||||
eq = a(n + 3) - a(n + 2) - a(n + 1) + a(n)
|
||||
assert rsolve(eq, a(n)) == (-1)**n*C1 + C0 + C2*n
|
||||
eq2 = a(n + 3) + 3*a(n + 2) + 3*a(n + 1) + a(n)
|
||||
assert (rsolve(eq2, a(n)) ==
|
||||
(-1)**n*C0 + (-1)**(n + 1)*C1*n + (-1)**(n + 1)*C2*n**2)
|
||||
|
||||
assert rsolve(a(n) - 2*a(n - 3) + 5*a(n - 2) - 4*a(n - 1),
|
||||
a(n), {a(0): 1, a(1): 3, a(2): 8}) == 3*2**n - n - 2
|
||||
|
||||
# From issue thread (but fixed by https://github.com/diofant/diofant/commit/da9789c6cd7d0c2ceeea19fbf59645987125b289):
|
||||
assert rsolve(a(n) - 2*a(n - 1) - n, a(n), {a(0): 1}) == 3*2**n - n - 2
|
||||
|
||||
|
||||
def test_diofantissue_294():
|
||||
f = y(n) - y(n - 1) - 2*y(n - 2) - 2*n
|
||||
assert rsolve(f, y(n)) == (-1)**n*C0 + 2**n*C1 - n - Rational(5, 2)
|
||||
# issue sympy/sympy#11261
|
||||
assert rsolve(f, y(n), {y(0): -1, y(1): 1}) == (-(-1)**n/2 + 2*2**n -
|
||||
n - Rational(5, 2))
|
||||
# issue sympy/sympy#7055
|
||||
assert rsolve(-2*y(n) + y(n + 1) + n - 1, y(n)) == 2**n*C0 + n
|
||||
|
||||
|
||||
def test_issue_15553():
|
||||
f = Function("f")
|
||||
assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n)) == 2**n*C0 - n - 2
|
||||
assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n)) == 2**n*C0 - n**2 - 2*n - 4
|
||||
assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n), {f(1): 0}) == 7*2**n/2 - n**2 - 2*n - 4
|
||||
assert rsolve(Eq(f(n), 2*f(n - 1) + 3*n**2), f(n)) == 2**n*C0 - 3*n**2 - 12*n - 18
|
||||
assert rsolve(Eq(f(n), 2*f(n - 1) + n**2), f(n)) == 2**n*C0 - n**2 - 4*n - 6
|
||||
assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n), {f(0): 1}) == 3*2**n - n - 2
|
||||
@@ -0,0 +1,254 @@
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.core.relational import Eq, Ne
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.random import random, choice
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.ntheory.generate import randprime
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.solvers.solveset import linear_eq_to_matrix
|
||||
from sympy.solvers.simplex import (_lp as lp, _primal_dual,
|
||||
UnboundedLPError, InfeasibleLPError, lpmin, lpmax,
|
||||
_m, _abcd, _simplex, linprog)
|
||||
|
||||
from sympy.external.importtools import import_module
|
||||
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
from sympy.abc import x, y, z
|
||||
|
||||
|
||||
np = import_module("numpy")
|
||||
scipy = import_module("scipy")
|
||||
|
||||
|
||||
def test_lp():
|
||||
r1 = y + 2*z <= 3
|
||||
r2 = -x - 3*z <= -2
|
||||
r3 = 2*x + y + 7*z <= 5
|
||||
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
|
||||
objective = -x - y - 5 * z
|
||||
ans = optimum, argmax = lp(max, objective, constraints)
|
||||
assert ans == lpmax(objective, constraints)
|
||||
assert objective.subs(argmax) == optimum
|
||||
for constr in constraints:
|
||||
assert constr.subs(argmax) == True
|
||||
|
||||
r1 = x - y + 2*z <= 3
|
||||
r2 = -x + 2*y - 3*z <= -2
|
||||
r3 = 2*x + y - 7*z <= -5
|
||||
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
|
||||
objective = -x - y - 5*z
|
||||
ans = optimum, argmax = lp(max, objective, constraints)
|
||||
assert ans == lpmax(objective, constraints)
|
||||
assert objective.subs(argmax) == optimum
|
||||
for constr in constraints:
|
||||
assert constr.subs(argmax) == True
|
||||
|
||||
r1 = x - y + 2*z <= -4
|
||||
r2 = -x + 2*y - 3*z <= 8
|
||||
r3 = 2*x + y - 7*z <= 10
|
||||
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
|
||||
const = 2
|
||||
objective = -x-y-5*z+const # has constant term
|
||||
ans = optimum, argmax = lp(max, objective, constraints)
|
||||
assert ans == lpmax(objective, constraints)
|
||||
assert objective.subs(argmax) == optimum
|
||||
for constr in constraints:
|
||||
assert constr.subs(argmax) == True
|
||||
|
||||
# Section 4 Problem 1 from
|
||||
# http://web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_LP.pdf
|
||||
# answer on page 55
|
||||
v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4')
|
||||
r1 = x1 - x2 - 2*x3 - x4 <= 4
|
||||
r2 = 2*x1 + x3 -4*x4 <= 2
|
||||
r3 = -2*x1 + x2 + x4 <= 1
|
||||
objective, constraints = x1 - 2*x2 - 3*x3 - x4, [r1, r2, r3] + [
|
||||
i >= 0 for i in v]
|
||||
ans = optimum, argmax = lp(max, objective, constraints)
|
||||
assert ans == lpmax(objective, constraints)
|
||||
assert ans == (4, {x1: 7, x2: 0, x3: 0, x4: 3})
|
||||
|
||||
# input contains Floats
|
||||
r1 = x - y + 2.0*z <= -4
|
||||
r2 = -x + 2*y - 3.0*z <= 8
|
||||
r3 = 2*x + y - 7*z <= 10
|
||||
constraints = [r1, r2, r3] + [i >= 0 for i in (x, y, z)]
|
||||
objective = -x-y-5*z
|
||||
optimum, argmax = lp(max, objective, constraints)
|
||||
assert objective.subs(argmax) == optimum
|
||||
for constr in constraints:
|
||||
assert constr.subs(argmax) == True
|
||||
|
||||
# input contains non-float or non-Rational
|
||||
r1 = x - y + sqrt(2) * z <= -4
|
||||
r2 = -x + 2*y - 3*z <= 8
|
||||
r3 = 2*x + y - 7*z <= 10
|
||||
raises(TypeError, lambda: lp(max, -x-y-5*z, [r1, r2, r3]))
|
||||
|
||||
r1 = x >= 0
|
||||
raises(UnboundedLPError, lambda: lp(max, x, [r1]))
|
||||
r2 = x <= -1
|
||||
raises(InfeasibleLPError, lambda: lp(max, x, [r1, r2]))
|
||||
|
||||
# strict inequalities are not allowed
|
||||
r1 = x > 0
|
||||
raises(TypeError, lambda: lp(max, x, [r1]))
|
||||
|
||||
# not equals not allowed
|
||||
r1 = Ne(x, 0)
|
||||
raises(TypeError, lambda: lp(max, x, [r1]))
|
||||
|
||||
def make_random_problem(nvar=2, num_constraints=2, sparsity=.1):
|
||||
def rand():
|
||||
if random() < sparsity:
|
||||
return sympify(0)
|
||||
int1, int2 = [randprime(0, 200) for _ in range(2)]
|
||||
return Rational(int1, int2)*choice([-1, 1])
|
||||
variables = symbols('x1:%s' % (nvar + 1))
|
||||
constraints = [(sum(rand()*x for x in variables) <= rand())
|
||||
for _ in range(num_constraints)]
|
||||
objective = sum(rand() * x for x in variables)
|
||||
return objective, constraints, variables
|
||||
|
||||
# equality
|
||||
r1 = Eq(x, y)
|
||||
r2 = Eq(y, z)
|
||||
r3 = z <= 3
|
||||
constraints = [r1, r2, r3]
|
||||
objective = x
|
||||
ans = optimum, argmax = lp(max, objective, constraints)
|
||||
assert ans == lpmax(objective, constraints)
|
||||
assert objective.subs(argmax) == optimum
|
||||
for constr in constraints:
|
||||
assert constr.subs(argmax) == True
|
||||
|
||||
|
||||
def test_simplex():
|
||||
L = [
|
||||
[[1, 1], [-1, 1], [0, 1], [-1, 0]],
|
||||
[5, 1, 2, -1],
|
||||
[[1, 1]],
|
||||
[-1]]
|
||||
A, B, C, D = _abcd(_m(*L), list=False)
|
||||
assert _simplex(A, B, -C, -D) == (-6, [3, 2], [1, 0, 0, 0])
|
||||
assert _simplex(A, B, -C, -D, dual=True) == (-6,
|
||||
[1, 0, 0, 0], [5, 0])
|
||||
|
||||
assert _simplex([[]],[],[[1]],[0]) == (0, [0], [])
|
||||
|
||||
# handling of Eq (or Eq-like x<=y, x>=y conditions)
|
||||
assert lpmax(x - y, [x <= y + 2, x >= y + 2, x >= 0, y >= 0]
|
||||
) == (2, {x: 2, y: 0})
|
||||
assert lpmax(x - y, [x <= y + 2, Eq(x, y + 2), x >= 0, y >= 0]
|
||||
) == (2, {x: 2, y: 0})
|
||||
assert lpmax(x - y, [x <= y + 2, Eq(x, 2)]) == (2, {x: 2, y: 0})
|
||||
assert lpmax(y, [Eq(y, 2)]) == (2, {y: 2})
|
||||
|
||||
# the conditions are equivalent to Eq(x, y + 2)
|
||||
assert lpmin(y, [x <= y + 2, x >= y + 2, y >= 0]
|
||||
) == (0, {x: 2, y: 0})
|
||||
# equivalent to Eq(y, -2)
|
||||
assert lpmax(y, [0 <= y + 2, 0 >= y + 2]) == (-2, {y: -2})
|
||||
assert lpmax(y, [0 <= y + 2, 0 >= y + 2, y <= 0]
|
||||
) == (-2, {y: -2})
|
||||
|
||||
# extra symbols symbols
|
||||
assert lpmin(x, [y >= 1, x >= y]) == (1, {x: 1, y: 1})
|
||||
assert lpmin(x, [y >= 1, x >= y + z, x >= 0, z >= 0]
|
||||
) == (1, {x: 1, y: 1, z: 0})
|
||||
|
||||
# detect oscillation
|
||||
# o1
|
||||
v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4')
|
||||
raises(InfeasibleLPError, lambda: lpmin(
|
||||
9*x2 - 8*x3 + 3*x4 + 6,
|
||||
[5*x2 - 2*x3 <= 0,
|
||||
-x1 - 8*x2 + 9*x3 <= -3,
|
||||
10*x1 - x2+ 9*x4 <= -4] + [i >= 0 for i in v]))
|
||||
# o2 - equations fed to lpmin are changed into a matrix
|
||||
# system that doesn't oscillate and has the same solution
|
||||
# as below
|
||||
M = linear_eq_to_matrix
|
||||
f = 5*x2 + x3 + 4*x4 - x1
|
||||
L = 5*x2 + 2*x3 + 5*x4 - (x1 + 5)
|
||||
cond = [L <= 0] + [Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)]
|
||||
c, d = M(f, v)
|
||||
a, b = M(L, v)
|
||||
aeq, beq = M(cond[1:], v)
|
||||
ans = (S(9)/2, [0, S(1)/2, 0, S(1)/2])
|
||||
assert linprog(c, a, b, aeq, beq, bounds=(0, 1)) == ans
|
||||
lpans = lpmin(f, cond + [x1 >= 0, x1 <= 1,
|
||||
x2 >= 0, x2 <= 1, x3 >= 0, x3 <= 1, x4 >= 0, x4 <= 1])
|
||||
assert (lpans[0], list(lpans[1].values())) == ans
|
||||
|
||||
|
||||
def test_lpmin_lpmax():
|
||||
v = x1, x2, y1, y2 = symbols('x1 x2 y1 y2')
|
||||
L = [[1, -1]], [1], [[1, 1]], [2]
|
||||
a, b, c, d = [Matrix(i) for i in L]
|
||||
m = Matrix([[a, b], [c, d]])
|
||||
f, constr = _primal_dual(m)[0]
|
||||
ans = lpmin(f, constr + [i >= 0 for i in v[:2]])
|
||||
assert ans == (-1, {x1: 1, x2: 0}),ans
|
||||
|
||||
L = [[1, -1], [1, 1]], [1, 1], [[1, 1]], [2]
|
||||
a, b, c, d = [Matrix(i) for i in L]
|
||||
m = Matrix([[a, b], [c, d]])
|
||||
f, constr = _primal_dual(m)[1]
|
||||
ans = lpmax(f, constr + [i >= 0 for i in v[-2:]])
|
||||
assert ans == (-1, {y1: 1, y2: 0})
|
||||
|
||||
|
||||
def test_linprog():
|
||||
for do in range(2):
|
||||
if not do:
|
||||
M = lambda a, b: linear_eq_to_matrix(a, b)
|
||||
else:
|
||||
# check matrices as list
|
||||
M = lambda a, b: tuple([
|
||||
i.tolist() for i in linear_eq_to_matrix(a, b)])
|
||||
|
||||
v = x, y, z = symbols('x1:4')
|
||||
f = x + y - 2*z
|
||||
c = M(f, v)[0]
|
||||
ineq = [7*x + 4*y - 7*z <= 3,
|
||||
3*x - y + 10*z <= 6,
|
||||
x >= 0, y >= 0, z >= 0]
|
||||
ab = M([i.lts - i.gts for i in ineq], v)
|
||||
ans = (-S(6)/5, [0, 0, S(3)/5])
|
||||
assert lpmin(f, ineq) == (ans[0], dict(zip(v, ans[1])))
|
||||
assert linprog(c, *ab) == ans
|
||||
|
||||
f += 1
|
||||
c = M(f, v)[0]
|
||||
eq = [Eq(y - 9*x, 1)]
|
||||
abeq = M([i.lhs - i.rhs for i in eq], v)
|
||||
ans = (1 - S(2)/5, [0, 1, S(7)/10])
|
||||
assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1])))
|
||||
assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1])
|
||||
|
||||
eq = [z - y <= S.Half]
|
||||
abeq = M([i.lhs - i.rhs for i in eq], v)
|
||||
ans = (1 - S(10)/9, [0, S(1)/9, S(11)/18])
|
||||
assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1])))
|
||||
assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1])
|
||||
|
||||
bounds = [(0, None), (0, None), (None, S.Half)]
|
||||
ans = (0, [0, 0, S.Half])
|
||||
assert lpmin(f, ineq + [z <= S.Half]) == (
|
||||
ans[0], dict(zip(v, ans[1])))
|
||||
assert linprog(c, *ab, bounds=bounds) == (ans[0] - 1, ans[1])
|
||||
assert linprog(c, *ab, bounds={v.index(z): bounds[-1]}
|
||||
) == (ans[0] - 1, ans[1])
|
||||
eq = [z - y <= S.Half]
|
||||
|
||||
assert linprog([[1]], [], [], bounds=(2, 3)) == (2, [2])
|
||||
assert linprog([1], [], [], bounds=(2, 3)) == (2, [2])
|
||||
assert linprog([1], bounds=(2, 3)) == (2, [2])
|
||||
assert linprog([1, -1], [[1, 1]], [2], bounds={1:(None, None)}
|
||||
) == (-2, [0, 2])
|
||||
assert linprog([1, -1], [[1, 1]], [5], bounds={1:(3, None)}
|
||||
) == (-5, [0, 5])
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user