Source code for Goulib.piecewise

#!/usr/bin/env python
# coding: utf8
'''
piecewise-defined functions
'''

__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2013, Philippe Guglielmetti"
__license__ = "LGPL"

import bisect, math

from . import expr, math2, itertools2
    
class Piecewise(expr.Expr):
    '''piecewise function defined by a sorted list of (startx, Expr)

    Breakpoints are stored in ``self.x`` (kept sorted) and the expression
    that applies from each breakpoint onwards in the parallel list ``self.y``.
    ``self.period`` is a (start, end) tuple; an infinite end means the
    function is not periodic.
    '''

    def __init__(self, init=(), default=0, period=None):
        '''
        :param init: another Piecewise to copy, or an iterable of (x, y) pieces
        :param default: value of the first piece, starting at period[0]
            (not used when copying)
        :param period: (start, end) tuple, or a number p standing for (0, p).
            None (the default) means "not periodic" when building from pieces
            and "same period as init" when copying
        '''
        # Note : started by deriving a list of (point,value), but this leads to a problem:
        # the value is taken into account in sort order by bisect
        # so instead of defining one more class with a __cmp__ method, I split both lists
        unbounded = (-math2.inf, +math2.inf)
        if period is not None and math2.is_number(period):
            period = (0, period)
        try:  # copy constructor ?
            xs, ys = list(init.x), list(init.y)
        except AttributeError:
            xs = ys = None  # init is a plain iterable of pieces
        if xs is not None:
            self.x, self.y = xs, ys
            # The period of the source is inherited unless one is given
            # explicitly, which allows to force periodicity. The previous
            # default was a non-empty (truthy) tuple, so `period or init.period`
            # never reached init.period and every copy lost its periodicity.
            self.period = period or getattr(init, 'period', None) or unbounded
        else:
            self.x = []
            self.y = []
            self.period = unbounded if period is None else period
            self.append(self.period[0], default)
            self.extend(init)

    def __getitem__(self, i):
        ''':return: the (x, y) tuple of the i-th piece'''
        return (self.x[i], self.y[i])

    def is_periodic(self):
        ''':return: False if the end of the period is infinite,
        otherwise the length of the period'''
        if math.isinf(self.period[1]):
            return False
        return self.period[1] - self.period[0]

    def _str_period(self):
        '''suffix describing the period, empty if not periodic'''
        p = self.is_periodic()
        return ", period=%s" % p if p else ""

    def __str__(self):
        # list(self) goes through __iter__, which also drops redundant pieces
        return str(list(self)) + self._str_period()

    def __repr__(self):
        # list(self) goes through __iter__, which also drops redundant pieces
        return repr(list(self)) + self._str_period()

    def _x(self, x):
        '''handle periodicity : x modulo the period length, or x itself'''
        p = self.is_periodic()
        return x % p if p else x

    def index(self, x):
        '''return index of piece'''
        return bisect.bisect_right(self.x, self._x(x)) - 1

    def __call__(self, x):
        '''returns value of Expr at point x, or a list of values if x is iterable'''
        if itertools2.isiterable(x):
            return [self(xi) for xi in x]
        i = self.index(x)
        xx = self._x(x)
        return self.y[i](xx)

    def insort(self, x, v=None):
        '''insert a point (or returns it if it already exists)
        note : method name follows bisect.insort convention

        :param x: position of the point, folded into the period first
        :param v: value from x onwards; if None the piece containing x is split
        :return: index of the point in self.x
        '''
        x = self._x(x)
        i = bisect.bisect_left(self.x, x)  # do not use self.index here !
        if i < len(self.x) and x == self.x[i]:
            return i
        # insert either the v value, or copy the current value at x
        # note : we might have consecutive tuples with the same y value
        if v is not None:
            self.y.insert(i, expr.Expr(v))
        else:  # split the piece at x
            self.y.insert(i, self.y[i - 1])
        self.x.insert(i, x)
        return i

    def __iter__(self):
        '''iterators through discontinuities. take the opportunity to delete redundant tuples'''
        prev = None
        i = 0
        while i < len(self.x):
            x, y = self.x[i], self.y[i]
            if y == prev:  # simplify
                # same expression as the previous piece : remove it in place
                # and look at the same index again
                self.y.pop(i)
                self.x.pop(i)
            else:
                yield x, y
                prev = y
                i += 1

    def append(self, x, y=None):
        '''appends a (x,y) piece. In fact inserts it at correct position

        :param x: start of the piece, or a (x, y) tuple if y is None
        :return: self, to allow chained calls
        '''
        if y is None:
            (x, y) = x
        x = self._x(x)
        self.insort(x, y)
        return self  # to allow chained calls

    def extend(self, iterable):
        '''appends an iterable of (x,y) values'''
        for p in iterable:
            self.append(p)

    def iapply(self, f, right):
        '''apply function to self

        :param f: function handed over to Expr.apply for each concerned piece
        :param right: nothing for a monadic function, a Piecewise,
            or a (start, value[, end]) tuple; a missing or None end means
            "up to the last piece"
        :return: self
        '''
        if not right:  # monadic . apply to each expr
            self.y = [v.apply(f) for v in self.y]
        elif isinstance(right, Piecewise):  # combine each piece of right with self
            # Iterate right completely BEFORE pairing pieces: __iter__ removes
            # redundant pieces lazily, so indexing right while iterating it
            # could take a not yet removed piece as end bound and skip its span.
            pieces = list(right)
            last = len(pieces) - 1
            for k, (start, value) in enumerate(pieces):
                if k < last:
                    self.iapply(f, (start, value, pieces[k + 1][0]))
                else:
                    self.iapply(f, (start, value))
        else:  # assume a triplet (start,value,end) as called above
            start = right[0]
            try:
                end = right[2]
            except IndexError:  # (start,value) only
                end = None
            self.insort(start)
            if end is None:
                i = self.insort(start)
                j = len(self.x)
            else:
                j = self.insort(end)
                # look start up again : inserting end before it shifted its index
                i = self.insort(start)
                if j < i:
                    i, j = j, i
            for k in range(i, j):
                self.y[k] = self.y[k].apply(f, right[1])  # calls Expr.apply
        return self

    def apply(self, f, right=None):
        '''apply function to copy of self'''
        return Piecewise(self).iapply(f, right)

    def applx(self, f):
        '''apply a function to each x value

        :return: self (modified in place)
        '''
        self.x = [f(x) for x in self.x]
        self.y = [y.applx(f) for y in self.y]
        return self

    def __lshift__(self, dx):
        ''':return: copy of self with f(x)=x-dx applied through applx'''
        return Piecewise(self).applx(lambda x: x - dx)

    def __rshift__(self, dx):
        ''':return: copy of self with f(x)=x+dx applied through applx'''
        return Piecewise(self).applx(lambda x: x + dx)

    def _switch_points(self, xmin, xmax):
        '''yields the (x,y) points needed to draw the pieces between xmin and xmax.
        A step is produced by yielding two points with the same x'''
        prevy = None
        firstpoint, lastpoint = False, False
        for x, y in self:
            y = y(x)
            if x < xmin:
                # only the first piece starting before xmin is kept, clamped to xmin
                if firstpoint:
                    continue
                firstpoint = True
                x = xmin
            if x > xmax:
                # only the first piece starting after xmax is kept, clamped to xmax
                if lastpoint:
                    break
                lastpoint = True
                x = xmax
            if prevy is not None and not math2.isclose(y, prevy):  # step
                yield x, prevy
            yield x, y
            prevy = y

    def points(self, xmin=None, xmax=None):
        ''':return: x,y lists of float : points for a line plot

        note : self.x[1] is read, so at least two pieces are required
        '''
        resx = []
        resy = []
        # x[0] is skipped : when built from pieces it is the default piece at period[0]
        dx = self.x[-1] - self.x[1]
        p = self.is_periodic()
        if xmin is None:
            # by default we extend the range by 10%
            xmin = min(0, self.x[1] - dx * .1)
        if xmax is None:
            if p:  # by default we show 2.5 periods
                # p is the period length; there is no self.p attribute
                xmax = xmin + p * 2.5
            else:  # by default we extend the range by 10%
                xmax = self.x[-1] + dx * .1
        for x, y in self._switch_points(xmin, xmax):
            resx.append(x)
            resy.append(y)
        # close the line at xmax; also covers the case where no point was produced
        if not resx or xmax > resx[-1]:
            resx.append(xmax)
            resy.append(self(xmax))
        return resx, resy

    def _plot(self, ax, xmax=None, **kwargs):
        '''plots function'''
        (x, y) = self.points(xmax=xmax)
        return super(Piecewise, self)._plot(ax, x, y, **kwargs)