# Source code for Goulib.expr

#!/usr/bin/env python
# coding: utf8
'''
simple symbolic math expressions
'''

__author__ = "Philippe Guglielmetti, J.F. Sebastian, Geoff Reedy"
__copyright__ = "Copyright 2013, Philippe Guglielmetti"
__credits__ = [
    'http://stackoverflow.com/questions/2371436/evaluating-a-mathematical-expression-in-a-string',
    'http://stackoverflow.com/questions/3867028/converting-a-python-numeric-expression-to-latex',
    ]
__license__ = "LGPL"

import six, logging, copy, collections, inspect, re

from Goulib import plot #sets matplotlib backend

from Goulib import itertools2, math2

# http://stackoverflow.com/questions/2371436/evaluating-a-mathematical-expression-in-a-string
# https://github.com/erwanp/pytexit

import ast

#indexes in _operators, _ functions and _constants to use for corresponding symbols
_dialect_str = 2
_dialect_python = 3
_dialect_latex = 4

constants = { # constants in this dict are recognized in output
    bool : {},
    float : {},
    complex : {},
}

from sortedcollections import SortedDict
# Registry of functions usable inside Expr: name -> (callable, precedence,
# str, repr, LaTeX). A SortedDict keeps the names in sorted order.
# Only functions listed here can be used in Expr.
functions = SortedDict()
# Registry of allowed operators: ast operator class -> description tuple.
# Only operators listed here are allowed.
operators = {}


def eval(node, **kwargs):
    '''safe eval of ast node : only functions and _operators listed above can be used

    :param node: ast.AST to evaluate
    :param ctx: dict of varname : value to substitute in node
    :param operators: optional operator table overriding the module-level one
    :param functions: optional function table overriding the module-level one
    :return: number or expression string
    :raise NameError: if node calls a function that is not in the function table
    '''
    _ctx = kwargs.get('ctx', {})
    _operators = kwargs.get('operators', operators)
    if isinstance(node, ast.Num):  # <number>
        return node.n
    elif isinstance(node, ast.Name):
        return _ctx.get(node.id, node.id)  # return value or var
    elif isinstance(node, ast.Attribute):
        return getattr(_ctx[node.value.id], node.attr)
    elif isinstance(node, ast.Tuple):
        return tuple(eval(e, **kwargs) for e in node.elts)
    elif isinstance(node, ast.Call):
        _functions = kwargs.get('functions', functions)
        # reject disallowed functions before evaluating any of their arguments
        if node.func.id not in _functions:
            raise NameError('%s function not allowed' % node.func.id)
        params = [eval(arg, **kwargs) for arg in node.args]
        f = _functions[node.func.id][0]
        res = f(*params)
        return math2.int_or_float(res, 0, 1e-12)  # try to correct small error
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        op = _operators[type(node.op)]
        left = eval(node.left, **kwargs)
        right = eval(node.right, **kwargs)
        if math2.is_number(left) and math2.is_number(right):
            return op[0](left, right)  # no correction here !
        # at least one side is still symbolic: rebuild a Python-syntax string
        return "%s%s%s" % (left, op[_dialect_python], right)
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        right = eval(node.operand, **kwargs)
        return _operators[type(node.op)][0](right)
    elif isinstance(node, ast.Compare):
        # a < b < c means (a < b) and (b < c): each operand is evaluated once and
        # evaluation stops at the first false comparison, as in Python itself
        left = eval(node.left, **kwargs)
        res = True
        for op, comparator in zip(node.ops, node.comparators):
            right = eval(comparator, **kwargs)
            res = _operators[type(op)][0](left, right)
            if not res:
                return res
            left = right
        return res
    elif six.PY3 and isinstance(node, ast.NameConstant):
        return node.value
    else:
        logging.warning(ast.dump(node, False, False))
        return eval(node.body, **kwargs)  # last chance
def _lambda_body(text):
    '''returns the leading expression of text, i.e. the part before a comment,
    a top-level comma or a closing bracket that was opened outside text'''
    depth = 0
    for i, c in enumerate(text):
        if c in '([{':
            depth += 1
        elif c in ')]}':
            if depth == 0:  # closes a bracket surrounding the lambda
                return text[:i].strip()
            depth -= 1
        elif c == '#' or (c == ',' and depth == 0):
            return text[:i].strip()
    return text.strip()


def get_function_source(f):
    '''returns cleaned code of a function or lambda

    currently only supports:
    - lambda x:formula_of_(x)
    - def anything(x): return formula_of_(x)

    :param f: function or lambda whose source code inspect can retrieve
    :return: string, the formula after ``lambda ...:`` or after ``return``
    :raise ValueError: if the source matches neither form
    '''
    f = inspect.getsource(f).rstrip('\n')  # TODO: merge lines more subtly
    # [^:]* stops at the colon ending the lambda's argument list
    g = re.search(r'\blambda([^:]*):(.*)', f)
    if g:
        return _lambda_body(g.group(2))
    g = re.search(r'def \w*\((.*)\):\s*return (.*)', f)
    if g is None:
        raise ValueError('not a valid function code %s' % f)
    return g.group(2)
def plouffe(f, epsilon=1e-6):
    '''try to express a number as a short formula string

    :param f: number to beautify
    :param epsilon: tolerance passed to math2.is_integer
    :return: '1/n' if 1/f is an integer n (within epsilon),
      'sqrt(n)' if f*f is an integer n (within epsilon),
      the '-'-prefixed formula of -f when f is negative and -f has one,
      otherwise f unchanged
    '''
    if f < 0:
        r = plouffe(-f, epsilon)  # keep the caller's tolerance for the magnitude
        if isinstance(r, six.string_types):
            return '-' + r
        return f
    if f != 0 and math2.is_integer(1 / f, epsilon):
        f = '1/%d' % math2.rint(1 / f)
    elif math2.is_integer(f * f, epsilon):
        f = 'sqrt(%d)' % math2.rint(f * f)
    return f
class Expr(plot.Plot):
    '''
    Math expressions that can be evaluated like standard functions
    combined using standard operators
    and plotted in IPython/Jupyter notebooks
    '''

    def __init__(self, f, **kwargs):
        '''
        :param f: function or operator, Expr to copy construct, ast.AST node,
          number, bool (or 'True'/'False'), or formula string
          ('^' is accepted as the power operator)
        :param operators: optional operator table, defaults to the module-level one
        :param functions: optional function table, defaults to the module-level one
        '''
        # set before any early return below so that every Expr, including
        # numbers, copies and AST-built results of operators, supports complexity()
        self._operators = kwargs.get('operators', operators)
        self._functions = kwargs.get('functions', functions)

        if isinstance(f, Expr):
            pass  # skip for now, handled below
        elif inspect.isfunction(f):
            try:
                f = get_function_source(f)
            except ValueError:
                f = '%s(x)' % f.__name__
        elif callable(f):  # builtin function (collections.Callable is gone in Python 3.10)
            f = '%s(x)' % f.__name__
        elif f in ('True', 'False'):
            f = bool(f == 'True')

        if type(f) is bool:
            self.body = ast.Num(f)
            return

        if type(f) is float:  # try to beautify it
            if math2.is_integer(f):
                f = math2.rint(f)
            else:
                f = plouffe(f)

        if math2.is_number(f):  # store it with full precision
            # (otherwise Py2 doesn't find pi in _constants ...)
            self.body = ast.Num(f)
            return

        if isinstance(f, Expr):  # copy constructor
            self.body = f.body
            return
        if isinstance(f, ast.AST):
            self.body = f
            return

        f = str(f).replace('^', '**')  # accept ^ as power operator rather than xor ...
        self.body = compile(f, 'Expr', 'eval', ast.PyCF_ONLY_AST).body

    @property
    def isNum(self):
        ''':return: True if the body is a single number node'''
        return isinstance(self.body, ast.Num)

    @property
    def isconstant(self):
        ''':return: True if Expr evaluates to a constant number or bool'''
        res = self()
        if math2.is_number(res):
            return True
        if isinstance(res, bool):
            return True
        return False

    def __call__(self, x=None, **kwargs):
        '''evaluate the Expr at x OR compose self(x())

        :param x: value of x, an iterable of values (a list of results is returned),
          or an Expr to compose with
        :param kwargs: values of the other variables
        :return: number, list, Expr (self if some params remain symbolic),
          or None if evaluation raised another exception
        '''
        if isinstance(x, Expr):  # composition
            return self.applx(x)
        if itertools2.isiterable(x):  # return a displayable list
            # forward the other variables' values to every element
            return [self(v, **kwargs) for v in x]
        if x is not None:
            kwargs['x'] = x
        kwargs['self'] = self  # allows to call methods such as in Stats
        try:
            e = eval(self.body, ctx=kwargs)
        except TypeError:  # some params remain symbolic
            return self
        except Exception:  # ZeroDivisionError, OverflowError
            return None
        if math2.is_number(e):
            return e
        return Expr(e)

    def __float__(self):
        '''
        :return: float value of the Expr
        :raise TypeError: if the Expr does not evaluate to a number
        '''
        res = self()
        if res is None or isinstance(res, Expr):
            raise TypeError('Expr does not evaluate to a number')
        return float(res)  # __float__ must return a real float, not an int

    def __repr__(self):
        return TextVisitor(_dialect_python).visit(self.body)

    def __str__(self):
        return TextVisitor(_dialect_str).visit(self.body)

    def _repr_html_(self):
        '''default rich format is LaTeX'''
        return self._repr_latex_()

    def latex(self):
        ''':return: string LaTex formula'''
        return TextVisitor(_dialect_latex).visit(self.body)

    def _repr_latex_(self):
        return r'$%s$' % self.latex()

    def points(self, xmin=-1, xmax=1, step=0.1):
        ''':return: x,y lists of float : points for a line plot'''
        if self.isconstant:
            return [xmin, xmax], [self(xmin), self(xmax)]
        x = list(itertools2.arange(xmin, xmax, step))
        y = self(x)
        return x, y

    def _plot(self, ax, x=None, y=None, **kwargs):
        if x is None:
            x, y = self.points()
        if y is None:
            y = self(x)
        # slightly shift the points to make superimposed curves more visible
        offset = kwargs.pop('offset', 0)
        points = list(zip(x, y))  # might contain (x,None) for undefined points
        # curves between defined points
        for xy in itertools2.isplit(points, lambda _: not math2.is_real(_[1])):
            x, y = [], []  # matplotlib doesn't support generators...
            for v in xy:
                x.append(v[0] + offset)
                y.append(v[1] + offset)
            ax.plot(x, y, **kwargs)
        return ax

    def apply(self, f, right=None):
        '''function composition self o f = f(self(x))

        :param f: ast.unaryop instance, ast operator instance (when right is given),
          or an object providing applx (such as an Expr)
        :param right: right operand when f is a binary operator
        :return: Expr
        '''
        if right is None:
            if isinstance(f, ast.unaryop):
                node = ast.UnaryOp(f, self.body)
            else:
                return f.applx(self)
        else:
            if not isinstance(right, Expr):
                right = Expr(right)
            node = ast.BinOp(self.body, f, right.body)
        return Expr(node)

    def applx(self, f, var='x'):
        '''function composition f o self = self(f(x))

        :param f: Expr or ast node substituted for var
        :param var: name of the variable to replace
        :return: new Expr; self.body is deep-copied, not modified
        '''
        if isinstance(f, Expr):
            f = f.body

        class Subst(ast.NodeTransformer):
            def visit_Name(self, node):
                if node.id == var:
                    return f
                else:
                    return node

        node = copy.deepcopy(self.body)
        return Expr(Subst().visit(node))

    def __eq__(self, other):
        if math2.is_number(other):
            try:
                return self() == other
            except Exception:
                return False
        if not isinstance(other, Expr):
            other = Expr(other)
        return str(self()) == str(other())

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        if math2.is_number(other):
            try:
                return self() < other
            except Exception:
                return False
        if not isinstance(other, Expr):
            other = Expr(other)
        return float(self()) < float(other())

    def __le__(self, other):
        return self < other or self == other

    def __ge__(self, other):
        return not self < other

    def __gt__(self, other):
        return self >= other and not self == other

    def __add__(self, right):
        return self.apply(ast.Add(), right)

    def __radd__(self, left):
        return Expr(left) + self

    def __sub__(self, right):
        return self.apply(ast.Sub(), right)

    def __rsub__(self, left):
        return Expr(left) - self

    def __neg__(self):
        return self.apply(ast.USub())

    def __mul__(self, right):
        return self.apply(ast.Mult(), right)

    def __rmul__(self, right):
        return Expr(right) * self

    def __truediv__(self, right):
        return self.apply(ast.Div(), right)

    def __rtruediv__(self, left):
        return Expr(left) / self

    def __pow__(self, right):
        return self.apply(ast.Pow(), right)

    def __rpow__(self, left):
        return Expr(left) ** self

    __div__ = __truediv__
    __rdiv__ = __rtruediv__

    def __invert__(self):
        return self.apply(ast.Invert())

    def __and__(self, right):
        return self.apply(ast.And(), right)

    def __or__(self, right):
        return self.apply(ast.Or(), right)

    def __xor__(self, right):
        return self.apply(ast.BitXor(), right)

    def __lshift__(self, dx):
        return self.applx(ast.BinOp(ast.Name('x', None), ast.Add(), ast.Num(dx)))

    def __rshift__(self, dx):
        return self.applx(ast.BinOp(ast.Name('x', None), ast.Sub(), ast.Num(dx)))

    def complexity(self):
        '''
        measures the complexity of Expr
        :return: int, sum of the precedence of used ops
        :raise KeyError: if a node's operator or type is not in the operator table
        '''
        ops = self._operators

        def _weight(node):
            # precedence of the node's operator, or of the node type for leaves
            op = getattr(node, 'op', None)
            key = type(node) if op is None else type(op)
            if key not in ops and isinstance(node, ast.Num):
                key = ast.Num  # Python >= 3.8 parses numbers as ast.Constant
            return ops[key][1]

        def _node_complexity(node):
            res = _weight(node)
            for child in ('operand', 'left', 'right'):
                sub = getattr(node, child, None)
                if sub is not None:
                    res += _node_complexity(sub)
            return res

        return _node_complexity(self.body)
# supported _operators with precedence and text + LaTeX repr
# precedence as in https://docs.python.org/reference/expressions.html#operator-precedence
#

import operator as op

# table of allowed operators
# note we very slightly prefer + over - and * over / for simpler expression generation
# each entry: (callable, precedence, str symbol, python symbol, LaTeX symbol);
# a symbol containing %s is used as a format string for the operand(s)

operators = {
    ast.Or: (op.or_, 300, ' or ', ' or ', ' \\vee '),
    ast.And: (op.and_, 400, ' and ', ' and ', ' \\wedge '),
    # trailing space so LaTeX sees '\neg x', not the undefined command '\negx'
    ast.Not: (op.not_, 500, 'not ', 'not ', '\\neg '),
    ast.Eq: (op.eq, 600, '=', ' == ', ' = '),
    # plain < > and \geq: \gtr, \ltr and \gec are not standard LaTeX commands
    ast.Gt: (op.gt, 600, ' > ', ' > ', ' > '),
    ast.GtE: (op.ge, 600, ' >= ', ' >= ', ' \\geq '),
    ast.Lt: (op.lt, 600, ' < ', ' < ', ' < '),
    ast.LtE: (op.le, 600, ' <= ', ' <= ', ' \\leq '),
    ast.BitXor: (op.xor, 800, ' xor ', ' xor ', ' xor '),
    ast.LShift: (op.lshift, 1000, ' << ', ' << ', ' \\ll '),
    ast.RShift: (op.rshift, 1000, ' >> ', ' >> ', ' \\gg '),
    ast.Add: (op.add, 1100, '+', '+', '+'),
    ast.Sub: (op.sub, 1101, '-', '-', '-'),
    ast.Mult: (op.mul, 1200, '*', '*', ' \\cdot '),
    ast.Div: (op.truediv, 1201, '/', '/', '\\frac{%s}{%s}'),
    ast.FloorDiv: (op.floordiv, 1201, '//', '//', '\\left\\lfloor\\frac{%s}{%s}\\right\\rfloor'),
    ast.Mod: (op.mod, 1200, ' mod ', '%', ' \\bmod '),
    ast.Invert: (op.not_, 1300, '~', '~', '\\sim '),
    ast.UAdd: (op.pos, 1150, '+', '+', '+'),
    ast.USub: (op.neg, 1150, '-', '-', '-'),
    ast.Pow: (math2.pow, 1400, '^', '**', '^'),  # ipow returns an integer when result is integer ...

    # precedence of other types below
    ast.Call: (None, 9000),
    ast.Name: (None, 9000),
    ast.Num: (None, 9000),
}

# Python >= 3.8 parses every literal as ast.Constant (and older Python 3 parses
# True/False as ast.NameConstant): give them the same leaf precedence as ast.Num
# so that precedence lookups on literal nodes succeed
for _leaf in (getattr(ast, 'Constant', None), getattr(ast, 'NameConstant', None)):
    if _leaf is not None:
        operators.setdefault(_leaf, (None, 9000))
del _leaf
def add_function(f, s=None, r=None, l=None):
    '''
    register f so that it can be called inside Expr.

    :param f: function, registered under f.__name__
    :param s: string representation, should be formula-like
    :param r: repr representation, should be cut&pastable in a calculator, or in python ...
      (s is used when r is empty)
    :param l: LaTeX representation
    :return: the stored tuple (f, precedence, s, r, l)
    '''
    name = f.__name__
    shown_repr = r if r else s
    functions[name] = (f, 9999, s, shown_repr, l)
    return functions[name]
def add_constant(c, name, s=None, r=None, l=None):
    '''
    add a constant to those recognized in Expr.

    :param c: constant
    :param name: fallback for the representations below (LaTeX fallback is '\\'+name)
    :param s: string representation, should be formula-like
    :param r: repr representation, should be cut&pastable in a calculator, or in python ...
    :param l: LaTeX representation
    :return: the stored tuple (None, None, s, r, l)
    '''
    entry = (None, None, s or name, r or name, l or '\\' + name)
    # constants of a type without a table yet (e.g. int) get one instead of a KeyError
    constants.setdefault(type(c), {})[c] = entry
    return entry
def add_module(module):
    '''
    register the public contents of module for use in Expr:
    callables become functions, numbers become constants.

    :param module: module (or any object with a __dict__) to scan;
      names starting with '_' are skipped
    '''
    for fname, f in six.iteritems(module.__dict__):
        if fname.startswith('_'):
            continue
        if callable(f):  # collections.Callable no longer exists in Python >= 3.10
            add_function(f)
        elif math2.is_number(f):
            add_constant(f, fname)
# default LaTeX form of a constant is '\\'+name; '\True'/'\False' are not LaTeX commands
add_constant(True, 'True', l='True')
add_constant(False, 'False', l='False')

import math
add_module(math)

# LaTeX overrides for math constants whose default '\\'+name is wrong:
# '\e' is undefined and '\inf' is the infimum operator
add_constant(math.e, 'e', l='e')
add_constant(float('inf'), 'inf', l='\\infty')

add_function(abs, l='\\lvert{%s}\\rvert')
add_function(math.fabs, l='\\lvert{%s}\\rvert')
add_function(math.factorial, '%s!', 'fact', '%s!')
# double factorial must render as n!!, not n!
add_function(math2.factorial2, '%s!!', 'factorial2', '%s!!')
add_function(math2.sqrt, l='\\sqrt{%s}')
add_function(math.trunc, l='\\left\\lfloor{%s}\\right\\rfloor')
add_function(math.floor, l='\\left\\lfloor{%s}\\right\\rfloor')
add_function(math.ceil, l='\\left\\lceil{%s}\\right\\rceil')
add_function(math.asin, l='\\arcsin')
add_function(math.acos, l='\\arccos')
add_function(math.atan, l='\\arctan')
add_function(math.asinh, l='\\sinh^{-1}')
add_function(math.acosh, l='\\cosh^{-1}')
add_function(math.atanh, l='\\tanh^{-1}')
add_function(math.log, l='\\ln')
add_function(math.log1p, l='\\ln\\left(1+{%s}\\right)')  # log1p(x) = ln(1+x)
add_function(math.log10, l='\\log_{10}')
add_function(math2.log2, l='\\log_2')
add_function(math.gamma, l='\\Gamma')
add_function(math.exp, l='e^{%s}')
add_function(math.expm1, l='e^{%s}-1')
add_function(math.lgamma, 'log(abs(gamma(%s)))',
             l='\\ln\\left\\lvert\\Gamma\\left({%s}\\right)\\right\\rvert')
add_function(math.degrees, l='%s\\cdot\\frac{360}{2\\pi}')
add_function(math.radians, l='%s\\cdot\\frac{2\\pi}{360}')

# default '\i' is a text-mode dotless i, invalid in math mode
add_constant(complex(0, 1), 'i', l='i')

#http://stackoverflow.com/questions/3867028/converting-a-python-numeric-expression-to-latex
class TextVisitor(ast.NodeVisitor):
    '''renders an AST (such as Expr.body) as text in one dialect:
    _dialect_str, _dialect_python or _dialect_latex'''

    def __init__(self, dialect, operators=operators, functions=functions):
        '''
        :param dialect: int index in _operators of symbols to use
        :param operators: operator table, defaults to the module-level one
        :param functions: function table, defaults to the module-level one
        '''
        self._dialect = dialect
        self._operators = operators
        self._functions = functions

    def prec(self, op):
        ''' calculate the precedence of op

        :param op: AST node or operator; BinOp/UnaryOp use their operator
        :return: int precedence; negative numbers rank like unary minus
        :raise KeyError: if op's type is not in the operator table
        '''
        if isinstance(op, (ast.BinOp, ast.UnaryOp)):
            op = op.op
        if isinstance(op, ast.Num) and math2.is_real(op.n) and op.n < 0:
            return self._operators[ast.USub][1]
        try:
            return self._operators[type(op)][1]
        except KeyError:
            # Python >= 3.8 parses every literal as ast.Constant, a type the
            # table may not list: rank literals like ast.Num leaves
            if type(op).__name__ in ('Constant', 'NameConstant'):
                return self._operators[ast.Num][1]
            raise

    def _par(self, content):
        '''wrap content in parentheses suited to the dialect'''
        if self._dialect == _dialect_latex:
            return '\\left(%s\\right)' % content
        return '(%s)' % content

    def visit_Call(self, n):
        args = r', '.join(map(self.visit, n.args))
        func = self.visit(n.func)
        fname = self._functions[func][self._dialect]
        if fname is None:  # no symbol registered for this dialect: use the name
            if self._dialect == _dialect_latex:
                fname = '\\' + func
            else:
                fname = func
        if '%s' in fname:  # template such as '%s!' or '\\sqrt{%s}'
            if len(n.args) > 1:  # TODO: or ... what ?
                args = self._par(args)
            return fname % args
        return fname + self._par(args)

    def visit_Name(self, n):
        return n.id

    def visit_NameConstant(self, node):
        return str(node.value)

    def visit_UnaryOp(self, n):
        op = self.visit(n.operand)
        if self.prec(n.op) > self.prec(n.operand):
            op = self._par(op)
        symbol = self._operators[type(n.op)][self._dialect]
        if '%s' in symbol:
            return symbol % op
        return symbol + op

    def _Bin(self, left, op, right):
        # commute x*3 in 3*x
        if isinstance(op, ast.Mult):
            if isinstance(right, ast.Num):
                if not Expr(left).isconstant:
                    return self._Bin(right, op, left)
        l, r = self.visit(left), self.visit(right)
        symbol = self._operators[type(op)][self._dialect]
        if '%s' in symbol:  # no parenthesis required in this case
            return symbol % (l, r)
        # handle precedence (parenthesis) if needed
        if self.prec(op) > self.prec(left):
            l = self._par(l)
        if self.prec(op) > self.prec(right):
            if self._dialect == _dialect_latex and isinstance(op, ast.Pow):
                r = '{' + r + '}'  # LaTeX exponent group
            else:
                r = self._par(r)
        # remove * if possible (implicit multiplication), except between two digits
        if self._dialect != _dialect_python and isinstance(op, ast.Mult):
            if not l[-1].isdigit() or not r[0].isdigit():
                symbol = ''
        res = l + symbol + r
        # TODO: find a better way to do this ...
        d = self._dialect
        plusminus = self._operators[ast.Add][d] + self._operators[ast.USub][d]
        minusminus = self._operators[ast.Sub][d] + self._operators[ast.USub][d]
        res = res.replace(plusminus, self._operators[ast.Sub][d])
        res = res.replace(minusminus, self._operators[ast.Add][d])
        return res

    def visit_BinOp(self, n):
        return self._Bin(n.left, n.op, n.right)

    def visit_Compare(self, n):
        # TODO: what to do with multiple ops/comparators ?
        return self._Bin(n.left, n.ops[0], n.comparators[0])

    def visit_Num(self, n):
        # registered constants (e.g. pi) are rendered by name
        try:
            d = constants[type(n.n)]
            return d[n.n][self._dialect]
        except KeyError:
            pass
        return str(math2.int_or_float(n.n))

    def generic_visit(self, n):
        try:
            l = list(map(self.visit, n))
            return ''.join(l)
        except TypeError:
            pass

        if isinstance(n, ast.AST):
            l = map(self.visit, [getattr(n, f) for f in n._fields])
            return ''.join(l)
        else:
            return str(n)