switching to high quality piper tts and added label translations

This commit is contained in:
Matthias Hinrichs
2026-01-29 23:48:19 +01:00
commit d80c619df9
3934 changed files with 1451600 additions and 0 deletions
@@ -0,0 +1,75 @@
"""A module for solving all kinds of equations.
Examples
========
>>> from sympy.solvers import solve
>>> from sympy.abc import x
>>> solve(((x + 1)**5).expand(), x)
[-1]
"""
from sympy.core.assumptions import check_assumptions, failing_assumptions
from .solvers import solve, solve_linear_system, solve_linear_system_LU, \
solve_undetermined_coeffs, nsolve, solve_linear, checksol, \
det_quick, inv_quick
from sympy.solvers.diophantine.diophantine import diophantine
from .recurr import rsolve, rsolve_poly, rsolve_ratio, rsolve_hyper
from .ode import checkodesol, classify_ode, dsolve, \
homogeneous_order
from .polysys import solve_poly_system, solve_triangulated, factor_system
from .pde import pde_separate, pde_separate_add, pde_separate_mul, \
pdsolve, classify_pde, checkpdesol
from .deutils import ode_order
from .inequalities import reduce_inequalities, reduce_abs_inequality, \
reduce_abs_inequalities, solve_poly_inequality, solve_rational_inequalities, solve_univariate_inequality
from .decompogen import decompogen
from .solveset import solveset, linsolve, linear_eq_to_matrix, nonlinsolve, substitution
from .simplex import lpmin, lpmax, linprog
# This is here instead of sympy/sets/__init__.py to avoid circular import issues
from ..core.singleton import S
Complexes = S.Complexes
__all__ = [
'solve', 'solve_linear_system', 'solve_linear_system_LU',
'solve_undetermined_coeffs', 'nsolve', 'solve_linear', 'checksol',
'det_quick', 'inv_quick', 'check_assumptions', 'failing_assumptions',
'diophantine',
'rsolve', 'rsolve_poly', 'rsolve_ratio', 'rsolve_hyper',
'checkodesol', 'classify_ode', 'dsolve', 'homogeneous_order',
'solve_poly_system', 'solve_triangulated', 'factor_system',
'pde_separate', 'pde_separate_add', 'pde_separate_mul', 'pdsolve',
'classify_pde', 'checkpdesol',
'ode_order',
'reduce_inequalities', 'reduce_abs_inequality', 'reduce_abs_inequalities',
'solve_poly_inequality', 'solve_rational_inequalities',
'solve_univariate_inequality',
'decompogen',
'solveset', 'linsolve', 'linear_eq_to_matrix', 'nonlinsolve',
'substitution',
# This is here instead of sympy/sets/__init__.py to avoid circular import issues
'Complexes',
'lpmin', 'lpmax', 'linprog'
]
@@ -0,0 +1,12 @@
from sympy.core.symbol import Symbol
from sympy.matrices.dense import (eye, zeros)
from sympy.solvers.solvers import solve_linear_system
N = 8
M = zeros(N, N + 1)
M[:, :N] = eye(N)
S = [Symbol('A%i' % i) for i in range(N)]
def timeit_linsolve_trivial():
solve_linear_system(M, *S)
@@ -0,0 +1,509 @@
from sympy.core.add import Add
from sympy.core.exprtools import factor_terms
from sympy.core.function import expand_log, _mexpand
from sympy.core.power import Pow
from sympy.core.singleton import S
from sympy.core.sorting import ordered
from sympy.core.symbol import Dummy
from sympy.functions.elementary.exponential import (LambertW, exp, log)
from sympy.functions.elementary.miscellaneous import root
from sympy.polys.polyroots import roots
from sympy.polys.polytools import Poly, factor
from sympy.simplify.simplify import separatevars
from sympy.simplify.radsimp import collect
from sympy.simplify.simplify import powsimp
from sympy.solvers.solvers import solve, _invert
from sympy.utilities.iterables import uniq
def _filtered_gens(poly, symbol):
"""process the generators of ``poly``, returning the set of generators that
have ``symbol``. If there are two generators that are inverses of each other,
prefer the one that has no denominator.
Examples
========
>>> from sympy.solvers.bivariate import _filtered_gens
>>> from sympy import Poly, exp
>>> from sympy.abc import x
>>> _filtered_gens(Poly(x + 1/x + exp(x)), x)
{x, exp(x)}
"""
# TODO it would be good to pick the smallest divisible power
# instead of the base for something like x**4 + x**2 -->
# return x**2 not x
gens = {g for g in poly.gens if symbol in g.free_symbols}
for g in list(gens):
ag = 1/g
if g in gens and ag in gens:
if ag.as_numer_denom()[1] is not S.One:
g = ag
gens.remove(g)
return gens
def _mostfunc(lhs, func, X=None):
"""Returns the term in lhs which contains the most of the
func-type things e.g. log(log(x)) wins over log(x) if both terms appear.
``func`` can be a function (exp, log, etc...) or any other SymPy object,
like Pow.
If ``X`` is not ``None``, then the function returns the term composed with the
most ``func`` having the specified variable.
Examples
========
>>> from sympy.solvers.bivariate import _mostfunc
>>> from sympy import exp
>>> from sympy.abc import x, y
>>> _mostfunc(exp(x) + exp(exp(x) + 2), exp)
exp(exp(x) + 2)
>>> _mostfunc(exp(x) + exp(exp(y) + 2), exp)
exp(exp(y) + 2)
>>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x)
exp(x)
>>> _mostfunc(x, exp, x) is None
True
>>> _mostfunc(exp(x) + exp(x*y), exp, x)
exp(x)
"""
fterms = [tmp for tmp in lhs.atoms(func) if (not X or
X.is_Symbol and X in tmp.free_symbols or
not X.is_Symbol and tmp.has(X))]
if len(fterms) == 1:
return fterms[0]
elif fterms:
return max(list(ordered(fterms)), key=lambda x: x.count(func))
return None
def _linab(arg, symbol):
"""Return ``a, b, X`` assuming ``arg`` can be written as ``a*X + b``
where ``X`` is a symbol-dependent factor and ``a`` and ``b`` are
independent of ``symbol``.
Examples
========
>>> from sympy.solvers.bivariate import _linab
>>> from sympy.abc import x, y
>>> from sympy import exp, S
>>> _linab(S(2), x)
(2, 0, 1)
>>> _linab(2*x, x)
(2, 0, x)
>>> _linab(y + y*x + 2*x, x)
(y + 2, y, x)
>>> _linab(3 + 2*exp(x), x)
(2, 3, exp(x))
"""
arg = factor_terms(arg.expand())
ind, dep = arg.as_independent(symbol)
if arg.is_Mul and dep.is_Add:
a, b, x = _linab(dep, symbol)
return ind*a, ind*b, x
if not arg.is_Add:
b = 0
a, x = ind, dep
else:
b = ind
a, x = separatevars(dep).as_independent(symbol, as_Add=False)
if x.could_extract_minus_sign():
a = -a
x = -x
return a, b, x
def _lambert(eq, x):
"""
Given an expression assumed to be in the form
``F(X, a..f) = a*log(b*X + c) + d*X + f = 0``
where X = g(x) and x = g^-1(X), return the Lambert solution,
``x = g^-1(-c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(-f/a)))``.
"""
eq = _mexpand(expand_log(eq))
mainlog = _mostfunc(eq, log, x)
if not mainlog:
return [] # violated assumptions
other = eq.subs(mainlog, 0)
if isinstance(-other, log):
eq = (eq - other).subs(mainlog, mainlog.args[0])
mainlog = mainlog.args[0]
if not isinstance(mainlog, log):
return [] # violated assumptions
other = -(-other).args[0]
eq += other
if x not in other.free_symbols:
return [] # violated assumptions
d, f, X2 = _linab(other, x)
logterm = collect(eq - other, mainlog)
a = logterm.as_coefficient(mainlog)
if a is None or x in a.free_symbols:
return [] # violated assumptions
logarg = mainlog.args[0]
b, c, X1 = _linab(logarg, x)
if X1 != X2:
return [] # violated assumptions
# invert the generator X1 so we have x(u)
u = Dummy('rhs')
xusolns = solve(X1 - u, x)
# There are infinitely many branches for LambertW
# but only branches for k = -1 and 0 might be real. The k = 0
# branch is real and the k = -1 branch is real if the LambertW argument
# in in range [-1/e, 0]. Since `solve` does not return infinite
# solutions we will only include the -1 branch if it tests as real.
# Otherwise, inclusion of any LambertW in the solution indicates to
# the user that there are imaginary solutions corresponding to
# different k values.
lambert_real_branches = [-1, 0]
sol = []
# solution of the given Lambert equation is like
# sol = -c/b + (a/d)*LambertW(arg, k),
# where arg = d/(a*b)*exp((c*d-b*f)/a/b) and k in lambert_real_branches.
# Instead of considering the single arg, `d/(a*b)*exp((c*d-b*f)/a/b)`,
# the individual `p` roots obtained when writing `exp((c*d-b*f)/a/b)`
# as `exp(A/p) = exp(A)**(1/p)`, where `p` is an Integer, are used.
# calculating args for LambertW
num, den = ((c*d-b*f)/a/b).as_numer_denom()
p, den = den.as_coeff_Mul()
e = exp(num/den)
t = Dummy('t')
args = [d/(a*b)*t for t in roots(t**p - e, t).keys()]
# calculating solutions from args
for arg in args:
for k in lambert_real_branches:
w = LambertW(arg, k)
if k and not w.is_real:
continue
rhs = -c/b + (a/d)*w
sol.extend(xu.subs(u, rhs) for xu in xusolns)
return sol
def _solve_lambert(f, symbol, gens):
"""Return solution to ``f`` if it is a Lambert-type expression
else raise NotImplementedError.
For ``f(X, a..f) = a*log(b*X + c) + d*X - f = 0`` the solution
for ``X`` is ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))``.
There are a variety of forms for `f(X, a..f)` as enumerated below:
1a1)
if B**B = R for R not in [0, 1] (since those cases would already
be solved before getting here) then log of both sides gives
log(B) + log(log(B)) = log(log(R)) and
X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R))
1a2)
if B*(b*log(B) + c)**a = R then log of both sides gives
log(B) + a*log(b*log(B) + c) = log(R) and
X = log(B), d=1, f=log(R)
1b)
if a*log(b*B + c) + d*B = R and
X = B, f = R
2a)
if (b*B + c)*exp(d*B + g) = R then log of both sides gives
log(b*B + c) + d*B + g = log(R) and
X = B, a = 1, f = log(R) - g
2b)
if g*exp(d*B + h) - b*B = c then the log form is
log(g) + d*B + h - log(b*B + c) = 0 and
X = B, a = -1, f = -h - log(g)
3)
if d*p**(a*B + g) - b*B = c then the log form is
log(d) + (a*B + g)*log(p) - log(b*B + c) = 0 and
X = B, a = -1, d = a*log(p), f = -log(d) - g*log(p)
"""
def _solve_even_degree_expr(expr, t, symbol):
"""Return the unique solutions of equations derived from
``expr`` by replacing ``t`` with ``+/- symbol``.
Parameters
==========
expr : Expr
The expression which includes a dummy variable t to be
replaced with +symbol and -symbol.
symbol : Symbol
The symbol for which a solution is being sought.
Returns
=======
List of unique solution of the two equations generated by
replacing ``t`` with positive and negative ``symbol``.
Notes
=====
If ``expr = 2*log(t) + x/2` then solutions for
``2*log(x) + x/2 = 0`` and ``2*log(-x) + x/2 = 0`` are
returned by this function. Though this may seem
counter-intuitive, one must note that the ``expr`` being
solved here has been derived from a different expression. For
an expression like ``eq = x**2*g(x) = 1``, if we take the
log of both sides we obtain ``log(x**2) + log(g(x)) = 0``. If
x is positive then this simplifies to
``2*log(x) + log(g(x)) = 0``; the Lambert-solving routines will
return solutions for this, but we must also consider the
solutions for ``2*log(-x) + log(g(x))`` since those must also
be a solution of ``eq`` which has the same value when the ``x``
in ``x**2`` is negated. If `g(x)` does not have even powers of
symbol then we do not want to replace the ``x`` there with
``-x``. So the role of the ``t`` in the expression received by
this function is to mark where ``+/-x`` should be inserted
before obtaining the Lambert solutions.
"""
nlhs, plhs = [
expr.xreplace({t: sgn*symbol}) for sgn in (-1, 1)]
sols = _solve_lambert(nlhs, symbol, gens)
if plhs != nlhs:
sols.extend(_solve_lambert(plhs, symbol, gens))
# uniq is needed for a case like
# 2*log(t) - log(-z**2) + log(z + log(x) + log(z))
# where substituting t with +/-x gives all the same solution;
# uniq, rather than list(set()), is used to maintain canonical
# order
return list(uniq(sols))
nrhs, lhs = f.as_independent(symbol, as_Add=True)
rhs = -nrhs
lamcheck = [tmp for tmp in gens
if (tmp.func in [exp, log] or
(tmp.is_Pow and symbol in tmp.exp.free_symbols))]
if not lamcheck:
raise NotImplementedError()
if lhs.is_Add or lhs.is_Mul:
# replacing all even_degrees of symbol with dummy variable t
# since these will need special handling; non-Add/Mul do not
# need this handling
t = Dummy('t', **symbol.assumptions0)
lhs = lhs.replace(
lambda i: # find symbol**even
i.is_Pow and i.base == symbol and i.exp.is_even,
lambda i: # replace t**even
t**i.exp)
if lhs.is_Add and lhs.has(t):
t_indep = lhs.subs(t, 0)
t_term = lhs - t_indep
_rhs = rhs - t_indep
if not t_term.is_Add and _rhs and not (
t_term.has(S.ComplexInfinity, S.NaN)):
eq = expand_log(log(t_term) - log(_rhs))
return _solve_even_degree_expr(eq, t, symbol)
elif lhs.is_Mul and rhs:
# this needs to happen whether t is present or not
lhs = expand_log(log(lhs), force=True)
rhs = log(rhs)
if lhs.has(t) and lhs.is_Add:
# it expanded from Mul to Add
eq = lhs - rhs
return _solve_even_degree_expr(eq, t, symbol)
# restore symbol in lhs
lhs = lhs.xreplace({t: symbol})
lhs = powsimp(factor(lhs, deep=True))
# make sure we have inverted as completely as possible
r = Dummy()
i, lhs = _invert(lhs - r, symbol)
rhs = i.xreplace({r: rhs})
# For the first forms:
#
# 1a1) B**B = R will arrive here as B*log(B) = log(R)
# lhs is Mul so take log of both sides:
# log(B) + log(log(B)) = log(log(R))
# 1a2) B*(b*log(B) + c)**a = R will arrive unchanged so
# lhs is Mul, so take log of both sides:
# log(B) + a*log(b*log(B) + c) = log(R)
# 1b) d*log(a*B + b) + c*B = R will arrive unchanged so
# lhs is Add, so isolate c*B and expand log of both sides:
# log(c) + log(B) = log(R - d*log(a*B + b))
soln = []
if not soln:
mainlog = _mostfunc(lhs, log, symbol)
if mainlog:
if lhs.is_Mul and rhs != 0:
soln = _lambert(log(lhs) - log(rhs), symbol)
elif lhs.is_Add:
other = lhs.subs(mainlog, 0)
if other and not other.is_Add and [
tmp for tmp in other.atoms(Pow)
if symbol in tmp.free_symbols]:
if not rhs:
diff = log(other) - log(other - lhs)
else:
diff = log(lhs - other) - log(rhs - other)
soln = _lambert(expand_log(diff), symbol)
else:
#it's ready to go
soln = _lambert(lhs - rhs, symbol)
# For the next forms,
#
# collect on main exp
# 2a) (b*B + c)*exp(d*B + g) = R
# lhs is mul, so take log of both sides:
# log(b*B + c) + d*B = log(R) - g
# 2b) g*exp(d*B + h) - b*B = R
# lhs is add, so add b*B to both sides,
# take the log of both sides and rearrange to give
# log(R + b*B) - d*B = log(g) + h
if not soln:
mainexp = _mostfunc(lhs, exp, symbol)
if mainexp:
lhs = collect(lhs, mainexp)
if lhs.is_Mul and rhs != 0:
soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
elif lhs.is_Add:
# move all but mainexp-containing term to rhs
other = lhs.subs(mainexp, 0)
mainterm = lhs - other
rhs = rhs - other
if (mainterm.could_extract_minus_sign() and
rhs.could_extract_minus_sign()):
mainterm *= -1
rhs *= -1
diff = log(mainterm) - log(rhs)
soln = _lambert(expand_log(diff), symbol)
# For the last form:
#
# 3) d*p**(a*B + g) - b*B = c
# collect on main pow, add b*B to both sides,
# take log of both sides and rearrange to give
# a*B*log(p) - log(b*B + c) = -log(d) - g*log(p)
if not soln:
mainpow = _mostfunc(lhs, Pow, symbol)
if mainpow and symbol in mainpow.exp.free_symbols:
lhs = collect(lhs, mainpow)
if lhs.is_Mul and rhs != 0:
# b*B = 0
soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol)
elif lhs.is_Add:
# move all but mainpow-containing term to rhs
other = lhs.subs(mainpow, 0)
mainterm = lhs - other
rhs = rhs - other
diff = log(mainterm) - log(rhs)
soln = _lambert(expand_log(diff), symbol)
if not soln:
raise NotImplementedError('%s does not appear to have a solution in '
'terms of LambertW' % f)
return list(ordered(soln))
def bivariate_type(f, x, y, *, first=True):
"""Given an expression, f, 3 tests will be done to see what type
of composite bivariate it might be, options for u(x, y) are::
x*y
x+y
x*y+x
x*y+y
If it matches one of these types, ``u(x, y)``, ``P(u)`` and dummy
variable ``u`` will be returned. Solving ``P(u)`` for ``u`` and
equating the solutions to ``u(x, y)`` and then solving for ``x`` or
``y`` is equivalent to solving the original expression for ``x`` or
``y``. If ``x`` and ``y`` represent two functions in the same
variable, e.g. ``x = g(t)`` and ``y = h(t)``, then if ``u(x, y) - p``
can be solved for ``t`` then these represent the solutions to
``P(u) = 0`` when ``p`` are the solutions of ``P(u) = 0``.
Only positive values of ``u`` are considered.
Examples
========
>>> from sympy import solve
>>> from sympy.solvers.bivariate import bivariate_type
>>> from sympy.abc import x, y
>>> eq = (x**2 - 3).subs(x, x + y)
>>> bivariate_type(eq, x, y)
(x + y, _u**2 - 3, _u)
>>> uxy, pu, u = _
>>> usol = solve(pu, u); usol
[sqrt(3)]
>>> [solve(uxy - s) for s in solve(pu, u)]
[[{x: -y + sqrt(3)}]]
>>> all(eq.subs(s).equals(0) for sol in _ for s in sol)
True
"""
u = Dummy('u', positive=True)
if first:
p = Poly(f, x, y)
f = p.as_expr()
_x = Dummy()
_y = Dummy()
rv = bivariate_type(Poly(f.subs({x: _x, y: _y}), _x, _y), _x, _y, first=False)
if rv:
reps = {_x: x, _y: y}
return rv[0].xreplace(reps), rv[1].xreplace(reps), rv[2]
return
p = f
f = p.as_expr()
# f(x*y)
args = Add.make_args(p.as_expr())
new = []
for a in args:
a = _mexpand(a.subs(x, u/y))
free = a.free_symbols
if x in free or y in free:
break
new.append(a)
else:
return x*y, Add(*new), u
def ok(f, v, c):
new = _mexpand(f.subs(v, c))
free = new.free_symbols
return None if (x in free or y in free) else new
# f(a*x + b*y)
new = []
d = p.degree(x)
if p.degree(y) == d:
a = root(p.coeff_monomial(x**d), d)
b = root(p.coeff_monomial(y**d), d)
new = ok(f, x, (u - b*y)/a)
if new is not None:
return a*x + b*y, new, u
# f(a*x*y + b*y)
new = []
d = p.degree(x)
if p.degree(y) == d:
for itry in range(2):
a = root(p.coeff_monomial(x**d*y**d), d)
b = root(p.coeff_monomial(y**d), d)
new = ok(f, x, (u - b*y)/a/y)
if new is not None:
return a*x*y + b*y, new, u
x, y = y, x
@@ -0,0 +1,126 @@
from sympy.core import (Function, Pow, sympify, Expr)
from sympy.core.relational import Relational
from sympy.core.singleton import S
from sympy.polys import Poly, decompose
from sympy.utilities.misc import func_name
from sympy.functions.elementary.miscellaneous import Min, Max
def decompogen(f, symbol):
"""
Computes General functional decomposition of ``f``.
Given an expression ``f``, returns a list ``[f_1, f_2, ..., f_n]``,
where::
f = f_1 o f_2 o ... f_n = f_1(f_2(... f_n))
Note: This is a General decomposition function. It also decomposes
Polynomials. For only Polynomial decomposition see ``decompose`` in polys.
Examples
========
>>> from sympy.abc import x
>>> from sympy import decompogen, sqrt, sin, cos
>>> decompogen(sin(cos(x)), x)
[sin(x), cos(x)]
>>> decompogen(sin(x)**2 + sin(x) + 1, x)
[x**2 + x + 1, sin(x)]
>>> decompogen(sqrt(6*x**2 - 5), x)
[sqrt(x), 6*x**2 - 5]
>>> decompogen(sin(sqrt(cos(x**2 + 1))), x)
[sin(x), sqrt(x), cos(x), x**2 + 1]
>>> decompogen(x**4 + 2*x**3 - x - 1, x)
[x**2 - x - 1, x**2 + x]
"""
f = sympify(f)
if not isinstance(f, Expr) or isinstance(f, Relational):
raise TypeError('expecting Expr but got: `%s`' % func_name(f))
if symbol not in f.free_symbols:
return [f]
# ===== Simple Functions ===== #
if isinstance(f, (Function, Pow)):
if f.is_Pow and f.base == S.Exp1:
arg = f.exp
else:
arg = f.args[0]
if arg == symbol:
return [f]
return [f.subs(arg, symbol)] + decompogen(arg, symbol)
# ===== Min/Max Functions ===== #
if isinstance(f, (Min, Max)):
args = list(f.args)
d0 = None
for i, a in enumerate(args):
if not a.has_free(symbol):
continue
d = decompogen(a, symbol)
if len(d) == 1:
d = [symbol] + d
if d0 is None:
d0 = d[1:]
elif d[1:] != d0:
# decomposition is not the same for each arg:
# mark as having no decomposition
d = [symbol]
break
args[i] = d[0]
if d[0] == symbol:
return [f]
return [f.func(*args)] + d0
# ===== Convert to Polynomial ===== #
fp = Poly(f)
gens = list(filter(lambda x: symbol in x.free_symbols, fp.gens))
if len(gens) == 1 and gens[0] != symbol:
f1 = f.subs(gens[0], symbol)
f2 = gens[0]
return [f1] + decompogen(f2, symbol)
# ===== Polynomial decompose() ====== #
try:
return decompose(f)
except ValueError:
return [f]
def compogen(g_s, symbol):
"""
Returns the composition of functions.
Given a list of functions ``g_s``, returns their composition ``f``,
where:
f = g_1 o g_2 o .. o g_n
Note: This is a General composition function. It also composes Polynomials.
For only Polynomial composition see ``compose`` in polys.
Examples
========
>>> from sympy.solvers.decompogen import compogen
>>> from sympy.abc import x
>>> from sympy import sqrt, sin, cos
>>> compogen([sin(x), cos(x)], x)
sin(cos(x))
>>> compogen([x**2 + x + 1, sin(x)], x)
sin(x)**2 + sin(x) + 1
>>> compogen([sqrt(x), 6*x**2 - 5], x)
sqrt(6*x**2 - 5)
>>> compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x)
sin(sqrt(cos(x**2 + 1)))
>>> compogen([x**2 - x - 1, x**2 + x], x)
-x**2 - x + (x**2 + x)**2 - 1
"""
if len(g_s) == 1:
return g_s[0]
foo = g_s[0].subs(symbol, g_s[1])
if len(g_s) == 2:
return foo
return compogen([foo] + g_s[2:], symbol)
@@ -0,0 +1,273 @@
"""Utility functions for classifying and solving
ordinary and partial differential equations.
Contains
========
_preprocess
ode_order
_desolve
"""
from sympy.core import Pow
from sympy.core.function import Derivative, AppliedUndef
from sympy.core.relational import Equality
from sympy.core.symbol import Wild
def _preprocess(expr, func=None, hint='_Integral'):
"""Prepare expr for solving by making sure that differentiation
is done so that only func remains in unevaluated derivatives and
(if hint does not end with _Integral) that doit is applied to all
other derivatives. If hint is None, do not do any differentiation.
(Currently this may cause some simple differential equations to
fail.)
In case func is None, an attempt will be made to autodetect the
function to be solved for.
>>> from sympy.solvers.deutils import _preprocess
>>> from sympy import Derivative, Function
>>> from sympy.abc import x, y, z
>>> f, g = map(Function, 'fg')
If f(x)**p == 0 and p>0 then we can solve for f(x)=0
>>> _preprocess((f(x).diff(x)-4)**5, f(x))
(Derivative(f(x), x) - 4, f(x))
Apply doit to derivatives that contain more than the function
of interest:
>>> _preprocess(Derivative(f(x) + x, x))
(Derivative(f(x), x) + 1, f(x))
Do others if the differentiation variable(s) intersect with those
of the function of interest or contain the function of interest:
>>> _preprocess(Derivative(g(x), y, z), f(y))
(0, f(y))
>>> _preprocess(Derivative(f(y), z), f(y))
(0, f(y))
Do others if the hint does not end in '_Integral' (the default
assumes that it does):
>>> _preprocess(Derivative(g(x), y), f(x))
(Derivative(g(x), y), f(x))
>>> _preprocess(Derivative(f(x), y), f(x), hint='')
(0, f(x))
Do not do any derivatives if hint is None:
>>> eq = Derivative(f(x) + 1, x) + Derivative(f(x), y)
>>> _preprocess(eq, f(x), hint=None)
(Derivative(f(x) + 1, x) + Derivative(f(x), y), f(x))
If it's not clear what the function of interest is, it must be given:
>>> eq = Derivative(f(x) + g(x), x)
>>> _preprocess(eq, g(x))
(Derivative(f(x), x) + Derivative(g(x), x), g(x))
>>> try: _preprocess(eq)
... except ValueError: print("A ValueError was raised.")
A ValueError was raised.
"""
if isinstance(expr, Pow):
# if f(x)**p=0 then f(x)=0 (p>0)
if (expr.exp).is_positive:
expr = expr.base
derivs = expr.atoms(Derivative)
if not func:
funcs = set().union(*[d.atoms(AppliedUndef) for d in derivs])
if len(funcs) != 1:
raise ValueError('The function cannot be '
'automatically detected for %s.' % expr)
func = funcs.pop()
fvars = set(func.args)
if hint is None:
return expr, func
reps = [(d, d.doit()) for d in derivs if not hint.endswith('_Integral') or
d.has(func) or set(d.variables) & fvars]
eq = expr.subs(reps)
return eq, func
def ode_order(expr, func):
"""
Returns the order of a given differential
equation with respect to func.
This function is implemented recursively.
Examples
========
>>> from sympy import Function
>>> from sympy.solvers.deutils import ode_order
>>> from sympy.abc import x
>>> f, g = map(Function, ['f', 'g'])
>>> ode_order(f(x).diff(x, 2) + f(x).diff(x)**2 +
... f(x).diff(x), f(x))
2
>>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), f(x))
2
>>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), g(x))
3
"""
a = Wild('a', exclude=[func])
if expr.match(a):
return 0
if isinstance(expr, Derivative):
if expr.args[0] == func:
return len(expr.variables)
else:
args = expr.args[0].args
rv = len(expr.variables)
if args:
rv += max(ode_order(_, func) for _ in args)
return rv
else:
return max(ode_order(_, func) for _ in expr.args) if expr.args else 0
def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=True, **kwargs):
"""This is a helper function to dsolve and pdsolve in the ode
and pde modules.
If the hint provided to the function is "default", then a dict with
the following keys are returned
'func' - It provides the function for which the differential equation
has to be solved. This is useful when the expression has
more than one function in it.
'default' - The default key as returned by classifier functions in ode
and pde.py
'hint' - The hint given by the user for which the differential equation
is to be solved. If the hint given by the user is 'default',
then the value of 'hint' and 'default' is the same.
'order' - The order of the function as returned by ode_order
'match' - It returns the match as given by the classifier functions, for
the default hint.
If the hint provided to the function is not "default" and is not in
('all', 'all_Integral', 'best'), then a dict with the above mentioned keys
is returned along with the keys which are returned when dict in
classify_ode or classify_pde is set True
If the hint given is in ('all', 'all_Integral', 'best'), then this function
returns a nested dict, with the keys, being the set of classified hints
returned by classifier functions, and the values being the dict of form
as mentioned above.
Key 'eq' is a common key to all the above mentioned hints which returns an
expression if eq given by user is an Equality.
See Also
========
classify_ode(ode.py)
classify_pde(pde.py)
"""
if isinstance(eq, Equality):
eq = eq.lhs - eq.rhs
# preprocess the equation and find func if not given
if prep or func is None:
eq, func = _preprocess(eq, func)
prep = False
# type is an argument passed by the solve functions in ode and pde.py
# that identifies whether the function caller is an ordinary
# or partial differential equation. Accordingly corresponding
# changes are made in the function.
type = kwargs.get('type', None)
xi = kwargs.get('xi')
eta = kwargs.get('eta')
x0 = kwargs.get('x0', 0)
terms = kwargs.get('n')
if type == 'ode':
from sympy.solvers.ode import classify_ode, allhints
classifier = classify_ode
string = 'ODE '
dummy = ''
elif type == 'pde':
from sympy.solvers.pde import classify_pde, allhints
classifier = classify_pde
string = 'PDE '
dummy = 'p'
# Magic that should only be used internally. Prevents classify_ode from
# being called more than it needs to be by passing its results through
# recursive calls.
if kwargs.get('classify', True):
hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta,
n=terms, x0=x0, hint=hint, prep=prep)
else:
# Here is what all this means:
#
# hint: The hint method given to _desolve() by the user.
# hints: The dictionary of hints that match the DE, along with other
# information (including the internal pass-through magic).
# default: The default hint to return, the first hint from allhints
# that matches the hint; obtained from classify_ode().
# match: Dictionary containing the match dictionary for each hint
# (the parts of the DE for solving). When going through the
# hints in "all", this holds the match string for the current
# hint.
# order: The order of the DE, as determined by ode_order().
hints = kwargs.get('hint',
{'default': hint,
hint: kwargs['match'],
'order': kwargs['order']})
if not hints['default']:
# classify_ode will set hints['default'] to None if no hints match.
if hint not in allhints and hint != 'default':
raise ValueError("Hint not recognized: " + hint)
elif hint not in hints['ordered_hints'] and hint != 'default':
raise ValueError(string + str(eq) + " does not match hint " + hint)
# If dsolve can't solve the purely algebraic equation then dsolve will raise
# ValueError
elif hints['order'] == 0:
raise ValueError(
str(eq) + " is not a solvable differential equation in " + str(func))
else:
raise NotImplementedError(dummy + "solve" + ": Cannot solve " + str(eq))
if hint == 'default':
return _desolve(eq, func, ics=ics, hint=hints['default'], simplify=simplify,
prep=prep, x0=x0, classify=False, order=hints['order'],
match=hints[hints['default']], xi=xi, eta=eta, n=terms, type=type)
elif hint in ('all', 'all_Integral', 'best'):
retdict = {}
gethints = set(hints) - {'order', 'default', 'ordered_hints'}
if hint == 'all_Integral':
for i in hints:
if i.endswith('_Integral'):
gethints.remove(i.removesuffix('_Integral'))
# special cases
for k in ["1st_homogeneous_coeff_best", "1st_power_series",
"lie_group", "2nd_power_series_ordinary", "2nd_power_series_regular"]:
if k in gethints:
gethints.remove(k)
for i in gethints:
sol = _desolve(eq, func, ics=ics, hint=i, x0=x0, simplify=simplify, prep=prep,
classify=False, n=terms, order=hints['order'], match=hints[i], type=type)
retdict[i] = sol
retdict['all'] = True
retdict['eq'] = eq
return retdict
elif hint not in allhints: # and hint not in ('default', 'ordered_hints'):
raise ValueError("Hint not recognized: " + hint)
elif hint not in hints:
raise ValueError(string + str(eq) + " does not match hint " + hint)
else:
# Key added to identify the hint needed to solve the equation
hints['hint'] = hint
hints.update({'func': func, 'eq': eq})
return hints
@@ -0,0 +1,5 @@
from .diophantine import diophantine, classify_diop, diop_solve
__all__ = [
'diophantine', 'classify_diop', 'diop_solve'
]
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,986 @@
"""Tools for solving inequalities and systems of inequalities. """
import itertools
from sympy.calculus.util import (continuous_domain, periodicity,
function_range)
from sympy.core import sympify
from sympy.core.exprtools import factor_terms
from sympy.core.relational import Relational, Lt, Ge, Eq
from sympy.core.symbol import Symbol, Dummy
from sympy.sets.sets import Interval, FiniteSet, Union, Intersection
from sympy.core.singleton import S
from sympy.core.function import expand_mul
from sympy.functions.elementary.complexes import Abs
from sympy.logic import And
from sympy.polys import Poly, PolynomialError, parallel_poly_from_expr
from sympy.polys.polyutils import _nsort
from sympy.solvers.solveset import solvify, solveset
from sympy.utilities.iterables import sift, iterable
from sympy.utilities.misc import filldedent
def solve_poly_inequality(poly, rel):
"""Solve a polynomial inequality with rational coefficients.
Examples
========
>>> from sympy import solve_poly_inequality, Poly
>>> from sympy.abc import x
>>> solve_poly_inequality(Poly(x, x, domain='ZZ'), '==')
[{0}]
>>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '!=')
[Interval.open(-oo, -1), Interval.open(-1, 1), Interval.open(1, oo)]
>>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '==')
[{-1}, {1}]
See Also
========
solve_poly_inequalities
"""
if not isinstance(poly, Poly):
raise ValueError(
'For efficiency reasons, `poly` should be a Poly instance')
if poly.as_expr().is_number:
t = Relational(poly.as_expr(), 0, rel)
if t is S.true:
return [S.Reals]
elif t is S.false:
return [S.EmptySet]
else:
raise NotImplementedError(
"could not determine truth value of %s" % t)
reals, intervals = poly.real_roots(multiple=False), []
if rel == '==':
for root, _ in reals:
interval = Interval(root, root)
intervals.append(interval)
elif rel == '!=':
left = S.NegativeInfinity
for right, _ in reals + [(S.Infinity, 1)]:
interval = Interval(left, right, True, True)
intervals.append(interval)
left = right
else:
if poly.LC() > 0:
sign = +1
else:
sign = -1
eq_sign, equal = None, False
if rel == '>':
eq_sign = +1
elif rel == '<':
eq_sign = -1
elif rel == '>=':
eq_sign, equal = +1, True
elif rel == '<=':
eq_sign, equal = -1, True
else:
raise ValueError("'%s' is not a valid relation" % rel)
right, right_open = S.Infinity, True
for left, multiplicity in reversed(reals):
if multiplicity % 2:
if sign == eq_sign:
intervals.insert(
0, Interval(left, right, not equal, right_open))
sign, right, right_open = -sign, left, not equal
else:
if sign == eq_sign and not equal:
intervals.insert(
0, Interval(left, right, True, right_open))
right, right_open = left, True
elif sign != eq_sign and equal:
intervals.insert(0, Interval(left, left))
if sign == eq_sign:
intervals.insert(
0, Interval(S.NegativeInfinity, right, True, right_open))
return intervals
def solve_poly_inequalities(polys):
"""Solve polynomial inequalities with rational coefficients.
Examples
========
>>> from sympy import Poly
>>> from sympy.solvers.inequalities import solve_poly_inequalities
>>> from sympy.abc import x
>>> solve_poly_inequalities(((
... Poly(x**2 - 3), ">"), (
... Poly(-x**2 + 1), ">")))
Union(Interval.open(-oo, -sqrt(3)), Interval.open(-1, 1), Interval.open(sqrt(3), oo))
"""
return Union(*[s for p in polys for s in solve_poly_inequality(*p)])
def solve_rational_inequalities(eqs):
"""Solve a system of rational inequalities with rational coefficients.
Examples
========
>>> from sympy.abc import x
>>> from sympy import solve_rational_inequalities, Poly
>>> solve_rational_inequalities([[
... ((Poly(-x + 1), Poly(1, x)), '>='),
... ((Poly(-x + 1), Poly(1, x)), '<=')]])
{1}
>>> solve_rational_inequalities([[
... ((Poly(x), Poly(1, x)), '!='),
... ((Poly(-x + 1), Poly(1, x)), '>=')]])
Union(Interval.open(-oo, 0), Interval.Lopen(0, 1))
See Also
========
solve_poly_inequality
"""
result = S.EmptySet
for _eqs in eqs:
if not _eqs:
continue
global_intervals = [Interval(S.NegativeInfinity, S.Infinity)]
for (numer, denom), rel in _eqs:
numer_intervals = solve_poly_inequality(numer*denom, rel)
denom_intervals = solve_poly_inequality(denom, '==')
intervals = []
for numer_interval, global_interval in itertools.product(
numer_intervals, global_intervals):
interval = numer_interval.intersect(global_interval)
if interval is not S.EmptySet:
intervals.append(interval)
global_intervals = intervals
intervals = []
for global_interval in global_intervals:
for denom_interval in denom_intervals:
global_interval -= denom_interval
if global_interval is not S.EmptySet:
intervals.append(global_interval)
global_intervals = intervals
if not global_intervals:
break
for interval in global_intervals:
result = result.union(interval)
return result
def reduce_rational_inequalities(exprs, gen, relational=True):
"""Reduce a system of rational inequalities with rational coefficients.
Examples
========
>>> from sympy import Symbol
>>> from sympy.solvers.inequalities import reduce_rational_inequalities
>>> x = Symbol('x', real=True)
>>> reduce_rational_inequalities([[x**2 <= 0]], x)
Eq(x, 0)
>>> reduce_rational_inequalities([[x + 2 > 0]], x)
-2 < x
>>> reduce_rational_inequalities([[(x + 2, ">")]], x)
-2 < x
>>> reduce_rational_inequalities([[x + 2]], x)
Eq(x, -2)
This function find the non-infinite solution set so if the unknown symbol
is declared as extended real rather than real then the result may include
finiteness conditions:
>>> y = Symbol('y', extended_real=True)
>>> reduce_rational_inequalities([[y + 2 > 0]], y)
(-2 < y) & (y < oo)
"""
exact = True
eqs = []
solution = S.EmptySet # add pieces for each group
for _exprs in exprs:
if not _exprs:
continue
_eqs = []
_sol = S.Reals
for expr in _exprs:
if isinstance(expr, tuple):
expr, rel = expr
else:
if expr.is_Relational:
expr, rel = expr.lhs - expr.rhs, expr.rel_op
else:
rel = '=='
if expr is S.true:
numer, denom, rel = S.Zero, S.One, '=='
elif expr is S.false:
numer, denom, rel = S.One, S.One, '=='
else:
numer, denom = expr.together().as_numer_denom()
try:
(numer, denom), opt = parallel_poly_from_expr(
(numer, denom), gen)
except PolynomialError:
raise PolynomialError(filldedent('''
only polynomials and rational functions are
supported in this context.
'''))
if not opt.domain.is_Exact:
numer, denom, exact = numer.to_exact(), denom.to_exact(), False
domain = opt.domain.get_exact()
if not (domain.is_ZZ or domain.is_QQ):
expr = numer/denom
expr = Relational(expr, 0, rel)
_sol &= solve_univariate_inequality(expr, gen, relational=False)
else:
_eqs.append(((numer, denom), rel))
if _eqs:
_sol &= solve_rational_inequalities([_eqs])
exclude = solve_rational_inequalities([[((d, d.one), '==')
for i in eqs for ((n, d), _) in i if d.has(gen)]])
_sol -= exclude
solution |= _sol
if not exact and solution:
solution = solution.evalf()
if relational:
solution = solution.as_relational(gen)
return solution
def reduce_abs_inequality(expr, rel, gen):
"""Reduce an inequality with nested absolute values.
Examples
========
>>> from sympy import reduce_abs_inequality, Abs, Symbol
>>> x = Symbol('x', real=True)
>>> reduce_abs_inequality(Abs(x - 5) - 3, '<', x)
(2 < x) & (x < 8)
>>> reduce_abs_inequality(Abs(x + 2)*3 - 13, '<', x)
(-19/3 < x) & (x < 7/3)
See Also
========
reduce_abs_inequalities
"""
if gen.is_extended_real is False:
raise TypeError(filldedent('''
Cannot solve inequalities with absolute values containing
non-real variables.
'''))
def _bottom_up_scan(expr):
exprs = []
if expr.is_Add or expr.is_Mul:
op = expr.func
for arg in expr.args:
_exprs = _bottom_up_scan(arg)
if not exprs:
exprs = _exprs
else:
exprs = [(op(expr, _expr), conds + _conds) for (expr, conds), (_expr, _conds) in
itertools.product(exprs, _exprs)]
elif expr.is_Pow:
n = expr.exp
if not n.is_Integer:
raise ValueError("Only Integer Powers are allowed on Abs.")
exprs.extend((expr**n, conds) for expr, conds in _bottom_up_scan(expr.base))
elif isinstance(expr, Abs):
_exprs = _bottom_up_scan(expr.args[0])
for expr, conds in _exprs:
exprs.append(( expr, conds + [Ge(expr, 0)]))
exprs.append((-expr, conds + [Lt(expr, 0)]))
else:
exprs = [(expr, [])]
return exprs
mapping = {'<': '>', '<=': '>='}
inequalities = []
for expr, conds in _bottom_up_scan(expr):
if rel not in mapping.keys():
expr = Relational( expr, 0, rel)
else:
expr = Relational(-expr, 0, mapping[rel])
inequalities.append([expr] + conds)
return reduce_rational_inequalities(inequalities, gen)
def reduce_abs_inequalities(exprs, gen):
"""Reduce a system of inequalities with nested absolute values.
Examples
========
>>> from sympy import reduce_abs_inequalities, Abs, Symbol
>>> x = Symbol('x', extended_real=True)
>>> reduce_abs_inequalities([(Abs(3*x - 5) - 7, '<'),
... (Abs(x + 25) - 13, '>')], x)
(-2/3 < x) & (x < 4) & (((-oo < x) & (x < -38)) | ((-12 < x) & (x < oo)))
>>> reduce_abs_inequalities([(Abs(x - 4) + Abs(3*x - 5) - 7, '<')], x)
(1/2 < x) & (x < 4)
See Also
========
reduce_abs_inequality
"""
return And(*[ reduce_abs_inequality(expr, rel, gen)
for expr, rel in exprs ])
def solve_univariate_inequality(expr, gen, relational=True, domain=S.Reals, continuous=False):
"""Solves a real univariate inequality.
Parameters
==========
expr : Relational
The target inequality
gen : Symbol
The variable for which the inequality is solved
relational : bool
A Relational type output is expected or not
domain : Set
The domain over which the equation is solved
continuous: bool
True if expr is known to be continuous over the given domain
(and so continuous_domain() does not need to be called on it)
Raises
======
NotImplementedError
The solution of the inequality cannot be determined due to limitation
in :func:`sympy.solvers.solveset.solvify`.
Notes
=====
Currently, we cannot solve all the inequalities due to limitations in
:func:`sympy.solvers.solveset.solvify`. Also, the solution returned for trigonometric inequalities
are restricted in its periodic interval.
See Also
========
sympy.solvers.solveset.solvify: solver returning solveset solutions with solve's output API
Examples
========
>>> from sympy import solve_univariate_inequality, Symbol, sin, Interval, S
>>> x = Symbol('x')
>>> solve_univariate_inequality(x**2 >= 4, x)
((2 <= x) & (x < oo)) | ((-oo < x) & (x <= -2))
>>> solve_univariate_inequality(x**2 >= 4, x, relational=False)
Union(Interval(-oo, -2), Interval(2, oo))
>>> domain = Interval(0, S.Infinity)
>>> solve_univariate_inequality(x**2 >= 4, x, False, domain)
Interval(2, oo)
>>> solve_univariate_inequality(sin(x) > 0, x, relational=False)
Interval.open(0, pi)
"""
from sympy.solvers.solvers import denoms
if domain.is_subset(S.Reals) is False:
raise NotImplementedError(filldedent('''
Inequalities in the complex domain are
not supported. Try the real domain by
setting domain=S.Reals'''))
elif domain is not S.Reals:
rv = solve_univariate_inequality(
expr, gen, relational=False, continuous=continuous).intersection(domain)
if relational:
rv = rv.as_relational(gen)
return rv
else:
pass # continue with attempt to solve in Real domain
# This keeps the function independent of the assumptions about `gen`.
# `solveset` makes sure this function is called only when the domain is
# real.
_gen = gen
_domain = domain
if gen.is_extended_real is False:
rv = S.EmptySet
return rv if not relational else rv.as_relational(_gen)
elif gen.is_extended_real is None:
gen = Dummy('gen', extended_real=True)
try:
expr = expr.xreplace({_gen: gen})
except TypeError:
raise TypeError(filldedent('''
When gen is real, the relational has a complex part
which leads to an invalid comparison like I < 0.
'''))
rv = None
if expr is S.true:
rv = domain
elif expr is S.false:
rv = S.EmptySet
else:
e = expr.lhs - expr.rhs
period = periodicity(e, gen)
if period == S.Zero:
e = expand_mul(e)
const = expr.func(e, 0)
if const is S.true:
rv = domain
elif const is S.false:
rv = S.EmptySet
elif period is not None:
frange = function_range(e, gen, domain)
rel = expr.rel_op
if rel in ('<', '<='):
if expr.func(frange.sup, 0):
rv = domain
elif not expr.func(frange.inf, 0):
rv = S.EmptySet
elif rel in ('>', '>='):
if expr.func(frange.inf, 0):
rv = domain
elif not expr.func(frange.sup, 0):
rv = S.EmptySet
inf, sup = domain.inf, domain.sup
if sup - inf is S.Infinity:
domain = Interval(0, period, False, True).intersect(_domain)
_domain = domain
if rv is None:
n, d = e.as_numer_denom()
try:
if gen not in n.free_symbols and len(e.free_symbols) > 1:
raise ValueError
# this might raise ValueError on its own
# or it might give None...
solns = solvify(e, gen, domain)
if solns is None:
# in which case we raise ValueError
raise ValueError
except (ValueError, NotImplementedError):
# replace gen with generic x since it's
# univariate anyway
raise NotImplementedError(filldedent('''
The inequality, %s, cannot be solved using
solve_univariate_inequality.
''' % expr.subs(gen, Symbol('x'))))
expanded_e = expand_mul(e)
def valid(x):
# this is used to see if gen=x satisfies the
# relational by substituting it into the
# expanded form and testing against 0, e.g.
# if expr = x*(x + 1) < 2 then e = x*(x + 1) - 2
# and expanded_e = x**2 + x - 2; the test is
# whether a given value of x satisfies
# x**2 + x - 2 < 0
#
# expanded_e, expr and gen used from enclosing scope
v = expanded_e.subs(gen, expand_mul(x))
try:
r = expr.func(v, 0)
except TypeError:
r = S.false
if r in (S.true, S.false):
return r
if v.is_extended_real is False:
return S.false
else:
v = v.n(2)
if v.is_comparable:
return expr.func(v, 0)
# not comparable or couldn't be evaluated
raise NotImplementedError(
'relationship did not evaluate: %s' % r)
singularities = []
for d in denoms(expr, gen):
singularities.extend(solvify(d, gen, domain))
if not continuous:
domain = continuous_domain(expanded_e, gen, domain)
include_x = '=' in expr.rel_op and expr.rel_op != '!='
try:
discontinuities = set(domain.boundary -
FiniteSet(domain.inf, domain.sup))
# remove points that are not between inf and sup of domain
critical_points = FiniteSet(*(solns + singularities + list(
discontinuities))).intersection(
Interval(domain.inf, domain.sup,
domain.inf not in domain, domain.sup not in domain))
if all(r.is_number for r in critical_points):
reals = _nsort(critical_points, separated=True)[0]
else:
sifted = sift(critical_points, lambda x: x.is_extended_real)
if sifted[None]:
# there were some roots that weren't known
# to be real
raise NotImplementedError
try:
reals = sifted[True]
if len(reals) > 1:
reals = sorted(reals)
except TypeError:
raise NotImplementedError
except NotImplementedError:
raise NotImplementedError('sorting of these roots is not supported')
# If expr contains imaginary coefficients, only take real
# values of x for which the imaginary part is 0
make_real = S.Reals
if (coeffI := expanded_e.coeff(S.ImaginaryUnit)) != S.Zero:
check = True
im_sol = FiniteSet()
try:
a = solveset(coeffI, gen, domain)
if not isinstance(a, Interval):
for z in a:
if z not in singularities and valid(z) and z.is_extended_real:
im_sol += FiniteSet(z)
else:
start, end = a.inf, a.sup
for z in _nsort(critical_points + FiniteSet(end)):
valid_start = valid(start)
if start != end:
valid_z = valid(z)
pt = _pt(start, z)
if pt not in singularities and pt.is_extended_real and valid(pt):
if valid_start and valid_z:
im_sol += Interval(start, z)
elif valid_start:
im_sol += Interval.Ropen(start, z)
elif valid_z:
im_sol += Interval.Lopen(start, z)
else:
im_sol += Interval.open(start, z)
start = z
for s in singularities:
im_sol -= FiniteSet(s)
except (TypeError):
im_sol = S.Reals
check = False
if im_sol is S.EmptySet:
raise ValueError(filldedent('''
%s contains imaginary parts which cannot be
made 0 for any value of %s satisfying the
inequality, leading to relations like I < 0.
''' % (expr.subs(gen, _gen), _gen)))
make_real = make_real.intersect(im_sol)
sol_sets = [S.EmptySet]
start = domain.inf
if start in domain and valid(start) and start.is_finite:
sol_sets.append(FiniteSet(start))
for x in reals:
end = x
if valid(_pt(start, end)):
sol_sets.append(Interval(start, end, True, True))
if x in singularities:
singularities.remove(x)
else:
if x in discontinuities:
discontinuities.remove(x)
_valid = valid(x)
else: # it's a solution
_valid = include_x
if _valid:
sol_sets.append(FiniteSet(x))
start = end
end = domain.sup
if end in domain and valid(end) and end.is_finite:
sol_sets.append(FiniteSet(end))
if valid(_pt(start, end)):
sol_sets.append(Interval.open(start, end))
if coeffI != S.Zero and check:
rv = (make_real).intersect(_domain)
else:
rv = Intersection(
(Union(*sol_sets)), make_real, _domain).subs(gen, _gen)
return rv if not relational else rv.as_relational(_gen)
def _pt(start, end):
"""Return a point between start and end"""
if not start.is_infinite and not end.is_infinite:
pt = (start + end)/2
elif start.is_infinite and end.is_infinite:
pt = S.Zero
else:
if (start.is_infinite and start.is_extended_positive is None or
end.is_infinite and end.is_extended_positive is None):
raise ValueError('cannot proceed with unsigned infinite values')
if (end.is_infinite and end.is_extended_negative or
start.is_infinite and start.is_extended_positive):
start, end = end, start
# if possible, use a multiple of self which has
# better behavior when checking assumptions than
# an expression obtained by adding or subtracting 1
if end.is_infinite:
if start.is_extended_positive:
pt = start*2
elif start.is_extended_negative:
pt = start*S.Half
else:
pt = start + 1
elif start.is_infinite:
if end.is_extended_positive:
pt = end*S.Half
elif end.is_extended_negative:
pt = end*2
else:
pt = end - 1
return pt
def _solve_inequality(ie, s, linear=False):
"""Return the inequality with s isolated on the left, if possible.
If the relationship is non-linear, a solution involving And or Or
may be returned. False or True are returned if the relationship
is never True or always True, respectively.
If `linear` is True (default is False) an `s`-dependent expression
will be isolated on the left, if possible
but it will not be solved for `s` unless the expression is linear
in `s`. Furthermore, only "safe" operations which do not change the
sense of the relationship are applied: no division by an unsigned
value is attempted unless the relationship involves Eq or Ne and
no division by a value not known to be nonzero is ever attempted.
Examples
========
>>> from sympy import Eq, Symbol
>>> from sympy.solvers.inequalities import _solve_inequality as f
>>> from sympy.abc import x, y
For linear expressions, the symbol can be isolated:
>>> f(x - 2 < 0, x)
x < 2
>>> f(-x - 6 < x, x)
x > -3
Sometimes nonlinear relationships will be False
>>> f(x**2 + 4 < 0, x)
False
Or they may involve more than one region of values:
>>> f(x**2 - 4 < 0, x)
(-2 < x) & (x < 2)
To restrict the solution to a relational, set linear=True
and only the x-dependent portion will be isolated on the left:
>>> f(x**2 - 4 < 0, x, linear=True)
x**2 < 4
Division of only nonzero quantities is allowed, so x cannot
be isolated by dividing by y:
>>> y.is_nonzero is None # it is unknown whether it is 0 or not
True
>>> f(x*y < 1, x)
x*y < 1
And while an equality (or inequality) still holds after dividing by a
non-zero quantity
>>> nz = Symbol('nz', nonzero=True)
>>> f(Eq(x*nz, 1), x)
Eq(x, 1/nz)
the sign must be known for other inequalities involving > or <:
>>> f(x*nz <= 1, x)
nz*x <= 1
>>> p = Symbol('p', positive=True)
>>> f(x*p <= 1, x)
x <= 1/p
When there are denominators in the original expression that
are removed by expansion, conditions for them will be returned
as part of the result:
>>> f(x < x*(2/x - 1), x)
(x < 1) & Ne(x, 0)
"""
from sympy.solvers.solvers import denoms
if s not in ie.free_symbols:
return ie
if ie.rhs == s:
ie = ie.reversed
if ie.lhs == s and s not in ie.rhs.free_symbols:
return ie
def classify(ie, s, i):
# return True or False if ie evaluates when substituting s with
# i else None (if unevaluated) or NaN (when there is an error
# in evaluating)
try:
v = ie.subs(s, i)
if v is S.NaN:
return v
elif v not in (True, False):
return
return v
except TypeError:
return S.NaN
rv = None
oo = S.Infinity
expr = ie.lhs - ie.rhs
try:
p = Poly(expr, s)
if p.degree() == 0:
rv = ie.func(p.as_expr(), 0)
elif not linear and p.degree() > 1:
# handle in except clause
raise NotImplementedError
except (PolynomialError, NotImplementedError):
if not linear:
try:
rv = reduce_rational_inequalities([[ie]], s)
except PolynomialError:
rv = solve_univariate_inequality(ie, s)
# remove restrictions wrt +/-oo that may have been
# applied when using sets to simplify the relationship
okoo = classify(ie, s, oo)
if okoo is S.true and classify(rv, s, oo) is S.false:
rv = rv.subs(s < oo, True)
oknoo = classify(ie, s, -oo)
if (oknoo is S.true and
classify(rv, s, -oo) is S.false):
rv = rv.subs(-oo < s, True)
rv = rv.subs(s > -oo, True)
if rv is S.true:
rv = (s <= oo) if okoo is S.true else (s < oo)
if oknoo is not S.true:
rv = And(-oo < s, rv)
else:
p = Poly(expr)
conds = []
if rv is None:
e = p.as_expr() # this is in expanded form
# Do a safe inversion of e, moving non-s terms
# to the rhs and dividing by a nonzero factor if
# the relational is Eq/Ne; for other relationals
# the sign must also be positive or negative
rhs = 0
b, ax = e.as_independent(s, as_Add=True)
e -= b
rhs -= b
ef = factor_terms(e)
a, e = ef.as_independent(s, as_Add=False)
if (a.is_zero != False or # don't divide by potential 0
a.is_negative ==
a.is_positive is None and # if sign is not known then
ie.rel_op not in ('!=', '==')): # reject if not Eq/Ne
e = ef
a = S.One
rhs /= a
if a.is_positive:
rv = ie.func(e, rhs)
else:
rv = ie.reversed.func(e, rhs)
# return conditions under which the value is
# valid, too.
beginning_denoms = denoms(ie.lhs) | denoms(ie.rhs)
current_denoms = denoms(rv)
for d in beginning_denoms - current_denoms:
c = _solve_inequality(Eq(d, 0), s, linear=linear)
if isinstance(c, Eq) and c.lhs == s:
if classify(rv, s, c.rhs) is S.true:
# rv is permitting this value but it shouldn't
conds.append(~c)
for i in (-oo, oo):
if (classify(rv, s, i) is S.true and
classify(ie, s, i) is not S.true):
conds.append(s < i if i is oo else i < s)
conds.append(rv)
return And(*conds)
def _reduce_inequalities(inequalities, symbols):
# helper for reduce_inequalities
poly_part, abs_part = {}, {}
other = []
for inequality in inequalities:
expr, rel = inequality.lhs, inequality.rel_op # rhs is 0
# check for gens using atoms which is more strict than free_symbols to
# guard against EX domain which won't be handled by
# reduce_rational_inequalities
gens = expr.atoms(Symbol)
if len(gens) == 1:
gen = gens.pop()
else:
common = expr.free_symbols & symbols
if len(common) == 1:
gen = common.pop()
other.append(_solve_inequality(Relational(expr, 0, rel), gen))
continue
else:
raise NotImplementedError(filldedent('''
inequality has more than one symbol of interest.
'''))
if expr.is_polynomial(gen):
poly_part.setdefault(gen, []).append((expr, rel))
else:
components = expr.find(lambda u:
u.has(gen) and (
u.is_Function or u.is_Pow and not u.exp.is_Integer))
if components and all(isinstance(i, Abs) for i in components):
abs_part.setdefault(gen, []).append((expr, rel))
else:
other.append(_solve_inequality(Relational(expr, 0, rel), gen))
poly_reduced = [reduce_rational_inequalities([exprs], gen) for gen, exprs in poly_part.items()]
abs_reduced = [reduce_abs_inequalities(exprs, gen) for gen, exprs in abs_part.items()]
return And(*(poly_reduced + abs_reduced + other))
def reduce_inequalities(inequalities, symbols=[]):
"""Reduce a system of inequalities with rational coefficients.
Examples
========
>>> from sympy.abc import x, y
>>> from sympy import reduce_inequalities
>>> reduce_inequalities(0 <= x + 3, [])
(-3 <= x) & (x < oo)
>>> reduce_inequalities(0 <= x + y*2 - 1, [x])
(x < oo) & (x >= 1 - 2*y)
"""
if not iterable(inequalities):
inequalities = [inequalities]
inequalities = [sympify(i) for i in inequalities]
gens = set().union(*[i.free_symbols for i in inequalities])
if not iterable(symbols):
symbols = [symbols]
symbols = (set(symbols) or gens) & gens
if any(i.is_extended_real is False for i in symbols):
raise TypeError(filldedent('''
inequalities cannot contain symbols that are not real.
'''))
# make vanilla symbol real
recast = {i: Dummy(i.name, extended_real=True)
for i in gens if i.is_extended_real is None}
inequalities = [i.xreplace(recast) for i in inequalities]
symbols = {i.xreplace(recast) for i in symbols}
# prefilter
keep = []
for i in inequalities:
if isinstance(i, Relational):
i = i.func(i.lhs.as_expr() - i.rhs.as_expr(), 0)
elif i not in (True, False):
i = Eq(i, 0)
if i == True:
continue
elif i == False:
return S.false
if i.lhs.is_number:
raise NotImplementedError(
"could not determine truth value of %s" % i)
keep.append(i)
inequalities = keep
del keep
# solve system
rv = _reduce_inequalities(inequalities, symbols)
# restore original symbols and return
return rv.xreplace({v: k for k, v in recast.items()})
@@ -0,0 +1,16 @@
from .ode import (allhints, checkinfsol, classify_ode,
constantsimp, dsolve, homogeneous_order)
from .lie_group import infinitesimals
from .subscheck import checkodesol
from .systems import (canonical_odes, linear_ode_to_matrix,
linodesolve)
__all__ = [
'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp',
'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix',
'linodesolve'
]
@@ -0,0 +1,272 @@
r'''
This module contains the implementation of the 2nd_hypergeometric hint for
dsolve. This is an incomplete implementation of the algorithm described in [1].
The algorithm solves 2nd order linear ODEs of the form
.. math:: y'' + A(x) y' + B(x) y = 0\text{,}
where `A` and `B` are rational functions. The algorithm should find any
solution of the form
.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,}
where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function".
Currently only the 2F1 case is implemented in SymPy but the other cases are
described in the paper and could be implemented in future (contributions
welcome!).
References
==========
.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order
linear ODEs, (2004).
https://arxiv.org/abs/math-ph/0402063
'''
from sympy.core import S, Pow
from sympy.core.function import expand
from sympy.core.relational import Eq
from sympy.core.symbol import Symbol, Wild
from sympy.functions import exp, sqrt, hyper
from sympy.integrals import Integral
from sympy.polys import roots, gcd
from sympy.polys.polytools import cancel, factor
from sympy.simplify import collect, simplify, logcombine # type: ignore
from sympy.simplify.powsimp import powdenest
from sympy.solvers.ode.ode import get_numbered_constants
def match_2nd_hypergeometric(eq, func):
x = func.args[0]
df = func.diff(x)
a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)])
b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)])
c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)])
deq = a3*(func.diff(x, 2)) + b3*df + c3*func
r = collect(eq,
[func.diff(x, 2), func.diff(x), func]).match(deq)
if r:
if not all(val.is_polynomial() for val in r.values()):
n, d = eq.as_numer_denom()
eq = expand(n)
r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq)
if r and r[a3]!=0:
A = cancel(r[b3]/r[a3])
B = cancel(r[c3]/r[a3])
return [A, B]
else:
return []
def equivalence_hypergeometric(A, B, func):
# This method for finding the equivalence is only for 2F1 type.
# We can extend it for 1F1 and 0F1 type also.
x = func.args[0]
# making given equation in normal form
I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B))
# computing shifted invariant(J1) of the equation
J1 = factor(cancel(x**2*I1 + S(1)/4))
num, dem = J1.as_numer_denom()
num = powdenest(expand(num))
dem = powdenest(expand(dem))
# this function will compute the different powers of variable(x) in J1.
# then it will help in finding value of k. k is power of x such that we can express
# J1 = x**k * J0(x**k) then all the powers in J0 become integers.
def _power_counting(num):
_pow = {0}
for val in num:
if val.has(x):
if isinstance(val, Pow) and val.as_base_exp()[0] == x:
_pow.add(val.as_base_exp()[1])
elif val == x:
_pow.add(val.as_base_exp()[1])
else:
_pow.update(_power_counting(val.args))
return _pow
pow_num = _power_counting((num, ))
pow_dem = _power_counting((dem, ))
pow_dem.update(pow_num)
_pow = pow_dem
k = gcd(_pow)
# computing I0 of the given equation
I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True)
I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True)))
# Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing
# x with x**(1/k) should result in I0 being a rational function of x or
# otherwise the hypergeometric solver cannot be used. Note that k can be a
# non-integer rational such as 2/7.
if not I0.is_rational_function(x):
return None
num, dem = I0.as_numer_denom()
max_num_pow = max(_power_counting((num, )))
dem_args = dem.args
sing_point = []
dem_pow = []
# calculating singular point of I0.
for arg in dem_args:
if arg.has(x):
if isinstance(arg, Pow):
# (x-a)**n
dem_pow.append(arg.as_base_exp()[1])
sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0])
else:
# (x-a) type
dem_pow.append(arg.as_base_exp()[1])
sing_point.append(list(roots(arg, x).keys())[0])
dem_pow.sort()
# checking if equivalence is exists or not.
if equivalence(max_num_pow, dem_pow) == "2F1":
return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"}
else:
return None
def match_2nd_2F1_hypergeometric(I, k, sing_point, func):
x = func.args[0]
a = Wild("a")
b = Wild("b")
c = Wild("c")
t = Wild("t")
s = Wild("s")
r = Wild("r")
alpha = Wild("alpha")
beta = Wild("beta")
gamma = Wild("gamma")
delta = Wild("delta")
# I0 of the standard 2F1 equation.
I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2)
if sing_point != [0, 1]:
# If singular point is [0, 1] then we have standard equation.
eqs = []
sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)]
# making equations for the finding the mobius transformation
for i in range(3):
if i<len(sing_point):
eqs.append(Eq(sing_eqs[i], sing_point[i]))
else:
eqs.append(Eq(1/sing_eqs[i], 0))
# solving above equations for the mobius transformation
_beta = -alpha*sing_point[0]
_delta = -gamma*sing_point[1]
_gamma = alpha
if len(sing_point) == 3:
_gamma = (_beta + sing_point[2]*alpha)/(sing_point[2] - sing_point[1])
mob = (alpha*x + beta)/(gamma*x + delta)
mob = mob.subs(beta, _beta)
mob = mob.subs(delta, _delta)
mob = mob.subs(gamma, _gamma)
mob = cancel(mob)
t = (beta - delta*x)/(gamma*x - alpha)
t = cancel(((t.subs(beta, _beta)).subs(delta, _delta)).subs(gamma, _gamma))
else:
mob = x
t = x
# applying mobius transformation in I to make it into I0.
I = I.subs(x, t)
I = I*(t.diff(x))**2
I = factor(I)
dict_I = {x**2:0, x:0, 1:0}
I0_num, I0_dem = I0.as_numer_denom()
# collecting coeff of (x**2, x), of the standard equation.
# substituting (a-b) = s, (a+b) = r
dict_I0 = {x**2:s**2 - 1, x:(2*(1-r)*c + (r+s)*(r-s)), 1:c*(c-2)}
# collecting coeff of (x**2, x) from I0 of the given equation.
dict_I.update(collect(expand(cancel(I*I0_dem)), [x**2, x], evaluate=False))
eqs = []
# We are comparing the coeff of powers of different x, for finding the values of
# parameters of standard equation.
for key in [x**2, x, 1]:
eqs.append(Eq(dict_I[key], dict_I0[key]))
# We can have many possible roots for the equation.
# I am selecting the root on the basis that when we have
# standard equation eq = x*(x-1)*f(x).diff(x, 2) + ((a+b+1)*x-c)*f(x).diff(x) + a*b*f(x)
# then root should be a, b, c.
_c = 1 - factor(sqrt(1+eqs[2].lhs))
if not _c.has(Symbol):
_c = min(list(roots(eqs[2], c)))
_s = factor(sqrt(eqs[0].lhs + 1))
_r = _c - factor(sqrt(_c**2 + _s**2 + eqs[1].lhs - 2*_c))
_a = (_r + _s)/2
_b = (_r - _s)/2
rn = {'a':simplify(_a), 'b':simplify(_b), 'c':simplify(_c), 'k':k, 'mobius':mob, 'type':"2F1"}
return rn
def equivalence(max_num_pow, dem_pow):
# this function is made for checking the equivalence with 2F1 type of equation.
# max_num_pow is the value of maximum power of x in numerator
# and dem_pow is list of powers of different factor of form (a*x b).
# reference from table 1 in paper - "Non-Liouvillian solutions for second order
# linear ODEs" by L. Chan, E.S. Cheb-Terrab.
# We can extend it for 1F1 and 0F1 type also.
if max_num_pow == 2:
if dem_pow in [[2, 2], [2, 2, 2]]:
return "2F1"
elif max_num_pow == 1:
if dem_pow in [[1, 2, 2], [2, 2, 2], [1, 2], [2, 2]]:
return "2F1"
elif max_num_pow == 0:
if dem_pow in [[1, 1, 2], [2, 2], [1, 2, 2], [1, 1], [2], [1, 2], [2, 2]]:
return "2F1"
return None
def get_sol_2F1_hypergeometric(eq, func, match_object):
x = func.args[0]
from sympy.simplify.hyperexpand import hyperexpand
from sympy.polys.polytools import factor
C0, C1 = get_numbered_constants(eq, num=2)
a = match_object['a']
b = match_object['b']
c = match_object['c']
A = match_object['A']
sol = None
if c.is_integer == False:
sol = C0*hyper([a, b], [c], x) + C1*hyper([a-c+1, b-c+1], [2-c], x)*x**(1-c)
elif c == 1:
y2 = Integral(exp(Integral((-(a+b+1)*x + c)/(x**2-x), x))/(hyperexpand(hyper([a, b], [c], x))**2), x)*hyper([a, b], [c], x)
sol = C0*hyper([a, b], [c], x) + C1*y2
elif (c-a-b).is_integer == False:
sol = C0*hyper([a, b], [1+a+b-c], 1-x) + C1*hyper([c-a, c-b], [1+c-a-b], 1-x)*(1-x)**(c-a-b)
if sol:
# applying transformation in the solution
subs = match_object['mobius']
dtdx = simplify(1/(subs.diff(x)))
_B = ((a + b + 1)*x - c).subs(x, subs)*dtdx
_B = factor(_B + ((x**2 -x).subs(x, subs))*(dtdx.diff(x)*dtdx))
_A = factor((x**2 - x).subs(x, subs)*(dtdx**2))
e = exp(logcombine(Integral(cancel(_B/(2*_A)), x), force=True))
sol = sol.subs(x, match_object['mobius'])
sol = sol.subs(x, x**match_object['k'])
e = e.subs(x, x**match_object['k'])
if not A.is_zero:
e1 = Integral(A/2, x)
e1 = exp(logcombine(e1, force=True))
sol = cancel((e/e1)*x**((-match_object['k']+1)/2))*sol
sol = Eq(func, sol)
return sol
sol = cancel((e)*x**((-match_object['k']+1)/2))*sol
sol = Eq(func, sol)
return sol
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,484 @@
r"""
This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients,
nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients,
nth_linear_constant_coeff_variation_of_parameters,
and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters.
All the functions in this file are used by more than one solvers so, instead of creating
instances in other classes for using them it is better to keep it here as separate helpers.
"""
from collections import Counter
from sympy.core import Add, S
from sympy.core.function import diff, expand, _mexpand, expand_mul
from sympy.core.relational import Eq
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Dummy, Wild
from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \
atan2, conjugate
from sympy.integrals import Integral
from sympy.polys import (Poly, RootOf, rootof, roots)
from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore
from sympy.utilities import numbered_symbols
from sympy.solvers.solvers import solve
from sympy.matrices import wronskian
from .subscheck import sub_func_doit
from sympy.solvers.ode.ode import get_numbered_constants
def _test_term(coeff, func, order):
r"""
Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x),
where K is independent of x and y(x), order>= 0.
So we need to check that for each term, coeff == K*x**order from
some K. We have a few cases, since coeff may have several
different types.
"""
x = func.args[0]
f = func.func
if order < 0:
raise ValueError("order should be greater than 0")
if coeff == 0:
return True
if order == 0:
if x in coeff.free_symbols:
return False
return True
if coeff.is_Mul:
if coeff.has(f(x)):
return False
return x**order in coeff.args
elif coeff.is_Pow:
return coeff.as_base_exp() == (x, order)
elif order == 1:
return x == coeff
return False
def _get_euler_characteristic_eq_sols(eq, func, match_obj):
r"""
Returns the solution of homogeneous part of the linear euler ODE and
the list of roots of characteristic equation.
The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order
of the derivative on each term, and coeff is the coefficient of that derivative.
"""
x = func.args[0]
f = func.func
# First, set up characteristic equation.
chareq, symbol = S.Zero, Dummy('x')
for i in match_obj:
if i >= 0:
chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand()
chareq = Poly(chareq, symbol)
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
collectterms = []
# A generator of constants
constants = list(get_numbered_constants(eq, num=chareq.degree()*2))
constants.reverse()
# Create a dict root: multiplicity or charroots
charroots = Counter(chareqroots)
gsol = S.Zero
ln = log
for root, multiplicity in charroots.items():
for i in range(multiplicity):
if isinstance(root, RootOf):
gsol += (x**root) * constants.pop()
if multiplicity != 1:
raise ValueError("Value should be 1")
collectterms = [(0, root, 0)] + collectterms
elif root.is_real:
gsol += ln(x)**i*(x**root) * constants.pop()
collectterms = [(i, root, 0)] + collectterms
else:
reroot = re(root)
imroot = im(root)
gsol += ln(x)**i * (x**reroot) * (
constants.pop() * sin(abs(imroot)*ln(x))
+ constants.pop() * cos(imroot*ln(x)))
collectterms = [(i, reroot, imroot)] + collectterms
gsol = Eq(f(x), gsol)
gensols = []
# Keep track of when to use sin or cos for nonzero imroot
for i, reroot, imroot in collectterms:
if imroot == 0:
gensols.append(ln(x)**i*x**reroot)
else:
sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x))
if sin_form in gensols:
cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x))
gensols.append(cos_form)
else:
gensols.append(sin_form)
return gsol, gensols
def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True):
r"""
Helper function for the method of variation of parameters and nonhomogeneous euler eq.
See the
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters`
docstring for more information on this method.
The parameter are ``match_obj`` should be a dictionary that has the following
keys:
``list``
A list of solutions to the homogeneous equation.
``sol``
The general solution.
"""
f = func.func
x = func.args[0]
r = match_obj
psol = 0
wr = wronskian(roots, x)
if simplify_flag:
wr = simplify(wr) # We need much better simplification for
# some ODEs. See issue 4662, for example.
# To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1
wr = trigsimp(wr, deep=True, recursive=True)
if not wr:
# The wronskian will be 0 iff the solutions are not linearly
# independent.
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply " +
"variation of parameters to " + str(eq) + " (Wronskian == 0)")
if len(roots) != order:
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply " +
"variation of parameters to " +
str(eq) + " (number of terms != order)")
negoneterm = S.NegativeOne**(order)
for i in roots:
psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order]
negoneterm *= -1
if simplify_flag:
psol = simplify(psol)
psol = trigsimp(psol, deep=True)
return Eq(f(x), homogen_sol.rhs + psol)
def _get_const_characteristic_eq_sols(r, func, order):
r"""
Returns the roots of characteristic equation of constant coefficient
linear ODE and list of collectterms which is later on used by simplification
to use collect on solution.
The parameter `r` is a dict of order:coeff terms, where order is the order of the
derivative on each term, and coeff is the coefficient of that derivative.
"""
x = func.args[0]
# First, set up characteristic equation.
chareq, symbol = S.Zero, Dummy('x')
for i in r.keys():
if isinstance(i, str) or i < 0:
pass
else:
chareq += r[i]*symbol**i
chareq = Poly(chareq, symbol)
# Can't just call roots because it doesn't return rootof for unsolveable
# polynomials.
chareqroots = roots(chareq, multiple=True)
if len(chareqroots) != order:
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs())
# Create a dict root: multiplicity or charroots
charroots = Counter(chareqroots)
# We need to keep track of terms so we can run collect() at the end.
# This is necessary for constantsimp to work properly.
collectterms = []
gensols = []
conjugate_roots = [] # used to prevent double-use of conjugate roots
# Loop over roots in theorder provided by roots/rootof...
for root in chareqroots:
# but don't repoeat multiple roots.
if root not in charroots:
continue
multiplicity = charroots.pop(root)
for i in range(multiplicity):
if chareq_is_complex:
gensols.append(x**i*exp(root*x))
collectterms = [(i, root, 0)] + collectterms
continue
reroot = re(root)
imroot = im(root)
if imroot.has(atan2) and reroot.has(atan2):
# Remove this condition when re and im stop returning
# circular atan2 usages.
gensols.append(x**i*exp(root*x))
collectterms = [(i, root, 0)] + collectterms
else:
if root in conjugate_roots:
collectterms = [(i, reroot, imroot)] + collectterms
continue
if imroot == 0:
gensols.append(x**i*exp(reroot*x))
collectterms = [(i, reroot, 0)] + collectterms
continue
conjugate_roots.append(conjugate(root))
gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x))
gensols.append(x**i*exp(reroot*x) * cos( imroot * x))
# This ordering is important
collectterms = [(i, reroot, imroot)] + collectterms
return gensols, collectterms
# Ideally these kind of simplification functions shouldn't be part of solvers.
# odesimp should be improved to handle these kind of specific simplifications.
def _get_simplified_sol(sol, func, collectterms):
r"""
Helper function which collects the solution on
collectterms. Ideally this should be handled by odesimp.It is used
only when the simplify is set to True in dsolve.
The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is
the multiplicity of the root, reroot is real part and imroot being the imaginary part.
"""
f = func.func
x = func.args[0]
collectterms.sort(key=default_sort_key)
collectterms.reverse()
assert len(sol) == 1 and sol[0].lhs == f(x)
sol = sol[0].rhs
sol = expand_mul(sol)
for i, reroot, imroot in collectterms:
sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x))
sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x))
for i, reroot, imroot in collectterms:
sol = collect(sol, x**i*exp(reroot*x))
sol = powsimp(sol)
return Eq(f(x), sol)
def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero):
r"""
Returns a trial function match if undetermined coefficients can be applied
to ``expr``, and ``None`` otherwise.
A trial expression can be found for an expression for use with the method
of undetermined coefficients if the expression is an
additive/multiplicative combination of constants, polynomials in `x` (the
independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and
`e^{a x}` terms (in other words, it has a finite number of linearly
independent derivatives).
Note that you may still need to multiply each term returned here by
sufficient `x` to make it linearly independent with the solutions to the
homogeneous equation.
This is intended for internal use by ``undetermined_coefficients`` hints.
SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of
only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So,
for example, you will need to manually convert `\sin^2(x)` into `[1 +
\cos(2 x)]/2` to properly apply the method of undetermined coefficients on
it.
Examples
========
>>> from sympy import log, exp
>>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match
>>> from sympy.abc import x
>>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x)
{'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}}
>>> _undetermined_coefficients_match(log(x), x)
{'test': False}
"""
a = Wild('a', exclude=[x])
b = Wild('b', exclude=[x])
expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1)
retdict = {}
def _test_term(expr, x) -> bool:
r"""
Test if ``expr`` fits the proper form for undetermined coefficients.
"""
if not expr.has(x):
return True
if expr.is_Add:
return all(_test_term(i, x) for i in expr.args)
if expr.is_Mul:
if expr.has(sin, cos):
foundtrig = False
# Make sure that there is only one trig function in the args.
# See the docstring.
for i in expr.args:
if i.has(sin, cos):
if foundtrig:
return False
else:
foundtrig = True
return all(_test_term(i, x) for i in expr.args)
if expr.is_Function:
return expr.func in (sin, cos, exp, sinh, cosh) and \
bool(expr.args[0].match(a*x + b))
if expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \
expr.exp >= 0:
return True
if expr.is_Pow and expr.base.is_number:
return bool(expr.exp.match(a*x + b))
return expr.is_Symbol or bool(expr.is_number)
def _get_trial_set(expr, x, exprs=set()):
r"""
Returns a set of trial terms for undetermined coefficients.
The idea behind undetermined coefficients is that the terms expression
repeat themselves after a finite number of derivatives, except for the
coefficients (they are linearly dependent). So if we collect these,
we should have the terms of our trial function.
"""
def _remove_coefficient(expr, x):
r"""
Returns the expression without a coefficient.
Similar to expr.as_independent(x)[1], except it only works
multiplicatively.
"""
term = S.One
if expr.is_Mul:
for i in expr.args:
if i.has(x):
term *= i
elif expr.has(x):
term = expr
return term
expr = expand_mul(expr)
if expr.is_Add:
for term in expr.args:
if _remove_coefficient(term, x) in exprs:
pass
else:
exprs.add(_remove_coefficient(term, x))
exprs = exprs.union(_get_trial_set(term, x, exprs))
else:
term = _remove_coefficient(expr, x)
tmpset = exprs.union({term})
oldset = set()
while tmpset != oldset:
# If you get stuck in this loop, then _test_term is probably
# broken
oldset = tmpset.copy()
expr = expr.diff(x)
term = _remove_coefficient(expr, x)
if term.is_Add:
tmpset = tmpset.union(_get_trial_set(term, x, tmpset))
else:
tmpset.add(term)
exprs = tmpset
return exprs
def is_homogeneous_solution(term):
r""" This function checks whether the given trialset contains any root
of homogeneous equation"""
return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero
retdict['test'] = _test_term(expr, x)
if retdict['test']:
# Try to generate a list of trial solutions that will have the
# undetermined coefficients. Note that if any of these are not linearly
# independent with any of the solutions to the homogeneous equation,
# then they will need to be multiplied by sufficient x to make them so.
# This function DOES NOT do that (it doesn't even look at the
# homogeneous equation).
temp_set = set()
for i in Add.make_args(expr):
act = _get_trial_set(i, x)
if eq_homogeneous is not S.Zero:
while any(is_homogeneous_solution(ts) for ts in act):
act = {x*ts for ts in act}
temp_set = temp_set.union(act)
retdict['trialset'] = temp_set
return retdict
def _solve_undetermined_coefficients(eq, func, order, match, trialset):
r"""
Helper function for the method of undetermined coefficients.
See the
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients`
docstring for more information on this method.
The parameter ``trialset`` is the set of trial functions as returned by
``_undetermined_coefficients_match()['trialset']``.
The parameter ``match`` should be a dictionary that has the following
keys:
``list``
A list of solutions to the homogeneous equation.
``sol``
The general solution.
"""
r = match
coeffs = numbered_symbols('a', cls=Dummy)
coefflist = []
gensols = r['list']
gsol = r['sol']
f = func.func
x = func.args[0]
if len(gensols) != order:
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply" +
" undetermined coefficients to " + str(eq) +
" (number of terms != order)")
trialfunc = 0
for i in trialset:
c = next(coeffs)
coefflist.append(c)
trialfunc += c*i
eqs = sub_func_doit(eq, f(x), trialfunc)
coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1))))
eqs = _mexpand(eqs)
for i in Add.make_args(eqs):
s = separatevars(i, dict=True, symbols=[x])
if coeffsdict.get(s[x]):
coeffsdict[s[x]] += s['coeff']
else:
coeffsdict[s[x]] = s['coeff']
coeffvals = solve(list(coeffsdict.values()), coefflist)
if not coeffvals:
raise NotImplementedError(
"Could not solve `%s` using the "
"method of undetermined coefficients "
"(unable to solve for coefficients)." % eq)
psol = trialfunc.subs(coeffvals)
return Eq(f(x), gsol.rhs + psol)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,893 @@
r"""
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
a function which gives all rational particular solutions to first order
Riccati ODEs. A general first order Riccati ODE is given by -
.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2
where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation
is a Bernoulli ODE. The algorithm presented below can find rational
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
or prove that no rational solution exists for the equation.
Background
==========
A Riccati equation can be transformed to its normal form
.. math:: y' + y^2 = a(x)
using the transformation
.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}
where `a(x)` is given by
.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2
Thus, we can develop an algorithm to solve for the Riccati equation
in its normal form, which would in turn give us the solution for
the original Riccati equation.
Algorithm
=========
The algorithm implemented here is presented in the Ph.D thesis
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
by N. Thieu Vo. The entire thesis can be found here -
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf
We have only implemented the Rational Riccati solver (Algorithm 11,
Pg 78-82 in Thesis). Before we proceed towards the implementation
of the algorithm, a few definitions to understand are -
1. Valuation of a Rational Function at `\infty`:
The valuation of a rational function `p(x)` at `\infty` is equal
to the difference between the degree of the denominator and the
numerator of `p(x)`.
NOTE: A general definition of valuation of a rational function
at any value of `x` can be found in Pg 63 of the thesis, but
is not of any interest for this algorithm.
2. Zeros and Poles of a Rational Function:
Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
of `x`. Then -
a. The Zeros of `a(x)` are the roots of `S(x)`.
b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
can also be a pole of a(x). We say that `a(x)` has a pole at
`\infty` if `a(\frac{1}{x})` has a pole at 0.
Every pole is associated with an order that is equal to the multiplicity
of its appearance as a root of `T(x)`. A pole is called a simple pole if
it has an order 1. Similarly, a pole is called a multiple pole if it has
an order `\ge` 2.
Necessary Conditions
====================
For a Riccati equation in its normal form,
.. math:: y' + y^2 = a(x)
we can define
a. A pole is called a movable pole if it is a pole of `y(x)` and is not
a pole of `a(x)`.
b. Similarly, a pole is called a non-movable pole if it is a pole of both
`y(x)` and `a(x)`.
Then, the algorithm states that a rational solution exists only if -
a. Every pole of `a(x)` must be either a simple pole or a multiple pole
of even order.
b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2.
This algorithm finds all possible rational solutions for the Riccati ODE.
If no rational solutions are found, it means that no rational solutions
exist.
The algorithm works for Riccati ODEs where the coefficients are rational
functions in the independent variable `x` with rational number coefficients
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
irrational numbers, symbols or any other kind of expression. The reasons
for this are -
1. When using symbols, different symbols could take the same value and this
would affect the multiplicity of poles if symbols are present here.
2. An integer degree bound is required to calculate a polynomial solution
to an auxiliary differential equation, which in turn gives the particular
solution for the original ODE. If symbols/floats/irrational numbers are
present, we cannot determine if the expression for the degree bound is an
integer or not.
Solution
========
With these definitions, we can state a general form for the solution of
the equation. `y(x)` must have the form -
.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i
where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`,
`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values
of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The
coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})`
can be determined from `a(x)`. We will have 2 choices each of these vectors
and part of the procedure is figuring out which of the 2 should be used
to get the solution correctly.
Implementation
==============
In this implementation, we use ``Poly`` to represent a rational function
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
represent rational functions directly using ``Poly``, we instead represent
a rational function with 2 ``Poly`` objects - one for its numerator and
the other for its denominator.
The code is written to match the steps given in the thesis (Pg 82)
Step 0 : Match the equation -
Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
an error
Step 1 : Transform the equation to its normal form as explained in the
theory section.
Step 2 : Initialize an empty set of solutions, ``sol``.
Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``.
Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
to ``sol``.
Step 5 : Find the poles and their multiplicities of `a(x)`. Let
the number of poles be `n`. Also find the valuation of `a(x)` at
`\infty` using ``val_at_inf``.
NOTE: Although the algorithm considers `\infty` as a pole, it is
not mentioned if it a part of the set of finite poles. `\infty`
is NOT a part of the set of finite poles. If a pole exists at
`\infty`, we use its multiplicity to find the laurent series of
`a(x)` about `\infty`.
Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)``
combinations of choosing between 2 choices for each of the `n` c-vectors
and 1 d-vector.
NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig
mistake. The term `- d_N` must be replaced with `-N d_N`. The same
has been explained in the code as well.
For each of these above combinations, do
Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
the polynomial solution we must find for the auxiliary equation.
Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is
one part of y(x) -
.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i
Step 10 : If `m` is a non-negative integer -
Step 11: Find a polynomial solution of degree `m` for the auxiliary equation.
There are 2 cases possible -
a. `m` is a non-negative integer: We can solve for the coefficients
in `p(x)` using Undetermined Coefficients.
b. `m` is not a non-negative integer: In this case, we cannot find
a polynomial solution to the auxiliary equation, and hence, we ignore
this value of `m`.
Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}`
to ``sol``.
Step 13 : For each solution in ``sol``, apply an inverse transformation,
so that the solutions of the original equation are found using the
solutions of the equation in its normal form.
"""
from itertools import product
from sympy.core import S
from sympy.core.add import Add
from sympy.core.numbers import oo, Float
from sympy.core.function import count_ops
from sympy.core.relational import Eq
from sympy.core.symbol import symbols, Symbol, Dummy
from sympy.functions import sqrt, exp
from sympy.functions.elementary.complexes import sign
from sympy.integrals.integrals import Integral
from sympy.polys.domains import ZZ
from sympy.polys.polytools import Poly
from sympy.polys.polyroots import roots
from sympy.solvers.solveset import linsolve
def riccati_normal(w, x, b1, b2):
"""
Given a solution `w(x)` to the equation
.. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2
and rational function coefficients `b_1(x)` and
`b_2(x)`, this function transforms the solution to
give a solution `y(x)` for its corresponding normal
Riccati ODE
.. math:: y'(x) + y(x)^2 = a(x)
using the transformation
.. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
"""
return -b2*w - b2.diff(x)/(2*b2) - b1/2
def riccati_inverse_normal(y, x, b1, b2, bp=None):
"""
Inverse transforming the solution to the normal
Riccati ODE to get the solution to the Riccati ODE.
"""
# bp is the expression which is independent of the solution
# and hence, it need not be computed again
if bp is None:
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
# w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
return -y/b2 + bp
def riccati_reduced(eq, f, x):
"""
Convert a Riccati ODE into its corresponding
normal Riccati ODE.
"""
match, funcs = match_riccati(eq, f, x)
# If equation is not a Riccati ODE, exit
if not match:
return False
# Using the rational functions, find the expression for a(x)
b0, b1, b2 = funcs
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
b2.diff(x, 2)/(2*b2)
# Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
return f(x).diff(x) + f(x)**2 - a
def linsolve_dict(eq, syms):
"""
Get the output of linsolve as a dict
"""
# Convert tuple type return value of linsolve
# to a dictionary for ease of use
sol = linsolve(eq, syms)
if not sol:
return {}
return dict(zip(syms, list(sol)[0]))
def match_riccati(eq, f, x):
"""
A function that matches and returns the coefficients
if an equation is a Riccati ODE
Parameters
==========
eq: Equation to be matched
f: Dependent variable
x: Independent variable
Returns
=======
match: True if equation is a Riccati ODE, False otherwise
funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
b0, b1 and b2 are rational functions which match the equation.
"""
# Group terms based on f(x)
if isinstance(eq, Eq):
eq = eq.lhs - eq.rhs
eq = eq.expand().collect(f(x))
cf = eq.coeff(f(x).diff(x))
# There must be an f(x).diff(x) term.
# eq must be an Add object since we are using the expanded
# equation and it must have atleast 2 terms (b2 != 0)
if cf != 0 and isinstance(eq, Add):
# Divide all coefficients by the coefficient of f(x).diff(x)
# and add the terms again to get the same equation
eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x))
# Match the equation with the pattern
b1 = -eq.coeff(f(x))
b2 = -eq.coeff(f(x)**2)
b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
funcs = [b0, b1, b2]
# Check if coefficients are not symbols and floats
if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs):
return False, []
# If b_0(x) contains f(x), it is not a Riccati ODE
if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x),
b1.is_rational_function(x), b2.is_rational_function(x))):
return False, []
return True, funcs
return False, []
def val_at_inf(num, den, x):
# Valuation of a rational function at oo = deg(denom) - deg(numer)
return den.degree(x) - num.degree(x)
def check_necessary_conds(val_inf, muls):
"""
The necessary conditions for a rational solution
to exist are as follows -
i) Every pole of a(x) must be either a simple pole
or a multiple pole of even order.
ii) The valuation of a(x) at infinity must be even
or be greater than or equal to 2.
Here, a simple pole is a pole with multiplicity 1
and a multiple pole is a pole with multiplicity
greater than 1.
"""
return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls)
def inverse_transform_poly(num, den, x):
"""
A function to make the substitution
x -> 1/x in a rational function that
is represented using Poly objects for
numerator and denominator.
"""
# Declare for reuse
one = Poly(1, x)
xpoly = Poly(x, x)
# Check if degree of numerator is same as denominator
pwr = val_at_inf(num, den, x)
if pwr >= 0:
# Denominator has greater degree. Substituting x with
# 1/x would make the extra power go to the numerator
if num.expr != 0:
num = num.transform(one, xpoly) * x**pwr
den = den.transform(one, xpoly)
else:
# Numerator has greater degree. Substituting x with
# 1/x would make the extra power go to the denominator
num = num.transform(one, xpoly)
den = den.transform(one, xpoly) * x**(-pwr)
return num.cancel(den, include=True)
def limit_at_inf(num, den, x):
"""
Find the limit of a rational function
at oo
"""
# pwr = degree(num) - degree(den)
pwr = -val_at_inf(num, den, x)
# Numerator has a greater degree than denominator
# Limit at infinity would depend on the sign of the
# leading coefficients of numerator and denominator
if pwr > 0:
return oo*sign(num.LC()/den.LC())
# Degree of numerator is equal to that of denominator
# Limit at infinity is just the ratio of leading coeffs
elif pwr == 0:
return num.LC()/den.LC()
# Degree of numerator is less than that of denominator
# Limit at infinity is just 0
else:
return 0
def construct_c_case_1(num, den, x, pole):
# Find the coefficient of 1/(x - pole)**2 in the
# Laurent series expansion of a(x) about pole.
num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
r = (num1.subs(x, pole))/(den1.subs(x, pole))
# If multiplicity is 2, the coefficient to be added
# in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
if r != -S(1)/4:
return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
return [[S.Half]]
def construct_c_case_2(num, den, x, pole, mul):
# Generate the coefficients using the recurrence
# relation mentioned in (5.14) in the thesis (Pg 80)
# r_i = mul/2
ri = mul//2
# Find the Laurent series coefficients about the pole
ser = rational_laurent_series(num, den, x, pole, mul, 6)
# Start with an empty memo to store the coefficients
# This is for the plus case
cplus = [0 for i in range(ri)]
# Base Case
cplus[ri-1] = sqrt(ser[2*ri])
# Iterate backwards to find all coefficients
s = ri - 1
sm = 0
for s in range(ri-1, 0, -1):
sm = 0
for j in range(s+1, ri):
sm += cplus[j-1]*cplus[ri+s-j-1]
if s!= 1:
cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])
# Memo for the minus case
cminus = [-x for x in cplus]
# Find the 0th coefficient in the recurrence
cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])
# Add both the plus and minus cases' coefficients
if cplus != cminus:
return [cplus, cminus]
return cplus
def construct_c_case_3():
# If multiplicity is 1, the coefficient to be added
# in the c-vector is 1 (no choice)
return [[1]]
def construct_c(num, den, x, poles, muls):
"""
Helper function to calculate the coefficients
in the c-vector for each pole.
"""
c = []
for pole, mul in zip(poles, muls):
c.append([])
# Case 3
if mul == 1:
# Add the coefficients from Case 3
c[-1].extend(construct_c_case_3())
# Case 1
elif mul == 2:
# Add the coefficients from Case 1
c[-1].extend(construct_c_case_1(num, den, x, pole))
# Case 2
else:
# Add the coefficients from Case 2
c[-1].extend(construct_c_case_2(num, den, x, pole, mul))
return c
def construct_d_case_4(ser, N):
# Initialize an empty vector
dplus = [0 for i in range(N+2)]
# d_N = sqrt(a_{2*N})
dplus[N] = sqrt(ser[2*N])
# Use the recurrence relations to find
# the value of d_s
for s in range(N-1, -2, -1):
sm = 0
for j in range(s+1, N):
sm += dplus[j]*dplus[N+s-j]
if s != -1:
dplus[s] = (ser[N+s] - sm)/(2*dplus[N])
# Coefficients for the case of d_N = -sqrt(a_{2*N})
dminus = [-x for x in dplus]
# The third equation in Eq 5.15 of the thesis is WRONG!
# d_N must be replaced with N*d_N in that equation.
dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])
if dplus != dminus:
return [dplus, dminus]
return dplus
def construct_d_case_5(ser):
# List to store coefficients for plus case
dplus = [0, 0]
# d_0 = sqrt(a_0)
dplus[0] = sqrt(ser[0])
# d_(-1) = a_(-1)/(2*d_0)
dplus[-1] = ser[-1]/(2*dplus[0])
# Coefficients for the minus case are just the negative
# of the coefficients for the positive case.
dminus = [-x for x in dplus]
if dplus != dminus:
return [dplus, dminus]
return dplus
def construct_d_case_6(num, den, x):
# s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
# s_oo = lim x->oo x**2 * a(x)
s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)
# d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
if s_inf != -S(1)/4:
return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
return [[S.Half]]
def construct_d(num, den, x, val_inf):
"""
Helper function to calculate the coefficients
in the d-vector based on the valuation of the
function at oo.
"""
N = -val_inf//2
# Multiplicity of oo as a pole
mul = -val_inf if val_inf < 0 else 0
ser = rational_laurent_series(num, den, x, oo, mul, 1)
# Case 4
if val_inf < 0:
d = construct_d_case_4(ser, N)
# Case 5
elif val_inf == 0:
d = construct_d_case_5(ser)
# Case 6
else:
d = construct_d_case_6(num, den, x)
return d
def rational_laurent_series(num, den, x, r, m, n):
r"""
The function computes the Laurent series coefficients
of a rational function.
Parameters
==========
num: A Poly object that is the numerator of `f(x)`.
den: A Poly object that is the denominator of `f(x)`.
x: The variable of expansion of the series.
r: The point of expansion of the series.
m: Multiplicity of r if r is a pole of `f(x)`. Should
be zero otherwise.
n: Order of the term upto which the series is expanded.
Returns
=======
series: A dictionary that has power of the term as key
and coefficient of that term as value.
Below is a basic outline of how the Laurent series of a
rational function `f(x)` about `x_0` is being calculated -
1. Substitute `x + x_0` in place of `x`. If `x_0`
is a pole of `f(x)`, multiply the expression by `x^m`
where `m` is the multiplicity of `x_0`. Denote the
the resulting expression as g(x). We do this substitution
so that we can now find the Laurent series of g(x) about
`x = 0`.
2. We can then assume that the Laurent series of `g(x)`
takes the following form -
.. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m
where `a_m` denotes the Laurent series coefficients.
3. Multiply the denominator to the RHS of the equation
and form a recurrence relation for the coefficients `a_m`.
"""
one = Poly(1, x, extension=True)
if r == oo:
# Series at x = oo is equal to first transforming
# the function from x -> 1/x and finding the
# series at x = 0
num, den = inverse_transform_poly(num, den, x)
r = S(0)
if r:
# For an expansion about a non-zero point, a
# transformation from x -> x + r must be made
num = num.transform(Poly(x + r, x, extension=True), one)
den = den.transform(Poly(x + r, x, extension=True), one)
# Remove the pole from the denominator if the series
# expansion is about one of the poles
num, den = (num*x**m).cancel(den, include=True)
# Equate coefficients for the first terms (base case)
maxdegree = 1 + max(num.degree(), den.degree())
syms = symbols(f'a:{maxdegree}', cls=Dummy)
diff = num - den * Poly(syms[::-1], x)
coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
(coeffs, ) = linsolve(coeff_diffs, syms)
# Use the recursion relation for the rest
recursion = den.all_coeffs()[::-1]
div, rec_rhs = recursion[0], recursion[1:]
series = list(coeffs)
while len(series) < n:
next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div
series.append(-next_coeff)
series = {m - i: val for i, val in enumerate(series)}
return series
def compute_m_ybar(x, poles, choice, N):
"""
Helper function to calculate -
1. m - The degree bound for the polynomial
solution that must be found for the auxiliary
differential equation.
2. ybar - Part of the solution which can be
computed using the poles, c and d vectors.
"""
ybar = 0
m = Poly(choice[-1][-1], x, extension=True)
# Calculate the first (nested) summation for ybar
# as given in Step 9 of the Thesis (Pg 82)
dybar = []
for i, polei in enumerate(poles):
for j, cij in enumerate(choice[i]):
dybar.append(cij/(x - polei)**(j + 1))
m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add
ybar += Add(*dybar)
# Calculate the second summation for ybar
for i in range(N+1):
ybar += choice[-1][i]*x**i
return (m.expr, ybar)
def solve_aux_eq(numa, dena, numy, deny, x, m):
"""
Helper function to find a polynomial solution
of degree m for the auxiliary differential
equation.
"""
# Assume that the solution is of the type
# p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m
psyms = symbols(f'C0:{m}', cls=Dummy)
K = ZZ[psyms]
psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)
# Eq (5.16) in Thesis - Pg 81
auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
if m >= 1:
px = psol.diff(x)
auxeq += px*(2*numy*deny*dena)
if m >= 2:
auxeq += px.diff(x)*(deny**2*dena)
if m != 0:
# m is a non-zero integer. Find the constant terms using undetermined coefficients
return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
else:
# m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation
return S.One, auxeq, auxeq == 0
def remove_redundant_sols(sol1, sol2, x):
"""
Helper function to remove redundant
solutions to the differential equation.
"""
# If y1 and y2 are redundant solutions, there is
# some value of the arbitrary constant for which
# they will be equal
syms1 = sol1.atoms(Symbol, Dummy)
syms2 = sol2.atoms(Symbol, Dummy)
num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
# Cross multiply
e = num1*den2 - den1*num2
# Check if there are any constants
syms = list(e.atoms(Symbol, Dummy))
if len(syms):
# Find values of constants for which solutions are equal
redn = linsolve(e.all_coeffs(), syms)
if len(redn):
# Return the general solution over a particular solution
if len(syms1) > len(syms2):
return sol2
# If both have constants, return the lesser complex solution
elif len(syms1) == len(syms2):
return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2
else:
return sol1
def get_gen_sol_from_part_sol(part_sols, a, x):
""""
Helper function which computes the general
solution for a Riccati ODE from its particular
solutions.
There are 3 cases to find the general solution
from the particular solutions for a Riccati ODE
depending on the number of particular solution(s)
we have - 1, 2 or 3.
For more information, see Section 6 of
"Methods of Solution of the Riccati Differential Equation"
by D. R. Haaheim and F. M. Stein
"""
# If no particular solutions are found, a general
# solution cannot be found
if len(part_sols) == 0:
return []
# In case of a single particular solution, the general
# solution can be found by using the substitution
# y = y1 + 1/z and solving a Bernoulli ODE to find z.
elif len(part_sols) == 1:
y1 = part_sols[0]
i = exp(Integral(2*y1, x))
z = i * Integral(a/i, x)
z = z.doit()
if a == 0 or z == 0:
return y1
return y1 + 1/z
# In case of 2 particular solutions, the general solution
# can be found by solving a separable equation. This is
# the most common case, i.e. most Riccati ODEs have 2
# rational particular solutions.
elif len(part_sols) == 2:
y1, y2 = part_sols
# One of them already has a constant
if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
u = exp(Integral(y2 - y1, x)).doit()
# Introduce a constant
else:
C1 = Dummy('C1')
u = C1*exp(Integral(y2 - y1, x)).doit()
if u == 1:
return y2
return (y2*u - y1)/(u - 1)
# In case of 3 particular solutions, a closed form
# of the general solution can be obtained directly
else:
y1, y2, y3 = part_sols[:3]
C1 = Dummy('C1')
return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)
def solve_riccati(fx, x, b0, b1, b2, gensol=False):
"""
The main function that gives particular/general
solutions to Riccati ODEs that have atleast 1
rational particular solution.
"""
# Step 1 : Convert to Normal Form
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
b2.diff(x, 2)/(2*b2)
a_t = a.together()
num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
num, den = num.cancel(den, include=True)
# Step 2
presol = []
# Step 3 : a(x) is 0
if num == 0:
presol.append(1/(x + Dummy('C1')))
# Step 4 : a(x) is a non-zero constant
elif x not in num.free_symbols.union(den.free_symbols):
presol.extend([sqrt(a), -sqrt(a)])
# Step 5 : Find poles and valuation at infinity
poles = roots(den, x)
poles, muls = list(poles.keys()), list(poles.values())
val_inf = val_at_inf(num, den, x)
if len(poles):
# Check necessary conditions (outlined in the module docstring)
if not check_necessary_conds(val_inf, muls):
raise ValueError("Rational Solution doesn't exist")
# Step 6
# Construct c-vectors for each singular point
c = construct_c(num, den, x, poles, muls)
# Construct d vectors for each singular point
d = construct_d(num, den, x, val_inf)
# Step 7 : Iterate over all possible combinations and return solutions
# For each possible combination, generate an array of 0's and 1's
# where 0 means pick 1st choice and 1 means pick the second choice.
# NOTE: We could exit from the loop if we find 3 particular solutions,
# but it is not implemented here as -
# a. Finding 3 particular solutions is very rare. Most of the time,
# only 2 particular solutions are found.
# b. In case we exit after finding 3 particular solutions, it might
# happen that 1 or 2 of them are redundant solutions. So, instead of
# spending some more time in computing the particular solutions,
# we will end up computing the general solution from a single
# particular solution which is usually slower than computing the
# general solution from 2 or 3 particular solutions.
c.append(d)
choices = product(*c)
for choice in choices:
m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
# Step 10 : Check if a valid solution exists. If yes, also check
# if m is a non-negative integer
if m.is_nonnegative == True and m.is_integer == True:
# Step 11 : Find polynomial solutions of degree m for the auxiliary equation
psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)
# Step 12 : If valid polynomial solution exists, append solution.
if exists:
# m == 0 case
if psol == 1 and coeffs == 0:
# p(x) = 1, so p'(x)/p(x) term need not be added
presol.append(ybar)
# m is a positive integer and there are valid coefficients
elif len(coeffs):
# Substitute the valid coefficients to get p(x)
psol = psol.xreplace(coeffs)
# y(x) = ybar(x) + p'(x)/p(x)
presol.append(ybar + psol.diff(x)/psol)
# Remove redundant solutions from the list of existing solutions
remove = set()
for i in range(len(presol)):
for j in range(i+1, len(presol)):
rem = remove_redundant_sols(presol[i], presol[j], x)
if rem is not None:
remove.add(rem)
sols = [x for x in presol if x not in remove]
# Step 15 : Inverse transform the solutions of the equation in normal form
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
# If general solution is required, compute it from the particular solutions
if gensol:
sols = [get_gen_sol_from_part_sol(sols, a, x)]
# Inverse transform the particular solutions
presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
return presol
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,392 @@
from sympy.core import S, Pow
from sympy.core.function import (Derivative, AppliedUndef, diff)
from sympy.core.relational import Equality, Eq
from sympy.core.symbol import Dummy
from sympy.core.sympify import sympify
from sympy.logic.boolalg import BooleanAtom
from sympy.functions import exp
from sympy.series import Order
from sympy.simplify.simplify import simplify, posify, besselsimp
from sympy.simplify.trigsimp import trigsimp
from sympy.simplify.sqrtdenest import sqrtdenest
from sympy.solvers import solve
from sympy.solvers.deutils import _preprocess, ode_order
from sympy.utilities.iterables import iterable, is_sequence
def sub_func_doit(eq, func, new):
r"""
When replacing the func with something else, we usually want the
derivative evaluated, so this function helps in making that happen.
Examples
========
>>> from sympy import Derivative, symbols, Function
>>> from sympy.solvers.ode.subscheck import sub_func_doit
>>> x, z = symbols('x, z')
>>> y = Function('y')
>>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
2
>>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
... 1/(x*(z + 1/x)))
x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
...- 1/(x**2*(z + 1/x)**2)
"""
reps= {func: new}
for d in eq.atoms(Derivative):
if d.expr == func:
reps[d] = new.diff(*d.variable_count)
else:
reps[d] = d.xreplace({func: new}).doit(deep=False)
return eq.xreplace(reps)
def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
r"""
Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.
This works when ``func`` is one function, like `f(x)` or a list of
functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can
be a single solution or a list of solutions. Each solution may be an
:py:class:`~sympy.core.relational.Equality` that the solution satisfies,
e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
:py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
will not be necessary to explicitly identify the function, but if the
function cannot be inferred from the original equation it can be supplied
through the ``func`` argument.
If a sequence of solutions is passed, the same sort of container will be
used to return the result for each solution.
It tries the following methods, in order, until it finds zero equivalence:
1. Substitute the solution for `f` in the original equation. This only
works if ``ode`` is solved for `f`. It will attempt to solve it first
unless ``solve_for_func == False``.
2. Take `n` derivatives of the solution, where `n` is the order of
``ode``, and check to see if that is equal to the solution. This only
works on exact ODEs.
3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
solving for the derivative of `f` of that order (this will always be
possible because `f` is a linear operator). Then back substitute each
derivative into ``ode`` in reverse order.
This function returns a tuple. The first item in the tuple is ``True`` if
the substitution results in ``0``, and ``False`` otherwise. The second
item in the tuple is what the substitution results in. It should always
be ``0`` if the first item is ``True``. Sometimes this function will
return ``False`` even when an expression is identically equal to ``0``.
This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
reduce the expression to ``0``. If an expression returned by this
function vanishes identically, then ``sol`` really is a solution to
the ``ode``.
If this function seems to hang, it is probably because of a hard
simplification.
To use this function to test, test the first item of the tuple.
Examples
========
>>> from sympy import (Eq, Function, checkodesol, symbols,
... Derivative, exp)
>>> x, C1, C2 = symbols('x,C1,C2')
>>> f, g = symbols('f g', cls=Function)
>>> checkodesol(f(x).diff(x), Eq(f(x), C1))
(True, 0)
>>> assert checkodesol(f(x).diff(x), C1)[0]
>>> assert not checkodesol(f(x).diff(x), x)[0]
>>> checkodesol(f(x).diff(x, 2), x**2)
(False, 2)
>>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
>>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
>>> checkodesol(eqs, sol)
(True, [0, 0])
"""
if iterable(ode):
return checksysodesol(ode, sol, func=func)
if not isinstance(ode, Equality):
ode = Eq(ode, 0)
if func is None:
try:
_, func = _preprocess(ode.lhs)
except ValueError:
funcs = [s.atoms(AppliedUndef) for s in (
sol if is_sequence(sol, set) else [sol])]
funcs = set().union(*funcs)
if len(funcs) != 1:
raise ValueError(
'must pass func arg to checkodesol for this case.')
func = funcs.pop()
if not isinstance(func, AppliedUndef) or len(func.args) != 1:
raise ValueError(
"func must be a function of one variable, not %s" % func)
if is_sequence(sol, set):
return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol])
if not isinstance(sol, Equality):
sol = Eq(func, sol)
elif sol.rhs == func:
sol = sol.reversed
if order == 'auto':
order = ode_order(ode, func)
solved = sol.lhs == func and not sol.rhs.has(func)
if solve_for_func and not solved:
rhs = solve(sol, func)
if rhs:
eqs = [Eq(func, t) for t in rhs]
if len(rhs) == 1:
eqs = eqs[0]
return checkodesol(ode, eqs, order=order,
solve_for_func=False)
x = func.args[0]
# Handle series solutions here
if sol.has(Order):
assert sol.lhs == func
Oterm = sol.rhs.getO()
solrhs = sol.rhs.removeO()
Oexpr = Oterm.expr
assert isinstance(Oexpr, Pow)
sorder = Oexpr.exp
assert Oterm == Order(x**sorder)
odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()
neworder = Order(x**(sorder - order))
odesubs = odesubs + neworder
assert odesubs.getO() == neworder
residual = odesubs.removeO()
return (residual == 0, residual)
s = True
testnum = 0
while s:
if testnum == 0:
# First pass, try substituting a solved solution directly into the
# ODE. This has the highest chance of succeeding.
ode_diff = ode.lhs - ode.rhs
if sol.lhs == func:
s = sub_func_doit(ode_diff, func, sol.rhs)
s = besselsimp(s)
else:
testnum += 1
continue
ss = simplify(s.rewrite(exp))
if ss:
# with the new numer_denom in power.py, if we do a simple
# expansion then testnum == 0 verifies all solutions.
s = ss.expand(force=True)
else:
s = 0
testnum += 1
elif testnum == 1:
# Second pass. If we cannot substitute f, try seeing if the nth
# derivative is equal, this will only work for odes that are exact,
# by definition.
s = simplify(
trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
trigsimp(ode.lhs) + trigsimp(ode.rhs))
# s2 = simplify(
# diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
# ode.lhs + ode.rhs)
testnum += 1
elif testnum == 2:
# Third pass. Try solving for df/dx and substituting that into the
# ODE. Thanks to Chris Smith for suggesting this method. Many of
# the comments below are his, too.
# The method:
# - Take each of 1..n derivatives of the solution.
# - Solve each nth derivative for d^(n)f/dx^(n)
# (the differential of that order)
# - Back substitute into the ODE in decreasing order
# (i.e., n, n-1, ...)
# - Check the result for zero equivalence
if sol.lhs == func and not sol.rhs.has(func):
diffsols = {0: sol.rhs}
elif sol.rhs == func and not sol.lhs.has(func):
diffsols = {0: sol.lhs}
else:
diffsols = {}
sol = sol.lhs - sol.rhs
for i in range(1, order + 1):
# Differentiation is a linear operator, so there should always
# be 1 solution. Nonetheless, we test just to make sure.
# We only need to solve once. After that, we automatically
# have the solution to the differential in the order we want.
if i == 1:
ds = sol.diff(x)
try:
sdf = solve(ds, func.diff(x, i))
if not sdf:
raise NotImplementedError
except NotImplementedError:
testnum += 1
break
else:
diffsols[i] = sdf[0]
else:
# This is what the solution says df/dx should be.
diffsols[i] = diffsols[i - 1].diff(x)
# Make sure the above didn't fail.
if testnum > 2:
continue
else:
# Substitute it into ODE to check for self consistency.
lhs, rhs = ode.lhs, ode.rhs
for i in range(order, -1, -1):
if i == 0 and 0 not in diffsols:
# We can only substitute f(x) if the solution was
# solved for f(x).
break
lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
ode_or_bool = Eq(lhs, rhs)
ode_or_bool = simplify(ode_or_bool)
if isinstance(ode_or_bool, (bool, BooleanAtom)):
if ode_or_bool:
lhs = rhs = S.Zero
else:
lhs = ode_or_bool.lhs
rhs = ode_or_bool.rhs
# No sense in overworking simplify -- just prove that the
# numerator goes to zero
num = trigsimp((lhs - rhs).as_numer_denom()[0])
# since solutions are obtained using force=True we test
# using the same level of assumptions
## replace function with dummy so assumptions will work
_func = Dummy('func')
num = num.subs(func, _func)
## posify the expression
num, reps = posify(num)
s = simplify(num).xreplace(reps).xreplace({_func: func})
testnum += 1
else:
break
if not s:
return (True, s)
elif s is True: # The code above never was able to change s
raise NotImplementedError("Unable to test if " + str(sol) +
" is a solution to " + str(ode) + ".")
else:
return (False, s)
def checksysodesol(eqs, sols, func=None):
r"""
Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
checks that the result of substitutions for each equation is ``0``. The
equations and solutions passed can be any iterable.
This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
For each function, ``sols`` can have a single solution or a list of solutions.
In most cases it will not be necessary to explicitly identify the function,
but if the function cannot be inferred from the original equation it
can be supplied through the ``func`` argument.
When a sequence of equations is passed, the same sequence is used to return
the result for each equation with each function substituted with corresponding
solutions.
It tries the following method to find zero equivalence for each equation:
Substitute the solutions for functions, like `x(t)` and `y(t)` into the
original equations containing those functions.
This function returns a tuple. The first item in the tuple is ``True`` if
the substitution results for each equation is ``0``, and ``False`` otherwise.
The second item in the tuple is what the substitution results in. Each element
of the ``list`` should always be ``0`` corresponding to each equation if the
first item is ``True``. Note that sometimes this function may return ``False``,
but with an expression that is identically equal to ``0``, instead of returning
``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
reduce the expression to ``0``. If an expression returned by each function
vanishes identically, then ``sols`` really is a solution to ``eqs``.
If this function seems to hang, it is probably because of a difficult simplification.
Examples
========
>>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
>>> from sympy.solvers.ode.subscheck import checksysodesol
>>> C1, C2 = symbols('C1:3')
>>> t = symbols('t')
>>> x, y = symbols('x, y', cls=Function)
>>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
>>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
>>> checksysodesol(eq, sol)
(True, [0, 0])
>>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
>>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
>>> checksysodesol(eq, sol)
(True, [0, 0])
"""
def _sympify(eq):
return list(map(sympify, eq if iterable(eq) else [eq]))
eqs = _sympify(eqs)
for i in range(len(eqs)):
if isinstance(eqs[i], Equality):
eqs[i] = eqs[i].lhs - eqs[i].rhs
if func is None:
funcs = []
for eq in eqs:
derivs = eq.atoms(Derivative)
func = set().union(*[d.atoms(AppliedUndef) for d in derivs])
funcs.extend(func)
funcs = list(set(funcs))
if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\
and len({func.args for func in funcs})!=1:
raise ValueError("func must be a function of one variable, not %s" % func)
for sol in sols:
if len(sol.atoms(AppliedUndef)) != 1:
raise ValueError("solutions should have one function only")
if len(funcs) != len({sol.lhs for sol in sols}):
raise ValueError("number of solutions provided does not match the number of equations")
dictsol = {}
for sol in sols:
func = list(sol.atoms(AppliedUndef))[0]
if sol.rhs == func:
sol = sol.reversed
solved = sol.lhs == func and not sol.rhs.has(func)
if not solved:
rhs = solve(sol, func)
if not rhs:
raise NotImplementedError
else:
rhs = sol.rhs
dictsol[func] = rhs
checkeq = []
for eq in eqs:
for func in funcs:
eq = sub_func_doit(eq, func, dictsol[func])
ss = simplify(eq)
if ss != 0:
eq = ss.expand(force=True)
if eq != 0:
eq = sqrtdenest(eq).simplify()
else:
eq = 0
checkeq.append(eq)
if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0:
return (True, checkeq)
else:
return (False, checkeq)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,152 @@
from sympy.core.function import Function
from sympy.core.numbers import Rational
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (atan, sin, tan)
from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals)
from sympy.solvers.ode.subscheck import checkodesol
from sympy.testing.pytest import XFAIL
C1 = Symbol('C1')
x, y = symbols("x y")
f = Function('f')
xi = Function('xi')
eta = Function('eta')
def test_heuristic1():
a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0")
df = f(x).diff(x)
eq = Eq(df, x**2*f(x))
eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x)
eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2)
eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x))
eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2)
eq5 = x**2*df - f(x) + x**2*exp(x - (1/x))
eqlist = [eq, eq1, eq2, eq3, eq4, eq5]
i = infinitesimals(eq, hint='abaco1_simple')
assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0},
{eta(x, f(x)): f(x), xi(x, f(x)): 0},
{eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}]
i1 = infinitesimals(eq1, hint='abaco1_simple')
assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}]
i2 = infinitesimals(eq2, hint='abaco1_simple')
assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}]
i3 = infinitesimals(eq3, hint='abaco1_simple')
assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1},
{eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}]
i4 = infinitesimals(eq4, hint='abaco1_simple')
assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0},
{eta(x, f(x)): 0,
xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}]
i5 = infinitesimals(eq5, hint='abaco1_simple')
assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}]
ilist = [i, i1, i2, i3, i4, i5]
for eq, i in (zip(eqlist, ilist)):
check = checkinfsol(eq, i)
assert check[0]
# This ODE can be solved by the Lie Group method, when there are
# better assumptions
eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2)
i = infinitesimals(eq6, hint='abaco1_product')
assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}]
assert checkinfsol(eq6, i)[0]
eq7 = x*(f(x).diff(x)) + 1 - f(x)**2
i = infinitesimals(eq7, hint='chi')
assert checkinfsol(eq7, i)[0]
def test_heuristic3():
a, b = symbols("a b")
df = f(x).diff(x)
eq = x**2*df + x*f(x) + f(x)**2 + x**2
i = infinitesimals(eq, hint='bivariate')
assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}]
assert checkinfsol(eq, i)[0]
eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x
i = infinitesimals(eq, hint='bivariate')
assert checkinfsol(eq, i)[0]
def test_heuristic_function_sum():
eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x +
(1 - 3*f(x))*(x/f(x)**2))
i = infinitesimals(eq, hint='function_sum')
assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}]
assert checkinfsol(eq, i)[0]
def test_heuristic_abaco2_similar():
a, b = symbols("a b")
F = Function('F')
eq = f(x).diff(x) - F(a*x + b*f(x))
i = infinitesimals(eq, hint='abaco2_similar')
assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}]
assert checkinfsol(eq, i)[0]
eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x)))
i = infinitesimals(eq, hint='abaco2_similar')
assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}]
assert checkinfsol(eq, i)[0]
def test_heuristic_abaco2_unique_unknown():
a, b = symbols("a b")
F = Function('F')
eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b)
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}]
assert checkinfsol(eq, i)[0]
eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x)))
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}]
assert checkinfsol(eq, i)[0]
eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert checkinfsol(eq, i)[0]
def test_heuristic_linear():
a, b, m, n = symbols("a b m n")
eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1))
i = infinitesimals(eq, hint='linear')
assert checkinfsol(eq, i)[0]
@XFAIL
def test_kamke():
a, b, alpha, c = symbols("a b alpha c")
eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c
i = infinitesimals(eq, hint='sum_function') # XFAIL
assert checkinfsol(eq, i)[0]
def test_user_infinitesimals():
x = Symbol("x") # assuming x is real generates an error
eq = x*(f(x).diff(x)) + 1 - f(x)**2
sol = Eq(f(x), (C1 + x**2)/(C1 - x**2))
infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0}
assert dsolve(eq, hint='lie_group', **infinitesimals) == sol
assert checkodesol(eq, sol) == (True, 0)
@XFAIL
def test_lie_group_issue15219():
eqn = exp(f(x).diff(x)-f(x))
assert 'lie_group' not in classify_ode(eqn, f(x))
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,877 @@
from sympy.core.random import randint
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational, oo)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.hyperbolic import tanh
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import sin
from sympy.polys.polytools import Poly
from sympy.simplify.ratsimp import ratsimp
from sympy.solvers.ode.subscheck import checkodesol
from sympy.testing.pytest import slow
from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal,
riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf,
check_necessary_conds, val_at_inf, construct_c_case_1,
construct_c_case_2, construct_c_case_3, construct_d_case_4,
construct_d_case_5, construct_d_case_6, rational_laurent_series,
solve_riccati)
f = Function('f')
x = symbols('x')
# These are the functions used to generate the tests
# SHOULD NOT BE USED DIRECTLY IN TESTS
def rand_rational(maxint):
return Rational(randint(-maxint, maxint), randint(1, maxint))
def rand_poly(x, degree, maxint):
return Poly([rand_rational(maxint) for _ in range(degree+1)], x)
def rand_rational_function(x, degree, maxint):
degnum = randint(1, degree)
degden = randint(1, degree)
num = rand_poly(x, degnum, maxint)
den = rand_poly(x, degden, maxint)
while den == Poly(0, x):
den = rand_poly(x, degden, maxint)
return num / den
def find_riccati_ode(ratfunc, x, yf):
y = ratfunc
yp = y.diff(x)
q1 = rand_rational_function(x, 1, 3)
q2 = rand_rational_function(x, 1, 3)
while q2 == 0:
q2 = rand_rational_function(x, 1, 3)
q0 = ratsimp(yp - q1*y - q2*y**2)
eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2)
sol = Eq(yf, y)
assert checkodesol(eq, sol) == (True, 0)
return eq, q0, q1, q2
# Testing functions start
def test_riccati_transformation():
"""
This function tests the transformation of the
solution of a Riccati ODE to the solution of
its corresponding normal Riccati ODE.
Each test case 4 values -
1. w - The solution to be transformed
2. b1 - The coefficient of f(x) in the ODE.
3. b2 - The coefficient of f(x)**2 in the ODE.
4. y - The solution to the normal Riccati ODE.
"""
tests = [
(
x/(x - 1),
(x**2 + 7)/3*x,
x,
-x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x)
),
(
(2*x + 3)/(2*x + 2),
(3 - 3*x)/(x + 1),
5*x,
-5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x)
),
(
-1/(2*x**2 - 1),
0,
(2 - x)/(4*x - 2),
(2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \
2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False))
),
(
x,
(8*x - 12)/(12*x + 9),
x**3/(6*x - 9),
-x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \
- 9)**2 + 3*x**2/(6*x - 9))/(2*x**3)
)]
for w, b1, b2, y in tests:
assert y == riccati_normal(w, x, b1, b2)
assert w == riccati_inverse_normal(y, x, b1, b2).cancel()
# Test bp parameter in riccati_inverse_normal
tests = [
(
(-2*x - 1)/(2*x**2 + 2*x - 2),
-2/x,
(-x - 1)/(4*x),
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \
- 2)) + 1/x
),
(
3/(2*x**2),
-2/x,
(-x - 1)/(4*x),
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3)
)]
for w, b1, b2, bp, y in tests:
assert y == riccati_normal(w, x, b1, b2)
assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel()
def test_riccati_reduced():
"""
This function tests the transformation of a
Riccati ODE to its normal Riccati ODE.
Each test case 2 values -
1. eq - A Riccati ODE.
2. normal_eq - The normal Riccati ODE of eq.
"""
tests = [
(
f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2,
f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2)
),
(
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x,
-3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \
-x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \
- (-1 + (x + 1)/x)/(x*(-x - 1))
),
(
f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2),
-(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \
f(x)**2 + f(x).diff(x) - 1/(2*x + 1)
),
(
f(x).diff(x) - f(x)**2/x,
f(x)**2 + f(x).diff(x) + 1/(4*x**2)
),
(
-3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x,
f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \
+ 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2)
),
(
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x,
False
),
(
f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2),
False
)]
for eq, normal_eq in tests:
assert normal_eq == riccati_reduced(eq, f, x)
def test_match_riccati():
"""
This function tests if an ODE is Riccati or not.
Each test case has 5 values -
1. eq - The Riccati ODE.
2. match - Boolean indicating if eq is a Riccati ODE.
3. b0 -
4. b1 - Coefficient of f(x) in eq.
5. b2 - Coefficient of f(x)**2 in eq.
"""
tests = [
# Test Rational Riccati ODEs
(
f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \
- 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \
- (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2),
True,
45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \
(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \
315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \
+ 846*x**2 + 180*x - 72),
Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2),
1/(3*x + 1)
),
(
f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \
1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \
/(324*x**3 + 216*x**2),
True,
-4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \
+ 24*x) + 265/(324*x + 216),
Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6),
x/3 - 1
),
(
f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \
+ 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \
360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \
S(3)/2) - (x/3 - 3)*f(x)**2/(3*x),
True,
304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \
108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \
360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \
x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \
189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \
53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \
+ 53*x**3 - 63*x**2 + 40*x - 12),
Mul(-1, 3 - 2*x, evaluate=False)/(x - 3),
Mul(-1, 9 - x, evaluate=False)/(9*x)
),
# Test Non-Rational Riccati ODEs
(
f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \
x*f(x)**2/(x**(S(3)/4)),
False, 0, 0, 0
),
(
f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2,
False, 0, 0, 0
),
(
f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2,
False, 0, 0, 0
),
# Test Non-Riccati ODEs
(
(1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x),
False, 0, 0, 0
),
(
f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3,
False, 0, 0, 0
),
(
f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \
+ 3) + f(x)**2,
False, 0, 0, 0
)]
for eq, res, b0, b1, b2 in tests:
match, funcs = match_riccati(eq, f, x)
assert match == res
if res:
assert [b0, b1, b2] == funcs
def test_val_at_inf():
"""
This function tests the valuation of rational
function at oo.
Each test case has 3 values -
1. num - Numerator of rational function.
2. den - Denominator of rational function.
3. val_inf - Valuation of rational function at oo
"""
tests = [
# degree(denom) > degree(numer)
(
Poly(10*x**3 + 8*x**2 - 13*x + 6, x),
Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x),
7
),
(
Poly(1, x),
Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x),
4
),
# degree(denom) == degree(numer)
(
Poly(-6*x**3 - 8*x**2 + 8*x - 6, x),
Poly(-5*x**3 + 12*x**2 - 6*x - 9, x),
0
),
# degree(denom) < degree(numer)
(
Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x),
Poly(-14*x**2 + x, x),
-6
),
(
Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x),
Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x),
-2
)]
for num, den, val in tests:
assert val_at_inf(num, den, x) == val
def test_necessary_conds():
"""
This function tests the necessary conditions for
a Riccati ODE to have a rational particular solution.
"""
# Valuation at Infinity is an odd negative integer
assert check_necessary_conds(-3, [1, 2, 4]) == False
# Valuation at Infinity is a positive integer lesser than 2
assert check_necessary_conds(1, [1, 2, 4]) == False
# Multiplicity of a pole is an odd integer greater than 1
assert check_necessary_conds(2, [3, 1, 6]) == False
# All values are correct
assert check_necessary_conds(-10, [1, 2, 8, 12]) == True
def test_inverse_transform_poly():
"""
This function tests the substitution x -> 1/x
in rational functions represented using Poly.
"""
fns = [
(15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6),
(180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12),
(-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80),
(60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30),
(30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \
+ 15*x**2 + 45*x - 15)
]
for f in fns:
num, den = [Poly(e, x) for e in f.as_numer_denom()]
num, den = inverse_transform_poly(num, den, x)
assert f.subs(x, 1/x).cancel() == num/den
def test_limit_at_inf():
"""
This function tests the limit at oo of a
rational function.
Each test case has 3 values -
1. num - Numerator of rational function.
2. den - Denominator of rational function.
3. limit_at_inf - Limit of rational function at oo
"""
tests = [
# deg(denom) > deg(numer)
(
Poly(-12*x**2 + 20*x + 32, x),
Poly(32*x**3 + 72*x**2 + 3*x - 32, x),
0
),
# deg(denom) < deg(numer)
(
Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x),
Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x),
oo
),
# deg(denom) < deg(numer), one of the leading coefficients is negative
(
Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \
+ 630*x + 3780, x),
Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x),
-oo
),
# deg(denom) == deg(numer)
(
Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x),
Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x),
S(1)/7
),
(
Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x),
Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x),
S(288)/607
)]
for num, den, lim in tests:
assert limit_at_inf(num, den, x) == lim
def test_construct_c_case_1():
"""
This function tests the Case 1 in the step
to calculate coefficients of c-vectors.
Each test case has 4 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. pole - Pole of a(x) for which c-vector is being
calculated.
4. c - The c-vector for the pole.
"""
tests = [
(
Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True),
Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True),
S(0),
[[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]]
),
(
Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True),
Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True),
S(7)/4,
[[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]]
),
(
Poly(4*x + 2, x, extension=True),
Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \
+ 8*sqrt(3) + 16, x, domain='QQ<sqrt(3)>'),
(S(1) + sqrt(3))/2,
[[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \
[S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]]
)]
for num, den, pole, c in tests:
assert construct_c_case_1(num, den, x, pole) == c
def test_construct_c_case_2():
"""
This function tests the Case 2 in the step
to calculate coefficients of c-vectors.
Each test case has 5 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. pole - Pole of a(x) for which c-vector is being
calculated.
4. mul - The multiplicity of the pole.
5. c - The c-vector for the pole.
"""
tests = [
# Testing poles with multiplicity 2
(
Poly(1, x, extension=True),
Poly((x - 1)**2*(x - 2), x, extension=True),
1, 2,
[[-I*(-1 - I)/2], [I*(-1 + I)/2]]
),
(
Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True),
Poly((3*x - 1)**2*(x + 2)**2, x, extension=True),
S(1)/3, 2,
[[-S(89)/98], [-S(9)/98]]
),
# Testing poles with multiplicity 4
(
Poly(x**3 - x**2 + 4*x, x, extension=True),
Poly((x - 2)**4*(x + 5)**2, x, extension=True),
2, 4,
[[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \
[-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]]
),
(
Poly(3*x**5 + x**4 + 3, x, extension=True),
Poly((4*x + 1)**4*(x + 2), x, extension=True),
-S(1)/4, 4,
[[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \
[-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]]
),
# Testing poles with multiplicity 6
(
Poly(x**3 + 2, x, extension=True),
Poly((3*x - 1)**6*(x**2 + 1), x, extension=True),
S(1)/3, 6,
[[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \
[-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]]
),
(
Poly(x**2 + 12, x, extension=True),
Poly((x - sqrt(2))**6, x, extension=True),
sqrt(2), 6,
[[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \
[-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]]
)]
for num, den, pole, mul, c in tests:
assert construct_c_case_2(num, den, x, pole, mul) == c
def test_construct_c_case_3():
"""
This function tests the Case 3 in the step
to calculate coefficients of c-vectors.
"""
assert construct_c_case_3() == [[1]]
def test_construct_d_case_4():
"""
This function tests the Case 4 in the step
to calculate coefficients of the d-vector.
Each test case has 4 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. mul - Multiplicity of oo as a pole.
4. d - The d-vector.
"""
tests = [
# Tests with multiplicity at oo = 2
(
Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True),
Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True),
2,
[[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \
[-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]]
),
(
Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True),
Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True),
2,
[[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]]
),
# Tests with multiplicity at oo = 4
(
Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True),
Poly(3*x**2 + 10*x + 7, x, extension=True),
4,
[[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \
- 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \
sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]]
),
(
Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True),
Poly(12*x - 2, x, extension=True),
4,
[[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \
[-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]]
),
# Tests with multiplicity at oo = 4
(
Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True),
Poly(x + 2, x, extension=True),
6,
[[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]]
),
(
Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True),
Poly(2*x - 2, x, extension=True),
6,
[[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \
3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \
sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]]
)]
for num, den, mul, d in tests:
ser = rational_laurent_series(num, den, x, oo, mul, 1)
assert construct_d_case_4(ser, mul//2) == d
def test_construct_d_case_5():
"""
This function tests the Case 5 in the step
to calculate coefficients of the d-vector.
Each test case has 3 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. d - The d-vector.
"""
tests = [
(
Poly(2*x**3 + x**2 + x - 2, x, extension=True),
Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True),
[[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]]
),
(
Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'),
Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'),
[[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]]
),
(
Poly(x**2 - x + 1, x, domain='ZZ'),
Poly(3*x**2 + 7*x + 3, x, domain='ZZ'),
[[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]]
)]
for num, den, d in tests:
# Multiplicity of oo is 0
ser = rational_laurent_series(num, den, x, oo, 0, 1)
assert construct_d_case_5(ser) == d
def test_construct_d_case_6():
"""
This function tests the Case 6 in the step
to calculate coefficients of the d-vector.
Each test case has 3 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. d - The d-vector.
"""
tests = [
(
Poly(-2*x**2 - 5, x, domain='ZZ'),
Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'),
[[S(1)/2 + I/2], [S(1)/2 - I/2]]
),
(
Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'),
Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'),
[[1], [0]]
),
(
Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'),
Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \
- 27*x - 42, x, domain='ZZ'),
[[1], [0]]
)]
for num, den, d in tests:
assert construct_d_case_6(num, den, x) == d
def test_rational_laurent_series():
"""
This function tests the computation of coefficients
of Laurent series of a rational function.
Each test case has 5 values -
1. num - Numerator of the rational function.
2. den - Denominator of the rational function.
3. x0 - Point about which Laurent series is to
be calculated.
4. mul - Multiplicity of x0 if x0 is a pole of
the rational function (0 otherwise).
5. n - Number of terms upto which the series
is to be calculated.
"""
tests = [
# Laurent series about simple pole (Multiplicity = 1)
(
Poly(x**2 - 3*x + 9, x, extension=True),
Poly(x**2 - x, x, extension=True),
S(1), 1, 6,
{1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9}
),
# Laurent series about multiple pole (Multiplicity > 1)
(
Poly(64*x**3 - 1728*x + 1216, x, extension=True),
Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True),
S(9)/8, 2, 3,
{0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \
1: S(209149)/75645}
),
(
Poly(1, x, extension=True),
Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \
+ (4 + 8*sqrt(2))*x - 4, x, extension=True),
sqrt(2), 4, 6,
{4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \
+ sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \
)/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4}
),
# Laurent series about oo
(
Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True),
Poly(x**2 - 5, x, extension=True),
oo, 3, 6,
{3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17}
),
# Laurent series at x0 where x0 is not a pole of the function
# Using multiplicity as 0 (as x0 will not be a pole)
(
Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True),
Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True),
S(2)/5, 0, 1,
{0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \
-3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016}
),
(
Poly(-7*x**2 + 2*x - 4, x, extension=True),
Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True),
oo, 0, 6,
{0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7}
)]
for num, den, x0, mul, n, ser in tests:
assert ser == rational_laurent_series(num, den, x, x0, mul, n)
def check_dummy_sol(eq, solse, dummy_sym):
"""
Helper function to check if actual solution
matches expected solution if actual solution
contains dummy symbols.
"""
if isinstance(eq, Eq):
eq = eq.lhs - eq.rhs
_, funcs = match_riccati(eq, f, x)
sols = solve_riccati(f(x), x, *funcs)
C1 = Dummy('C1')
sols = [sol.subs(C1, dummy_sym) for sol in sols]
assert all(x[0] for x in checkodesol(eq, sols))
assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse))
def test_solve_riccati():
"""
This function tests the computation of rational
particular solutions for a Riccati ODE.
Each test case has 2 values -
1. eq - Riccati ODE to be solved.
2. sol - Expected solution to the equation.
Some examples have been taken from the paper - "Statistical Investigation of
First-Order Algebraic ODEs and their Rational General Solutions" by
Georg Grasegger, N. Thieu Vo, Franz Winkler
https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf
"""
C0 = Dummy('C0')
# Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2,
# a, b, c are rational functions of x
tests = [
# a(x) is a constant
(
Eq(f(x).diff(x) + f(x)**2 - 2, 0),
[Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))]
),
# a(x) is a constant
(
f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2,
[Eq(f(x), (-2*C0 - x)/(C0*x + x**2))]
),
# a(x) is a constant
(
2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x),
[Eq(f(x), (C0 + 2*x**2)/(C0 + x))]
),
# Pole with multiplicity 1
(
Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)),
[Eq(f(x), 1/(x**2 - x))]
),
# One pole of multiplicity 2
(
x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x),
[Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)]
),
(
x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x),
[Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)]
),
# Multiple poles of multiplicity 2
(
-f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \
- 1)**2),
[Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \
- 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \
57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \
- 3*x + 1))]
),
# Regression: Poles with even multiplicity > 2 fixed
(
f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \
7*x**2 - 20*x + 4)/(4*x**4),
[Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \
- 2*x**2))]
),
# Regression: Poles with even multiplicity > 2 fixed
(
Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\
(x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \
792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)),
[Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))]
),
# More than 2 poles with multiplicity 2
# Regression: Fixed mistake in necessary conditions
(
Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \
(8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2),
[Eq(f(x), (1 - 4*x)/(2*x - 2))]
),
# Regression: Fixed mistake in necessary conditions
(
Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \
+ 3*f(x)**2/(6*x + 2)),
[Eq(f(x), (2*x + 1)/(2*x - 2))]
),
# Imaginary poles
(
f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \
- 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2),
[Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \
- x)), Eq(f(x), -1/(x - 1))],
),
# Imaginary coefficients in equation
(
f(x).diff(x) - 2*I*(f(x)**2 + 1)/x,
[Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)]
),
# Regression: linsolve returning empty solution
# Large value of m (> 10)
(
Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\
(2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)),
[Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \
18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\
x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\
((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \
415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\
*x**3 + 32744250*x**2 + 16372125*x)))]
),
# Regression: Fixed bug due to a typo in paper
(
Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6),
[Eq(f(x), 6*x)]
),
# Regression: Fixed bug due to a typo in paper
(
Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \
+ 9 + (1 - x)*f(x)/x + 3/x),
[Eq(f(x), -3*x/2 - 3)]
)]
for eq, sol in tests:
check_dummy_sol(eq, sol, C0)
@slow
def test_solve_riccati_slow():
"""
This function tests the computation of rational
particular solutions for a Riccati ODE.
Each test case has 2 values -
1. eq - Riccati ODE to be solved.
2. sol - Expected solution to the equation.
"""
C0 = Dummy('C0')
tests = [
# Very large values of m (989 and 991)
(
Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \
(54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \
2187*x + 2187) + 495),
[Eq(f(x), (18*x + 6)/(2*x - 9))]
)]
for eq, sol in tests:
check_dummy_sol(eq, sol, C0)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,203 @@
from sympy.core.function import (Derivative, Function, diff)
from sympy.core.numbers import (I, Rational, pi)
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.functions.special.error_functions import (Ei, erf, erfi)
from sympy.integrals.integrals import Integral
from sympy.solvers.ode.subscheck import checkodesol, checksysodesol
from sympy.functions import besselj, bessely
from sympy.testing.pytest import raises, slow
C0, C1, C2, C3, C4 = symbols('C0:5')
u, x, y, z = symbols('u,x:z', real=True)
f = Function('f')
g = Function('g')
h = Function('h')
@slow
def test_checkodesol():
# For the most part, checkodesol is well tested in the tests below.
# These tests only handle cases not checked below.
raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x)))
raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y),
x), f(x, y)))
assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \
(False, -f(x).diff(x) + f(x, y).diff(x) - 1)
assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True
assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1)
sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0)
assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \
(False, 60*x**4*((log(x) + 1)**2 + log(x))*(
log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9)
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \
(True, 0)
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0),
solve_for_func=False) == (True, 0)
assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x),
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \
[(True, 0), (True, 0), (False, C2)]
assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x),
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \
{(True, 0), (True, 0), (False, C2)}
assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \
[(True, 0), (True, 0)]
assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0)
# Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that
# checkodesol tries back substituting f(x) when it can.
eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x)
sol3 = Eq(f(x), log(log(C1/x)**(-x)))
assert not checkodesol(eq3, sol3)[1].has(f(x))
# This case was failing intermittently depending on hash-seed:
eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x))
sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))
assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0]
eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x)
sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x))
assert checkodesol(eq, sol) == (True, 0)
eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))]
sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]
assert checkodesol(eqs, sol) == (True, [0, 0])
def test_checksysodesol():
x, y, z = symbols('x, y, z', cls=Function)
t = Symbol('t')
eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t)))
sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \
Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t)))
sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \
Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \
Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t)))
sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \
Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \
C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23))
sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \
Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \
sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23))
sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \
Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t)))
sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\
exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \
C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t)))
sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\
exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \
C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t)))
sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \
Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t)))
root0 = -sqrt(-sqrt(47) + 7)
root1 = sqrt(-sqrt(47) + 7)
root2 = -sqrt(sqrt(47) + 7)
root3 = sqrt(sqrt(47) + 7)
sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \
Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \
C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12))
root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2))
root1 = sqrt(-sqrt(109)/2 + Rational(15, 2))
root2 = -sqrt(sqrt(109)/2 + Rational(15, 2))
root3 = sqrt(sqrt(109)/2 + Rational(15, 2))
sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \
Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \
C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0))
sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \
C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \
+ C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t)))
I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t
I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t
sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t)))
sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \
Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t)))
sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \
Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \
Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t))))
sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \
Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \
Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t)))
sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \
Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \
Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t)))
sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \
Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5))
sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5))
sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2))
sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)}
assert checksysodesol(eq, sol) == (True, [0, 0])
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,966 @@
"""
This module contains pdsolve() and different helper functions that it
uses. It is heavily inspired by the ode module and hence the basic
infrastructure remains the same.
**Functions in this module**
These are the user functions in this module:
- pdsolve() - Solves PDE's
- classify_pde() - Classifies PDEs into possible hints for dsolve().
- pde_separate() - Separate variables in partial differential equation either by
additive or multiplicative separation approach.
These are the helper functions in this module:
- pde_separate_add() - Helper function for searching additive separable solutions.
- pde_separate_mul() - Helper function for searching multiplicative
separable solutions.
**Currently implemented solver methods**
The following methods are implemented for solving partial differential
equations. See the docstrings of the various pde_hint() functions for
more information on each (run help(pde)):
- 1st order linear homogeneous partial differential equations
with constant coefficients.
- 1st order linear general partial differential equations
with constant coefficients.
- 1st order linear partial differential equations with
variable coefficients.
"""
from functools import reduce
from itertools import combinations_with_replacement
from sympy.simplify import simplify # type: ignore
from sympy.core import Add, S
from sympy.core.function import Function, expand, AppliedUndef, Subs
from sympy.core.relational import Equality, Eq
from sympy.core.symbol import Symbol, Wild, symbols
from sympy.functions import exp
from sympy.integrals.integrals import Integral, integrate
from sympy.utilities.iterables import has_dups, is_sequence
from sympy.utilities.misc import filldedent
from sympy.solvers.deutils import _preprocess, ode_order, _desolve
from sympy.solvers.solvers import solve
from sympy.simplify.radsimp import collect
import operator
allhints = (
"1st_linear_constant_coeff_homogeneous",
"1st_linear_constant_coeff",
"1st_linear_constant_coeff_Integral",
"1st_linear_variable_coeff"
)
def pdsolve(eq, func=None, hint='default', dict=False, solvefun=None, **kwargs):
"""
Solves any (supported) kind of partial differential equation.
**Usage**
pdsolve(eq, f(x,y), hint) -> Solve partial differential equation
eq for function f(x,y), using method hint.
**Details**
``eq`` can be any supported partial differential equation (see
the pde docstring for supported methods). This can either
be an Equality, or an expression, which is assumed to be
equal to 0.
``f(x,y)`` is a function of two variables whose derivatives in that
variable make up the partial differential equation. In many
cases it is not necessary to provide this; it will be autodetected
(and an error raised if it could not be detected).
``hint`` is the solving method that you want pdsolve to use. Use
classify_pde(eq, f(x,y)) to get all of the possible hints for
a PDE. The default hint, 'default', will use whatever hint
is returned first by classify_pde(). See Hints below for
more options that you can use for hint.
``solvefun`` is the convention used for arbitrary functions returned
by the PDE solver. If not set by the user, it is set by default
to be F.
**Hints**
Aside from the various solving methods, there are also some
meta-hints that you can pass to pdsolve():
"default":
This uses whatever hint is returned first by
classify_pde(). This is the default argument to
pdsolve().
"all":
To make pdsolve apply all relevant classification hints,
use pdsolve(PDE, func, hint="all"). This will return a
dictionary of hint:solution terms. If a hint causes
pdsolve to raise the NotImplementedError, value of that
hint's key will be the exception object raised. The
dictionary will also include some special keys:
- order: The order of the PDE. See also ode_order() in
deutils.py
- default: The solution that would be returned by
default. This is the one produced by the hint that
appears first in the tuple returned by classify_pde().
"all_Integral":
This is the same as "all", except if a hint also has a
corresponding "_Integral" hint, it only returns the
"_Integral" hint. This is useful if "all" causes
pdsolve() to hang because of a difficult or impossible
integral. This meta-hint will also be much faster than
"all", because integrate() is an expensive routine.
See also the classify_pde() docstring for more info on hints,
and the pde docstring for a list of all supported hints.
**Tips**
- You can declare the derivative of an unknown function this way:
>>> from sympy import Function, Derivative
>>> from sympy.abc import x, y # x and y are the independent variables
>>> f = Function("f")(x, y) # f is a function of x and y
>>> # fx will be the partial derivative of f with respect to x
>>> fx = Derivative(f, x)
>>> # fy will be the partial derivative of f with respect to y
>>> fy = Derivative(f, y)
- See test_pde.py for many tests, which serves also as a set of
examples for how to use pdsolve().
- pdsolve always returns an Equality class (except for the case
when the hint is "all" or "all_Integral"). Note that it is not possible
to get an explicit solution for f(x, y) as in the case of ODE's
- Do help(pde.pde_hintname) to get help more information on a
specific hint
Examples
========
>>> from sympy.solvers.pde import pdsolve
>>> from sympy import Function, Eq
>>> from sympy.abc import x, y
>>> f = Function('f')
>>> u = f(x, y)
>>> ux = u.diff(x)
>>> uy = u.diff(y)
>>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0)
>>> pdsolve(eq)
Eq(f(x, y), F(3*x - 2*y)*exp(-2*x/13 - 3*y/13))
"""
if not solvefun:
solvefun = Function('F')
# See the docstring of _desolve for more details.
hints = _desolve(eq, func=func, hint=hint, simplify=True,
type='pde', **kwargs)
eq = hints.pop('eq', False)
all_ = hints.pop('all', False)
if all_:
# TODO : 'best' hint should be implemented when adequate
# number of hints are added.
pdedict = {}
failed_hints = {}
gethints = classify_pde(eq, dict=True)
pdedict.update({'order': gethints['order'],
'default': gethints['default']})
for hint in hints:
try:
rv = _helper_simplify(eq, hint, hints[hint]['func'],
hints[hint]['order'], hints[hint][hint], solvefun)
except NotImplementedError as detail:
failed_hints[hint] = detail
else:
pdedict[hint] = rv
pdedict.update(failed_hints)
return pdedict
else:
return _helper_simplify(eq, hints['hint'], hints['func'],
hints['order'], hints[hints['hint']], solvefun)
def _helper_simplify(eq, hint, func, order, match, solvefun):
"""Helper function of pdsolve that calls the respective
pde functions to solve for the partial differential
equations. This minimizes the computation in
calling _desolve multiple times.
"""
solvefunc = globals()["pde_" + hint.removesuffix("_Integral")]
return _handle_Integral(solvefunc(eq, func, order,
match, solvefun), func, order, hint)
def _handle_Integral(expr, func, order, hint):
r"""
Converts a solution with integrals in it into an actual solution.
Simplifies the integral mainly using doit()
"""
if hint.endswith("_Integral"):
return expr
elif hint == "1st_linear_constant_coeff":
return simplify(expr.doit())
else:
return expr
def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
"""
Returns a tuple of possible pdsolve() classifications for a PDE.
The tuple is ordered so that first item is the classification that
pdsolve() uses to solve the PDE by default. In general,
classifications near the beginning of the list will produce
better solutions faster than those near the end, though there are
always exceptions. To make pdsolve use a different classification,
use pdsolve(PDE, func, hint=<classification>). See also the pdsolve()
docstring for different meta-hints you can use.
If ``dict`` is true, classify_pde() will return a dictionary of
hint:match expression terms. This is intended for internal use by
pdsolve(). Note that because dictionaries are ordered arbitrarily,
this will most likely not be in the same order as the tuple.
You can get help on different hints by doing help(pde.pde_hintname),
where hintname is the name of the hint without "_Integral".
See sympy.pde.allhints or the sympy.pde docstring for a list of all
supported hints that can be returned from classify_pde.
Examples
========
>>> from sympy.solvers.pde import classify_pde
>>> from sympy import Function, Eq
>>> from sympy.abc import x, y
>>> f = Function('f')
>>> u = f(x, y)
>>> ux = u.diff(x)
>>> uy = u.diff(y)
>>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0)
>>> classify_pde(eq)
('1st_linear_constant_coeff_homogeneous',)
"""
if func and len(func.args) != 2:
raise NotImplementedError("Right now only partial "
"differential equations of two variables are supported")
if prep or func is None:
prep, func_ = _preprocess(eq, func)
if func is None:
func = func_
if isinstance(eq, Equality):
if eq.rhs != 0:
return classify_pde(eq.lhs - eq.rhs, func)
eq = eq.lhs
f = func.func
x = func.args[0]
y = func.args[1]
fx = f(x,y).diff(x)
fy = f(x,y).diff(y)
# TODO : For now pde.py uses support offered by the ode_order function
# to find the order with respect to a multi-variable function. An
# improvement could be to classify the order of the PDE on the basis of
# individual variables.
order = ode_order(eq, f(x,y))
# hint:matchdict or hint:(tuple of matchdicts)
# Also will contain "default":<default hint> and "order":order items.
matching_hints = {'order': order}
if not order:
if dict:
matching_hints["default"] = None
return matching_hints
return ()
eq = expand(eq)
a = Wild('a', exclude = [f(x,y)])
b = Wild('b', exclude = [f(x,y), fx, fy, x, y])
c = Wild('c', exclude = [f(x,y), fx, fy, x, y])
d = Wild('d', exclude = [f(x,y), fx, fy, x, y])
e = Wild('e', exclude = [f(x,y), fx, fy])
n = Wild('n', exclude = [x, y])
# Try removing the smallest power of f(x,y)
# from the highest partial derivatives of f(x,y)
reduced_eq = eq
if eq.is_Add:
power = None
for i in set(combinations_with_replacement((x,y), order)):
coeff = eq.coeff(f(x,y).diff(*i))
if coeff == 1:
continue
match = coeff.match(a*f(x,y)**n)
if match and match[a]:
if power is None or match[n] < power:
power = match[n]
if power:
den = f(x,y)**power
reduced_eq = Add(*[arg/den for arg in eq.args])
if order == 1:
reduced_eq = collect(reduced_eq, f(x, y))
r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e)
if r:
if not r[e]:
## Linear first-order homogeneous partial-differential
## equation with constant coefficients
r.update({'b': b, 'c': c, 'd': d})
matching_hints["1st_linear_constant_coeff_homogeneous"] = r
elif r[b]**2 + r[c]**2 != 0:
## Linear first-order general partial-differential
## equation with constant coefficients
r.update({'b': b, 'c': c, 'd': d, 'e': e})
matching_hints["1st_linear_constant_coeff"] = r
matching_hints["1st_linear_constant_coeff_Integral"] = r
else:
b = Wild('b', exclude=[f(x, y), fx, fy])
c = Wild('c', exclude=[f(x, y), fx, fy])
d = Wild('d', exclude=[f(x, y), fx, fy])
r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e)
if r:
r.update({'b': b, 'c': c, 'd': d, 'e': e})
matching_hints["1st_linear_variable_coeff"] = r
# Order keys based on allhints.
rettuple = tuple(i for i in allhints if i in matching_hints)
if dict:
# Dictionaries are ordered arbitrarily, so make note of which
# hint would come first for pdsolve(). Use an ordered dict in Py 3.
matching_hints["default"] = None
matching_hints["ordered_hints"] = rettuple
for i in allhints:
if i in matching_hints:
matching_hints["default"] = i
break
return matching_hints
return rettuple
def checkpdesol(pde, sol, func=None, solve_for_func=True):
"""
Checks if the given solution satisfies the partial differential
equation.
pde is the partial differential equation which can be given in the
form of an equation or an expression. sol is the solution for which
the pde is to be checked. This can also be given in an equation or
an expression form. If the function is not provided, the helper
function _preprocess from deutils is used to identify the function.
If a sequence of solutions is passed, the same sort of container will be
used to return the result for each solution.
The following methods are currently being implemented to check if the
solution satisfies the PDE:
1. Directly substitute the solution in the PDE and check. If the
solution has not been solved for f, then it will solve for f
provided solve_for_func has not been set to False.
If the solution satisfies the PDE, then a tuple (True, 0) is returned.
Otherwise a tuple (False, expr) where expr is the value obtained
after substituting the solution in the PDE. However if a known solution
returns False, it may be due to the inability of doit() to simplify it to zero.
Examples
========
>>> from sympy import Function, symbols
>>> from sympy.solvers.pde import checkpdesol, pdsolve
>>> x, y = symbols('x y')
>>> f = Function('f')
>>> eq = 2*f(x,y) + 3*f(x,y).diff(x) + 4*f(x,y).diff(y)
>>> sol = pdsolve(eq)
>>> assert checkpdesol(eq, sol)[0]
>>> eq = x*f(x,y) + f(x,y).diff(x)
>>> checkpdesol(eq, sol)
(False, (x*F(4*x - 3*y) - 6*F(4*x - 3*y)/25 + 4*Subs(Derivative(F(_xi_1), _xi_1), _xi_1, 4*x - 3*y))*exp(-6*x/25 - 8*y/25))
"""
# Converting the pde into an equation
if not isinstance(pde, Equality):
pde = Eq(pde, 0)
# If no function is given, try finding the function present.
if func is None:
try:
_, func = _preprocess(pde.lhs)
except ValueError:
funcs = [s.atoms(AppliedUndef) for s in (
sol if is_sequence(sol, set) else [sol])]
funcs = set().union(funcs)
if len(funcs) != 1:
raise ValueError(
'must pass func arg to checkpdesol for this case.')
func = funcs.pop()
# If the given solution is in the form of a list or a set
# then return a list or set of tuples.
if is_sequence(sol, set):
return type(sol)([checkpdesol(
pde, i, func=func,
solve_for_func=solve_for_func) for i in sol])
# Convert solution into an equation
if not isinstance(sol, Equality):
sol = Eq(func, sol)
elif sol.rhs == func:
sol = sol.reversed
# Try solving for the function
solved = sol.lhs == func and not sol.rhs.has(func)
if solve_for_func and not solved:
solved = solve(sol, func)
if solved:
if len(solved) == 1:
return checkpdesol(pde, Eq(func, solved[0]),
func=func, solve_for_func=False)
else:
return checkpdesol(pde, [Eq(func, t) for t in solved],
func=func, solve_for_func=False)
# try direct substitution of the solution into the PDE and simplify
if sol.lhs == func:
pde = pde.lhs - pde.rhs
s = simplify(pde.subs(func, sol.rhs).doit())
return s is S.Zero, s
raise NotImplementedError(filldedent('''
Unable to test if %s is a solution to %s.''' % (sol, pde)))
def pde_1st_linear_constant_coeff_homogeneous(eq, func, order, match, solvefun):
r"""
Solves a first order linear homogeneous
partial differential equation with constant coefficients.
The general form of this partial differential equation is
.. math:: a \frac{\partial f(x,y)}{\partial x}
+ b \frac{\partial f(x,y)}{\partial y} + c f(x,y) = 0
where `a`, `b` and `c` are constants.
The general solution is of the form:
.. math::
f(x, y) = F(- a y + b x ) e^{- \frac{c (a x + b y)}{a^2 + b^2}}
and can be found in SymPy with ``pdsolve``::
>>> from sympy.solvers import pdsolve
>>> from sympy.abc import x, y, a, b, c
>>> from sympy import Function, pprint
>>> f = Function('f')
>>> u = f(x,y)
>>> ux = u.diff(x)
>>> uy = u.diff(y)
>>> genform = a*ux + b*uy + c*u
>>> pprint(genform)
d d
a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y)
dx dy
>>> pprint(pdsolve(genform))
-c*(a*x + b*y)
---------------
2 2
a + b
f(x, y) = F(-a*y + b*x)*e
Examples
========
>>> from sympy import pdsolve
>>> from sympy import Function, pprint
>>> from sympy.abc import x,y
>>> f = Function('f')
>>> pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y))
Eq(f(x, y), F(x - y)*exp(-x/2 - y/2))
>>> pprint(pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)))
x y
- - - -
2 2
f(x, y) = F(x - y)*e
References
==========
- Viktor Grigoryan, "Partial Differential Equations"
Math 124A - Fall 2010, pp.7
"""
# TODO : For now homogeneous first order linear PDE's having
# two variables are implemented. Once there is support for
# solving systems of ODE's, this can be extended to n variables.
f = func.func
x = func.args[0]
y = func.args[1]
b = match[match['b']]
c = match[match['c']]
d = match[match['d']]
return Eq(f(x,y), exp(-S(d)/(b**2 + c**2)*(b*x + c*y))*solvefun(c*x - b*y))
def pde_1st_linear_constant_coeff(eq, func, order, match, solvefun):
r"""
Solves a first order linear partial differential equation
with constant coefficients.
The general form of this partial differential equation is
.. math:: a \frac{\partial f(x,y)}{\partial x}
+ b \frac{\partial f(x,y)}{\partial y}
+ c f(x,y) = G(x,y)
where `a`, `b` and `c` are constants and `G(x, y)` can be an arbitrary
function in `x` and `y`.
The general solution of the PDE is:
.. math::
f(x, y) = \left. \left[F(\eta) + \frac{1}{a^2 + b^2}
\int\limits^{a x + b y} G\left(\frac{a \xi + b \eta}{a^2 + b^2},
\frac{- a \eta + b \xi}{a^2 + b^2} \right)
e^{\frac{c \xi}{a^2 + b^2}}\, d\xi\right]
e^{- \frac{c \xi}{a^2 + b^2}}
\right|_{\substack{\eta=- a y + b x\\ \xi=a x + b y }}\, ,
where `F(\eta)` is an arbitrary single-valued function. The solution
can be found in SymPy with ``pdsolve``::
>>> from sympy.solvers import pdsolve
>>> from sympy.abc import x, y, a, b, c
>>> from sympy import Function, pprint
>>> f = Function('f')
>>> G = Function('G')
>>> u = f(x, y)
>>> ux = u.diff(x)
>>> uy = u.diff(y)
>>> genform = a*ux + b*uy + c*u - G(x,y)
>>> pprint(genform)
d d
a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) - G(x, y)
dx dy
>>> pprint(pdsolve(genform, hint='1st_linear_constant_coeff_Integral'))
// a*x + b*y \ \|
|| / | ||
|| | | ||
|| | c*xi | ||
|| | ------- | ||
|| | 2 2 | ||
|| | /a*xi + b*eta -a*eta + b*xi\ a + b | ||
|| | G|------------, -------------|*e d(xi)| ||
|| | | 2 2 2 2 | | ||
|| | \ a + b a + b / | -c*xi ||
|| | | -------||
|| / | 2 2||
|| | a + b ||
f(x, y) = ||F(eta) + -------------------------------------------------------|*e ||
|| 2 2 | ||
\\ a + b / /|eta=-a*y + b*x, xi=a*x + b*y
Examples
========
>>> from sympy.solvers.pde import pdsolve
>>> from sympy import Function, pprint, exp
>>> from sympy.abc import x,y
>>> f = Function('f')
>>> eq = -2*f(x,y).diff(x) + 4*f(x,y).diff(y) + 5*f(x,y) - exp(x + 3*y)
>>> pdsolve(eq)
Eq(f(x, y), (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y))
References
==========
- Viktor Grigoryan, "Partial Differential Equations"
Math 124A - Fall 2010, pp.7
"""
# TODO : For now homogeneous first order linear PDE's having
# two variables are implemented. Once there is support for
# solving systems of ODE's, this can be extended to n variables.
xi, eta = symbols("xi eta")
f = func.func
x = func.args[0]
y = func.args[1]
b = match[match['b']]
c = match[match['c']]
d = match[match['d']]
e = -match[match['e']]
expterm = exp(-S(d)/(b**2 + c**2)*xi)
functerm = solvefun(eta)
solvedict = solve((b*x + c*y - xi, c*x - b*y - eta), x, y)
# Integral should remain as it is in terms of xi,
# doit() should be done in _handle_Integral.
genterm = (1/S(b**2 + c**2))*Integral(
(1/expterm*e).subs(solvedict), (xi, b*x + c*y))
return Eq(f(x,y), Subs(expterm*(functerm + genterm),
(eta, xi), (c*x - b*y, b*x + c*y)))
def pde_1st_linear_variable_coeff(eq, func, order, match, solvefun):
r"""
Solves a first order linear partial differential equation
with variable coefficients. The general form of this partial
differential equation is
.. math:: a(x, y) \frac{\partial f(x, y)}{\partial x}
+ b(x, y) \frac{\partial f(x, y)}{\partial y}
+ c(x, y) f(x, y) = G(x, y)
where `a(x, y)`, `b(x, y)`, `c(x, y)` and `G(x, y)` are arbitrary
functions in `x` and `y`. This PDE is converted into an ODE by
making the following transformation:
1. `\xi` as `x`
2. `\eta` as the constant in the solution to the differential
equation `\frac{dy}{dx} = -\frac{b}{a}`
Making the previous substitutions reduces it to the linear ODE
.. math:: a(\xi, \eta)\frac{du}{d\xi} + c(\xi, \eta)u - G(\xi, \eta) = 0
which can be solved using ``dsolve``.
>>> from sympy.abc import x, y
>>> from sympy import Function, pprint
>>> a, b, c, G, f= [Function(i) for i in ['a', 'b', 'c', 'G', 'f']]
>>> u = f(x,y)
>>> ux = u.diff(x)
>>> uy = u.diff(y)
>>> genform = a(x, y)*u + b(x, y)*ux + c(x, y)*uy - G(x,y)
>>> pprint(genform)
d d
-G(x, y) + a(x, y)*f(x, y) + b(x, y)*--(f(x, y)) + c(x, y)*--(f(x, y))
dx dy
Examples
========
>>> from sympy.solvers.pde import pdsolve
>>> from sympy import Function, pprint
>>> from sympy.abc import x,y
>>> f = Function('f')
>>> eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2
>>> pdsolve(eq)
Eq(f(x, y), F(x*y)*exp(y**2/2) + 1)
References
==========
- Viktor Grigoryan, "Partial Differential Equations"
Math 124A - Fall 2010, pp.7
"""
from sympy.solvers.ode import dsolve
eta = symbols("eta")
f = func.func
x = func.args[0]
y = func.args[1]
b = match[match['b']]
c = match[match['c']]
d = match[match['d']]
e = -match[match['e']]
if not d:
# To deal with cases like b*ux = e or c*uy = e
if not (b and c):
if c:
try:
tsol = integrate(e/c, y)
except NotImplementedError:
raise NotImplementedError("Unable to find a solution"
" due to inability of integrate")
else:
return Eq(f(x,y), solvefun(x) + tsol)
if b:
try:
tsol = integrate(e/b, x)
except NotImplementedError:
raise NotImplementedError("Unable to find a solution"
" due to inability of integrate")
else:
return Eq(f(x,y), solvefun(y) + tsol)
if not c:
# To deal with cases when c is 0, a simpler method is used.
# The PDE reduces to b*(u.diff(x)) + d*u = e, which is a linear ODE in x
plode = f(x).diff(x)*b + d*f(x) - e
sol = dsolve(plode, f(x))
syms = sol.free_symbols - plode.free_symbols - {x, y}
rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, y)
return Eq(f(x, y), rhs)
if not b:
# To deal with cases when b is 0, a simpler method is used.
# The PDE reduces to c*(u.diff(y)) + d*u = e, which is a linear ODE in y
plode = f(y).diff(y)*c + d*f(y) - e
sol = dsolve(plode, f(y))
syms = sol.free_symbols - plode.free_symbols - {x, y}
rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, x)
return Eq(f(x, y), rhs)
dummy = Function('d')
h = (c/b).subs(y, dummy(x))
sol = dsolve(dummy(x).diff(x) - h, dummy(x))
if isinstance(sol, list):
sol = sol[0]
solsym = sol.free_symbols - h.free_symbols - {x, y}
if len(solsym) == 1:
solsym = solsym.pop()
etat = (solve(sol, solsym)[0]).subs(dummy(x), y)
ysub = solve(eta - etat, y)[0]
deq = (b*(f(x).diff(x)) + d*f(x) - e).subs(y, ysub)
final = (dsolve(deq, f(x), hint='1st_linear')).rhs
if isinstance(final, list):
final = final[0]
finsyms = final.free_symbols - deq.free_symbols - {x, y}
rhs = _simplify_variable_coeff(final, finsyms, solvefun, etat)
return Eq(f(x, y), rhs)
else:
raise NotImplementedError("Cannot solve the partial differential equation due"
" to inability of constantsimp")
def _simplify_variable_coeff(sol, syms, func, funcarg):
r"""
Helper function to replace constants by functions in 1st_linear_variable_coeff
"""
eta = Symbol("eta")
if len(syms) == 1:
sym = syms.pop()
final = sol.subs(sym, func(funcarg))
else:
for sym in syms:
final = sol.subs(sym, func(funcarg))
return simplify(final.subs(eta, funcarg))
def pde_separate(eq, fun, sep, strategy='mul'):
"""Separate variables in partial differential equation either by additive
or multiplicative separation approach. It tries to rewrite an equation so
that one of the specified variables occurs on a different side of the
equation than the others.
:param eq: Partial differential equation
:param fun: Original function F(x, y, z)
:param sep: List of separated functions [X(x), u(y, z)]
:param strategy: Separation strategy. You can choose between additive
separation ('add') and multiplicative separation ('mul') which is
default.
Examples
========
>>> from sympy import E, Eq, Function, pde_separate, Derivative as D
>>> from sympy.abc import x, t
>>> u, X, T = map(Function, 'uXT')
>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='add')
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]
>>> eq = Eq(D(u(x, t), x, 2), D(u(x, t), t, 2))
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='mul')
[Derivative(X(x), (x, 2))/X(x), Derivative(T(t), (t, 2))/T(t)]
See Also
========
pde_separate_add, pde_separate_mul
"""
do_add = False
if strategy == 'add':
do_add = True
elif strategy == 'mul':
do_add = False
else:
raise ValueError('Unknown strategy: %s' % strategy)
if isinstance(eq, Equality):
if eq.rhs != 0:
return pde_separate(Eq(eq.lhs - eq.rhs, 0), fun, sep, strategy)
else:
return pde_separate(Eq(eq, 0), fun, sep, strategy)
if eq.rhs != 0:
raise ValueError("Value should be 0")
# Handle arguments
orig_args = list(fun.args)
subs_args = [arg for s in sep for arg in s.args]
if do_add:
functions = reduce(operator.add, sep)
else:
functions = reduce(operator.mul, sep)
# Check whether variables match
if len(subs_args) != len(orig_args):
raise ValueError("Variable counts do not match")
# Check for duplicate arguments like [X(x), u(x, y)]
if has_dups(subs_args):
raise ValueError("Duplicate substitution arguments detected")
# Check whether the variables match
if set(orig_args) != set(subs_args):
raise ValueError("Arguments do not match")
# Substitute original function with separated...
result = eq.lhs.subs(fun, functions).doit()
# Divide by terms when doing multiplicative separation
if not do_add:
eq = 0
for i in result.args:
eq += i/functions
result = eq
svar = subs_args[0]
dvar = subs_args[1:]
return _separate(result, svar, dvar)
def pde_separate_add(eq, fun, sep):
"""
Helper function for searching additive separable solutions.
Consider an equation of two independent variables x, y and a dependent
variable w, we look for the product of two functions depending on different
arguments:
`w(x, y, z) = X(x) + y(y, z)`
Examples
========
>>> from sympy import E, Eq, Function, pde_separate_add, Derivative as D
>>> from sympy.abc import x, t
>>> u, X, T = map(Function, 'uXT')
>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
>>> pde_separate_add(eq, u(x, t), [X(x), T(t)])
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]
"""
return pde_separate(eq, fun, sep, strategy='add')
def pde_separate_mul(eq, fun, sep):
"""
Helper function for searching multiplicative separable solutions.
Consider an equation of two independent variables x, y and a dependent
variable w, we look for the product of two functions depending on different
arguments:
`w(x, y, z) = X(x)*u(y, z)`
Examples
========
>>> from sympy import Function, Eq, pde_separate_mul, Derivative as D
>>> from sympy.abc import x, y
>>> u, X, Y = map(Function, 'uXY')
>>> eq = Eq(D(u(x, y), x, 2), D(u(x, y), y, 2))
>>> pde_separate_mul(eq, u(x, y), [X(x), Y(y)])
[Derivative(X(x), (x, 2))/X(x), Derivative(Y(y), (y, 2))/Y(y)]
"""
return pde_separate(eq, fun, sep, strategy='mul')
def _separate(eq, dep, others):
"""Separate expression into two parts based on dependencies of variables."""
# FIRST PASS
# Extract derivatives depending our separable variable...
terms = set()
for term in eq.args:
if term.is_Mul:
for i in term.args:
if i.is_Derivative and not i.has(*others):
terms.add(term)
continue
elif term.is_Derivative and not term.has(*others):
terms.add(term)
# Find the factor that we need to divide by
div = set()
for term in terms:
ext, sep = term.expand().as_independent(dep)
# Failed?
if sep.has(*others):
return None
div.add(ext)
# FIXME: Find lcm() of all the divisors and divide with it, instead of
# current hack :(
# https://github.com/sympy/sympy/issues/4597
if len(div) > 0:
# double sum required or some tests will fail
eq = Add(*[simplify(Add(*[term/i for i in div])) for term in eq.args])
# SECOND PASS - separate the derivatives
div = set()
lhs = rhs = 0
for term in eq.args:
# Check, whether we have already term with independent variable...
if not term.has(*others):
lhs += term
continue
# ...otherwise, try to separate
temp, sep = term.expand().as_independent(dep)
# Failed?
if sep.has(*others):
return None
# Extract the divisors
div.add(sep)
rhs -= term.expand()
# Do the division
fulldiv = reduce(operator.add, div)
lhs = simplify(lhs/fulldiv).expand()
rhs = simplify(rhs/fulldiv).expand()
# ...and check whether we were successful :)
if lhs.has(*others) or rhs.has(dep):
return None
return [lhs, rhs]
@@ -0,0 +1,872 @@
"""Solvers of systems of polynomial equations. """
from __future__ import annotations
from typing import Any
from collections.abc import Sequence, Iterable
import itertools
from sympy import Dummy
from sympy.core import S
from sympy.core.expr import Expr
from sympy.core.exprtools import factor_terms
from sympy.core.sorting import default_sort_key
from sympy.logic.boolalg import Boolean
from sympy.polys import Poly, groebner, roots
from sympy.polys.domains import ZZ
from sympy.polys.polyoptions import build_options
from sympy.polys.polytools import parallel_poly_from_expr, sqf_part
from sympy.polys.polyerrors import (
ComputationFailed,
PolificationFailed,
CoercionFailed,
GeneratorsNeeded,
DomainError
)
from sympy.simplify import rcollect
from sympy.utilities import postfixes
from sympy.utilities.iterables import cartes
from sympy.utilities.misc import filldedent
from sympy.logic.boolalg import Or, And
from sympy.core.relational import Eq
class SolveFailed(Exception):
"""Raised when solver's conditions were not met. """
def solve_poly_system(seq, *gens, strict=False, **args):
"""
Return a list of solutions for the system of polynomial equations
or else None.
Parameters
==========
seq: a list/tuple/set
Listing all the equations that are needed to be solved
gens: generators
generators of the equations in seq for which we want the
solutions
strict: a boolean (default is False)
if strict is True, NotImplementedError will be raised if
the solution is known to be incomplete (which can occur if
not all solutions are expressible in radicals)
args: Keyword arguments
Special options for solving the equations.
Returns
=======
List[Tuple]
a list of tuples with elements being solutions for the
symbols in the order they were passed as gens
None
None is returned when the computed basis contains only the ground.
Examples
========
>>> from sympy import solve_poly_system
>>> from sympy.abc import x, y
>>> solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y)
[(0, 0), (2, -sqrt(2)), (2, sqrt(2))]
>>> solve_poly_system([x**5 - x + y**3, y**2 - 1], x, y, strict=True)
Traceback (most recent call last):
...
UnsolvableFactorError
"""
try:
polys, opt = parallel_poly_from_expr(seq, *gens, **args)
except PolificationFailed as exc:
raise ComputationFailed('solve_poly_system', len(seq), exc)
if len(polys) == len(opt.gens) == 2:
f, g = polys
if all(i <= 2 for i in f.degree_list() + g.degree_list()):
try:
return solve_biquadratic(f, g, opt)
except SolveFailed:
pass
return solve_generic(polys, opt, strict=strict)
def solve_biquadratic(f, g, opt):
"""Solve a system of two bivariate quadratic polynomial equations.
Parameters
==========
f: a single Expr or Poly
First equation
g: a single Expr or Poly
Second Equation
opt: an Options object
For specifying keyword arguments and generators
Returns
=======
List[Tuple]
a list of tuples with elements being solutions for the
symbols in the order they were passed as gens
None
None is returned when the computed basis contains only the ground.
Examples
========
>>> from sympy import Options, Poly
>>> from sympy.abc import x, y
>>> from sympy.solvers.polysys import solve_biquadratic
>>> NewOption = Options((x, y), {'domain': 'ZZ'})
>>> a = Poly(y**2 - 4 + x, y, x, domain='ZZ')
>>> b = Poly(y*2 + 3*x - 7, y, x, domain='ZZ')
>>> solve_biquadratic(a, b, NewOption)
[(1/3, 3), (41/27, 11/9)]
>>> a = Poly(y + x**2 - 3, y, x, domain='ZZ')
>>> b = Poly(-y + x - 4, y, x, domain='ZZ')
>>> solve_biquadratic(a, b, NewOption)
[(7/2 - sqrt(29)/2, -sqrt(29)/2 - 1/2), (sqrt(29)/2 + 7/2, -1/2 + \
sqrt(29)/2)]
"""
G = groebner([f, g])
if len(G) == 1 and G[0].is_ground:
return None
if len(G) != 2:
raise SolveFailed
x, y = opt.gens
p, q = G
if not p.gcd(q).is_ground:
# not 0-dimensional
raise SolveFailed
p = Poly(p, x, expand=False)
p_roots = [rcollect(expr, y) for expr in roots(p).keys()]
q = q.ltrim(-1)
q_roots = list(roots(q).keys())
solutions = [(p_root.subs(y, q_root), q_root) for q_root, p_root in
itertools.product(q_roots, p_roots)]
return sorted(solutions, key=default_sort_key)
def solve_generic(polys, opt, strict=False):
"""
Solve a generic system of polynomial equations.
Returns all possible solutions over C[x_1, x_2, ..., x_m] of a
set F = { f_1, f_2, ..., f_n } of polynomial equations, using
Groebner basis approach. For now only zero-dimensional systems
are supported, which means F can have at most a finite number
of solutions. If the basis contains only the ground, None is
returned.
The algorithm works by the fact that, supposing G is the basis
of F with respect to an elimination order (here lexicographic
order is used), G and F generate the same ideal, they have the
same set of solutions. By the elimination property, if G is a
reduced, zero-dimensional Groebner basis, then there exists an
univariate polynomial in G (in its last variable). This can be
solved by computing its roots. Substituting all computed roots
for the last (eliminated) variable in other elements of G, new
polynomial system is generated. Applying the above procedure
recursively, a finite number of solutions can be found.
The ability of finding all solutions by this procedure depends
on the root finding algorithms. If no solutions were found, it
means only that roots() failed, but the system is solvable. To
overcome this difficulty use numerical algorithms instead.
Parameters
==========
polys: a list/tuple/set
Listing all the polynomial equations that are needed to be solved
opt: an Options object
For specifying keyword arguments and generators
strict: a boolean
If strict is True, NotImplementedError will be raised if the solution
is known to be incomplete
Returns
=======
List[Tuple]
a list of tuples with elements being solutions for the
symbols in the order they were passed as gens
None
None is returned when the computed basis contains only the ground.
References
==========
.. [Buchberger01] B. Buchberger, Groebner Bases: A Short
Introduction for Systems Theorists, In: R. Moreno-Diaz,
B. Buchberger, J.L. Freire, Proceedings of EUROCAST'01,
February, 2001
.. [Cox97] D. Cox, J. Little, D. O'Shea, Ideals, Varieties
and Algorithms, Springer, Second Edition, 1997, pp. 112
Raises
========
NotImplementedError
If the system is not zero-dimensional (does not have a finite
number of solutions)
UnsolvableFactorError
If ``strict`` is True and not all solution components are
expressible in radicals
Examples
========
>>> from sympy import Poly, Options
>>> from sympy.solvers.polysys import solve_generic
>>> from sympy.abc import x, y
>>> NewOption = Options((x, y), {'domain': 'ZZ'})
>>> a = Poly(x - y + 5, x, y, domain='ZZ')
>>> b = Poly(x + y - 3, x, y, domain='ZZ')
>>> solve_generic([a, b], NewOption)
[(-1, 4)]
>>> a = Poly(x - 2*y + 5, x, y, domain='ZZ')
>>> b = Poly(2*x - y - 3, x, y, domain='ZZ')
>>> solve_generic([a, b], NewOption)
[(11/3, 13/3)]
>>> a = Poly(x**2 + y, x, y, domain='ZZ')
>>> b = Poly(x + y*4, x, y, domain='ZZ')
>>> solve_generic([a, b], NewOption)
[(0, 0), (1/4, -1/16)]
>>> a = Poly(x**5 - x + y**3, x, y, domain='ZZ')
>>> b = Poly(y**2 - 1, x, y, domain='ZZ')
>>> solve_generic([a, b], NewOption, strict=True)
Traceback (most recent call last):
...
UnsolvableFactorError
"""
def _is_univariate(f):
"""Returns True if 'f' is univariate in its last variable. """
for monom in f.monoms():
if any(monom[:-1]):
return False
return True
def _subs_root(f, gen, zero):
"""Replace generator with a root so that the result is nice. """
p = f.as_expr({gen: zero})
if f.degree(gen) >= 2:
p = p.expand(deep=False)
return p
def _solve_reduced_system(system, gens, entry=False):
"""Recursively solves reduced polynomial systems. """
if len(system) == len(gens) == 1:
# the below line will produce UnsolvableFactorError if
# strict=True and the solution from `roots` is incomplete
zeros = list(roots(system[0], gens[-1], strict=strict).keys())
return [(zero,) for zero in zeros]
basis = groebner(system, gens, polys=True)
if len(basis) == 1 and basis[0].is_ground:
if not entry:
return []
else:
return None
univariate = list(filter(_is_univariate, basis))
if len(basis) < len(gens):
raise NotImplementedError(filldedent('''
only zero-dimensional systems supported
(finite number of solutions)
'''))
if len(univariate) == 1:
f = univariate.pop()
else:
raise NotImplementedError(filldedent('''
only zero-dimensional systems supported
(finite number of solutions)
'''))
gens = f.gens
gen = gens[-1]
# the below line will produce UnsolvableFactorError if
# strict=True and the solution from `roots` is incomplete
zeros = list(roots(f.ltrim(gen), strict=strict).keys())
if not zeros:
return []
if len(basis) == 1:
return [(zero,) for zero in zeros]
solutions = []
for zero in zeros:
new_system = []
new_gens = gens[:-1]
for b in basis[:-1]:
eq = _subs_root(b, gen, zero)
if eq is not S.Zero:
new_system.append(eq)
for solution in _solve_reduced_system(new_system, new_gens):
solutions.append(solution + (zero,))
if solutions and len(solutions[0]) != len(gens):
raise NotImplementedError(filldedent('''
only zero-dimensional systems supported
(finite number of solutions)
'''))
return solutions
try:
result = _solve_reduced_system(polys, opt.gens, entry=True)
except CoercionFailed:
raise NotImplementedError
if result is not None:
return sorted(result, key=default_sort_key)
def solve_triangulated(polys, *gens, **args):
"""
Solve a polynomial system using Gianni-Kalkbrenner algorithm.
The algorithm proceeds by computing one Groebner basis in the ground
domain and then by iteratively computing polynomial factorizations in
appropriately constructed algebraic extensions of the ground domain.
Parameters
==========
polys: a list/tuple/set
Listing all the equations that are needed to be solved
gens: generators
generators of the equations in polys for which we want the
solutions
args: Keyword arguments
Special options for solving the equations
Returns
=======
List[Tuple]
A List of tuples. Solutions for symbols that satisfy the
equations listed in polys
Examples
========
>>> from sympy import solve_triangulated
>>> from sympy.abc import x, y, z
>>> F = [x**2 + y + z - 1, x + y**2 + z - 1, x + y + z**2 - 1]
>>> solve_triangulated(F, x, y, z)
[(0, 0, 1), (0, 1, 0), (1, 0, 0)]
Using extension for algebraic solutions.
>>> solve_triangulated(F, x, y, z, extension=True) #doctest: +NORMALIZE_WHITESPACE
[(0, 0, 1), (0, 1, 0), (1, 0, 0),
(CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0)),
(CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1))]
References
==========
1. Patrizia Gianni, Teo Mora, Algebraic Solution of System of
Polynomial Equations using Groebner Bases, AAECC-5 on Applied Algebra,
Algebraic Algorithms and Error-Correcting Codes, LNCS 356 247--257, 1989
"""
opt = build_options(gens, args)
G = groebner(polys, gens, polys=True)
G = list(reversed(G))
extension = opt.get('extension', False)
if extension:
def _solve_univariate(f):
return [r for r, _ in f.all_roots(multiple=False, radicals=False)]
else:
domain = opt.get('domain')
if domain is not None:
for i, g in enumerate(G):
G[i] = g.set_domain(domain)
def _solve_univariate(f):
return list(f.ground_roots().keys())
f, G = G[0].ltrim(-1), G[1:]
dom = f.get_domain()
zeros = _solve_univariate(f)
if extension:
solutions = {((zero,), dom.algebraic_field(zero)) for zero in zeros}
else:
solutions = {((zero,), dom) for zero in zeros}
var_seq = reversed(gens[:-1])
vars_seq = postfixes(gens[1:])
for var, vars in zip(var_seq, vars_seq):
_solutions = set()
for values, dom in solutions:
H, mapping = [], list(zip(vars, values))
for g in G:
_vars = (var,) + vars
if g.has_only_gens(*_vars) and g.degree(var) != 0:
if extension:
g = g.set_domain(g.domain.unify(dom))
h = g.ltrim(var).eval(dict(mapping))
if g.degree(var) == h.degree():
H.append(h)
p = min(H, key=lambda h: h.degree())
zeros = _solve_univariate(p)
for zero in zeros:
if not (zero in dom):
dom_zero = dom.algebraic_field(zero)
else:
dom_zero = dom
_solutions.add(((zero,) + values, dom_zero))
solutions = _solutions
return sorted((s for s, _ in solutions), key=default_sort_key)
def factor_system(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]:
"""
Factorizes a system of polynomial equations into
irreducible subsystems.
Parameters
==========
eqs : list
List of expressions to be factored.
Each expression is assumed to be equal to zero.
gens : list, optional
Generator(s) of the polynomial ring.
If not provided, all free symbols will be used.
**kwargs : dict, optional
Same optional arguments taken by ``factor``
Returns
=======
list[list[Expr]]
A list of lists of expressions, where each sublist represents
an irreducible subsystem. When solved, each subsystem gives
one component of the solution. Only generic solutions are
returned (cases not requiring parameters to be zero).
Examples
========
>>> from sympy.solvers.polysys import factor_system, factor_system_cond
>>> from sympy.abc import x, y, a, b, c
A simple system with multiple solutions:
>>> factor_system([x**2 - 1, y - 1])
[[x + 1, y - 1], [x - 1, y - 1]]
A system with no solution:
>>> factor_system([x, 1])
[]
A system where any value of the symbol(s) is a solution:
>>> factor_system([x - x, (x + 1)**2 - (x**2 + 2*x + 1)])
[[]]
A system with no generic solution:
>>> factor_system([a*x*(x-1), b*y, c], [x, y])
[]
If c is added to the unknowns then the system has a generic solution:
>>> factor_system([a*x*(x-1), b*y, c], [x, y, c])
[[x - 1, y, c], [x, y, c]]
Alternatively :func:`factor_system_cond` can be used to get degenerate
cases as well:
>>> factor_system_cond([a*x*(x-1), b*y, c], [x, y])
[[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]]
Each of the above cases is only satisfiable in the degenerate case `c = 0`.
The solution set of the original system represented
by eqs is the union of the solution sets of the
factorized systems.
An empty list [] means no generic solution exists.
A list containing an empty list [[]] means any value of
the symbol(s) is a solution.
See Also
========
factor_system_cond : Returns both generic and degenerate solutions
factor_system_bool : Returns a Boolean combination representing all solutions
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
over the rational numbers
"""
systems = _factor_system_poly_from_expr(eqs, gens, **kwargs)
systems_generic = [sys for sys in systems if not _is_degenerate(sys)]
systems_expr = [[p.as_expr() for p in system] for system in systems_generic]
return systems_expr
def _is_degenerate(system: list[Poly]) -> bool:
"""Helper function to check if a system is degenerate"""
return any(p.is_ground for p in system)
def factor_system_bool(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> Boolean:
"""
Factorizes a system of polynomial equations into irreducible DNF.
The system of expressions(eqs) is taken and a Boolean combination
of equations is returned that represents the same solution set.
The result is in disjunctive normal form (OR of ANDs).
Parameters
==========
eqs : list
List of expressions to be factored.
Each expression is assumed to be equal to zero.
gens : list, optional
Generator(s) of the polynomial ring.
If not provided, all free symbols will be used.
**kwargs : dict, optional
Optional keyword arguments
Returns
=======
Boolean:
A Boolean combination of equations. The result is typically in
the form of a conjunction (AND) of a disjunctive normal form
with additional conditions.
Examples
========
>>> from sympy.solvers.polysys import factor_system_bool
>>> from sympy.abc import x, y, a, b, c
>>> factor_system_bool([x**2 - 1])
Eq(x - 1, 0) | Eq(x + 1, 0)
>>> factor_system_bool([x**2 - 1, y - 1])
(Eq(x - 1, 0) & Eq(y - 1, 0)) | (Eq(x + 1, 0) & Eq(y - 1, 0))
>>> eqs = [a * (x - 1), b]
>>> factor_system_bool([a*(x - 1), b])
(Eq(a, 0) & Eq(b, 0)) | (Eq(b, 0) & Eq(x - 1, 0))
>>> factor_system_bool([a*x**2 - a, b*(x + 1), c], [x])
(Eq(c, 0) & Eq(x + 1, 0)) | (Eq(a, 0) & Eq(b, 0) & Eq(c, 0)) | (Eq(b, 0) & Eq(c, 0) & Eq(x - 1, 0))
>>> factor_system_bool([x**2 + 2*x + 1 - (x + 1)**2])
True
The result is logically equivalent to the system of equations
i.e. eqs. The function returns ``True`` when all values of
the symbol(s) is a solution and ``False`` when the system
cannot be solved.
See Also
========
factor_system : Returns factors and solvability condition separately
factor_system_cond : Returns both factors and conditions
"""
systems = factor_system_cond(eqs, gens, **kwargs)
return Or(*[And(*[Eq(eq, 0) for eq in sys]) for sys in systems])
def factor_system_cond(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]:
"""
Factorizes a polynomial system into irreducible components and returns
both generic and degenerate solutions.
Parameters
==========
eqs : list
List of expressions to be factored.
Each expression is assumed to be equal to zero.
gens : list, optional
Generator(s) of the polynomial ring.
If not provided, all free symbols will be used.
**kwargs : dict, optional
Optional keyword arguments.
Returns
=======
list[list[Expr]]
A list of lists of expressions, where each sublist represents
an irreducible subsystem. Includes both generic solutions and
degenerate cases requiring equality conditions on parameters.
Examples
========
>>> from sympy.solvers.polysys import factor_system_cond
>>> from sympy.abc import x, y, a, b, c
>>> factor_system_cond([x**2 - 4, a*y, b], [x, y])
[[x + 2, y, b], [x - 2, y, b], [x + 2, a, b], [x - 2, a, b]]
>>> factor_system_cond([a*x*(x-1), b*y, c], [x, y])
[[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]]
An empty list [] means no solution exists.
A list containing an empty list [[]] means any value of
the symbol(s) is a solution.
See Also
========
factor_system : Returns only generic solutions
factor_system_bool : Returns a Boolean combination representing all solutions
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
over the rational numbers
"""
systems_poly = _factor_system_poly_from_expr(eqs, gens, **kwargs)
systems = [[p.as_expr() for p in system] for system in systems_poly]
return systems
def _factor_system_poly_from_expr(
eqs: Sequence[Expr | complex], gens: Sequence[Expr], **kwargs: Any
) -> list[list[Poly]]:
"""
Convert expressions to polynomials and factor the system.
Takes a sequence of expressions, converts them to
polynomials, and factors the resulting system. Handles both regular
polynomial systems and purely numerical cases.
"""
try:
polys, opts = parallel_poly_from_expr(eqs, *gens, **kwargs)
only_numbers = False
except (GeneratorsNeeded, PolificationFailed):
_u = Dummy('u')
polys, opts = parallel_poly_from_expr(eqs, [_u], **kwargs)
assert opts['domain'].is_Numerical
only_numbers = True
if only_numbers:
return [[]] if all(p == 0 for p in polys) else []
return factor_system_poly(polys)
def factor_system_poly(polys: list[Poly]) -> list[list[Poly]]:
"""
Factors a system of polynomial equations into irreducible subsystems
Core implementation that works directly with Poly instances.
Parameters
==========
polys : list[Poly]
A list of Poly instances to be factored.
Returns
=======
list[list[Poly]]
A list of lists of polynomials, where each sublist represents
an irreducible component of the solution. Includes both
generic and degenerate cases.
Examples
========
>>> from sympy import symbols, Poly, ZZ
>>> from sympy.solvers.polysys import factor_system_poly
>>> a, b, c, x = symbols('a b c x')
>>> p1 = Poly((a - 1)*(x - 2), x, domain=ZZ[a,b,c])
>>> p2 = Poly((b - 3)*(x - 2), x, domain=ZZ[a,b,c])
>>> p3 = Poly(c, x, domain=ZZ[a,b,c])
The equation to be solved for x is ``x - 2 = 0`` provided either
of the two conditions on the parameters ``a`` and ``b`` is nonzero
and the constant parameter ``c`` should be zero.
>>> sys1, sys2 = factor_system_poly([p1, p2, p3])
>>> sys1
[Poly(x - 2, x, domain='ZZ[a,b,c]'),
Poly(c, x, domain='ZZ[a,b,c]')]
>>> sys2
[Poly(a - 1, x, domain='ZZ[a,b,c]'),
Poly(b - 3, x, domain='ZZ[a,b,c]'),
Poly(c, x, domain='ZZ[a,b,c]')]
An empty list [] when returned means no solution exists.
Whereas a list containing an empty list [[]] means any value is a solution.
See Also
========
factor_system : Returns only generic solutions
factor_system_bool : Returns a Boolean combination representing the solutions
factor_system_cond : Returns both generic and degenerate solutions
sympy.polys.polytools.factor : Factors a polynomial into irreducible factors
over the rational numbers
"""
if not all(isinstance(poly, Poly) for poly in polys):
raise TypeError("polys should be a list of Poly instances")
if not polys:
return [[]]
base_domain = polys[0].domain
base_gens = polys[0].gens
if not all(poly.domain == base_domain and poly.gens == base_gens for poly in polys[1:]):
raise DomainError("All polynomials must have the same domain and generators")
factor_sets = []
for poly in polys:
constant, factors_mult = poly.factor_list()
if constant.is_zero is True:
continue
elif constant.is_zero is False:
if not factors_mult:
return []
factor_sets.append([f for f, _ in factors_mult])
else:
constant = sqf_part(factor_terms(constant).as_coeff_Mul()[1])
constp = Poly(constant, base_gens, domain=base_domain)
factors = [f for f, _ in factors_mult]
factors.append(constp)
factor_sets.append(factors)
if not factor_sets:
return [[]]
result = _factor_sets(factor_sets)
return _sort_systems(result)
def _factor_sets_slow(eqs: list[list]) -> set[frozenset]:
"""
Helper to find the minimal set of factorised subsystems that is
equivalent to the original system.
The result is in DNF.
"""
if not eqs:
return {frozenset()}
systems_set = {frozenset(sys) for sys in cartes(*eqs)}
return {s1 for s1 in systems_set if not any(s1 > s2 for s2 in systems_set)}
def _factor_sets(eqs: list[list]) -> set[frozenset]:
"""
Helper that builds factor combinations.
"""
if not eqs:
return {frozenset()}
current_set = min(eqs, key=len)
other_sets = [s for s in eqs if s is not current_set]
stack = [(factor, [s for s in other_sets if factor not in s], {factor})
for factor in current_set]
result = set()
while stack:
factor, remaining_sets, current_solution = stack.pop()
if not remaining_sets:
result.add(frozenset(current_solution))
continue
next_set = min(remaining_sets, key=len)
next_remaining = [s for s in remaining_sets if s is not next_set]
for next_factor in next_set:
valid_remaining = [s for s in next_remaining if next_factor not in s]
new_solution = current_solution | {next_factor}
stack.append((next_factor, valid_remaining, new_solution))
return {s1 for s1 in result if not any(s1 > s2 for s2 in result)}
def _sort_systems(systems: Iterable[Iterable[Poly]]) -> list[list[Poly]]:
"""Sorts a list of lists of polynomials"""
systems_list = [sorted(s, key=_poly_sort_key, reverse=True) for s in systems]
return sorted(systems_list, key=_sys_sort_key, reverse=True)
def _poly_sort_key(poly):
"""Sort key for polynomials"""
if poly.domain.is_FF:
poly = poly.set_domain(ZZ)
return poly.degree_list(), poly.rep.to_list()
def _sys_sort_key(sys):
"""Sort key for lists of polynomials"""
return list(zip(*map(_poly_sort_key, sys)))
@@ -0,0 +1,843 @@
r"""
This module is intended for solving recurrences or, in other words,
difference equations. Currently supported are linear, inhomogeneous
equations with polynomial or rational coefficients.
The solutions are obtained among polynomials, rational functions,
hypergeometric terms, or combinations of hypergeometric term which
are pairwise dissimilar.
``rsolve_X`` functions were meant as a low level interface
for ``rsolve`` which would use Mathematica's syntax.
Given a recurrence relation:
.. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) +
... + a_{0}(n) y(n) = f(n)
where `k > 0` and `a_{i}(n)` are polynomials in `n`. To use
``rsolve_X`` we need to put all coefficients in to a list ``L`` of
`k+1` elements the following way:
``L = [a_{0}(n), ..., a_{k-1}(n), a_{k}(n)]``
where ``L[i]``, for `i=0, \ldots, k`, maps to
`a_{i}(n) y(n+i)` (`y(n+i)` is implicit).
For example if we would like to compute `m`-th Bernoulli polynomial
up to a constant (example was taken from rsolve_poly docstring),
then we would use `b(n+1) - b(n) = m n^{m-1}` recurrence, which
has solution `b(n) = B_m + C`.
Then ``L = [-1, 1]`` and `f(n) = m n^(m-1)` and finally for `m=4`:
>>> from sympy import Symbol, bernoulli, rsolve_poly
>>> n = Symbol('n', integer=True)
>>> rsolve_poly([-1, 1], 4*n**3, n)
C0 + n**4 - 2*n**3 + n**2
>>> bernoulli(4, n)
n**4 - 2*n**3 + n**2 - 1/30
For the sake of completeness, `f(n)` can be:
[1] a polynomial -> rsolve_poly
[2] a rational function -> rsolve_ratio
[3] a hypergeometric function -> rsolve_hyper
"""
from collections import defaultdict
from sympy.concrete import product
from sympy.core.singleton import S
from sympy.core.numbers import Rational, I
from sympy.core.symbol import Symbol, Wild, Dummy
from sympy.core.relational import Equality
from sympy.core.add import Add
from sympy.core.mul import Mul
from sympy.core.sorting import default_sort_key
from sympy.core.sympify import sympify
from sympy.simplify import simplify, hypersimp, hypersimilar # type: ignore
from sympy.solvers import solve, solve_undetermined_coeffs
from sympy.polys import Poly, quo, gcd, lcm, roots, resultant
from sympy.functions import binomial, factorial, FallingFactorial, RisingFactorial
from sympy.matrices import Matrix, casoratian
from sympy.utilities.iterables import numbered_symbols
def rsolve_poly(coeffs, f, n, shift=0, **hints):
r"""
Given linear recurrence operator `\operatorname{L}` of order
`k` with polynomial coefficients and inhomogeneous equation
`\operatorname{L} y = f`, where `f` is a polynomial, we seek for
all polynomial solutions over field `K` of characteristic zero.
The algorithm performs two basic steps:
(1) Compute degree `N` of the general polynomial solution.
(2) Find all polynomials of degree `N` or less
of `\operatorname{L} y = f`.
There are two methods for computing the polynomial solutions.
If the degree bound is relatively small, i.e. it's smaller than
or equal to the order of the recurrence, then naive method of
undetermined coefficients is being used. This gives a system
of algebraic equations with `N+1` unknowns.
In the other case, the algorithm performs transformation of the
initial equation to an equivalent one for which the system of
algebraic equations has only `r` indeterminates. This method is
quite sophisticated (in comparison with the naive one) and was
invented together by Abramov, Bronstein and Petkovsek.
It is possible to generalize the algorithm implemented here to
the case of linear q-difference and differential equations.
Lets say that we would like to compute `m`-th Bernoulli polynomial
up to a constant. For this we can use `b(n+1) - b(n) = m n^{m-1}`
recurrence, which has solution `b(n) = B_m + C`. For example:
>>> from sympy import Symbol, rsolve_poly
>>> n = Symbol('n', integer=True)
>>> rsolve_poly([-1, 1], 4*n**3, n)
C0 + n**4 - 2*n**3 + n**2
References
==========
.. [1] S. A. Abramov, M. Bronstein and M. Petkovsek, On polynomial
solutions of linear operator equations, in: T. Levelt, ed.,
Proc. ISSAC '95, ACM Press, New York, 1995, 290-296.
.. [2] M. Petkovsek, Hypergeometric solutions of linear recurrences
with polynomial coefficients, J. Symbolic Computation,
14 (1992), 243-264.
.. [3] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
"""
f = sympify(f)
if not f.is_polynomial(n):
return None
homogeneous = f.is_zero
r = len(coeffs) - 1
coeffs = [Poly(coeff, n) for coeff in coeffs]
polys = [Poly(0, n)]*(r + 1)
terms = [(S.Zero, S.NegativeInfinity)]*(r + 1)
for i in range(r + 1):
for j in range(i, r + 1):
polys[i] += coeffs[j]*(binomial(j, i).as_poly(n))
if not polys[i].is_zero:
(exp,), coeff = polys[i].LT()
terms[i] = (coeff, exp)
d = b = terms[0][1]
for i in range(1, r + 1):
if terms[i][1] > d:
d = terms[i][1]
if terms[i][1] - i > b:
b = terms[i][1] - i
d, b = int(d), int(b)
x = Dummy('x')
degree_poly = S.Zero
for i in range(r + 1):
if terms[i][1] - i == b:
degree_poly += terms[i][0]*FallingFactorial(x, i)
nni_roots = list(roots(degree_poly, x, filter='Z',
predicate=lambda r: r >= 0).keys())
if nni_roots:
N = [max(nni_roots)]
else:
N = []
if homogeneous:
N += [-b - 1]
else:
N += [f.as_poly(n).degree() - b, -b - 1]
N = int(max(N))
if N < 0:
if homogeneous:
if hints.get('symbols', False):
return (S.Zero, [])
else:
return S.Zero
else:
return None
if N <= r:
C = []
y = E = S.Zero
for i in range(N + 1):
C.append(Symbol('C' + str(i + shift)))
y += C[i] * n**i
for i in range(r + 1):
E += coeffs[i].as_expr()*y.subs(n, n + i)
solutions = solve_undetermined_coeffs(E - f, C, n)
if solutions is not None:
_C = C
C = [c for c in C if (c not in solutions)]
result = y.subs(solutions)
else:
return None # TBD
else:
A = r
U = N + A + b + 1
nni_roots = list(roots(polys[r], filter='Z',
predicate=lambda r: r >= 0).keys())
if nni_roots != []:
a = max(nni_roots) + 1
else:
a = S.Zero
def _zero_vector(k):
return [S.Zero] * k
def _one_vector(k):
return [S.One] * k
def _delta(p, k):
B = S.One
D = p.subs(n, a + k)
for i in range(1, k + 1):
B *= Rational(i - k - 1, i)
D += B * p.subs(n, a + k - i)
return D
alpha = {}
for i in range(-A, d + 1):
I = _one_vector(d + 1)
for k in range(1, d + 1):
I[k] = I[k - 1] * (x + i - k + 1)/k
alpha[i] = S.Zero
for j in range(A + 1):
for k in range(d + 1):
B = binomial(k, i + j)
D = _delta(polys[j].as_expr(), k)
alpha[i] += I[k]*B*D
V = Matrix(U, A, lambda i, j: int(i == j))
if homogeneous:
for i in range(A, U):
v = _zero_vector(A)
for k in range(1, A + b + 1):
if i - k < 0:
break
B = alpha[k - A].subs(x, i - k)
for j in range(A):
v[j] += B * V[i - k, j]
denom = alpha[-A].subs(x, i)
for j in range(A):
V[i, j] = -v[j] / denom
else:
G = _zero_vector(U)
for i in range(A, U):
v = _zero_vector(A)
g = S.Zero
for k in range(1, A + b + 1):
if i - k < 0:
break
B = alpha[k - A].subs(x, i - k)
for j in range(A):
v[j] += B * V[i - k, j]
g += B * G[i - k]
denom = alpha[-A].subs(x, i)
for j in range(A):
V[i, j] = -v[j] / denom
G[i] = (_delta(f, i - A) - g) / denom
P, Q = _one_vector(U), _zero_vector(A)
for i in range(1, U):
P[i] = (P[i - 1] * (n - a - i + 1)/i).expand()
for i in range(A):
Q[i] = Add(*[(v*p).expand() for v, p in zip(V[:, i], P)])
if not homogeneous:
h = Add(*[(g*p).expand() for g, p in zip(G, P)])
C = [Symbol('C' + str(i + shift)) for i in range(A)]
g = lambda i: Add(*[c*_delta(q, i) for c, q in zip(C, Q)])
if homogeneous:
E = [g(i) for i in range(N + 1, U)]
else:
E = [g(i) + _delta(h, i) for i in range(N + 1, U)]
if E != []:
solutions = solve(E, *C)
if not solutions:
if homogeneous:
if hints.get('symbols', False):
return (S.Zero, [])
else:
return S.Zero
else:
return None
else:
solutions = {}
if homogeneous:
result = S.Zero
else:
result = h
_C = C[:]
for c, q in list(zip(C, Q)):
if c in solutions:
s = solutions[c]*q
C.remove(c)
else:
s = c*q
result += s.expand()
if C != _C:
# renumber so they are contiguous
result = result.xreplace(dict(zip(C, _C)))
C = _C[:len(C)]
if hints.get('symbols', False):
return (result, C)
else:
return result
def rsolve_ratio(coeffs, f, n, **hints):
r"""
Given linear recurrence operator `\operatorname{L}` of order `k`
with polynomial coefficients and inhomogeneous equation
`\operatorname{L} y = f`, where `f` is a polynomial, we seek
for all rational solutions over field `K` of characteristic zero.
This procedure accepts only polynomials, however if you are
interested in solving recurrence with rational coefficients
then use ``rsolve`` which will pre-process the given equation
and run this procedure with polynomial arguments.
The algorithm performs two basic steps:
(1) Compute polynomial `v(n)` which can be used as universal
denominator of any rational solution of equation
`\operatorname{L} y = f`.
(2) Construct new linear difference equation by substitution
`y(n) = u(n)/v(n)` and solve it for `u(n)` finding all its
polynomial solutions. Return ``None`` if none were found.
The algorithm implemented here is a revised version of the original
Abramov's algorithm, developed in 1989. The new approach is much
simpler to implement and has better overall efficiency. This
method can be easily adapted to the q-difference equations case.
Besides finding rational solutions alone, this functions is
an important part of Hyper algorithm where it is used to find
a particular solution for the inhomogeneous part of a recurrence.
Examples
========
>>> from sympy.abc import x
>>> from sympy.solvers.recurr import rsolve_ratio
>>> rsolve_ratio([-2*x**3 + x**2 + 2*x - 1, 2*x**3 + x**2 - 6*x,
... - 2*x**3 - 11*x**2 - 18*x - 9, 2*x**3 + 13*x**2 + 22*x + 8], 0, x)
C0*(2*x - 3)/(2*(x**2 - 1))
References
==========
.. [1] S. A. Abramov, Rational solutions of linear difference
and q-difference equations with polynomial coefficients,
in: T. Levelt, ed., Proc. ISSAC '95, ACM Press, New York,
1995, 285-289
See Also
========
rsolve_hyper
"""
f = sympify(f)
if not f.is_polynomial(n):
return None
coeffs = list(map(sympify, coeffs))
r = len(coeffs) - 1
A, B = coeffs[r], coeffs[0]
A = A.subs(n, n - r).expand()
h = Dummy('h')
res = resultant(A, B.subs(n, n + h), n)
if not res.is_polynomial(h):
p, q = res.as_numer_denom()
res = quo(p, q, h)
nni_roots = list(roots(res, h, filter='Z',
predicate=lambda r: r >= 0).keys())
if not nni_roots:
return rsolve_poly(coeffs, f, n, **hints)
else:
C, numers = S.One, [S.Zero]*(r + 1)
for i in range(int(max(nni_roots)), -1, -1):
d = gcd(A, B.subs(n, n + i), n)
A = quo(A, d, n)
B = quo(B, d.subs(n, n - i), n)
C *= Mul(*[d.subs(n, n - j) for j in range(i + 1)])
denoms = [C.subs(n, n + i) for i in range(r + 1)]
for i in range(r + 1):
g = gcd(coeffs[i], denoms[i], n)
numers[i] = quo(coeffs[i], g, n)
denoms[i] = quo(denoms[i], g, n)
for i in range(r + 1):
numers[i] *= Mul(*(denoms[:i] + denoms[i + 1:]))
result = rsolve_poly(numers, f * Mul(*denoms), n, **hints)
if result is not None:
if hints.get('symbols', False):
return (simplify(result[0] / C), result[1])
else:
return simplify(result / C)
else:
return None
def rsolve_hyper(coeffs, f, n, **hints):
r"""
Given linear recurrence operator `\operatorname{L}` of order `k`
with polynomial coefficients and inhomogeneous equation
`\operatorname{L} y = f` we seek for all hypergeometric solutions
over field `K` of characteristic zero.
The inhomogeneous part can be either hypergeometric or a sum
of a fixed number of pairwise dissimilar hypergeometric terms.
The algorithm performs three basic steps:
(1) Group together similar hypergeometric terms in the
inhomogeneous part of `\operatorname{L} y = f`, and find
particular solution using Abramov's algorithm.
(2) Compute generating set of `\operatorname{L}` and find basis
in it, so that all solutions are linearly independent.
(3) Form final solution with the number of arbitrary
constants equal to dimension of basis of `\operatorname{L}`.
Term `a(n)` is hypergeometric if it is annihilated by first order
linear difference equations with polynomial coefficients or, in
simpler words, if consecutive term ratio is a rational function.
The output of this procedure is a linear combination of fixed
number of hypergeometric terms. However the underlying method
can generate larger class of solutions - D'Alembertian terms.
Note also that this method not only computes the kernel of the
inhomogeneous equation, but also reduces in to a basis so that
solutions generated by this procedure are linearly independent
Examples
========
>>> from sympy.solvers import rsolve_hyper
>>> from sympy.abc import x
>>> rsolve_hyper([-1, -1, 1], 0, x)
C0*(1/2 - sqrt(5)/2)**x + C1*(1/2 + sqrt(5)/2)**x
>>> rsolve_hyper([-1, 1], 1 + x, x)
C0 + x*(x + 1)/2
References
==========
.. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences
with polynomial coefficients, J. Symbolic Computation,
14 (1992), 243-264.
.. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
"""
coeffs = list(map(sympify, coeffs))
f = sympify(f)
r, kernel, symbols = len(coeffs) - 1, [], set()
if not f.is_zero:
if f.is_Add:
similar = {}
for g in f.expand().args:
if not g.is_hypergeometric(n):
return None
for h in similar.keys():
if hypersimilar(g, h, n):
similar[h] += g
break
else:
similar[g] = S.Zero
inhomogeneous = [g + h for g, h in similar.items()]
elif f.is_hypergeometric(n):
inhomogeneous = [f]
else:
return None
for i, g in enumerate(inhomogeneous):
coeff, polys = S.One, coeffs[:]
denoms = [S.One]*(r + 1)
s = hypersimp(g, n)
for j in range(1, r + 1):
coeff *= s.subs(n, n + j - 1)
p, q = coeff.as_numer_denom()
polys[j] *= p
denoms[j] = q
for j in range(r + 1):
polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:]))
# FIXME: The call to rsolve_ratio below should suffice (rsolve_poly
# call can be removed) but the XFAIL test_rsolve_ratio_missed must
# be fixed first.
R = rsolve_ratio(polys, Mul(*denoms), n, symbols=True)
if R is not None:
R, syms = R
if syms:
R = R.subs(zip(syms, [0]*len(syms)))
else:
R = rsolve_poly(polys, Mul(*denoms), n)
if R:
inhomogeneous[i] *= R
else:
return None
result = Add(*inhomogeneous)
result = simplify(result)
else:
result = S.Zero
Z = Dummy('Z')
p, q = coeffs[0], coeffs[r].subs(n, n - r + 1)
p_factors = list(roots(p, n).keys())
q_factors = list(roots(q, n).keys())
factors = [(S.One, S.One)]
for p in p_factors:
for q in q_factors:
if p.is_integer and q.is_integer and p <= q:
continue
else:
factors += [(n - p, n - q)]
p = [(n - p, S.One) for p in p_factors]
q = [(S.One, n - q) for q in q_factors]
factors = p + factors + q
for A, B in factors:
polys, degrees = [], []
D = A*B.subs(n, n + r - 1)
for i in range(r + 1):
a = Mul(*[A.subs(n, n + j) for j in range(i)])
b = Mul(*[B.subs(n, n + j) for j in range(i, r)])
poly = quo(coeffs[i]*a*b, D, n)
polys.append(poly.as_poly(n))
if not poly.is_zero:
degrees.append(polys[i].degree())
if degrees:
d, poly = max(degrees), S.Zero
else:
return None
for i in range(r + 1):
coeff = polys[i].nth(d)
if coeff is not S.Zero:
poly += coeff * Z**i
for z in roots(poly, Z).keys():
if z.is_zero:
continue
recurr_coeffs = [polys[i].as_expr()*z**i for i in range(r + 1)]
if d == 0 and 0 != Add(*[recurr_coeffs[j]*j for j in range(1, r + 1)]):
# faster inline check (than calling rsolve_poly) for a
# constant solution to a constant coefficient recurrence.
sol = [Symbol("C" + str(len(symbols)))]
else:
sol, syms = rsolve_poly(recurr_coeffs, 0, n, len(symbols), symbols=True)
sol = sol.collect(syms)
sol = [sol.coeff(s) for s in syms]
for C in sol:
ratio = z * A * C.subs(n, n + 1) / B / C
ratio = simplify(ratio)
# If there is a nonnegative root in the denominator of the ratio,
# this indicates that the term y(n_root) is zero, and one should
# start the product with the term y(n_root + 1).
n0 = 0
for n_root in roots(ratio.as_numer_denom()[1], n).keys():
if n_root.has(I):
return None
elif (n0 < (n_root + 1)) == True:
n0 = n_root + 1
K = product(ratio, (n, n0, n - 1))
if K.has(factorial, FallingFactorial, RisingFactorial):
K = simplify(K)
if casoratian(kernel + [K], n, zero=False) != 0:
kernel.append(K)
kernel.sort(key=default_sort_key)
sk = list(zip(numbered_symbols('C'), kernel))
for C, ker in sk:
result += C * ker
if hints.get('symbols', False):
# XXX: This returns the symbols in a non-deterministic order
symbols |= {s for s, k in sk}
return (result, list(symbols))
else:
return result
def rsolve(f, y, init=None):
r"""
Solve univariate recurrence with rational coefficients.
Given `k`-th order linear recurrence `\operatorname{L} y = f`,
or equivalently:
.. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) +
\cdots + a_{0}(n) y(n) = f(n)
where `a_{i}(n)`, for `i=0, \ldots, k`, are polynomials or rational
functions in `n`, and `f` is a hypergeometric function or a sum
of a fixed number of pairwise dissimilar hypergeometric terms in
`n`, finds all solutions or returns ``None``, if none were found.
Initial conditions can be given as a dictionary in two forms:
(1) ``{ n_0 : v_0, n_1 : v_1, ..., n_m : v_m}``
(2) ``{y(n_0) : v_0, y(n_1) : v_1, ..., y(n_m) : v_m}``
or as a list ``L`` of values:
``L = [v_0, v_1, ..., v_m]``
where ``L[i] = v_i``, for `i=0, \ldots, m`, maps to `y(n_i)`.
Examples
========
Lets consider the following recurrence:
.. math:: (n - 1) y(n + 2) - (n^2 + 3 n - 2) y(n + 1) +
2 n (n + 1) y(n) = 0
>>> from sympy import Function, rsolve
>>> from sympy.abc import n
>>> y = Function('y')
>>> f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n)
>>> rsolve(f, y(n))
2**n*C0 + C1*factorial(n)
>>> rsolve(f, y(n), {y(0):0, y(1):3})
3*2**n - 3*factorial(n)
See Also
========
rsolve_poly, rsolve_ratio, rsolve_hyper
"""
if isinstance(f, Equality):
f = f.lhs - f.rhs
n = y.args[0]
k = Wild('k', exclude=(n,))
# Preprocess user input to allow things like
# y(n) + a*(y(n + 1) + y(n - 1))/2
f = f.expand().collect(y.func(Wild('m', integer=True)))
h_part = defaultdict(list)
i_part = []
for g in Add.make_args(f):
coeff, dep = g.as_coeff_mul(y.func)
if not dep:
i_part.append(coeff)
continue
for h in dep:
if h.is_Function and h.func == y.func:
result = h.args[0].match(n + k)
if result is not None:
h_part[int(result[k])].append(coeff)
continue
raise ValueError(
"'%s(%s + k)' expected, got '%s'" % (y.func, n, h))
for k in h_part:
h_part[k] = Add(*h_part[k])
h_part.default_factory = lambda: 0
i_part = Add(*i_part)
for k, coeff in h_part.items():
h_part[k] = simplify(coeff)
common = S.One
if not i_part.is_zero and not i_part.is_hypergeometric(n) and \
not (i_part.is_Add and all((x.is_hypergeometric(n) for x in i_part.expand().args))):
raise ValueError("The independent term should be a sum of hypergeometric functions, got '%s'" % i_part)
for coeff in h_part.values():
if coeff.is_rational_function(n):
if not coeff.is_polynomial(n):
common = lcm(common, coeff.as_numer_denom()[1], n)
else:
raise ValueError(
"Polynomial or rational function expected, got '%s'" % coeff)
i_numer, i_denom = i_part.as_numer_denom()
if i_denom.is_polynomial(n):
common = lcm(common, i_denom, n)
if common is not S.One:
for k, coeff in h_part.items():
numer, denom = coeff.as_numer_denom()
h_part[k] = numer*quo(common, denom, n)
i_part = i_numer*quo(common, i_denom, n)
K_min = min(h_part.keys())
if K_min < 0:
K = abs(K_min)
H_part = defaultdict(lambda: S.Zero)
i_part = i_part.subs(n, n + K).expand()
common = common.subs(n, n + K).expand()
for k, coeff in h_part.items():
H_part[k + K] = coeff.subs(n, n + K).expand()
else:
H_part = h_part
K_max = max(H_part.keys())
coeffs = [H_part[i] for i in range(K_max + 1)]
result = rsolve_hyper(coeffs, -i_part, n, symbols=True)
if result is None:
return None
solution, symbols = result
if init in ({}, []):
init = None
if symbols and init is not None:
if isinstance(init, list):
init = {i: init[i] for i in range(len(init))}
equations = []
for k, v in init.items():
try:
i = int(k)
except TypeError:
if k.is_Function and k.func == y.func:
i = int(k.args[0])
else:
raise ValueError("Integer or term expected, got '%s'" % k)
eq = solution.subs(n, i) - v
if eq.has(S.NaN):
eq = solution.limit(n, i) - v
equations.append(eq)
result = solve(equations, *symbols)
if not result:
return None
else:
solution = solution.subs(result)
return solution
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,179 @@
"""
If the arbitrary constant class from issue 4435 is ever implemented, this
should serve as a set of test cases.
"""
from sympy.core.function import Function
from sympy.core.numbers import I
from sympy.core.power import Pow
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import Symbol
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.hyperbolic import (cosh, sinh)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (acos, cos, sin)
from sympy.integrals.integrals import Integral
from sympy.solvers.ode.ode import constantsimp, constant_renumber
from sympy.testing.pytest import XFAIL
x = Symbol('x')
y = Symbol('y')
z = Symbol('z')
u2 = Symbol('u2')
_a = Symbol('_a')
C1 = Symbol('C1')
C2 = Symbol('C2')
C3 = Symbol('C3')
f = Function('f')
def test_constant_mul():
# We want C1 (Constant) below to absorb the y's, but not the x's
assert constant_renumber(constantsimp(y*C1, [C1])) == C1*y
assert constant_renumber(constantsimp(C1*y, [C1])) == C1*y
assert constant_renumber(constantsimp(x*C1, [C1])) == x*C1
assert constant_renumber(constantsimp(C1*x, [C1])) == x*C1
assert constant_renumber(constantsimp(2*C1, [C1])) == C1
assert constant_renumber(constantsimp(C1*2, [C1])) == C1
assert constant_renumber(constantsimp(y*C1*x, [C1, y])) == C1*x
assert constant_renumber(constantsimp(x*y*C1, [C1, y])) == x*C1
assert constant_renumber(constantsimp(y*x*C1, [C1, y])) == x*C1
assert constant_renumber(constantsimp(C1*x*y, [C1, y])) == C1*x
assert constant_renumber(constantsimp(x*C1*y, [C1, y])) == x*C1
assert constant_renumber(constantsimp(C1*y*(y + 1), [C1])) == C1*y*(y+1)
assert constant_renumber(constantsimp(y*C1*(y + 1), [C1])) == C1*y*(y+1)
assert constant_renumber(constantsimp(x*(y*C1), [C1])) == x*y*C1
assert constant_renumber(constantsimp(x*(C1*y), [C1])) == x*y*C1
assert constant_renumber(constantsimp(C1*(x*y), [C1, y])) == C1*x
assert constant_renumber(constantsimp((x*y)*C1, [C1, y])) == x*C1
assert constant_renumber(constantsimp((y*x)*C1, [C1, y])) == x*C1
assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y])) == C1
assert constant_renumber(constantsimp((C1*x)*y, [C1, y])) == C1*x
assert constant_renumber(constantsimp(y*(x*C1), [C1, y])) == x*C1
assert constant_renumber(constantsimp((x*C1)*y, [C1, y])) == x*C1
assert constant_renumber(constantsimp(C1*x*y*x*y*2, [C1, y])) == C1*x**2
assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z])) == C1*x
assert constant_renumber(constantsimp(C1*x*y**2*sin(z), [C1, y, z])) == C1*x
assert constant_renumber(constantsimp(C1*C1, [C1])) == C1
assert constant_renumber(constantsimp(C1*C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C2*C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C1*x*2**x, [C1])) == C1*x*2**x
def test_constant_add():
assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1
assert constant_renumber(constantsimp(C1 + 2, [C1])) == C1
assert constant_renumber(constantsimp(2 + C1, [C1])) == C1
assert constant_renumber(constantsimp(C1 + y, [C1, y])) == C1
assert constant_renumber(constantsimp(C1 + x, [C1])) == C1 + x
assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1
assert constant_renumber(constantsimp(C1 + C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C2 + C1, [C1, C2])) == C1
assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2])) == C1
def test_constant_power_as_base():
assert constant_renumber(constantsimp(C1**C1, [C1])) == C1
assert constant_renumber(constantsimp(Pow(C1, C1), [C1])) == C1
assert constant_renumber(constantsimp(C1**C1, [C1])) == C1
assert constant_renumber(constantsimp(C1**C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C2**C1, [C1, C2])) == C1
assert constant_renumber(constantsimp(C2**C2, [C1, C2])) == C1
assert constant_renumber(constantsimp(C1**y, [C1, y])) == C1
assert constant_renumber(constantsimp(C1**x, [C1])) == C1**x
assert constant_renumber(constantsimp(C1**2, [C1])) == C1
assert constant_renumber(
constantsimp(C1**(x*y), [C1])) == C1**(x*y)
def test_constant_power_as_exp():
assert constant_renumber(constantsimp(x**C1, [C1])) == x**C1
assert constant_renumber(constantsimp(y**C1, [C1, y])) == C1
assert constant_renumber(constantsimp(x**y**C1, [C1, y])) == x**C1
assert constant_renumber(
constantsimp((x**y)**C1, [C1])) == (x**y)**C1
assert constant_renumber(
constantsimp(x**(y**C1), [C1, y])) == x**C1
assert constant_renumber(constantsimp(x**C1**y, [C1, y])) == x**C1
assert constant_renumber(
constantsimp(x**(C1**y), [C1, y])) == x**C1
assert constant_renumber(
constantsimp((x**C1)**y, [C1])) == (x**C1)**y
assert constant_renumber(constantsimp(2**C1, [C1])) == C1
assert constant_renumber(constantsimp(S(2)**C1, [C1])) == C1
assert constant_renumber(constantsimp(exp(C1), [C1])) == C1
assert constant_renumber(
constantsimp(exp(C1 + x), [C1])) == C1*exp(x)
assert constant_renumber(constantsimp(Pow(2, C1), [C1])) == C1
def test_constant_function():
assert constant_renumber(constantsimp(sin(C1), [C1])) == C1
assert constant_renumber(constantsimp(f(C1), [C1])) == C1
assert constant_renumber(constantsimp(f(C1, C1), [C1])) == C1
assert constant_renumber(constantsimp(f(C1, C2), [C1, C2])) == C1
assert constant_renumber(constantsimp(f(C2, C1), [C1, C2])) == C1
assert constant_renumber(constantsimp(f(C2, C2), [C1, C2])) == C1
assert constant_renumber(
constantsimp(f(C1, x), [C1])) == f(C1, x)
assert constant_renumber(constantsimp(f(C1, y), [C1, y])) == C1
assert constant_renumber(constantsimp(f(y, C1), [C1, y])) == C1
assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y])) == C1
def test_constant_function_multiple():
# The rules to not renumber in this case would be too complicated, and
# dsolve is not likely to ever encounter anything remotely like this.
assert constant_renumber(
constantsimp(f(C1, C1, x), [C1])) == f(C1, C1, x)
def test_constant_multiple():
assert constant_renumber(constantsimp(C1*2 + 2, [C1])) == C1
assert constant_renumber(constantsimp(x*2/C1, [C1])) == C1*x
assert constant_renumber(constantsimp(C1**2*2 + 2, [C1])) == C1
assert constant_renumber(
constantsimp(sin(2*C1) + x + sqrt(2), [C1])) == C1 + x
assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2])) == C1
def test_constant_repeated():
assert C1 + C1*x == constant_renumber( C1 + C1*x)
def test_ode_solutions():
# only a few examples here, the rest will be tested in the actual dsolve tests
assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3])) == \
constant_renumber(C1*exp(x) + C2*exp(2*x))
assert constant_renumber(
constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2])
) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3)))
assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1])) == \
Eq(f(x), acos(C1/cos(x)))
assert constant_renumber(
constantsimp(Eq(log(f(x)/C1) + 2*exp(x/f(x)), 0), [C1])
) == Eq(log(C1*f(x)) + 2*exp(x/f(x)), 0)
assert constant_renumber(constantsimp(Eq(log(x*sqrt(2)*sqrt(1/x)*sqrt(f(x))
/C1) + x**2/(2*f(x)**2), 0), [C1])) == \
Eq(log(C1*sqrt(x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0)
assert constant_renumber(constantsimp(Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(x/C1) -
cos(f(x)/x)*exp(-f(x)/x)/2, 0), [C1])) == \
Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(C1*x) - cos(f(x)/x)*
exp(-f(x)/x)/2, 0)
assert constant_renumber(constantsimp(Eq(-Integral(-1/(sqrt(1 - u2**2)*u2),
(u2, _a, x/f(x))) + log(f(x)/C1), 0), [C1])) == \
Eq(-Integral(-1/(u2*sqrt(1 - u2**2)), (u2, _a, x/f(x))) +
log(C1*f(x)), 0)
assert [constantsimp(i, [C1]) for i in [Eq(f(x), sqrt(-C1*x + x**2)), Eq(f(x), -sqrt(-C1*x + x**2))]] == \
[Eq(f(x), sqrt(x*(C1 + x))), Eq(f(x), -sqrt(x*(C1 + x)))]
@XFAIL
def test_nonlocal_simplification():
assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x
def test_constant_Eq():
# C1 on the rhs is well-tested, but the lhs is only tested here
assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1)
assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1)
@@ -0,0 +1,59 @@
from sympy.solvers.decompogen import decompogen, compogen
from sympy.core.symbol import symbols
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt, Max
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.testing.pytest import XFAIL, raises
x, y = symbols('x y')
def test_decompogen():
assert decompogen(sin(cos(x)), x) == [sin(x), cos(x)]
assert decompogen(sin(x)**2 + sin(x) + 1, x) == [x**2 + x + 1, sin(x)]
assert decompogen(sqrt(6*x**2 - 5), x) == [sqrt(x), 6*x**2 - 5]
assert decompogen(sin(sqrt(cos(x**2 + 1))), x) == [sin(x), sqrt(x), cos(x), x**2 + 1]
assert decompogen(Abs(cos(x)**2 + 3*cos(x) - 4), x) == [Abs(x), x**2 + 3*x - 4, cos(x)]
assert decompogen(sin(x)**2 + sin(x) - sqrt(3)/2, x) == [x**2 + x - sqrt(3)/2, sin(x)]
assert decompogen(Abs(cos(y)**2 + 3*cos(x) - 4), x) == [Abs(x), 3*x + cos(y)**2 - 4, cos(x)]
assert decompogen(x, y) == [x]
assert decompogen(1, x) == [1]
assert decompogen(Max(3, x), x) == [Max(3, x)]
raises(TypeError, lambda: decompogen(x < 5, x))
u = 2*x + 3
assert decompogen(Max(sqrt(u),(u)**2), x) == [Max(sqrt(x), x**2), u]
assert decompogen(Max(u, u**2, y), x) == [Max(x, x**2, y), u]
assert decompogen(Max(sin(x), u), x) == [Max(2*x + 3, sin(x))]
def test_decompogen_poly():
assert decompogen(x**4 + 2*x**2 + 1, x) == [x**2 + 2*x + 1, x**2]
assert decompogen(x**4 + 2*x**3 - x - 1, x) == [x**2 - x - 1, x**2 + x]
@XFAIL
def test_decompogen_fails():
A = lambda x: x**2 + 2*x + 3
B = lambda x: 4*x**2 + 5*x + 6
assert decompogen(A(x*exp(x)), x) == [x**2 + 2*x + 3, x*exp(x)]
assert decompogen(A(B(x)), x) == [x**2 + 2*x + 3, 4*x**2 + 5*x + 6]
assert decompogen(A(1/x + 1/x**2), x) == [x**2 + 2*x + 3, 1/x + 1/x**2]
assert decompogen(A(1/x + 2/(x + 1)), x) == [x**2 + 2*x + 3, 1/x + 2/(x + 1)]
def test_compogen():
assert compogen([sin(x), cos(x)], x) == sin(cos(x))
assert compogen([x**2 + x + 1, sin(x)], x) == sin(x)**2 + sin(x) + 1
assert compogen([sqrt(x), 6*x**2 - 5], x) == sqrt(6*x**2 - 5)
assert compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) == sin(sqrt(
cos(x**2 + 1)))
assert compogen([Abs(x), x**2 + 3*x - 4, cos(x)], x) == Abs(cos(x)**2 +
3*cos(x) - 4)
assert compogen([x**2 + x - sqrt(3)/2, sin(x)], x) == (sin(x)**2 + sin(x) -
sqrt(3)/2)
assert compogen([Abs(x), 3*x + cos(y)**2 - 4, cos(x)], x) == \
Abs(3*cos(x) + cos(y)**2 - 4)
assert compogen([x**2 + 2*x + 1, x**2], x) == x**4 + 2*x**2 + 1
# the result is in unsimplified form
assert compogen([x**2 - x - 1, x**2 + x], x) == -x**2 - x + (x**2 + x)**2 - 1
@@ -0,0 +1,500 @@
"""Tests for tools for solving inequalities and systems of inequalities. """
from sympy.concrete.summations import Sum
from sympy.core.function import Function
from sympy.core.numbers import I, Rational, oo, pi
from sympy.core.relational import Eq, Ge, Gt, Le, Lt, Ne
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol)
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.miscellaneous import root, sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import cos, sin, tan
from sympy.integrals.integrals import Integral
from sympy.logic.boolalg import And, Or
from sympy.polys.polytools import Poly, PurePoly
from sympy.sets.sets import FiniteSet, Interval, Union
from sympy.solvers.inequalities import (reduce_inequalities,
solve_poly_inequality as psolve,
reduce_rational_inequalities,
solve_univariate_inequality as isolve,
reduce_abs_inequality,
_solve_inequality)
from sympy.polys.rootoftools import rootof
from sympy.solvers.solvers import solve
from sympy.solvers.solveset import solveset
from sympy.core.mod import Mod
from sympy.abc import x, y
from sympy.testing.pytest import raises, XFAIL
inf = oo.evalf()
def test_solve_poly_inequality():
assert psolve(Poly(0, x), '==') == [S.Reals]
assert psolve(Poly(1, x), '==') == [S.EmptySet]
assert psolve(PurePoly(x + 1, x), ">") == [Interval(-1, oo, True, False)]
def test_reduce_poly_inequalities_real_interval():
assert reduce_rational_inequalities(
[[Eq(x**2, 0)]], x, relational=False) == FiniteSet(0)
assert reduce_rational_inequalities(
[[Le(x**2, 0)]], x, relational=False) == FiniteSet(0)
assert reduce_rational_inequalities(
[[Lt(x**2, 0)]], x, relational=False) == S.EmptySet
assert reduce_rational_inequalities(
[[Ge(x**2, 0)]], x, relational=False) == \
S.Reals if x.is_real else Interval(-oo, oo)
assert reduce_rational_inequalities(
[[Gt(x**2, 0)]], x, relational=False) == \
FiniteSet(0).complement(S.Reals)
assert reduce_rational_inequalities(
[[Ne(x**2, 0)]], x, relational=False) == \
FiniteSet(0).complement(S.Reals)
assert reduce_rational_inequalities(
[[Eq(x**2, 1)]], x, relational=False) == FiniteSet(-1, 1)
assert reduce_rational_inequalities(
[[Le(x**2, 1)]], x, relational=False) == Interval(-1, 1)
assert reduce_rational_inequalities(
[[Lt(x**2, 1)]], x, relational=False) == Interval(-1, 1, True, True)
assert reduce_rational_inequalities(
[[Ge(x**2, 1)]], x, relational=False) == \
Union(Interval(-oo, -1), Interval(1, oo))
assert reduce_rational_inequalities(
[[Gt(x**2, 1)]], x, relational=False) == \
Interval(-1, 1).complement(S.Reals)
assert reduce_rational_inequalities(
[[Ne(x**2, 1)]], x, relational=False) == \
FiniteSet(-1, 1).complement(S.Reals)
assert reduce_rational_inequalities([[Eq(
x**2, 1.0)]], x, relational=False) == FiniteSet(-1.0, 1.0).evalf()
assert reduce_rational_inequalities(
[[Le(x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0)
assert reduce_rational_inequalities([[Lt(
x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0, True, True)
assert reduce_rational_inequalities(
[[Ge(x**2, 1.0)]], x, relational=False) == \
Union(Interval(-inf, -1.0), Interval(1.0, inf))
assert reduce_rational_inequalities(
[[Gt(x**2, 1.0)]], x, relational=False) == \
Union(Interval(-inf, -1.0, right_open=True),
Interval(1.0, inf, left_open=True))
assert reduce_rational_inequalities([[Ne(
x**2, 1.0)]], x, relational=False) == \
FiniteSet(-1.0, 1.0).complement(S.Reals)
s = sqrt(2)
assert reduce_rational_inequalities([[Lt(
x**2 - 1, 0), Gt(x**2 - 1, 0)]], x, relational=False) == S.EmptySet
assert reduce_rational_inequalities([[Le(x**2 - 1, 0), Ge(
x**2 - 1, 0)]], x, relational=False) == FiniteSet(-1, 1)
assert reduce_rational_inequalities(
[[Le(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False
) == Union(Interval(-s, -1, False, False), Interval(1, s, False, False))
assert reduce_rational_inequalities(
[[Le(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False
) == Union(Interval(-s, -1, False, True), Interval(1, s, True, False))
assert reduce_rational_inequalities(
[[Lt(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False
) == Union(Interval(-s, -1, True, False), Interval(1, s, False, True))
assert reduce_rational_inequalities(
[[Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False
) == Union(Interval(-s, -1, True, True), Interval(1, s, True, True))
assert reduce_rational_inequalities(
[[Lt(x**2 - 2, 0), Ne(x**2 - 1, 0)]], x, relational=False
) == Union(Interval(-s, -1, True, True), Interval(-1, 1, True, True),
Interval(1, s, True, True))
assert reduce_rational_inequalities([[Lt(x**2, -1.)]], x) is S.false
def test_reduce_poly_inequalities_complex_relational():
assert reduce_rational_inequalities(
[[Eq(x**2, 0)]], x, relational=True) == Eq(x, 0)
assert reduce_rational_inequalities(
[[Le(x**2, 0)]], x, relational=True) == Eq(x, 0)
assert reduce_rational_inequalities(
[[Lt(x**2, 0)]], x, relational=True) == False
assert reduce_rational_inequalities(
[[Ge(x**2, 0)]], x, relational=True) == And(Lt(-oo, x), Lt(x, oo))
assert reduce_rational_inequalities(
[[Gt(x**2, 0)]], x, relational=True) == \
And(Gt(x, -oo), Lt(x, oo), Ne(x, 0))
assert reduce_rational_inequalities(
[[Ne(x**2, 0)]], x, relational=True) == \
And(Gt(x, -oo), Lt(x, oo), Ne(x, 0))
for one in (S.One, S(1.0)):
inf = one*oo
assert reduce_rational_inequalities(
[[Eq(x**2, one)]], x, relational=True) == \
Or(Eq(x, -one), Eq(x, one))
assert reduce_rational_inequalities(
[[Le(x**2, one)]], x, relational=True) == \
And(And(Le(-one, x), Le(x, one)))
assert reduce_rational_inequalities(
[[Lt(x**2, one)]], x, relational=True) == \
And(And(Lt(-one, x), Lt(x, one)))
assert reduce_rational_inequalities(
[[Ge(x**2, one)]], x, relational=True) == \
And(Or(And(Le(one, x), Lt(x, inf)), And(Le(x, -one), Lt(-inf, x))))
assert reduce_rational_inequalities(
[[Gt(x**2, one)]], x, relational=True) == \
And(Or(And(Lt(-inf, x), Lt(x, -one)), And(Lt(one, x), Lt(x, inf))))
assert reduce_rational_inequalities(
[[Ne(x**2, one)]], x, relational=True) == \
Or(And(Lt(-inf, x), Lt(x, -one)),
And(Lt(-one, x), Lt(x, one)),
And(Lt(one, x), Lt(x, inf)))
def test_reduce_rational_inequalities_real_relational():
assert reduce_rational_inequalities([], x) == False
assert reduce_rational_inequalities(
[[(x**2 + 3*x + 2)/(x**2 - 16) >= 0]], x, relational=False) == \
Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo))
assert reduce_rational_inequalities(
[[((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0]], x,
relational=False) == \
Union(Interval.open(-5, 2), Interval.open(2, 3))
assert reduce_rational_inequalities([[(x + 1)/(x - 5) <= 0]], x,
relational=False) == \
Interval.Ropen(-1, 5)
assert reduce_rational_inequalities([[(x**2 + 4*x + 3)/(x - 1) > 0]], x,
relational=False) == \
Union(Interval.open(-3, -1), Interval.open(1, oo))
assert reduce_rational_inequalities([[(x**2 - 16)/(x - 1)**2 < 0]], x,
relational=False) == \
Union(Interval.open(-4, 1), Interval.open(1, 4))
assert reduce_rational_inequalities([[(3*x + 1)/(x + 4) >= 1]], x,
relational=False) == \
Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo))
assert reduce_rational_inequalities([[(x - 8)/x <= 3 - x]], x,
relational=False) == \
Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4))
# issue sympy/sympy#10237
assert reduce_rational_inequalities(
[[x < oo, x >= 0, -oo < x]], x, relational=False) == Interval(0, oo)
def test_reduce_abs_inequalities():
e = abs(x - 5) < 3
ans = And(Lt(2, x), Lt(x, 8))
assert reduce_inequalities(e) == ans
assert reduce_inequalities(e, x) == ans
assert reduce_inequalities(abs(x - 5)) == Eq(x, 5)
assert reduce_inequalities(
abs(2*x + 3) >= 8) == Or(And(Le(Rational(5, 2), x), Lt(x, oo)),
And(Le(x, Rational(-11, 2)), Lt(-oo, x)))
assert reduce_inequalities(abs(x - 4) + abs(
3*x - 5) < 7) == And(Lt(S.Half, x), Lt(x, 4))
assert reduce_inequalities(abs(x - 4) + abs(3*abs(x) - 5) < 7) == \
Or(And(S(-2) < x, x < -1), And(S.Half < x, x < 4))
nr = Symbol('nr', extended_real=False)
raises(TypeError, lambda: reduce_inequalities(abs(nr - 5) < 3))
assert reduce_inequalities(x < 3, symbols=[x, nr]) == And(-oo < x, x < 3)
def test_reduce_inequalities_general():
assert reduce_inequalities(Ge(sqrt(2)*x, 1)) == And(sqrt(2)/2 <= x, x < oo)
assert reduce_inequalities(x + 1 > 0) == And(S.NegativeOne < x, x < oo)
def test_reduce_inequalities_boolean():
assert reduce_inequalities(
[Eq(x**2, 0), True]) == Eq(x, 0)
assert reduce_inequalities([Eq(x**2, 0), False]) == False
assert reduce_inequalities(x**2 >= 0) is S.true # issue 10196
def test_reduce_inequalities_multivariate():
assert reduce_inequalities([Ge(x**2, 1), Ge(y**2, 1)]) == And(
Or(And(Le(S.One, x), Lt(x, oo)), And(Le(x, -1), Lt(-oo, x))),
Or(And(Le(S.One, y), Lt(y, oo)), And(Le(y, -1), Lt(-oo, y))))
def test_reduce_inequalities_errors():
raises(NotImplementedError, lambda: reduce_inequalities(Ge(sin(x) + x, 1)))
raises(NotImplementedError, lambda: reduce_inequalities(Ge(x**2*y + y, 1)))
def test__solve_inequalities():
assert reduce_inequalities(x + y < 1, symbols=[x]) == (x < 1 - y)
assert reduce_inequalities(x + y >= 1, symbols=[x]) == (x < oo) & (x >= -y + 1)
assert reduce_inequalities(Eq(0, x - y), symbols=[x]) == Eq(x, y)
assert reduce_inequalities(Ne(0, x - y), symbols=[x]) == Ne(x, y)
def test_issue_6343():
eq = -3*x**2/2 - x*Rational(45, 4) + Rational(33, 2) > 0
assert reduce_inequalities(eq) == \
And(x < Rational(-15, 4) + sqrt(401)/4, -sqrt(401)/4 - Rational(15, 4) < x)
def test_issue_8235():
assert reduce_inequalities(x**2 - 1 < 0) == \
And(S.NegativeOne < x, x < 1)
assert reduce_inequalities(x**2 - 1 <= 0) == \
And(S.NegativeOne <= x, x <= 1)
assert reduce_inequalities(x**2 - 1 > 0) == \
Or(And(-oo < x, x < -1), And(x < oo, S.One < x))
assert reduce_inequalities(x**2 - 1 >= 0) == \
Or(And(-oo < x, x <= -1), And(S.One <= x, x < oo))
eq = x**8 + x - 9 # we want CRootOf solns here
sol = solve(eq >= 0)
tru = Or(And(rootof(eq, 1) <= x, x < oo), And(-oo < x, x <= rootof(eq, 0)))
assert sol == tru
# recast vanilla as real
assert solve(sqrt((-x + 1)**2) < 1) == And(S.Zero < x, x < 2)
def test_issue_5526():
assert reduce_inequalities(0 <=
x + Integral(y**2, (y, 1, 3)) - 1, [x]) == \
(x >= -Integral(y**2, (y, 1, 3)) + 1)
f = Function('f')
e = Sum(f(x), (x, 1, 3))
assert reduce_inequalities(0 <= x + e + y**2, [x]) == \
(x >= -y**2 - Sum(f(x), (x, 1, 3)))
def test_solve_univariate_inequality():
assert isolve(x**2 >= 4, x, relational=False) == Union(Interval(-oo, -2),
Interval(2, oo))
assert isolve(x**2 >= 4, x) == Or(And(Le(2, x), Lt(x, oo)), And(Le(x, -2),
Lt(-oo, x)))
assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x, relational=False) == \
Union(Interval(1, 2), Interval(3, oo))
assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x) == \
Or(And(Le(1, x), Le(x, 2)), And(Le(3, x), Lt(x, oo)))
assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain = FiniteSet(0, 3)) == \
Or(Eq(x, 0), Eq(x, 3))
# issue 2785:
assert isolve(x**3 - 2*x - 1 > 0, x, relational=False) == \
Union(Interval(-1, -sqrt(5)/2 + S.Half, True, True),
Interval(S.Half + sqrt(5)/2, oo, True, True))
# issue 2794:
assert isolve(x**3 - x**2 + x - 1 > 0, x, relational=False) == \
Interval(1, oo, True)
#issue 13105
assert isolve((x + I)*(x + 2*I) < 0, x) == Eq(x, 0)
assert isolve(((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I) < 0, x) == Or(Eq(x, 1), Eq(x, 2))
assert isolve((((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I))/(x - 2) > 0, x) == Eq(x, 1)
raises (ValueError, lambda: isolve((x**2 - 3*x*I + 2)/x < 0, x))
# numerical testing in valid() is needed
assert isolve(x**7 - x - 2 > 0, x) == \
And(rootof(x**7 - x - 2, 0) < x, x < oo)
# handle numerator and denominator; although these would be handled as
# rational inequalities, these test confirm that the right thing is done
# when the domain is EX (e.g. when 2 is replaced with sqrt(2))
assert isolve(1/(x - 2) > 0, x) == And(S(2) < x, x < oo)
den = ((x - 1)*(x - 2)).expand()
assert isolve((x - 1)/den <= 0, x) == \
(x > -oo) & (x < 2) & Ne(x, 1)
n = Dummy('n')
raises(NotImplementedError, lambda: isolve(Abs(x) <= n, x, relational=False))
c1 = Dummy("c1", positive=True)
raises(NotImplementedError, lambda: isolve(n/c1 < 0, c1))
n = Dummy('n', negative=True)
assert isolve(n/c1 > -2, c1) == (-n/2 < c1)
assert isolve(n/c1 < 0, c1) == True
assert isolve(n/c1 > 0, c1) == False
zero = cos(1)**2 + sin(1)**2 - 1
raises(NotImplementedError, lambda: isolve(x**2 < zero, x))
raises(NotImplementedError, lambda: isolve(
x**2 < zero*I, x))
raises(NotImplementedError, lambda: isolve(1/(x - y) < 2, x))
raises(NotImplementedError, lambda: isolve(1/(x - y) < 0, x))
raises(TypeError, lambda: isolve(x - I < 0, x))
zero = x**2 + x - x*(x + 1)
assert isolve(zero < 0, x, relational=False) is S.EmptySet
assert isolve(zero <= 0, x, relational=False) is S.Reals
# make sure iter_solutions gets a default value
raises(NotImplementedError, lambda: isolve(
Eq(cos(x)**2 + sin(x)**2, 1), x))
def test_trig_inequalities():
# all the inequalities are solved in a periodic interval.
assert isolve(sin(x) < S.Half, x, relational=False) == \
Union(Interval(0, pi/6, False, True), Interval.open(pi*Rational(5, 6), 2*pi))
assert isolve(sin(x) > S.Half, x, relational=False) == \
Interval(pi/6, pi*Rational(5, 6), True, True)
assert isolve(cos(x) < S.Zero, x, relational=False) == \
Interval(pi/2, pi*Rational(3, 2), True, True)
assert isolve(cos(x) >= S.Zero, x, relational=False) == \
Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi))
assert isolve(tan(x) < S.One, x, relational=False) == \
Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi))
assert isolve(sin(x) <= S.Zero, x, relational=False) == \
Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi))
assert isolve(sin(x) <= S.One, x, relational=False) == S.Reals
assert isolve(cos(x) < S(-2), x, relational=False) == S.EmptySet
assert isolve(sin(x) >= S.NegativeOne, x, relational=False) == S.Reals
assert isolve(cos(x) > S.One, x, relational=False) == S.EmptySet
def test_issue_9954():
assert isolve(x**2 >= 0, x, relational=False) == S.Reals
assert isolve(x**2 >= 0, x, relational=True) == S.Reals.as_relational(x)
assert isolve(x**2 < 0, x, relational=False) == S.EmptySet
assert isolve(x**2 < 0, x, relational=True) == S.EmptySet.as_relational(x)
@XFAIL
def test_slow_general_univariate():
r = rootof(x**5 - x**2 + 1, 0)
assert solve(sqrt(x) + 1/root(x, 3) > 1) == \
Or(And(0 < x, x < r**6), And(r**6 < x, x < oo))
def test_issue_8545():
eq = 1 - x - abs(1 - x)
ans = And(Lt(1, x), Lt(x, oo))
assert reduce_abs_inequality(eq, '<', x) == ans
eq = 1 - x - sqrt((1 - x)**2)
assert reduce_inequalities(eq < 0) == ans
def test_issue_8974():
assert isolve(-oo < x, x) == And(-oo < x, x < oo)
assert isolve(oo > x, x) == And(-oo < x, x < oo)
def test_issue_10198():
assert reduce_inequalities(
-1 + 1/abs(1/x - 1) < 0) == (x > -oo) & (x < S(1)/2) & Ne(x, 0)
assert reduce_inequalities(abs(1/sqrt(x)) - 1, x) == Eq(x, 1)
assert reduce_abs_inequality(-3 + 1/abs(1 - 1/x), '<', x) == \
Or(And(-oo < x, x < 0),
And(S.Zero < x, x < Rational(3, 4)), And(Rational(3, 2) < x, x < oo))
raises(ValueError,lambda: reduce_abs_inequality(-3 + 1/abs(
1 - 1/sqrt(x)), '<', x))
def test_issue_10047():
# issue 10047: this must remain an inequality, not True, since if x
# is not real the inequality is invalid
# assert solve(sin(x) < 2) == (x <= oo)
# with PR 16956, (x <= oo) autoevaluates when x is extended_real
# which is assumed in the current implementation of inequality solvers
assert solve(sin(x) < 2) == True
assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals
def test_issue_10268():
assert solve(log(x) < 1000) == And(S.Zero < x, x < exp(1000))
@XFAIL
def test_isolve_Sets():
n = Dummy('n')
assert isolve(Abs(x) <= n, x, relational=False) == \
Piecewise((S.EmptySet, n < 0), (Interval(-n, n), True))
def test_integer_domain_relational_isolve():
dom = FiniteSet(0, 3)
x = Symbol('x',zero=False)
assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain=dom) == Eq(x, 3)
x = Symbol('x')
assert isolve(x + 2 < 0, x, domain=S.Integers) == \
(x <= -3) & (x > -oo) & Eq(Mod(x, 1), 0)
assert isolve(2 * x + 3 > 0, x, domain=S.Integers) == \
(x >= -1) & (x < oo) & Eq(Mod(x, 1), 0)
assert isolve((x ** 2 + 3 * x - 2) < 0, x, domain=S.Integers) == \
(x >= -3) & (x <= 0) & Eq(Mod(x, 1), 0)
assert isolve((x ** 2 + 3 * x - 2) > 0, x, domain=S.Integers) == \
((x >= 1) & (x < oo) & Eq(Mod(x, 1), 0)) | (
(x <= -4) & (x > -oo) & Eq(Mod(x, 1), 0))
def test_issue_10671_12466():
assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi)
i = Interval(1, 10)
assert solveset((1/x).diff(x) < 0, x, i) == i
assert solveset((log(x - 6)/x) <= 0, x, S.Reals) == \
Interval.Lopen(6, 7)
def test__solve_inequality():
for op in (Gt, Lt, Le, Ge, Eq, Ne):
assert _solve_inequality(op(x, 1), x).lhs == x
assert _solve_inequality(op(S.One, x), x).lhs == x
# don't get tricked by symbol on right: solve it
assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1)
ie = Eq(S.One, y)
assert _solve_inequality(ie, x) == ie
for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)):
for c in (0, 1):
e = 2*fx - c > 0
assert _solve_inequality(e, x, linear=True) == (
fx > c/S(2))
assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == (
x*(x + 1) < S.Half)
assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1)
nz = Symbol('nz', nonzero=True)
assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz)
assert _solve_inequality(x*nz < 1, x) == (x*nz < 1)
a = Symbol('a', positive=True)
assert _solve_inequality(a/x > 1, x) == (S.Zero < x) & (x < a)
assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a)
# make sure to include conditions under which solution is valid
e = Eq(1 - x, x*(1/x - 1))
assert _solve_inequality(e, x) == Ne(x, 0)
assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0)
def test__pt():
from sympy.solvers.inequalities import _pt
assert _pt(-oo, oo) == 0
assert _pt(S.One, S(3)) == 2
assert _pt(S.One, oo) == _pt(oo, S.One) == 2
assert _pt(S.One, -oo) == _pt(-oo, S.One) == S.Half
assert _pt(S.NegativeOne, oo) == _pt(oo, S.NegativeOne) == Rational(-1, 2)
assert _pt(S.NegativeOne, -oo) == _pt(-oo, S.NegativeOne) == -2
assert _pt(x, oo) == _pt(oo, x) == x + 1
assert _pt(x, -oo) == _pt(-oo, x) == x - 1
raises(ValueError, lambda: _pt(Dummy('i', infinite=True), S.One))
def test_issue_25697():
assert _solve_inequality(log(x, 3) <= 2, x) == (x <= 9) & (S.Zero < x)
def test_issue_25738():
assert reduce_inequalities(3 < abs(x)
) == reduce_inequalities(pi < abs(x)).subs(pi, 3)
def test_issue_25983():
assert(reduce_inequalities(pi/Abs(x) <= 1) == ((pi <= x) & (x < oo)) | ((-oo < x) & (x <= -pi)))
@@ -0,0 +1,139 @@
from sympy.core.function import nfloat
from sympy.core.numbers import (Float, I, Rational, pi)
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import sin
from sympy.integrals.integrals import Integral
from sympy.matrices.dense import Matrix
from mpmath import mnorm, mpf
from sympy.solvers import nsolve
from sympy.utilities.lambdify import lambdify
from sympy.testing.pytest import raises, XFAIL
from sympy.utilities.decorator import conserve_mpmath_dps
@XFAIL
def test_nsolve_fail():
x = symbols('x')
# Sometimes it is better to use the numerator (issue 4829)
# but sometimes it is not (issue 11768) so leave this to
# the discretion of the user
ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0)
assert ans > 0.46 and ans < 0.47
def test_nsolve_denominator():
x = symbols('x')
# Test that nsolve uses the full expression (numerator and denominator).
ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1)
# The root -2 was divided out, so make sure we don't find it.
assert ans == -1.0
def test_nsolve():
# onedimensional
x = Symbol('x')
assert nsolve(sin(x), 2) - pi.evalf() < 1e-15
assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10)
# Testing checks on number of inputs
raises(TypeError, lambda: nsolve(Eq(2*x, 2)))
raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2))
# multidimensional
x1 = Symbol('x1')
x2 = Symbol('x2')
f1 = 3 * x1**2 - 2 * x2**2 - 1
f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
f = Matrix((f1, f2)).T
F = lambdify((x1, x2), f.T, modules='mpmath')
for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]:
x = nsolve(f, (x1, x2), x0, tol=1.e-8)
assert mnorm(F(*x), 1) <= 1.e-10
# The Chinese mathematician Zhu Shijie was the very first to solve this
# nonlinear system 700 years ago (z was added to make it 3-dimensional)
x = Symbol('x')
y = Symbol('y')
z = Symbol('z')
f1 = -x + 2*y
f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
f3 = sqrt(x**2 + y**2)*z
f = Matrix((f1, f2, f3)).T
F = lambdify((x, y, z), f.T, modules='mpmath')
def getroot(x0):
root = nsolve(f, (x, y, z), x0)
assert mnorm(F(*root), 1) <= 1.e-8
return root
assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0]
assert nsolve([Eq(
f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works
a = Symbol('a')
assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) -
mpf('0.31883011387318591')) < 1e-15
def test_issue_6408():
x = Symbol('x')
assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0
def test_issue_6408_integral():
x, y = symbols('x y')
assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0
@conserve_mpmath_dps
def test_increased_dps():
# Issue 8564
import mpmath
mpmath.mp.dps = 128
x = Symbol('x')
e1 = x**2 - pi
q = nsolve(e1, x, 3.0)
assert abs(sqrt(pi).evalf(128) - q) < 1e-128
def test_nsolve_precision():
x, y = symbols('x y')
sol = nsolve(x**2 - pi, x, 3, prec=128)
assert abs(sqrt(pi).evalf(128) - sol) < 1e-128
assert isinstance(sol, Float)
sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128)
assert isinstance(sols, Matrix)
assert sols.shape == (2, 1)
assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128
assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128
assert all(isinstance(i, Float) for i in sols)
def test_nsolve_complex():
x, y = symbols('x y')
assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I
assert nsolve(x**2 + 2, I) == sqrt(2.)*I
assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
def test_nsolve_dict_kwarg():
x, y = symbols('x y')
# one variable
assert nsolve(x**2 - 2, 1, dict = True) == \
[{x: sqrt(2.)}]
# one variable with complex solution
assert nsolve(x**2 + 2, I, dict = True) == \
[{x: sqrt(2.)*I}]
# two variables
assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \
[{x: sqrt(2.), y: sqrt(3.)}]
def test_nsolve_rational():
x = symbols('x')
assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100)
def test_issue_14950():
x = Matrix(symbols('t s'))
x0 = Matrix([17, 23])
eqn = x + x0
assert nsolve(eqn, x, x0) == nfloat(-x0)
assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0)
@@ -0,0 +1,239 @@
from sympy.core.function import (Derivative as D, Function)
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.core import S
from sympy.solvers.pde import (pde_separate, pde_separate_add, pde_separate_mul,
pdsolve, classify_pde, checkpdesol)
from sympy.testing.pytest import raises
a, b, c, x, y = symbols('a b c x y')
def test_pde_separate_add():
x, y, z, t = symbols("x,y,z,t")
F, T, X, Y, Z, u = map(Function, 'FTXYZu')
eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t)))
res = pde_separate_add(eq, u(x, t), [X(x), T(t)])
assert res == [D(X(x), x)*exp(-X(x)), D(T(t), t)*exp(T(t))]
def test_pde_separate():
x, y, z, t = symbols("x,y,z,t")
F, T, X, Y, Z, u = map(Function, 'FTXYZu')
eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t)))
raises(ValueError, lambda: pde_separate(eq, u(x, t), [X(x), T(t)], 'div'))
def test_pde_separate_mul():
x, y, z, t = symbols("x,y,z,t")
c = Symbol("C", real=True)
Phi = Function('Phi')
F, R, T, X, Y, Z, u = map(Function, 'FRTXYZu')
r, theta, z = symbols('r,theta,z')
# Something simple :)
eq = Eq(D(F(x, y, z), x) + D(F(x, y, z), y) + D(F(x, y, z), z), 0)
# Duplicate arguments in functions
raises(
ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), u(z, z)]))
# Wrong number of arguments
raises(ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), Y(y)]))
# Wrong variables: [x, y] -> [x, z]
raises(
ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(t), Y(x, y)]))
assert pde_separate_mul(eq, F(x, y, z), [Y(y), u(x, z)]) == \
[D(Y(y), y)/Y(y), -D(u(x, z), x)/u(x, z) - D(u(x, z), z)/u(x, z)]
assert pde_separate_mul(eq, F(x, y, z), [X(x), Y(y), Z(z)]) == \
[D(X(x), x)/X(x), -D(Z(z), z)/Z(z) - D(Y(y), y)/Y(y)]
# wave equation
wave = Eq(D(u(x, t), t, t), c**2*D(u(x, t), x, x))
res = pde_separate_mul(wave, u(x, t), [X(x), T(t)])
assert res == [D(X(x), x, x)/X(x), D(T(t), t, t)/(c**2*T(t))]
# Laplace equation in cylindrical coords
eq = Eq(1/r * D(Phi(r, theta, z), r) + D(Phi(r, theta, z), r, 2) +
1/r**2 * D(Phi(r, theta, z), theta, 2) + D(Phi(r, theta, z), z, 2), 0)
# Separate z
res = pde_separate_mul(eq, Phi(r, theta, z), [Z(z), u(theta, r)])
assert res == [D(Z(z), z, z)/Z(z),
-D(u(theta, r), r, r)/u(theta, r) -
D(u(theta, r), r)/(r*u(theta, r)) -
D(u(theta, r), theta, theta)/(r**2*u(theta, r))]
# Lets use the result to create a new equation...
eq = Eq(res[1], c)
# ...and separate theta...
res = pde_separate_mul(eq, u(theta, r), [T(theta), R(r)])
assert res == [D(T(theta), theta, theta)/T(theta),
-r*D(R(r), r)/R(r) - r**2*D(R(r), r, r)/R(r) - c*r**2]
# ...or r...
res = pde_separate_mul(eq, u(theta, r), [R(r), T(theta)])
assert res == [r*D(R(r), r)/R(r) + r**2*D(R(r), r, r)/R(r) + c*r**2,
-D(T(theta), theta, theta)/T(theta)]
def test_issue_11726():
x, t = symbols("x t")
f = symbols("f", cls=Function)
X, T = symbols("X T", cls=Function)
u = f(x, t)
eq = u.diff(x, 2) - u.diff(t, 2)
res = pde_separate(eq, u, [T(x), X(t)])
assert res == [D(T(x), x, x)/T(x),D(X(t), t, t)/X(t)]
def test_pde_classify():
# When more number of hints are added, add tests for classifying here.
f = Function('f')
eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y)
eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y)
eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y)
eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y)
eq5 = x**2*f(x,y) + x*f(x,y).diff(x) + x*y*f(x,y).diff(y)
eq6 = y*x**2*f(x,y) + y*f(x,y).diff(x) + f(x,y).diff(y)
for eq in [eq1, eq2, eq3]:
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
for eq in [eq4, eq5, eq6]:
assert classify_pde(eq) == ('1st_linear_variable_coeff',)
def test_checkpdesol():
f, F = map(Function, ['f', 'F'])
eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y)
eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y)
eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y)
for eq in [eq1, eq2, eq3]:
assert checkpdesol(eq, pdsolve(eq))[0]
eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y)
eq5 = 2*f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y)
eq6 = f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y)
assert checkpdesol(eq4, [pdsolve(eq5), pdsolve(eq6)]) == [
(False, (x - 2)*F(3*x - y)*exp(-x/S(5) - 3*y/S(5))),
(False, (x - 1)*F(3*x - y)*exp(-x/S(10) - 3*y/S(10)))]
for eq in [eq4, eq5, eq6]:
assert checkpdesol(eq, pdsolve(eq))[0]
sol = pdsolve(eq4)
sol4 = Eq(sol.lhs - sol.rhs, 0)
raises(NotImplementedError, lambda:
checkpdesol(eq4, sol4, solve_for_func=False))
def test_solvefun():
f, F, G, H = map(Function, ['f', 'F', 'G', 'H'])
eq1 = f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)
assert pdsolve(eq1) == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2))
assert pdsolve(eq1, solvefun=G) == Eq(f(x, y), G(x - y)*exp(-x/2 - y/2))
assert pdsolve(eq1, solvefun=H) == Eq(f(x, y), H(x - y)*exp(-x/2 - y/2))
def test_pde_1st_linear_constant_coeff_homogeneous():
f, F = map(Function, ['f', 'F'])
u = f(x, y)
eq = 2*u + u.diff(x) + u.diff(y)
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
sol = pdsolve(eq)
assert sol == Eq(u, F(x - y)*exp(-x - y))
assert checkpdesol(eq, sol)[0]
eq = 4 + (3*u.diff(x)/u) + (2*u.diff(y)/u)
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
sol = pdsolve(eq)
assert sol == Eq(u, F(2*x - 3*y)*exp(-S(12)*x/13 - S(8)*y/13))
assert checkpdesol(eq, sol)[0]
eq = u + (6*u.diff(x)) + (7*u.diff(y))
assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',)
sol = pdsolve(eq)
assert sol == Eq(u, F(7*x - 6*y)*exp(-6*x/S(85) - 7*y/S(85)))
assert checkpdesol(eq, sol)[0]
eq = a*u + b*u.diff(x) + c*u.diff(y)
sol = pdsolve(eq)
assert checkpdesol(eq, sol)[0]
def test_pde_1st_linear_constant_coeff():
f, F = map(Function, ['f', 'F'])
u = f(x,y)
eq = -2*u.diff(x) + 4*u.diff(y) + 5*u - exp(x + 3*y)
sol = pdsolve(eq)
assert sol == Eq(f(x,y),
(F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y))
assert classify_pde(eq) == ('1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral')
assert checkpdesol(eq, sol)[0]
eq = (u.diff(x)/u) + (u.diff(y)/u) + 1 - (exp(x + y)/u)
sol = pdsolve(eq)
assert sol == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2) + exp(x + y)/3)
assert classify_pde(eq) == ('1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral')
assert checkpdesol(eq, sol)[0]
eq = 2*u + -u.diff(x) + 3*u.diff(y) + sin(x)
sol = pdsolve(eq)
assert sol == Eq(f(x, y),
F(3*x + y)*exp(x/5 - 3*y/5) - 2*sin(x)/5 - cos(x)/5)
assert classify_pde(eq) == ('1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral')
assert checkpdesol(eq, sol)[0]
eq = u + u.diff(x) + u.diff(y) + x*y
sol = pdsolve(eq)
assert sol.expand() == Eq(f(x, y),
x + y + (x - y)**2/4 - (x + y)**2/4 + F(x - y)*exp(-x/2 - y/2) - 2).expand()
assert classify_pde(eq) == ('1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral')
assert checkpdesol(eq, sol)[0]
eq = u + u.diff(x) + u.diff(y) + log(x)
assert classify_pde(eq) == ('1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral')
def test_pdsolve_all():
f, F = map(Function, ['f', 'F'])
u = f(x,y)
eq = u + u.diff(x) + u.diff(y) + x**2*y
sol = pdsolve(eq, hint = 'all')
keys = ['1st_linear_constant_coeff',
'1st_linear_constant_coeff_Integral', 'default', 'order']
assert sorted(sol.keys()) == keys
assert sol['order'] == 1
assert sol['default'] == '1st_linear_constant_coeff'
assert sol['1st_linear_constant_coeff'].expand() == Eq(f(x, y),
-x**2*y + x**2 + 2*x*y - 4*x - 2*y + F(x - y)*exp(-x/2 - y/2) + 6).expand()
def test_pdsolve_variable_coeff():
f, F = map(Function, ['f', 'F'])
u = f(x, y)
eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2
sol = pdsolve(eq, hint="1st_linear_variable_coeff")
assert sol == Eq(u, F(x*y)*exp(y**2/2) + 1)
assert checkpdesol(eq, sol)[0]
eq = x**2*u + x*u.diff(x) + x*y*u.diff(y)
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
assert sol == Eq(u, F(y*exp(-x))*exp(-x**2/2))
assert checkpdesol(eq, sol)[0]
eq = y*x**2*u + y*u.diff(x) + u.diff(y)
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
assert sol == Eq(u, F(-2*x + y**2)*exp(-x**3/3))
assert checkpdesol(eq, sol)[0]
eq = exp(x)**2*(u.diff(x)) + y
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
assert sol == Eq(u, y*exp(-2*x)/2 + F(y))
assert checkpdesol(eq, sol)[0]
eq = exp(2*x)*(u.diff(y)) + y*u - u
sol = pdsolve(eq, hint='1st_linear_variable_coeff')
assert sol == Eq(u, F(x)*exp(-y*(y - 2)*exp(-2*x)/2))
@@ -0,0 +1,462 @@
"""Tests for solvers of systems of polynomial equations. """
from sympy.polys.domains import ZZ, QQ_I
from sympy.core.numbers import (I, Integer, Rational)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.polys.domains.rationalfield import QQ
from sympy.polys.polyerrors import UnsolvableFactorError
from sympy.polys.polyoptions import Options
from sympy.polys.polytools import Poly
from sympy.polys.rootoftools import CRootOf
from sympy.solvers.solvers import solve
from sympy.utilities.iterables import flatten
from sympy.abc import a, b, c, x, y, z
from sympy.polys import PolynomialError
from sympy.solvers.polysys import (solve_poly_system,
solve_triangulated,
solve_biquadratic, SolveFailed,
solve_generic, factor_system_bool,
factor_system_cond, factor_system_poly,
factor_system, _factor_sets, _factor_sets_slow)
from sympy.polys.polytools import parallel_poly_from_expr
from sympy.testing.pytest import raises
from sympy.core.relational import Eq
from sympy.functions.elementary.trigonometric import sin, cos
from sympy.functions.elementary.exponential import exp
def test_solve_poly_system():
assert solve_poly_system([x - 1], x) == [(S.One,)]
assert solve_poly_system([y - x, y - x - 1], x, y) is None
assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)]
assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \
[(Rational(3, 2), Integer(2), Integer(10))]
assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \
[(0, 0), (2, -sqrt(2)), (2, sqrt(2))]
assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \
[(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))]
f_1 = x**2 + y + z - 1
f_2 = x + y**2 + z - 1
f_3 = x + y + z**2 - 1
a, b = sqrt(2) - 1, -sqrt(2) - 1
assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
solution = [(1, -1), (1, 1)]
assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution
assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution
assert solve_poly_system([x**2 - y**2, x - 1]) == solution
assert solve_poly_system(
[x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)]
raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y))
raises(NotImplementedError, lambda: solve_poly_system(
[z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2]))
raises(PolynomialError, lambda: solve_poly_system([1/x], x))
raises(NotImplementedError, lambda: solve_poly_system(
[x-1,], (x, y)))
raises(NotImplementedError, lambda: solve_poly_system(
[y-1,], (x, y)))
# solve_poly_system should ideally construct solutions using
# CRootOf for the following four tests
assert solve_poly_system([x**5 - x + 1], [x], strict=False) == []
raises(UnsolvableFactorError, lambda: solve_poly_system(
[x**5 - x + 1], [x], strict=True))
assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y],
strict=False) == [(1, -1), (1, 1)]
raises(UnsolvableFactorError,
lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1],
[x, y], strict=True))
def test_solve_generic():
NewOption = Options((x, y), {'domain': 'ZZ'})
assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \
[(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
(sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
(-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2),
(sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)]
# solve_generic should ideally construct solutions using
# CRootOf for the following two tests
assert solve_generic(
[2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \
[(Rational(1, 2), 1)]
raises(UnsolvableFactorError, lambda: solve_generic(
[2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True))
def test_solve_biquadratic():
x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r')
f_1 = (x - 1)**2 + (y - 1)**2 - r**2
f_2 = (x - 2)**2 + (y - 2)**2 - r**2
s = sqrt(2*r**2 - 1)
a = (3 - s)/2
b = (3 + s)/2
assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)]
f_1 = (x - 1)**2 + (y - 2)**2 - r**2
f_2 = (x - 1)**2 + (y - 1)**2 - r**2
assert solve_poly_system([f_1, f_2], x, y) == \
[(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)),
(1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))]
query = lambda expr: expr.is_Pow and expr.exp is S.Half
f_1 = (x - 1 )**2 + (y - 2)**2 - r**2
f_2 = (x - x1)**2 + (y - 1)**2 - r**2
result = solve_poly_system([f_1, f_2], x, y)
assert len(result) == 2 and all(len(r) == 2 for r in result)
assert all(r.count(query) == 1 for r in flatten(result))
f_1 = (x - x0)**2 + (y - y0)**2 - r**2
f_2 = (x - x1)**2 + (y - y1)**2 - r**2
result = solve_poly_system([f_1, f_2], x, y)
assert len(result) == 2 and all(len(r) == 2 for r in result)
assert all(len(r.find(query)) == 1 for r in flatten(result))
s1 = (x*y - y, x**2 - x)
assert solve(s1) == [{x: 1}, {x: 0, y: 0}]
s2 = (x*y - x, y**2 - y)
assert solve(s2) == [{y: 1}, {x: 0, y: 0}]
gens = (x, y)
for seq in (s1, s2):
(f, g), opt = parallel_poly_from_expr(seq, *gens)
raises(SolveFailed, lambda: solve_biquadratic(f, g, opt))
seq = (x**2 + y**2 - 2, y**2 - 1)
(f, g), opt = parallel_poly_from_expr(seq, *gens)
assert solve_biquadratic(f, g, opt) == [
(-1, -1), (-1, 1), (1, -1), (1, 1)]
ans = [(0, -1), (0, 1)]
seq = (x**2 + y**2 - 1, y**2 - 1)
(f, g), opt = parallel_poly_from_expr(seq, *gens)
assert solve_biquadratic(f, g, opt) == ans
seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1)
(f, g), opt = parallel_poly_from_expr(seq, *gens)
assert solve_biquadratic(f, g, opt) == ans
def test_solve_triangulated():
f_1 = x**2 + y + z - 1
f_2 = x + y**2 + z - 1
f_3 = x + y + z**2 - 1
a, b = sqrt(2) - 1, -sqrt(2) - 1
assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \
[(0, 0, 1), (0, 1, 0), (1, 0, 0)]
dom = QQ.algebraic_field(sqrt(2))
assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
a, b = CRootOf(z**2 + 2*z - 1, 0), CRootOf(z**2 + 2*z - 1, 1)
assert solve_triangulated([f_1, f_2, f_3], x, y, z, extension=True) == \
[(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
def test_solve_issue_3686():
roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y)
assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))]
roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y)
# TODO: does this really have to be so complicated?!
assert len(roots) == 2
assert roots[0][0] == 0
assert roots[0][1].epsilon_eq(-499.474999374969, 1e12)
assert roots[1][0] == 0
assert roots[1][1].epsilon_eq(500.474999374969, 1e12)
def test_factor_system():
assert factor_system([x**2 + 2*x + 1]) == [[x + 1]]
assert factor_system([x**2 + 2*x + 1, y**2 + 2*y + 1]) == [[x + 1, y + 1]]
assert factor_system([x**2 + 1]) == [[x**2 + 1]]
assert factor_system([]) == [[]]
assert factor_system([x**2 + y**2 + 2*x*y, x**2 - 2], extension=sqrt(2)) == [
[x + y, x + sqrt(2)],
[x + y, x - sqrt(2)],
]
assert factor_system([x**2 + 1, y**2 + 1], gaussian=True) == [
[x + I, y + I],
[x + I, y - I],
[x - I, y + I],
[x - I, y - I],
]
assert factor_system([x**2 + 1, y**2 + 1], domain=QQ_I) == [
[x + I, y + I],
[x + I, y - I],
[x - I, y + I],
[x - I, y - I],
]
assert factor_system([0]) == [[]]
assert factor_system([1]) == []
assert factor_system([0 , x]) == [[x]]
assert factor_system([1, 0, x]) == []
assert factor_system([x**4 - 1, y**6 - 1]) == [
[x**2 + 1, y**2 + y + 1],
[x**2 + 1, y**2 - y + 1],
[x**2 + 1, y + 1],
[x**2 + 1, y - 1],
[x + 1, y**2 + y + 1],
[x + 1, y**2 - y + 1],
[x - 1, y**2 + y + 1],
[x - 1, y**2 - y + 1],
[x + 1, y + 1],
[x + 1, y - 1],
[x - 1, y + 1],
[x - 1, y - 1],
]
assert factor_system([(x - 1)*(y - 2), (y - 2)*(z - 3)]) == [
[x - 1, z - 3],
[y - 2]
]
assert factor_system([sin(x)**2 + cos(x)**2 - 1, x]) == [
[x, sin(x)**2 + cos(x)**2 - 1],
]
assert factor_system([sin(x)**2 + cos(x)**2 - 1]) == [
[sin(x)**2 + cos(x)**2 - 1]
]
assert factor_system([sin(x)**2 + cos(x)**2]) == [
[sin(x)**2 + cos(x)**2]
]
assert factor_system([a*x, y, a]) == [[y, a]]
assert factor_system([a*x, y, a], [x, y]) == []
assert factor_system([a ** 2 * x, y], [x, y]) == [[x, y]]
assert factor_system([a*x*(x - 1), b*y, c], [x, y]) == []
assert factor_system([a*x*(x - 1), b*y, c], [x, y, c]) == [
[x - 1, y, c],
[x, y, c],
]
assert factor_system([a*x*(x - 1), b*y, c]) == [
[x - 1, y, c],
[x, y, c],
[x - 1, b, c],
[x, b, c],
[y, a, c],
[a, b, c],
]
assert factor_system([x**2 - 2], [y]) == []
assert factor_system([x**2 - 2], [x]) == [[x**2 - 2]]
assert factor_system([cos(x)**2 - sin(x)**2, cos(x)**2 + sin(x)**2 - 1]) == [
[sin(x)**2 + cos(x)**2 - 1, sin(x) + cos(x)],
[sin(x)**2 + cos(x)**2 - 1, -sin(x) + cos(x)],
]
assert factor_system([(cos(x) + sin(x))**2 - 1, cos(x)**2 - sin(x)**2 - cos(2*x)]) == [
[sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) + 1],
[sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) - 1],
]
assert factor_system([(cos(x) + sin(x))*exp(y) - 1, (cos(x) - sin(x))*exp(y) - 1]) == [
[exp(y)*sin(x) + exp(y)*cos(x) - 1, -exp(y)*sin(x) + exp(y)*cos(x) - 1]
]
def test_factor_system_poly():
px = lambda e: Poly(e, x)
pxab = lambda e: Poly(e, x, domain=ZZ[a, b])
pxI = lambda e: Poly(e, x, domain=QQ_I)
pxyz = lambda e: Poly(e, (x, y, z))
assert factor_system_poly([px(x**2 - 1), px(x**2 - 4)]) == [
[px(x + 2), px(x + 1)],
[px(x + 2), px(x - 1)],
[px(x + 1), px(x - 2)],
[px(x - 1), px(x - 2)],
]
assert factor_system_poly([px(x**2 - 1)]) == [[px(x + 1)], [px(x - 1)]]
assert factor_system_poly([pxyz(x**2*y - y), pxyz(x**2*z - z)]) == [
[pxyz(x + 1)],
[pxyz(x - 1)],
[pxyz(y), pxyz(z)],
]
assert factor_system_poly([px(x**2*(x - 1)**2), px(x*(x - 1))]) == [
[px(x)],
[px(x - 1)],
]
assert factor_system_poly([pxyz(x**2 + y*x), pxyz(x**2 + z*x)]) == [
[pxyz(x + y), pxyz(x + z)],
[pxyz(x)],
]
assert factor_system_poly([pxab((a - 1)*(x - 2)), pxab((b - 3)*(x - 2))]) == [
[pxab(x - 2)],
[pxab(a - 1), pxab(b - 3)],
]
assert factor_system_poly([pxI(x**2 + 1)]) == [[pxI(x + I)], [pxI(x - I)]]
assert factor_system_poly([]) == [[]]
assert factor_system_poly([px(1)]) == []
assert factor_system_poly([px(0), px(x)]) == [[px(x)]]
def test_factor_system_cond():
assert factor_system_cond([x ** 2 - 1, x ** 2 - 4]) == [
[x + 2, x + 1],
[x + 2, x - 1],
[x + 1, x - 2],
[x - 1, x - 2],
]
assert factor_system_cond([1]) == []
assert factor_system_cond([0]) == [[]]
assert factor_system_cond([1, x]) == []
assert factor_system_cond([0, x]) == [[x]]
assert factor_system_cond([]) == [[]]
assert factor_system_cond([x**2 + y*x]) == [[x + y], [x]]
assert factor_system_cond([(a - 1)*(x - 2), (b - 3)*(x - 2)], [x]) == [
[x - 2],
[a - 1, b - 3],
]
assert factor_system_cond([a * (x - 1), b], [x]) == [[x - 1, b], [a, b]]
assert factor_system_cond([a*x*(x-1), b*y, c], [x, y]) == [
[x - 1, y, c],
[x, y, c],
[x - 1, b, c],
[x, b, c],
[y, a, c],
[a, b, c],
]
assert factor_system_cond([x*(x-1), y], [x, y]) == [[x - 1, y], [x, y]]
assert factor_system_cond([a*x, y, a], [x, y]) == [[y, a]]
assert factor_system_cond([a*x, b*x], [x, y]) == [[x], [a, b]]
assert factor_system_cond([a*b*x, y], [x, y]) == [[x, y], [y, a*b]]
assert factor_system_cond([a*b*x, y]) == [[x, y], [y, a], [y, b]]
assert factor_system_cond([a**2*x, y], [x, y]) == [[x, y], [y, a]]
def test_factor_system_bool():
eqs = [a*(x - 1)*(y - 1), b*(x - 2)*(y - 1)*(y - 2)]
assert factor_system_bool(eqs, [x, y]) == (
Eq(y - 1, 0)
| (Eq(a, 0) & Eq(b, 0))
| (Eq(a, 0) & Eq(x - 2, 0))
| (Eq(a, 0) & Eq(y - 2, 0))
| (Eq(b, 0) & Eq(x - 1, 0))
| (Eq(x - 2, 0) & Eq(x - 1, 0))
| (Eq(x - 1, 0) & Eq(y - 2, 0))
)
assert factor_system_bool([x - 1], [x]) == Eq(x - 1, 0)
assert factor_system_bool([(x - 1)*(x - 2)], [x]) == Eq(x - 2, 0) | Eq(x - 1, 0)
assert factor_system_bool([], [x]) == True
assert factor_system_bool([0], [x]) == True
assert factor_system_bool([1], [x]) == False
assert factor_system_bool([a], [x]) == Eq(a, 0)
assert factor_system_bool([a * x, y, a], [x, y]) == Eq(a, 0) & Eq(y, 0)
assert (factor_system_bool([a*x, b*y*x, a], [x, y]) == (
Eq(a, 0) & Eq(b, 0))
| (Eq(a, 0) & Eq(x, 0))
| (Eq(a, 0) & Eq(y, 0)))
assert (factor_system_bool([a*x, b*x], [x, y]) == Eq(x, 0) |
(Eq(a, 0) & Eq(b, 0)))
assert (factor_system_bool([a*b*x, y], [x, y]) == (
Eq(x, 0) & Eq(y, 0)) |
(Eq(y, 0) & Eq(a*b, 0)))
assert (factor_system_bool([a**2*x, y], [x, y]) == (
Eq(a, 0) & Eq(y, 0)) |
(Eq(x, 0) & Eq(y, 0)))
assert factor_system_bool([a*x*y, b*y*z], [x, y, z]) == (
Eq(y, 0)
| (Eq(a, 0) & Eq(b, 0))
| (Eq(a, 0) & Eq(z, 0))
| (Eq(b, 0) & Eq(x, 0))
| (Eq(x, 0) & Eq(z, 0))
)
assert factor_system_bool([a*(x - 1), b], [x]) == (
(Eq(a, 0) & Eq(b, 0))
| (Eq(x - 1, 0) & Eq(b, 0))
)
def test_factor_sets():
#
from random import randint
def generate_random_system(n_eqs=3, n_factors=2, max_val=10):
return [
[randint(0, max_val) for _ in range(randint(1, n_factors))]
for _ in range(n_eqs)
]
test_cases = [
[[1, 2], [1, 3]],
[[1, 2], [3, 4]],
[[1], [1, 2], [2]],
]
for case in test_cases:
assert _factor_sets(case) == _factor_sets_slow(case)
for _ in range(100):
system = generate_random_system()
assert _factor_sets(system) == _factor_sets_slow(system)
@@ -0,0 +1,295 @@
from sympy.core.function import (Function, Lambda, expand)
from sympy.core.numbers import (I, Rational)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.combinatorial.factorials import (rf, binomial, factorial)
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.polys.polytools import factor
from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio
from sympy.testing.pytest import raises, slow, XFAIL
from sympy.abc import a, b
y = Function('y')
n, k = symbols('n,k', integer=True)
C0, C1, C2 = symbols('C0,C1,C2')
def test_rsolve_poly():
assert rsolve_poly([-1, -1, 1], 0, n) == 0
assert rsolve_poly([-1, -1, 1], 1, n) == -1
assert rsolve_poly([-1, n + 1], n, n) == 1
assert rsolve_poly([-1, 1], n, n) == C0 + (n**2 - n)/2
assert rsolve_poly([-n - 1, n], 1, n) == C0*n - 1
assert rsolve_poly([-4*n - 2, 1], 4*n + 1, n) == -1
assert rsolve_poly([-1, 1], n**5 + n**3, n) == \
C0 - n**3 / 2 - n**5 / 2 + n**2 / 6 + n**6 / 6 + 2*n**4 / 3
def test_rsolve_ratio():
solution = rsolve_ratio([-2*n**3 + n**2 + 2*n - 1, 2*n**3 + n**2 - 6*n,
-2*n**3 - 11*n**2 - 18*n - 9, 2*n**3 + 13*n**2 + 22*n + 8], 0, n)
assert solution == C0*(2*n - 3)/(n**2 - 1)/2
def test_rsolve_hyper():
assert rsolve_hyper([-1, -1, 1], 0, n) in [
C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n,
C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n,
]
assert rsolve_hyper([n**2 - 2, -2*n - 1, 1], 0, n) in [
C0*rf(sqrt(2), n) + C1*rf(-sqrt(2), n),
C1*rf(sqrt(2), n) + C0*rf(-sqrt(2), n),
]
assert rsolve_hyper([n**2 - k, -2*n - 1, 1], 0, n) in [
C0*rf(sqrt(k), n) + C1*rf(-sqrt(k), n),
C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n),
]
assert rsolve_hyper(
[2*n*(n + 1), -n**2 - 3*n + 2, n - 1], 0, n) == C1*factorial(n) + C0*2**n
assert rsolve_hyper(
[n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1], 0, n) == 0
assert rsolve_hyper([-n - 1, -1, 1], 0, n) == 0
assert rsolve_hyper([-1, 1], n, n).expand() == C0 + n**2/2 - n/2
assert rsolve_hyper([-1, 1], 1 + n, n).expand() == C0 + n**2/2 + n/2
assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n
assert rsolve_hyper([-a, 1],0,n).expand() == C0*a**n
assert rsolve_hyper([-a, 0, 1], 0, n).expand() == (-1)**n*C1*a**(n/2) + C0*a**(n/2)
assert rsolve_hyper([1, 1, 1], 0, n).expand() == \
C0*(Rational(-1, 2) - sqrt(3)*I/2)**n + C1*(Rational(-1, 2) + sqrt(3)*I/2)**n
assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == 0
@XFAIL
def test_rsolve_ratio_missed():
# this arises during computation
# assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n
assert rsolve_ratio([-n, n + 2], n, n) is not None
def recurrence_term(c, f):
"""Compute RHS of recurrence in f(n) with coefficients in c."""
return sum(c[i]*f.subs(n, n + i) for i in range(len(c)))
def test_rsolve_bulk():
"""Some bulk-generated tests."""
funcs = [ n, n + 1, n**2, n**3, n**4, n + n**2, 27*n + 52*n**2 - 3*
n**3 + 12*n**4 - 52*n**5 ]
coeffs = [ [-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1], [-n, 1], [n**2 -
n + 12, 1] ]
for p in funcs:
# compute difference
for c in coeffs:
q = recurrence_term(c, p)
if p.is_polynomial(n):
assert rsolve_poly(c, q, n) == p
# See issue 3956:
if p.is_hypergeometric(n) and len(c) <= 3:
assert rsolve_hyper(c, q, n).subs(zip(symbols('C:3'), [0, 0, 0])).expand() == p
def test_rsolve_0_sol_homogeneous():
# fixed by cherry-pick from
# https://github.com/diofant/diofant/commit/e1d2e52125199eb3df59f12e8944f8a5f24b00a5
assert rsolve_hyper([n**2 - n + 12, 1], n*(n**2 - n + 12) + n + 1, n) == n
def test_rsolve():
f = y(n + 2) - y(n + 1) - y(n)
h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \
- sqrt(5)*(S.Half - S.Half*sqrt(5))**n
assert rsolve(f, y(n)) in [
C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n,
C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n,
]
assert rsolve(f, y(n), [0, 5]) == h
assert rsolve(f, y(n), {0: 0, 1: 5}) == h
assert rsolve(f, y(n), {y(0): 0, y(1): 5}) == h
assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == h
assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == h
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n)
g = C1*factorial(n) + C0*2**n
h = -3*factorial(n) + 3*2**n
assert rsolve(f, y(n)) == g
assert rsolve(f, y(n), []) == g
assert rsolve(f, y(n), {}) == g
assert rsolve(f, y(n), [0, 3]) == h
assert rsolve(f, y(n), {0: 0, 1: 3}) == h
assert rsolve(f, y(n), {y(0): 0, y(1): 3}) == h
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = y(n) - y(n - 1) - 2
assert rsolve(f, y(n), {y(0): 0}) == 2*n
assert rsolve(f, y(n), {y(0): 1}) == 2*n + 1
assert rsolve(f, y(n), {y(0): 0, y(1): 1}) is None
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = 3*y(n - 1) - y(n) - 1
assert rsolve(f, y(n), {y(0): 0}) == -3**n/2 + S.Half
assert rsolve(f, y(n), {y(0): 1}) == 3**n/2 + S.Half
assert rsolve(f, y(n), {y(0): 2}) == 3*3**n/2 + S.Half
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = y(n) - 1/n*y(n - 1)
assert rsolve(f, y(n)) == C0/factorial(n)
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = y(n) - 1/n*y(n - 1) - 1
assert rsolve(f, y(n)) is None
f = 2*y(n - 1) + (1 - n)*y(n)/n
assert rsolve(f, y(n), {y(1): 1}) == 2**(n - 1)*n
assert rsolve(f, y(n), {y(1): 2}) == 2**(n - 1)*n*2
assert rsolve(f, y(n), {y(1): 3}) == 2**(n - 1)*n*3
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
f = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n)
assert rsolve(f, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2)
assert rsolve(
f, y(n), {y(3): 6, y(4): -24}) == -n*(n - 1)*(n - 2)*(-1)**(n)
assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0
assert rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a}).simplify() == a**n
assert rsolve(y(n) - a*y(n-2),y(n), \
{y(1): sqrt(a)*(a + b), y(2): a*(a - b)}).simplify() == \
a**(n/2 + 1) - b*(-sqrt(a))**n
f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n)
yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)})
sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12
assert factor(expand(yn, func=True)) == sol
sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n))
assert str(sol) == 'C0*((-sqrt(1 - a**2) - 1)/a)**n + C1*((sqrt(1 - a**2) - 1)/a)**n'
assert rsolve((k + 1)*y(k), y(k)) is None
assert (rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2), y(k))
is None)
assert rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n)) == (-1)**n*C0 - 2**n/3 - 3**n/4
def test_rsolve_raises():
x = Function('x')
raises(ValueError, lambda: rsolve(y(n) - y(k + 1), y(n)))
raises(ValueError, lambda: rsolve(y(n) - y(n + 1), x(n)))
raises(ValueError, lambda: rsolve(y(n) - x(n + 1), y(n)))
raises(ValueError, lambda: rsolve(y(n) - sqrt(n)*y(n + 1), y(n)))
raises(ValueError, lambda: rsolve(y(n) - y(n + 1), y(n), {x(0): 0}))
raises(ValueError, lambda: rsolve(y(n) + y(n + 1) + 2**n + cos(n), y(n)))
def test_issue_6844():
f = y(n + 2) - y(n + 1) + y(n)/4
assert rsolve(f, y(n)) == 2**(-n + 1)*C1*n + 2**(-n)*C0
assert rsolve(f, y(n), {y(0): 0, y(1): 1}) == 2**(1 - n)*n
def test_issue_18751():
r = Symbol('r', positive=True)
theta = Symbol('theta', real=True)
f = y(n) - 2 * r * cos(theta) * y(n - 1) + r**2 * y(n - 2)
assert rsolve(f, y(n)) == \
C0*(r*(cos(theta) - I*Abs(sin(theta))))**n + C1*(r*(cos(theta) + I*Abs(sin(theta))))**n
def test_constant_naming():
#issue 8697
assert rsolve(y(n+3) - y(n+2) - y(n+1) + y(n), y(n)) == (-1)**n*C1 + C0 + C2*n
assert rsolve(y(n+3)+3*y(n+2)+3*y(n+1)+y(n), y(n)).expand() == (-1)**n*C0 - (-1)**n*C1*n - (-1)**n*C2*n**2
assert rsolve(y(n) - 2*y(n - 3) + 5*y(n - 2) - 4*y(n - 1),y(n),[1,3,8]) == 3*2**n - n - 2
#issue 19630
assert rsolve(y(n+3) - 3*y(n+1) + 2*y(n), y(n), {y(1):0, y(2):8, y(3):-2}) == (-2)**n + 2*n
@slow
def test_issue_15751():
f = y(n) + 21*y(n + 1) - 273*y(n + 2) - 1092*y(n + 3) + 1820*y(n + 4) + 1092*y(n + 5) - 273*y(n + 6) - 21*y(n + 7) + y(n + 8)
assert rsolve(f, y(n)) is not None
def test_issue_17990():
f = -10*y(n) + 4*y(n + 1) + 6*y(n + 2) + 46*y(n + 3)
sol = rsolve(f, y(n))
expected = C0*((86*18**(S(1)/3)/69 + (-12 + (-1 + sqrt(3)*I)*(290412 +
3036*sqrt(9165))**(S(1)/3))*(1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**
(S(1)/3)/276)/((1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3))
)**n + C1*((86*18**(S(1)/3)/69 + (-12 + (-1 - sqrt(3)*I)*(290412 + 3036
*sqrt(9165))**(S(1)/3))*(1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**
(S(1)/3)/276)/((1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3))
)**n + C2*(-43*18**(S(1)/3)/(69*(24201 + 253*sqrt(9165))**(S(1)/3)) -
S(1)/23 + (290412 + 3036*sqrt(9165))**(S(1)/3)/138)**n
assert sol == expected
e = sol.subs({C0: 1, C1: 1, C2: 1, n: 1}).evalf()
assert abs(e + 0.130434782608696) < 1e-13
def test_issue_8697():
a = Function('a')
eq = a(n + 3) - a(n + 2) - a(n + 1) + a(n)
assert rsolve(eq, a(n)) == (-1)**n*C1 + C0 + C2*n
eq2 = a(n + 3) + 3*a(n + 2) + 3*a(n + 1) + a(n)
assert (rsolve(eq2, a(n)) ==
(-1)**n*C0 + (-1)**(n + 1)*C1*n + (-1)**(n + 1)*C2*n**2)
assert rsolve(a(n) - 2*a(n - 3) + 5*a(n - 2) - 4*a(n - 1),
a(n), {a(0): 1, a(1): 3, a(2): 8}) == 3*2**n - n - 2
# From issue thread (but fixed by https://github.com/diofant/diofant/commit/da9789c6cd7d0c2ceeea19fbf59645987125b289):
assert rsolve(a(n) - 2*a(n - 1) - n, a(n), {a(0): 1}) == 3*2**n - n - 2
def test_diofantissue_294():
f = y(n) - y(n - 1) - 2*y(n - 2) - 2*n
assert rsolve(f, y(n)) == (-1)**n*C0 + 2**n*C1 - n - Rational(5, 2)
# issue sympy/sympy#11261
assert rsolve(f, y(n), {y(0): -1, y(1): 1}) == (-(-1)**n/2 + 2*2**n -
n - Rational(5, 2))
# issue sympy/sympy#7055
assert rsolve(-2*y(n) + y(n + 1) + n - 1, y(n)) == 2**n*C0 + n
def test_issue_15553():
f = Function("f")
assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n)) == 2**n*C0 - n - 2
assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n)) == 2**n*C0 - n**2 - 2*n - 4
assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n), {f(1): 0}) == 7*2**n/2 - n**2 - 2*n - 4
assert rsolve(Eq(f(n), 2*f(n - 1) + 3*n**2), f(n)) == 2**n*C0 - 3*n**2 - 12*n - 18
assert rsolve(Eq(f(n), 2*f(n - 1) + n**2), f(n)) == 2**n*C0 - n**2 - 4*n - 6
assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n), {f(0): 1}) == 3*2**n - n - 2
@@ -0,0 +1,254 @@
from sympy.core.numbers import Rational
from sympy.core.relational import Eq, Ne
from sympy.core.symbol import symbols
from sympy.core.sympify import sympify
from sympy.core.singleton import S
from sympy.core.random import random, choice
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.ntheory.generate import randprime
from sympy.matrices.dense import Matrix
from sympy.solvers.solveset import linear_eq_to_matrix
from sympy.solvers.simplex import (_lp as lp, _primal_dual,
UnboundedLPError, InfeasibleLPError, lpmin, lpmax,
_m, _abcd, _simplex, linprog)
from sympy.external.importtools import import_module
from sympy.testing.pytest import raises
from sympy.abc import x, y, z
np = import_module("numpy")
scipy = import_module("scipy")
def test_lp():
r1 = y + 2*z <= 3
r2 = -x - 3*z <= -2
r3 = 2*x + y + 7*z <= 5
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
objective = -x - y - 5 * z
ans = optimum, argmax = lp(max, objective, constraints)
assert ans == lpmax(objective, constraints)
assert objective.subs(argmax) == optimum
for constr in constraints:
assert constr.subs(argmax) == True
r1 = x - y + 2*z <= 3
r2 = -x + 2*y - 3*z <= -2
r3 = 2*x + y - 7*z <= -5
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
objective = -x - y - 5*z
ans = optimum, argmax = lp(max, objective, constraints)
assert ans == lpmax(objective, constraints)
assert objective.subs(argmax) == optimum
for constr in constraints:
assert constr.subs(argmax) == True
r1 = x - y + 2*z <= -4
r2 = -x + 2*y - 3*z <= 8
r3 = 2*x + y - 7*z <= 10
constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0]
const = 2
objective = -x-y-5*z+const # has constant term
ans = optimum, argmax = lp(max, objective, constraints)
assert ans == lpmax(objective, constraints)
assert objective.subs(argmax) == optimum
for constr in constraints:
assert constr.subs(argmax) == True
# Section 4 Problem 1 from
# http://web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_LP.pdf
# answer on page 55
v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4')
r1 = x1 - x2 - 2*x3 - x4 <= 4
r2 = 2*x1 + x3 -4*x4 <= 2
r3 = -2*x1 + x2 + x4 <= 1
objective, constraints = x1 - 2*x2 - 3*x3 - x4, [r1, r2, r3] + [
i >= 0 for i in v]
ans = optimum, argmax = lp(max, objective, constraints)
assert ans == lpmax(objective, constraints)
assert ans == (4, {x1: 7, x2: 0, x3: 0, x4: 3})
# input contains Floats
r1 = x - y + 2.0*z <= -4
r2 = -x + 2*y - 3.0*z <= 8
r3 = 2*x + y - 7*z <= 10
constraints = [r1, r2, r3] + [i >= 0 for i in (x, y, z)]
objective = -x-y-5*z
optimum, argmax = lp(max, objective, constraints)
assert objective.subs(argmax) == optimum
for constr in constraints:
assert constr.subs(argmax) == True
# input contains non-float or non-Rational
r1 = x - y + sqrt(2) * z <= -4
r2 = -x + 2*y - 3*z <= 8
r3 = 2*x + y - 7*z <= 10
raises(TypeError, lambda: lp(max, -x-y-5*z, [r1, r2, r3]))
r1 = x >= 0
raises(UnboundedLPError, lambda: lp(max, x, [r1]))
r2 = x <= -1
raises(InfeasibleLPError, lambda: lp(max, x, [r1, r2]))
# strict inequalities are not allowed
r1 = x > 0
raises(TypeError, lambda: lp(max, x, [r1]))
# not equals not allowed
r1 = Ne(x, 0)
raises(TypeError, lambda: lp(max, x, [r1]))
def make_random_problem(nvar=2, num_constraints=2, sparsity=.1):
def rand():
if random() < sparsity:
return sympify(0)
int1, int2 = [randprime(0, 200) for _ in range(2)]
return Rational(int1, int2)*choice([-1, 1])
variables = symbols('x1:%s' % (nvar + 1))
constraints = [(sum(rand()*x for x in variables) <= rand())
for _ in range(num_constraints)]
objective = sum(rand() * x for x in variables)
return objective, constraints, variables
# equality
r1 = Eq(x, y)
r2 = Eq(y, z)
r3 = z <= 3
constraints = [r1, r2, r3]
objective = x
ans = optimum, argmax = lp(max, objective, constraints)
assert ans == lpmax(objective, constraints)
assert objective.subs(argmax) == optimum
for constr in constraints:
assert constr.subs(argmax) == True
def test_simplex():
L = [
[[1, 1], [-1, 1], [0, 1], [-1, 0]],
[5, 1, 2, -1],
[[1, 1]],
[-1]]
A, B, C, D = _abcd(_m(*L), list=False)
assert _simplex(A, B, -C, -D) == (-6, [3, 2], [1, 0, 0, 0])
assert _simplex(A, B, -C, -D, dual=True) == (-6,
[1, 0, 0, 0], [5, 0])
assert _simplex([[]],[],[[1]],[0]) == (0, [0], [])
# handling of Eq (or Eq-like x<=y, x>=y conditions)
assert lpmax(x - y, [x <= y + 2, x >= y + 2, x >= 0, y >= 0]
) == (2, {x: 2, y: 0})
assert lpmax(x - y, [x <= y + 2, Eq(x, y + 2), x >= 0, y >= 0]
) == (2, {x: 2, y: 0})
assert lpmax(x - y, [x <= y + 2, Eq(x, 2)]) == (2, {x: 2, y: 0})
assert lpmax(y, [Eq(y, 2)]) == (2, {y: 2})
# the conditions are equivalent to Eq(x, y + 2)
assert lpmin(y, [x <= y + 2, x >= y + 2, y >= 0]
) == (0, {x: 2, y: 0})
# equivalent to Eq(y, -2)
assert lpmax(y, [0 <= y + 2, 0 >= y + 2]) == (-2, {y: -2})
assert lpmax(y, [0 <= y + 2, 0 >= y + 2, y <= 0]
) == (-2, {y: -2})
# extra symbols symbols
assert lpmin(x, [y >= 1, x >= y]) == (1, {x: 1, y: 1})
assert lpmin(x, [y >= 1, x >= y + z, x >= 0, z >= 0]
) == (1, {x: 1, y: 1, z: 0})
# detect oscillation
# o1
v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4')
raises(InfeasibleLPError, lambda: lpmin(
9*x2 - 8*x3 + 3*x4 + 6,
[5*x2 - 2*x3 <= 0,
-x1 - 8*x2 + 9*x3 <= -3,
10*x1 - x2+ 9*x4 <= -4] + [i >= 0 for i in v]))
# o2 - equations fed to lpmin are changed into a matrix
# system that doesn't oscillate and has the same solution
# as below
M = linear_eq_to_matrix
f = 5*x2 + x3 + 4*x4 - x1
L = 5*x2 + 2*x3 + 5*x4 - (x1 + 5)
cond = [L <= 0] + [Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)]
c, d = M(f, v)
a, b = M(L, v)
aeq, beq = M(cond[1:], v)
ans = (S(9)/2, [0, S(1)/2, 0, S(1)/2])
assert linprog(c, a, b, aeq, beq, bounds=(0, 1)) == ans
lpans = lpmin(f, cond + [x1 >= 0, x1 <= 1,
x2 >= 0, x2 <= 1, x3 >= 0, x3 <= 1, x4 >= 0, x4 <= 1])
assert (lpans[0], list(lpans[1].values())) == ans
def test_lpmin_lpmax():
v = x1, x2, y1, y2 = symbols('x1 x2 y1 y2')
L = [[1, -1]], [1], [[1, 1]], [2]
a, b, c, d = [Matrix(i) for i in L]
m = Matrix([[a, b], [c, d]])
f, constr = _primal_dual(m)[0]
ans = lpmin(f, constr + [i >= 0 for i in v[:2]])
assert ans == (-1, {x1: 1, x2: 0}),ans
L = [[1, -1], [1, 1]], [1, 1], [[1, 1]], [2]
a, b, c, d = [Matrix(i) for i in L]
m = Matrix([[a, b], [c, d]])
f, constr = _primal_dual(m)[1]
ans = lpmax(f, constr + [i >= 0 for i in v[-2:]])
assert ans == (-1, {y1: 1, y2: 0})
def test_linprog():
for do in range(2):
if not do:
M = lambda a, b: linear_eq_to_matrix(a, b)
else:
# check matrices as list
M = lambda a, b: tuple([
i.tolist() for i in linear_eq_to_matrix(a, b)])
v = x, y, z = symbols('x1:4')
f = x + y - 2*z
c = M(f, v)[0]
ineq = [7*x + 4*y - 7*z <= 3,
3*x - y + 10*z <= 6,
x >= 0, y >= 0, z >= 0]
ab = M([i.lts - i.gts for i in ineq], v)
ans = (-S(6)/5, [0, 0, S(3)/5])
assert lpmin(f, ineq) == (ans[0], dict(zip(v, ans[1])))
assert linprog(c, *ab) == ans
f += 1
c = M(f, v)[0]
eq = [Eq(y - 9*x, 1)]
abeq = M([i.lhs - i.rhs for i in eq], v)
ans = (1 - S(2)/5, [0, 1, S(7)/10])
assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1])))
assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1])
eq = [z - y <= S.Half]
abeq = M([i.lhs - i.rhs for i in eq], v)
ans = (1 - S(10)/9, [0, S(1)/9, S(11)/18])
assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1])))
assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1])
bounds = [(0, None), (0, None), (None, S.Half)]
ans = (0, [0, 0, S.Half])
assert lpmin(f, ineq + [z <= S.Half]) == (
ans[0], dict(zip(v, ans[1])))
assert linprog(c, *ab, bounds=bounds) == (ans[0] - 1, ans[1])
assert linprog(c, *ab, bounds={v.index(z): bounds[-1]}
) == (ans[0] - 1, ans[1])
eq = [z - y <= S.Half]
assert linprog([[1]], [], [], bounds=(2, 3)) == (2, [2])
assert linprog([1], [], [], bounds=(2, 3)) == (2, [2])
assert linprog([1], bounds=(2, 3)) == (2, [2])
assert linprog([1, -1], [[1, 1]], [2], bounds={1:(None, None)}
) == (-2, [0, 2])
assert linprog([1, -1], [[1, 1]], [5], bounds={1:(3, None)}
) == (-5, [0, 5])
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff