#!/usr/bin/env python
# coding: utf8
'''
simple symbolic math expressions
'''
__author__ = "Philippe Guglielmetti, J.F. Sebastian, Geoff Reedy"
__copyright__ = "Copyright 2013, Philippe Guglielmetti"
__credits__ = [
'http://stackoverflow.com/questions/2371436/evaluating-a-mathematical-expression-in-a-string',
'http://stackoverflow.com/questions/3867028/converting-a-python-numeric-expression-to-latex',
]
__license__ = "LGPL"
import six, logging, copy, collections, inspect, re
from Goulib import plot #sets matplotlib backend
from Goulib import itertools2, math2
# http://stackoverflow.com/questions/2371436/evaluating-a-mathematical-expression-in-a-string
# https://github.com/erwanp/pytexit
import ast
# index, inside the symbol tuples of operators, functions and constants,
# of the representation to use for each output dialect
_dialect_str, _dialect_python, _dialect_latex = 2, 3, 4

# constants registered in these dicts (one dict per type, keyed by value)
# are recognized in output
constants = {kind: {} for kind in (bool, float, complex)}
from sortedcollections import SortedDict
# only functions registered here (name -> symbol tuple, see add_function) can be used in Expr
functions = SortedDict()
# only operators listed here are allowed; reassigned to the full operator table further below
operators = {}
def eval(node, **kwargs):
    '''safe eval of ast node : only functions and operators listed in the tables can be used

    :param node: ast.AST to evaluate
    :param ctx: dict of varname : value to substitute in node
    :param operators: optional dict of allowed operators, defaults to the module table
    :param functions: optional dict of allowed functions, defaults to the module table
    :return: number, bool, or expression string when some variables remain symbolic
    :raises NameError: if a function not listed in the functions table is called
    :raises TypeError: when an operator cannot combine symbolic operands
    '''
    _ctx = kwargs.get('ctx', {})
    _operators = kwargs.get('operators', operators)
    if isinstance(node, ast.Num):  # <number>
        return node.n
    elif isinstance(node, ast.Name):
        return _ctx.get(node.id, node.id)  # return value, or the var name if unknown
    elif isinstance(node, ast.Attribute):
        return getattr(_ctx[node.value.id], node.attr)
    elif isinstance(node, ast.Tuple):
        return tuple(eval(e, **kwargs) for e in node.elts)
    elif isinstance(node, ast.Call):
        _functions = kwargs.get('functions', functions)
        params = [eval(arg, **kwargs) for arg in node.args]
        if node.func.id not in _functions:
            raise NameError('%s function not allowed' % node.func.id)
        f = _functions[node.func.id][0]
        res = f(*params)
        return math2.int_or_float(res, 0, 1e-12)  # try to correct small error
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        binop = _operators[type(node.op)]
        left = eval(node.left, **kwargs)
        right = eval(node.right, **kwargs)
        if math2.is_number(left) and math2.is_number(right):
            return binop[0](left, right)  # no correction here !
        # symbolic result: parenthesize operands so the string keeps the grouping
        # of the tree, otherwise (a+1)*2 would come back as a+1*2
        return "(%s)%s(%s)" % (left, binop[_dialect_python], right)
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        right = eval(node.operand, **kwargs)
        if isinstance(right, six.string_types):
            # symbolic operand: op.not_('x') would silently return False
            raise TypeError('symbolic operand %s' % right)
        return _operators[type(node.op)][0](right)
    elif isinstance(node, ast.Compare):
        # chained comparisons: a < b < c means a < b and b < c
        left = eval(node.left, **kwargs)
        res = True
        for cmp_op, comparator in zip(node.ops, node.comparators):
            right = eval(comparator, **kwargs)
            if isinstance(left, six.string_types) or isinstance(right, six.string_types):
                # symbolic operand: op.eq('x',1) would silently return False
                raise TypeError('symbolic comparison')
            res = _operators[type(cmp_op)][0](left, right)
            if not res:
                return res
            left = right
        return res
    elif isinstance(node, ast.BoolOp):  # <value> and/or <value> ...
        f = _operators[type(node.op)][0]
        res = eval(node.values[0], **kwargs)
        for value in node.values[1:]:
            res = f(res, eval(value, **kwargs))
        return res
    elif six.PY3 and isinstance(node, ast.NameConstant):
        return node.value
    else:
        logging.warning(ast.dump(node, False, False))
        return eval(node.body, **kwargs)  # last chance
def get_function_source(f):
    '''returns cleaned code of a function or lambda

    currently only supports:
    - lambda x:formula_of_(x)
    - def anything(x): return formula_of_(x)

    :param f: function or lambda whose source code is available to inspect
    :return: string of the formula
    :raises ValueError: if none of the supported patterns is found in the source
    '''
    def _balance(res):
        '''strip spaces and the closing parenthesis of an enclosing call, if any'''
        res = res.strip()  # remove leading+trailing spaces
        extra = res.count(')') - res.count('(')
        if extra > 0:  # closing parenthesis of e.g. Expr(lambda x: ...)
            res = res[:-extra]
        return res

    src = inspect.getsource(f).rstrip('\n')  # TODO: merge lines more subtly
    # lambda followed by a closing parenthesis or a comment: Expr(lambda x: x*x)
    g = re.search(r'lambda(.*):(.*)(\)|#)', src)
    if g:
        return _balance(g.group(2))
    g = re.search(r'def \w*\((.*)\):\s*return (.*)', src)
    if g:
        return g.group(2)
    # lambda ending the source: f = lambda x: x*x
    g = re.search(r'lambda(.*):(.*)$', src)
    if g:
        return _balance(g.group(2))
    raise ValueError('not a valid function code %s' % src)
def plouffe(f, epsilon=1e-6):
    '''tries to express float f in a simpler closed form

    :param f: float
    :param epsilon: float tolerance used to recognize integers
    :return: string '1/n' or 'sqrt(n)', prefixed by '-' for negative f,
      when f matches one of these forms within epsilon; f unchanged otherwise
    '''
    if f < 0:
        r = plouffe(-f, epsilon)  # propagate the caller's tolerance
        if isinstance(r, six.string_types):
            return '-' + r
        return f
    if f == 0 or (math2.is_integer(f, epsilon) and math2.rint(f) != 0):
        return f  # already simple: avoid 'sqrt(0)', 'sqrt(4)' or '1/1'
    if math2.is_integer(1 / f, epsilon):
        return '1/%d' % math2.rint(1 / f)
    if math2.is_integer(f * f, epsilon):
        return 'sqrt(%d)' % math2.rint(f * f)
    return f
class Expr(plot.Plot):
    '''
    Math expressions that can be evaluated like standard functions
    combined using standard operators
    and plotted in IPython/Jupyter notebooks
    '''

    def __init__(self, f, **kwargs):
        '''
        :param f: function or operator, Expr to copy construct, formula string,
          number, bool or ast.AST node
        :param operators: optional dict of allowed operators
          (defaults to those of the copied Expr, or the module table)
        :param functions: optional dict of allowed functions
          (defaults to those of the copied Expr, or the module table)
        '''
        # set first so that every construction path (numbers, copies, ast nodes)
        # has them: complexity() relies on self._operators
        parent = f if isinstance(f, Expr) else None
        self._operators = kwargs.get('operators', getattr(parent, '_operators', operators))
        self._functions = kwargs.get('functions', getattr(parent, '_functions', functions))
        if isinstance(f, Expr):
            pass  # copy constructor, handled below
        elif inspect.isfunction(f):
            try:
                f = get_function_source(f)
            except ValueError:
                f = '%s(x)' % f.__name__
        elif callable(f):  # builtin function (collections.Callable was removed in Python 3.10)
            f = '%s(x)' % f.__name__
        elif f in ('True', 'False'):
            f = bool(f == 'True')
        if type(f) is bool:
            self.body = ast.Num(f)
            return
        if type(f) is float:  # try to beautify it
            if math2.is_integer(f):
                f = math2.rint(f)
            else:
                f = plouffe(f)
        if math2.is_number(f):  # store it with full precision
            # (otherwise Py2 doesn't find pi in constants ...)
            self.body = ast.Num(f)
            return
        if isinstance(f, Expr):  # copy constructor
            self.body = f.body
            return
        if isinstance(f, ast.AST):
            self.body = f
            return
        f = str(f).replace('^', '**')  # accept ^ as power operator rather than xor ...
        self.body = compile(f, 'Expr', 'eval', ast.PyCF_ONLY_AST).body

    @property
    def isNum(self):
        ''':return: True if the Expr body is a literal number node'''
        return isinstance(self.body, ast.Num)

    @property
    def isconstant(self):
        ''':return: True if Expr evaluates to a constant number or bool'''
        res = self()
        if math2.is_number(res):
            return True
        if isinstance(res, bool):
            return True
        return False

    def __call__(self, x=None, **kwargs):
        '''evaluate the Expr at x OR compose self(x())

        :param x: value, iterable of values (evaluated one by one) or Expr to compose with
        :param kwargs: values of other variables
        :return: number, list, Expr if the result remains symbolic,
          or None if evaluation failed (ZeroDivisionError, OverflowError, ...)
        '''
        if isinstance(x, Expr):  # composition
            return self.applx(x)
        if itertools2.isiterable(x):
            return [self(v) for v in x]  # return a displayable list
        if x is not None:
            kwargs['x'] = x
        kwargs['self'] = self  # allows to call methods such as in Stats
        try:
            e = eval(self.body, ctx=kwargs)
        except TypeError:  # some params remain symbolic
            return self
        except Exception:  # ZeroDivisionError, OverflowError: undefined point
            return None
        if math2.is_number(e):
            return e
        return Expr(e)

    def __float__(self):
        ''':return: float value of the Expr
        :raises TypeError: if the Expr is symbolic or cannot be evaluated
        '''
        res = self()
        if res is None or isinstance(res, Expr):
            # float() requires a real float, never an Expr
            raise TypeError('Expr cannot be converted to float')
        return float(res)

    def __repr__(self):
        return TextVisitor(_dialect_python).visit(self.body)

    def __str__(self):
        return TextVisitor(_dialect_str).visit(self.body)

    def _repr_html_(self):
        '''default rich format is LaTeX'''
        return self._repr_latex_()

    def latex(self):
        ''':return: string LaTex formula'''
        return TextVisitor(_dialect_latex).visit(self.body)

    def _repr_latex_(self):
        return r'$%s$' % self.latex()

    def points(self, xmin=-1, xmax=1, step=0.1):
        ''':return: x,y lists of float : points for a line plot'''
        if self.isconstant:
            return [xmin, xmax], [self(xmin), self(xmax)]
        x = list(itertools2.arange(xmin, xmax, step))
        y = self(x)
        return x, y

    def _plot(self, ax, x=None, y=None, **kwargs):
        if x is None:
            x, y = self.points()
        if y is None:
            y = self(x)
        offset = kwargs.pop('offset', 0)  # slightly shift the points to make superimposed curves more visible
        points = list(zip(x, y))  # might contain (x,None) for undefined points
        for xy in itertools2.isplit(points, lambda _: not math2.is_real(_[1])):  # curves between defined points
            x, y = [], []  # matplotlib doesn't support generators...
            for v in xy:
                x.append(v[0] + offset)
                y.append(v[1] + offset)
            ax.plot(x, y, **kwargs)
        return ax

    def apply(self, f, right=None):
        '''function composition self o f = f(self(x))

        :param f: ast unary operator, ast binary operator (with right), or Expr
        :param right: right operand of binary operator f
        :return: new Expr
        '''
        if right is None:
            if isinstance(f, ast.unaryop):
                node = ast.UnaryOp(f, self.body)
            else:
                return f.applx(self)
        else:
            if not isinstance(right, Expr):
                right = Expr(right)
            node = ast.BinOp(self.body, f, right.body)
        return Expr(node)

    def applx(self, f, var='x'):
        '''function composition f o self = self(f(x))

        :param f: Expr or ast node substituted for var
        :param var: string name of the variable to substitute
        :return: new Expr
        '''
        if isinstance(f, Expr):
            f = f.body

        class Subst(ast.NodeTransformer):
            def visit_Name(self, node):
                if node.id == var:
                    # fresh copy for each occurrence: the new tree shares no node
                    return copy.deepcopy(f)
                return node

        node = copy.deepcopy(self.body)
        return Expr(Subst().visit(node))

    def __eq__(self, other):
        if math2.is_number(other):
            try:
                return self() == other
            except Exception:
                return False
        if not isinstance(other, Expr):
            other = Expr(other)
        return str(self()) == str(other())

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        if math2.is_number(other):
            try:
                return self() < other
            except Exception:
                return False
        if not isinstance(other, Expr):
            other = Expr(other)
        return float(self()) < float(other())

    def __le__(self, other):
        return self < other or self == other

    def __ge__(self, other):
        return not self < other

    def __gt__(self, other):
        return self >= other and not self == other

    def __add__(self, right):
        return self.apply(ast.Add(), right)

    def __radd__(self, left):
        return Expr(left) + self

    def __sub__(self, right):
        return self.apply(ast.Sub(), right)

    def __rsub__(self, left):
        return Expr(left) - self

    def __neg__(self):
        return self.apply(ast.USub())

    def __mul__(self, right):
        return self.apply(ast.Mult(), right)

    def __rmul__(self, right):
        return Expr(right) * self

    def __truediv__(self, right):
        return self.apply(ast.Div(), right)

    def __rtruediv__(self, left):
        return Expr(left) / self

    def __pow__(self, right):
        return self.apply(ast.Pow(), right)

    def __rpow__(self, left):
        return Expr(left) ** self

    __div__ = __truediv__
    __rdiv__ = __rtruediv__

    def __invert__(self):
        return self.apply(ast.Invert())

    def __and__(self, right):
        return self.apply(ast.And(), right)

    def __or__(self, right):
        return self.apply(ast.Or(), right)

    def __xor__(self, right):
        return self.apply(ast.BitXor(), right)

    def __lshift__(self, dx):
        return self.applx(ast.BinOp(ast.Name('x', None), ast.Add(), ast.Num(dx)))

    def __rshift__(self, dx):
        return self.applx(ast.BinOp(ast.Name('x', None), ast.Sub(), ast.Num(dx)))

    def complexity(self):
        ''' measures the complexity of Expr
        :return: int, sum of the precedence of used ops
        '''
        def _node_complexity(node):
            op = getattr(node, 'op', None)
            if op is not None:
                key = type(op)
            elif isinstance(node, ast.Num):
                # numbers are ast.Constant since Python 3.8, so type(node) is not ast.Num
                key = ast.Num
            else:
                key = type(node)
            res = self._operators[key][1]
            for child in ('operand', 'left', 'right'):
                sub = getattr(node, child, None)
                if sub is not None:
                    res += _node_complexity(sub)
            return res
        return _node_complexity(self.body)
# supported operators: ast operator type -> (function, precedence, str, python, LaTeX symbols)
# precedence as in https://docs.python.org/reference/expressions.html#operator-precedence
#
import operator as op
# table of allowed operators
# note we very slightly prefer + over - and * over / for simpler expression generation
operators = {
    ast.Or: (op.or_, 300, ' or ', ' or ', ' \\vee '),
    ast.And: (op.and_, 400, ' and ', ' and ', ' \\wedge '),
    ast.Not: (op.not_, 500, 'not ', 'not ', '\\neg '),  # trailing space: '\\negx' is not LaTeX
    ast.Eq: (op.eq, 600, '=', ' == ', ' = '),
    ast.NotEq: (op.ne, 600, '!=', ' != ', ' \\neq '),
    ast.Gt: (op.gt, 600, ' > ', ' > ', ' > '),
    ast.GtE: (op.ge, 600, ' >= ', ' >= ', ' \\geq '),
    ast.Lt: (op.lt, 600, ' < ', ' < ', ' < '),
    ast.LtE: (op.le, 600, ' <= ', ' <= ', ' \\leq '),
    ast.BitXor: (op.xor, 800, ' xor ', ' xor ', ' \\oplus '),
    ast.LShift: (op.lshift, 1000, ' << ', ' << ', ' \\ll '),
    ast.RShift: (op.rshift, 1000, ' >> ', ' >> ', ' \\gg '),
    ast.Add: (op.add, 1100, '+', '+', '+'),
    ast.Sub: (op.sub, 1101, '-', '-', '-'),
    ast.Mult: (op.mul, 1200, '*', '*', ' \\cdot '),
    ast.Div: (op.truediv, 1201, '/', '/', '\\frac{%s}{%s}'),
    ast.FloorDiv: (op.floordiv, 1201, '//', '//', '\\left\\lfloor\\frac{%s}{%s}\\right\\rfloor'),
    ast.Mod: (op.mod, 1200, ' mod ', '%', ' \\bmod '),
    ast.Invert: (op.not_, 1300, '~', '~', '\\sim '),
    ast.UAdd: (op.pos, 1150, '+', '+', '+'),
    ast.USub: (op.neg, 1150, '-', '-', '-'),
    ast.Pow: (math2.pow, 1400, '^', '**', '^'),  # math2.pow rather than operator.pow
    # precedence of other types below
    ast.Call: (None, 9000),
    ast.Name: (None, 9000),
    ast.Num: (None, 9000),
}
def add_function(f, s=None, r=None, l=None):
    '''register f as a function allowed in Expr.

    :param f: function, registered under its __name__
    :param s: string representation, should be formula-like
    :param r: repr representation, should be cut&pastable in a calculator, or in python ...
      defaults to s
    :param l: LaTeX representation
    :return: the stored tuple (f, precedence, s, r, l)
    '''
    entry = (f, 9999, s, r if r else s, l)
    functions[f.__name__] = entry
    return entry
# names which are LaTeX commands for Greek letters (\pi, \tau, ...)
_GREEK_LETTERS = frozenset((
    'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'varepsilon', 'zeta', 'eta',
    'theta', 'vartheta', 'iota', 'kappa', 'lambda', 'mu', 'nu', 'xi', 'pi',
    'varpi', 'rho', 'varrho', 'sigma', 'varsigma', 'tau', 'upsilon', 'phi',
    'varphi', 'chi', 'psi', 'omega', 'Gamma', 'Delta', 'Theta', 'Lambda',
    'Xi', 'Pi', 'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
))


def _latex_name(name):
    '''default LaTeX representation of a constant called name'''
    if name in _GREEK_LETTERS:
        return '\\' + name
    if name == 'inf':
        return '\\infty'
    # '\\e', '\\nan' or '\\True' would not be valid LaTeX
    return '\\mathrm{%s}' % name.replace('_', '\\_')


def add_constant(c, name, s=None, r=None, l=None):
    ''' add a constant to those recognized in Expr.

    :param c: constant
    :param name: default representation, used when s, r or l is not given
    :param s: string representation, should be formula-like
    :param r: repr representation, should be cut&pastable in a calculator, or in python ...
    :param l: LaTeX representation. Defaults to \\name for Greek letter names,
      \\infty for inf and \\mathrm{name} otherwise
    '''
    if not l:
        l = _latex_name(name)
    # setdefault allows constants of any type (int, ...) to be registered
    constants.setdefault(type(c), {})[c] = (None, None, s or name, r or name, l)
def add_module(module):
    '''register the public callables of module as Expr functions
    and its public numeric attributes as Expr constants

    :param module: module whose __dict__ is scanned; names starting with _ are skipped
    '''
    for name, obj in six.iteritems(module.__dict__):
        if name.startswith('_'):
            continue
        if callable(obj):  # collections.Callable was removed in Python 3.10
            add_function(obj)
        elif math2.is_number(obj):
            add_constant(obj, name)
# explicit LaTeX representation: booleans are rendered as upright text
add_constant(True, 'True', l='\\mathrm{True}')
add_constant(False, 'False', l='\\mathrm{False}')
import math
add_module(math)
add_function(abs, l='\\lvert{%s}\\rvert')
add_function(math.fabs, l='\\lvert{%s}\\rvert')
# repr uses the registered function name so that repr() output can be parsed back by Expr
add_function(math.factorial, '%s!', 'factorial', '%s!')
add_function(math2.factorial2, '%s!!', math2.factorial2.__name__, '%s!!')
add_function(math2.sqrt, l='\\sqrt{%s}')
add_function(math.trunc, l='\\left\\lfloor{%s}\\right\\rfloor')
add_function(math.floor, l='\\left\\lfloor{%s}\\right\\rfloor')
add_function(math.ceil, l='\\left\\lceil{%s}\\right\\rceil')
add_function(math.asin, l='\\arcsin')
add_function(math.acos, l='\\arccos')
add_function(math.atan, l='\\arctan')
add_function(math.asinh, l='\\sinh^{-1}')
add_function(math.acosh, l='\\cosh^{-1}')
add_function(math.atanh, l='\\tanh^{-1}')
add_function(math.log, l='\\ln')
add_function(math.log1p, l='\\ln\\left(1+{%s}\\right)')  # log1p(x) = ln(1+x)
add_function(math.log10, l='\\log_{10}')
add_function(math2.log2, l='\\log_2')
add_function(math.gamma, l='\\Gamma')
add_function(math.exp, l='e^{%s}')
add_function(math.expm1, l='e^{%s}-1')
add_function(math.lgamma, 'log(abs(gamma(%s)))',
             l='\\ln\\left\\lvert\\Gamma\\left({%s}\\right)\\right\\rvert')
add_function(math.degrees, l='%s\\cdot\\frac{360}{2\\pi}')
add_function(math.radians, l='%s\\cdot\\frac{2\\pi}{360}')
# imaginary unit: '1j' is the Python literal, upright i in LaTeX
add_constant(complex(0, 1), 'i', r='1j', l='\\mathrm{i}')
#http://stackoverflow.com/questions/3867028/converting-a-python-numeric-expression-to-latex
class TextVisitor(ast.NodeVisitor):
    '''renders an Expr ast tree as text in one dialect
    (_dialect_str, _dialect_python or _dialect_latex)'''

    # functions having a dedicated LaTeX command (\sin, \gcd, ...);
    # other functions are rendered with \operatorname{name}
    _latex_functions = frozenset((
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc', 'sinh', 'cosh', 'tanh', 'coth',
        'arcsin', 'arccos', 'arctan', 'exp', 'log', 'ln', 'lg', 'gcd',
        'max', 'min', 'det',
    ))

    # a LaTeX control word such as \pi at the end of a string
    _control_word = re.compile(r'\\[A-Za-z]+$')

    def __init__(self, dialect, operators=operators, functions=functions):
        ''':param dialect: int index in the operators/functions/constants tuples of symbols to use
        :param operators: operators table, defaults to the module table
        :param functions: functions table, defaults to the module table
        '''
        self._dialect = dialect
        self._operators = operators
        self._functions = functions

    @staticmethod
    def _value(node):
        '''value of a number node: ast.Constant.value (Python >= 3.8) or ast.Num.n'''
        try:
            return node.value
        except AttributeError:
            return node.n

    def prec(self, op):
        ''' calculate the precedence of op, an ast operator or expression node '''
        if isinstance(op, (ast.BinOp, ast.UnaryOp, ast.BoolOp)):
            op = op.op
        elif isinstance(op, ast.Compare):
            op = op.ops[0]
        if isinstance(op, ast.Num):
            value = self._value(op)
            if math2.is_real(value) and value < 0:
                return self._operators[ast.USub][1]  # -2 binds like unary minus
            # numbers are ast.Constant since Python 3.8, so type(op) is not ast.Num
            return self._operators[ast.Num][1]
        try:
            return self._operators[type(op)][1]
        except KeyError:
            # atoms missing from the table (True/False, attributes, ...) bind like names
            return self._operators[ast.Name][1]

    def _par(self, content):
        if self._dialect == _dialect_latex:
            return '\\left(%s\\right)' % content
        else:
            return '(%s)' % content

    def _sep(self, left, right):
        '''separator written between left and right when no operator symbol is used:
        a space in LaTeX if left ends with a control word and right starts with a letter
        (\\pi x, not \\pix), nothing otherwise'''
        if (self._dialect == _dialect_latex and right[:1].isalpha()
                and self._control_word.search(left)):
            return ' '
        return ''

    def _needs_par(self, op, operand, is_right):
        '''True if operand must be parenthesized as (left or right) operand of binary op'''
        p, q = self.prec(op), self.prec(operand)
        if not isinstance(operand, ast.BinOp):
            return p > q
        # compare Python precedence levels, ignoring the +1 offsets used in the table
        # to slightly prefer + over - and * over /
        p, q = p // 100, q // 100
        if p != q:
            return p > q
        if isinstance(op, ast.Pow):  # right associative: (x**y)**z needs parenthesis
            return not is_right
        if not is_right:  # left associative: (a-b)-c is a-b-c
            return False
        # a+(b+c), a+(b-c), a*(b*c) and a*(b/c) can drop their parenthesis, not a-(b-c)
        if isinstance(op, ast.Add):
            return not isinstance(operand.op, (ast.Add, ast.Sub))
        if isinstance(op, ast.Mult):
            return not isinstance(operand.op, (ast.Mult, ast.Div))
        return True

    def visit_Call(self, n):
        args = r', '.join(map(self.visit, n.args))
        func = self.visit(n.func)
        entry = self._functions.get(func)
        fname = entry[self._dialect] if entry else None
        if fname is None:
            if self._dialect != _dialect_latex:
                fname = func
            elif func in self._latex_functions:
                fname = '\\' + func
            else:  # '\\hypot' or '\\atan2' would not be valid LaTeX
                fname = '\\operatorname{%s}' % func
        if '%s' in fname:
            if len(n.args) > 1:  # TODO: or ... what ?
                args = self._par(args)
            return fname % args
        return fname + self._par(args)

    def visit_Name(self, n):
        return n.id

    def visit_Attribute(self, n):
        return '%s.%s' % (self.visit(n.value), n.attr)

    def visit_NameConstant(self, node):
        return str(node.value)

    def visit_Constant(self, n):
        '''Python >= 3.8 parses every literal as ast.Constant'''
        value = n.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return self.visit_Num(n)
        return self.visit_NameConstant(n)

    def visit_UnaryOp(self, n):
        operand = self.visit(n.operand)
        if self.prec(n.op) > self.prec(n.operand):
            operand = self._par(operand)
        symbol = self._operators[type(n.op)][self._dialect]
        if '%s' in symbol:
            return symbol % operand
        return symbol + self._sep(symbol, operand) + operand

    def _Bin(self, left, op, right):
        # commute x*3 in 3*x
        if isinstance(op, ast.Mult) and isinstance(right, ast.Num):
            if not Expr(left).isconstant:
                return self._Bin(right, op, left)
        l, r = self.visit(left), self.visit(right)
        symbol = self._operators[type(op)][self._dialect]
        if '%s' in symbol:  # no parenthesis required in this case
            return symbol % (l, r)
        latex = self._dialect == _dialect_latex
        # handle precedence (parenthesis) if needed
        if self._needs_par(op, left, False):
            l = self._par(l)
        if latex and isinstance(op, ast.Pow):
            # LaTeX exponents are grouped by braces: x^{y+1}, x^{23}, x^{y^2}
            if len(r) > 1:
                r = '{' + r + '}'
        elif self._needs_par(op, right, True):
            r = self._par(r)
        # remove * if possible
        if self._dialect != _dialect_python and isinstance(op, ast.Mult):
            if not l[-1].isdigit() or not r[0].isdigit():
                symbol = self._sep(l, r)
        res = l + symbol + r
        # TODO: find a better way to do this ...
        add = self._operators[ast.Add][self._dialect]
        sub = self._operators[ast.Sub][self._dialect]
        usub = self._operators[ast.USub][self._dialect]
        res = res.replace(add + usub, sub)
        res = res.replace(sub + usub, add)
        return res

    def visit_BinOp(self, n):
        return self._Bin(n.left, n.op, n.right)

    def visit_BoolOp(self, n):
        symbol = self._operators[type(n.op)][self._dialect]
        parts = []
        for value in n.values:
            text = self.visit(value)
            if self.prec(n.op) > self.prec(value):
                text = self._par(text)
            parts.append(text)
        return symbol.join(parts)

    def visit_Compare(self, n):
        # chained comparisons a < b < c are rendered in full
        res = self._Bin(n.left, n.ops[0], n.comparators[0])
        for op, right in zip(n.ops[1:], n.comparators[1:]):
            r = self.visit(right)
            if self.prec(op) > self.prec(right):
                r = self._par(r)
            res += self._operators[type(op)][self._dialect] + r
        return res

    def visit_Num(self, n):
        value = self._value(n)
        try:
            return constants[type(value)][value][self._dialect]
        except KeyError:
            pass
        return str(math2.int_or_float(value))

    def generic_visit(self, n):
        if isinstance(n, ast.AST):
            return ''.join(self.visit(getattr(n, f)) for f in n._fields)
        if isinstance(n, six.string_types):
            return n  # iterating a string would recurse forever
        try:
            items = iter(n)
        except TypeError:
            return str(n)
        return ''.join(map(self.visit, items))