'''
piecewise-defined functions
'''
__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2013, Philippe Guglielmetti"
__license__ = "LGPL"
import bisect
import math
from Goulib import expr, math2, itertools2
class Piecewise(expr.Expr):
    '''
    piecewise function defined by a sorted list of (startx, Expr)

    Breakpoints are stored in ``self.x`` and the expression of each piece in
    the parallel list ``self.y``: piece ``i`` starts at ``self.x[i]`` and is
    selected for every x up to the next breakpoint (see :meth:`index`).
    ``self.period`` is a (start, end) pair; the function is treated as
    periodic when ``end`` is finite.
    '''

    def __init__(self, init=(), default=0, period=(-math2.inf, +math2.inf)):
        '''
        :param init: either an object exposing ``x``, ``y`` and ``period``
            attributes (copy constructor) or an iterable of (x, y) pairs
        :param default: value of the piece inserted at ``period[0]``
            (not used when copying)
        :param period: (start, end) pair, or a single number p meaning (0, p)
        '''
        # Note : started by deriving a list of (point,value), but this leads to a problem:
        # the value is taken into account in sort order by bisect
        # so instead of defining one more class with a __cmp__ method, I split both lists
        if math2.is_number(period):
            period = (0, period)
        try:  # copy constructor ?
            self.x = list(init.x)
            self.y = list(init.y)
            # allow to force periodicity
            # NOTE(review): the default period is a non-empty tuple, hence
            # truthy, so init.period is only inherited when the caller passes
            # a falsy period (e.g. None). Copies made with the default
            # arguments therefore get the default period -- confirm intended.
            self.period = period or init.period
        except AttributeError:
            # init is a plain iterable of (x, y) pairs
            self.x = []
            self.y = []
            self.period = period
            self.append(period[0], default)
            self.extend(init)

        # to initialize context and such stuff
        super(Piecewise, self).__init__(0)
        self.body = '?'  # should not happen

    def __len__(self):
        ''':return: int number of stored pieces'''
        return len(self.x)

    def __getitem__(self, i):
        ''':return: (start x, Expr) tuple of the i-th stored piece'''
        return (self.x[i], self.y[i])

    def is_periodic(self):
        ''':return: False if the period end is infinite, else the period length'''
        if math.isinf(self.period[1]):
            return False
        return self.period[1] - self.period[0]

    def _str_period(self):
        '''suffix describing the period, empty when not periodic'''
        p = self.is_periodic()
        return ", period=%s" % p if p else ""

    def __str__(self):
        return str(list(self)) + self._str_period()

    def __repr__(self):
        return repr(list(self)) + self._str_period()

    def latex(self):
        ''':return: string LaTex formula'''
        # Work on a snapshot: iterating self drops redundant pieces, so
        # indexing self while iterating it could pick the start of a piece
        # that is about to be removed as the end of the current one.
        pieces = list(self)
        last = len(pieces) - 1

        def condition(i):
            start = pieces[i][0]
            end = pieces[i + 1][0] if i < last else math2.inf
            if i == 0:
                return r'{x}<{' + str(end) + '}'
            if i == last:
                return r'{x}\geq{' + str(start) + '}'
            return r'{' + str(start) + r'}\leq{x}<{' + str(end) + '}'

        rows = [y.latex() + '&' + condition(i)
                for i, (_, y) in enumerate(pieces)]
        return r'\begin{cases}' + r'\\'.join(rows) + r'\end{cases}'

    def _x(self, x):
        '''handle periodicity'''
        # NOTE(review): x is reduced modulo the period length only;
        # period[0] is not taken into account here.
        p = self.is_periodic()
        return x % p if p else x

    def index(self, x):
        '''return index of piece'''
        return bisect.bisect_right(self.x, self._x(x)) - 1

    def __call__(self, x):
        '''returns value of Expr at point x

        :param x: a single abscissa, or an iterable of abscissas
        :return: the value at x, or the list of values for an iterable
        '''
        if itertools2.isiterable(x):
            return [self(xi) for xi in x]
        piece = self.y[self.index(x)]
        return piece(self._x(x))

    def insort(self, x, v=None):
        '''insert a point (or returns it if it already exists)
        note : method name follows bisect.insort convention

        :param x: abscissa of the breakpoint
        :param v: optional value of the new piece, wrapped in an Expr;
            when None the piece preceding x is duplicated
        :return: int index of the breakpoint in self.x
        '''
        x = self._x(x)
        i = bisect.bisect_left(self.x, x)  # do not use self.index here !
        if i < len(self) and x == self.x[i]:
            return i  # existing breakpoint : its current value is kept
        # insert either the v value, or copy the current value at x
        # note : we might have consecutive tuples with the same y value
        if v is not None:
            self.y.insert(i, expr.Expr(v))
        else:  # split the piece at x
            self.y.insert(i, self.y[i - 1])
        self.x.insert(i, x)
        return i

    def __iter__(self):
        '''iterators through discontinuities. take the opportunity to delete redundant tuples'''
        prev = None
        i = 0
        while i < len(self):
            x, y = self.x[i], self.y[i]
            if y == prev:  # simplify
                # same value as the previous piece : remove it in place and
                # stay on the same index, which now holds the next piece
                self.y.pop(i)
                self.x.pop(i)
            else:
                yield x, y
                prev = y
                i += 1

    def append(self, x, y=None):
        '''appends a (x,y) piece. In fact inserts it at correct position

        :param x: abscissa, or an (x, y) pair when y is None
        :param y: value of the piece
        :return: self
        '''
        if y is None:
            (x, y) = x
        x = self._x(x)
        # NOTE(review): when a breakpoint already exists at x, insort leaves
        # its value untouched, so y is ignored -- confirm this is intended.
        self.insort(x, y)
        return self  # to allow chained calls

    def extend(self, iterable):
        '''appends an iterable of (x,y) values'''
        for p in iterable:
            self.append(p)

    def iapply(self, f, right):
        '''apply function to self

        :param f: function handed to Expr.apply for each affected piece
        :param right: falsy for a monadic f applied to every piece,
            a Piecewise to combine piece by piece,
            or a (start, value[, end]) sequence
        :return: self
        '''
        if not right:  # monadic . apply to each expr
            self.y = [v.apply(f) for v in self.y]
        elif isinstance(right, Piecewise):  # combine each piece of right with self
            # Snapshot right first: iterating it removes redundant pieces, so
            # the next breakpoint must be read from the normalised list.
            pieces = list(right)
            for i, (start, value) in enumerate(pieces):
                if i + 1 < len(pieces):
                    self.iapply(f, (start, value, pieces[i + 1][0]))
                else:  # last piece has no end
                    self.iapply(f, (start, value))
        else:  # assume a triplet (start,value,end) as called above
            i = self.insort(right[0])
            try:
                end = right[2]
            except IndexError:  # (start, value) pair : no end given
                end = None
            if end is None:
                j = len(self)
            else:
                j = self.insort(end)
                # NOTE(review): the indices are swapped when the end sorts
                # before the start; kept as in the original implementation.
                if j < i:
                    i, j = j, i

            for k in range(i, j):
                self.y[k] = self.y[k].apply(f, right[1])  # calls Expr.apply
        return self

    def apply(self, f, right=None):
        '''apply function to copy of self'''
        return Piecewise(self).iapply(f, right)

    def applx(self, f):
        ''' apply a function to each x value '''
        self.x = [f(x) for x in self.x]
        self.y = [y.applx(f) for y in self.y]
        return self

    def __lshift__(self, dx):
        ''':return: copy of self with dx subtracted from every x'''
        return Piecewise(self).applx(lambda x: x - dx)

    def __rshift__(self, dx):
        ''':return: copy of self with dx added to every x'''
        return Piecewise(self).applx(lambda x: x + dx)

    def _switch_points(self, xmin, xmax):
        '''yields the (x, y) points of a step line plot clipped to [xmin, xmax]'''
        prevy = None
        firstpoint, lastpoint = False, False
        for x, piece in self:
            y = piece(x)  # evaluated at the breakpoint, before clipping x
            if x < xmin:
                if firstpoint:
                    continue
                firstpoint = True
                x = xmin
            if x > xmax:
                if lastpoint:
                    break
                lastpoint = True
                x = xmax
            if prevy is not None and not math2.isclose(y, prevy):  # step
                yield x, prevy
            yield x, y
            prevy = y

    def points(self, xmin=None, xmax=None):
        ''':return: x,y lists of float : points for a line plot

        Indexes ``self.x[1]``, so at least two breakpoints are required.
        '''
        resx = []
        resy = []
        dx = self.x[-1] - self.x[1]
        p = self.is_periodic()

        if xmin is None:
            # by default we extend the range by 10%
            xmin = min(0, self.x[1] - dx * .1)

        if xmax is None:
            if p:
                # by default we show 2.5 periods
                xmax = xmin + p * 2.5
            else:
                # by default we extend the range by 10%
                xmax = self.x[-1] + dx * .1

        for x, y in self._switch_points(xmin, xmax):
            resx.append(x)
            resy.append(y)

        # close the plot at xmax; checking resx avoids relying on the loop
        # variable, which is unbound if no point was produced
        if not resx or xmax > resx[-1]:
            resx.append(xmax)
            resy.append(self(xmax))

        return resx, resy

    def _plot(self, ax, xmax=None, **kwargs):
        '''plots function'''
        (x, y) = self.points(xmax=xmax)
        return super(Piecewise, self)._plot(ax, x, y, **kwargs)