switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
Joint Random Variables Module
|
||||
|
||||
See Also
|
||||
========
|
||||
sympy.stats.rv
|
||||
sympy.stats.frv
|
||||
sympy.stats.crv
|
||||
sympy.stats.drv
|
||||
"""
|
||||
from math import prod
|
||||
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.function import Lambda
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Dummy, Symbol)
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.sets.sets import ProductSet
|
||||
from sympy.tensor.indexed import Indexed
|
||||
from sympy.concrete.products import Product
|
||||
from sympy.concrete.summations import Sum, summation
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.integrals.integrals import Integral, integrate
|
||||
from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy
|
||||
from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace
|
||||
from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace
|
||||
from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution,
|
||||
ProductDomain, RandomSymbol, random_symbols,
|
||||
SingleDomain, _symbol_converter)
|
||||
from sympy.utilities.iterables import iterable
|
||||
from sympy.utilities.misc import filldedent
|
||||
from sympy.external import import_module
|
||||
|
||||
# __all__ = ['marginal_distribution']
|
||||
|
||||
class JointPSpace(ProductPSpace):
|
||||
"""
|
||||
Represents a joint probability space. Represented using symbols for
|
||||
each component and a distribution.
|
||||
"""
|
||||
def __new__(cls, sym, dist):
|
||||
if isinstance(dist, SingleContinuousDistribution):
|
||||
return SingleContinuousPSpace(sym, dist)
|
||||
if isinstance(dist, SingleDiscreteDistribution):
|
||||
return SingleDiscretePSpace(sym, dist)
|
||||
sym = _symbol_converter(sym)
|
||||
return Basic.__new__(cls, sym, dist)
|
||||
|
||||
@property
|
||||
def set(self):
|
||||
return self.domain.set
|
||||
|
||||
@property
|
||||
def symbol(self):
|
||||
return self.args[0]
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
return self.args[1]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return JointRandomSymbol(self.symbol, self)
|
||||
|
||||
@property
|
||||
def component_count(self):
|
||||
_set = self.distribution.set
|
||||
if isinstance(_set, ProductSet):
|
||||
return S(len(_set.args))
|
||||
elif isinstance(_set, Product):
|
||||
return _set.limits[0][-1]
|
||||
return S.One
|
||||
|
||||
@property
|
||||
def pdf(self):
|
||||
sym = [Indexed(self.symbol, i) for i in range(self.component_count)]
|
||||
return self.distribution(*sym)
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
rvs = random_symbols(self.distribution)
|
||||
if not rvs:
|
||||
return SingleDomain(self.symbol, self.distribution.set)
|
||||
return ProductDomain(*[rv.pspace.domain for rv in rvs])
|
||||
|
||||
def component_domain(self, index):
|
||||
return self.set.args[index]
|
||||
|
||||
def marginal_distribution(self, *indices):
|
||||
count = self.component_count
|
||||
if count.atoms(Symbol):
|
||||
raise ValueError("Marginal distributions cannot be computed "
|
||||
"for symbolic dimensions. It is a work under progress.")
|
||||
orig = [Indexed(self.symbol, i) for i in range(count)]
|
||||
all_syms = [Symbol(str(i)) for i in orig]
|
||||
replace_dict = dict(zip(all_syms, orig))
|
||||
sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices)
|
||||
limits = [[i,] for i in all_syms if i not in sym]
|
||||
index = 0
|
||||
for i in range(count):
|
||||
if i not in indices:
|
||||
limits[index].append(self.distribution.set.args[i])
|
||||
limits[index] = tuple(limits[index])
|
||||
index += 1
|
||||
if self.distribution.is_Continuous:
|
||||
f = Lambda(sym, integrate(self.distribution(*all_syms), *limits))
|
||||
elif self.distribution.is_Discrete:
|
||||
f = Lambda(sym, summation(self.distribution(*all_syms), *limits))
|
||||
return f.xreplace(replace_dict)
|
||||
|
||||
def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
|
||||
syms = tuple(self.value[i] for i in range(self.component_count))
|
||||
rvs = rvs or syms
|
||||
if not any(i in rvs for i in syms):
|
||||
return expr
|
||||
expr = expr*self.pdf
|
||||
for rv in rvs:
|
||||
if isinstance(rv, Indexed):
|
||||
expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])})
|
||||
elif isinstance(rv, RandomSymbol):
|
||||
expr = expr.xreplace({rv: rv.symbol})
|
||||
if self.value in random_symbols(expr):
|
||||
raise NotImplementedError(filldedent('''
|
||||
Expectations of expression with unindexed joint random symbols
|
||||
cannot be calculated yet.'''))
|
||||
limits = tuple((Indexed(str(rv.base),rv.args[1]),
|
||||
self.distribution.set.args[rv.args[1]]) for rv in syms)
|
||||
return Integral(expr, *limits)
|
||||
|
||||
def where(self, condition):
|
||||
raise NotImplementedError()
|
||||
|
||||
def compute_density(self, expr):
|
||||
raise NotImplementedError()
|
||||
|
||||
def sample(self, size=(), library='scipy', seed=None):
|
||||
"""
|
||||
Internal sample method
|
||||
|
||||
Returns dictionary mapping RandomSymbol to realization value.
|
||||
"""
|
||||
return {RandomSymbol(self.symbol, self): self.distribution.sample(size,
|
||||
library=library, seed=seed)}
|
||||
|
||||
def probability(self, condition):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SampleJointScipy:
|
||||
"""Returns the sample from scipy of the given distribution"""
|
||||
def __new__(cls, dist, size, seed=None):
|
||||
return cls._sample_scipy(dist, size, seed)
|
||||
|
||||
@classmethod
|
||||
def _sample_scipy(cls, dist, size, seed):
|
||||
"""Sample from SciPy."""
|
||||
|
||||
import numpy
|
||||
if seed is None or isinstance(seed, int):
|
||||
rand_state = numpy.random.default_rng(seed=seed)
|
||||
else:
|
||||
rand_state = seed
|
||||
from scipy import stats as scipy_stats
|
||||
scipy_rv_map = {
|
||||
'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs(
|
||||
mean=matrix2numpy(dist.mu).flatten(),
|
||||
cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state),
|
||||
'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs(
|
||||
alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state),
|
||||
'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs(
|
||||
n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state)
|
||||
}
|
||||
|
||||
sample_shape = {
|
||||
'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
|
||||
'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
|
||||
'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
|
||||
}
|
||||
|
||||
dist_list = scipy_rv_map.keys()
|
||||
|
||||
if dist.__class__.__name__ not in dist_list:
|
||||
return None
|
||||
|
||||
samples = scipy_rv_map[dist.__class__.__name__](dist, size)
|
||||
return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
|
||||
|
||||
class SampleJointNumpy:
|
||||
"""Returns the sample from numpy of the given distribution"""
|
||||
|
||||
def __new__(cls, dist, size, seed=None):
|
||||
return cls._sample_numpy(dist, size, seed)
|
||||
|
||||
@classmethod
|
||||
def _sample_numpy(cls, dist, size, seed):
|
||||
"""Sample from NumPy."""
|
||||
|
||||
import numpy
|
||||
if seed is None or isinstance(seed, int):
|
||||
rand_state = numpy.random.default_rng(seed=seed)
|
||||
else:
|
||||
rand_state = seed
|
||||
numpy_rv_map = {
|
||||
'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal(
|
||||
mean=matrix2numpy(dist.mu, float).flatten(),
|
||||
cov=matrix2numpy(dist.sigma, float), size=size),
|
||||
'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet(
|
||||
alpha=list2numpy(dist.alpha, float).flatten(), size=size),
|
||||
'MultinomialDistribution': lambda dist, size: rand_state.multinomial(
|
||||
n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size)
|
||||
}
|
||||
|
||||
sample_shape = {
|
||||
'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
|
||||
'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
|
||||
'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
|
||||
}
|
||||
|
||||
dist_list = numpy_rv_map.keys()
|
||||
|
||||
if dist.__class__.__name__ not in dist_list:
|
||||
return None
|
||||
|
||||
samples = numpy_rv_map[dist.__class__.__name__](dist, prod(size))
|
||||
return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
|
||||
|
||||
class SampleJointPymc:
|
||||
"""Returns the sample from pymc of the given distribution"""
|
||||
|
||||
def __new__(cls, dist, size, seed=None):
|
||||
return cls._sample_pymc(dist, size, seed)
|
||||
|
||||
@classmethod
|
||||
def _sample_pymc(cls, dist, size, seed):
|
||||
"""Sample from PyMC."""
|
||||
|
||||
try:
|
||||
import pymc
|
||||
except ImportError:
|
||||
import pymc3 as pymc
|
||||
pymc_rv_map = {
|
||||
'MultivariateNormalDistribution': lambda dist:
|
||||
pymc.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(),
|
||||
cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])),
|
||||
'MultivariateBetaDistribution': lambda dist:
|
||||
pymc.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()),
|
||||
'MultinomialDistribution': lambda dist:
|
||||
pymc.Multinomial('X', n=int(dist.n),
|
||||
p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p)))
|
||||
}
|
||||
|
||||
sample_shape = {
|
||||
'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
|
||||
'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
|
||||
'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
|
||||
}
|
||||
|
||||
dist_list = pymc_rv_map.keys()
|
||||
|
||||
if dist.__class__.__name__ not in dist_list:
|
||||
return None
|
||||
|
||||
import logging
|
||||
logging.getLogger("pymc3").setLevel(logging.ERROR)
|
||||
with pymc.Model():
|
||||
pymc_rv_map[dist.__class__.__name__](dist)
|
||||
samples = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X']
|
||||
return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
|
||||
|
||||
|
||||
_get_sample_class_jrv = {
|
||||
'scipy': SampleJointScipy,
|
||||
'pymc3': SampleJointPymc,
|
||||
'pymc': SampleJointPymc,
|
||||
'numpy': SampleJointNumpy
|
||||
}
|
||||
|
||||
class JointDistribution(Distribution, NamedArgsMixin):
|
||||
"""
|
||||
Represented by the random variables part of the joint distribution.
|
||||
Contains methods for PDF, CDF, sampling, marginal densities, etc.
|
||||
"""
|
||||
|
||||
_argnames = ('pdf', )
|
||||
|
||||
def __new__(cls, *args):
|
||||
args = list(map(sympify, args))
|
||||
for i in range(len(args)):
|
||||
if isinstance(args[i], list):
|
||||
args[i] = ImmutableMatrix(args[i])
|
||||
return Basic.__new__(cls, *args)
|
||||
|
||||
@property
|
||||
def domain(self):
|
||||
return ProductDomain(self.symbols)
|
||||
|
||||
@property
|
||||
def pdf(self):
|
||||
return self.density.args[1]
|
||||
|
||||
def cdf(self, other):
|
||||
if not isinstance(other, dict):
|
||||
raise ValueError("%s should be of type dict, got %s"%(other, type(other)))
|
||||
rvs = other.keys()
|
||||
_set = self.domain.set.sets
|
||||
expr = self.pdf(tuple(i.args[0] for i in self.symbols))
|
||||
for i in range(len(other)):
|
||||
if rvs[i].is_Continuous:
|
||||
density = Integral(expr, (rvs[i], _set[i].inf,
|
||||
other[rvs[i]]))
|
||||
elif rvs[i].is_Discrete:
|
||||
density = Sum(expr, (rvs[i], _set[i].inf,
|
||||
other[rvs[i]]))
|
||||
return density
|
||||
|
||||
def sample(self, size=(), library='scipy', seed=None):
|
||||
""" A random realization from the distribution """
|
||||
|
||||
libraries = ('scipy', 'numpy', 'pymc3', 'pymc')
|
||||
if library not in libraries:
|
||||
raise NotImplementedError("Sampling from %s is not supported yet."
|
||||
% str(library))
|
||||
if not import_module(library):
|
||||
raise ValueError("Failed to import %s" % library)
|
||||
|
||||
samps = _get_sample_class_jrv[library](self, size, seed=seed)
|
||||
|
||||
if samps is not None:
|
||||
return samps
|
||||
raise NotImplementedError(
|
||||
"Sampling for %s is not currently implemented from %s"
|
||||
% (self.__class__.__name__, library)
|
||||
)
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.pdf(*args)
|
||||
|
||||
class JointRandomSymbol(RandomSymbol):
|
||||
"""
|
||||
Representation of random symbols with joint probability distributions
|
||||
to allow indexing."
|
||||
"""
|
||||
def __getitem__(self, key):
|
||||
if isinstance(self.pspace, JointPSpace):
|
||||
if (self.pspace.component_count <= key) == True:
|
||||
raise ValueError("Index keys for %s can only up to %s." %
|
||||
(self.name, self.pspace.component_count - 1))
|
||||
return Indexed(self, key)
|
||||
|
||||
|
||||
|
||||
class MarginalDistribution(Distribution):
|
||||
"""
|
||||
Represents the marginal distribution of a joint probability space.
|
||||
|
||||
Initialised using a probability distribution and random variables(or
|
||||
their indexed components) which should be a part of the resultant
|
||||
distribution.
|
||||
"""
|
||||
|
||||
def __new__(cls, dist, *rvs):
|
||||
if len(rvs) == 1 and iterable(rvs[0]):
|
||||
rvs = tuple(rvs[0])
|
||||
if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs):
|
||||
raise ValueError(filldedent('''Marginal distribution can be
|
||||
intitialised only in terms of random variables or indexed random
|
||||
variables'''))
|
||||
rvs = Tuple.fromiter(rv for rv in rvs)
|
||||
if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0:
|
||||
return dist
|
||||
return Basic.__new__(cls, dist, rvs)
|
||||
|
||||
def check(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def set(self):
|
||||
rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)]
|
||||
return ProductSet(*[rv.pspace.set for rv in rvs])
|
||||
|
||||
@property
|
||||
def symbols(self):
|
||||
rvs = self.args[1]
|
||||
return {rv.pspace.symbol for rv in rvs}
|
||||
|
||||
def pdf(self, *x):
|
||||
expr, rvs = self.args[0], self.args[1]
|
||||
marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
|
||||
if isinstance(expr, JointDistribution):
|
||||
count = len(expr.domain.args)
|
||||
x = Dummy('x', real=True)
|
||||
syms = tuple(Indexed(x, i) for i in count)
|
||||
expr = expr.pdf(syms)
|
||||
else:
|
||||
syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs)
|
||||
return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)
|
||||
|
||||
def compute_pdf(self, expr, rvs):
|
||||
for rv in rvs:
|
||||
lpdf = 1
|
||||
if isinstance(rv, RandomSymbol):
|
||||
lpdf = rv.pspace.pdf
|
||||
expr = self.marginalise_out(expr*lpdf, rv)
|
||||
return expr
|
||||
|
||||
def marginalise_out(self, expr, rv):
|
||||
from sympy.concrete.summations import Sum
|
||||
if isinstance(rv, RandomSymbol):
|
||||
dom = rv.pspace.set
|
||||
elif isinstance(rv, Indexed):
|
||||
dom = rv.base.component_domain(
|
||||
rv.pspace.component_domain(rv.args[1]))
|
||||
expr = expr.xreplace({rv: rv.pspace.symbol})
|
||||
if rv.pspace.is_Continuous:
|
||||
#TODO: Modify to support integration
|
||||
#for all kinds of sets.
|
||||
expr = Integral(expr, (rv.pspace.symbol, dom))
|
||||
elif rv.pspace.is_Discrete:
|
||||
#incorporate this into `Sum`/`summation`
|
||||
if dom in (S.Integers, S.Naturals, S.Naturals0):
|
||||
dom = (dom.inf, dom.sup)
|
||||
expr = Sum(expr, (rv.pspace.symbol, dom))
|
||||
return expr
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.pdf(*args)
|
||||
Reference in New Issue
Block a user