#!/usr/bin/env python
# coding: utf8
"""
plotable rich object display on IPython/Jupyter notebooks
"""
__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2015, Philippe Guglielmetti"
__credits__ = []
__license__ = "LGPL"
#import matplotlib and set backend once for all
import matplotlib, os, sys, logging, six, base64
# Select the 'Agg' backend (no Xwindows needed) in two situations:
#  - on https://travis-ci.org/ automated test runs, flagged by the TRAVIS environment variable
#  - when a trace function is installed, i.e. under a debugger, because 'QtAgg' crashes python while debugging
#    http://stackoverflow.com/questions/333995/how-to-detect-that-python-code-is-being-executed-through-the-debugger
# In every other case matplotlib keeps its own default backend.
if os.getenv('TRAVIS') or sys.gettrace():
    matplotlib.use('Agg')
# matplotlib.use('pdf') #for high quality pdf, but doesn't work for png, svg ...

# record which backend ended up being selected
logging.info('matplotlib backend is %s'%matplotlib.get_backend())
from . import itertools2
class Plot(object):
    """base class for plotable rich object display on IPython notebooks

    inspired from http://nbviewer.ipython.org/github/ipython/ipython/blob/3607712653c66d63e0d7f13f073bde8c0f209ba8/docs/examples/notebooks/display_protocol.ipynb

    Subclasses implement :meth:`_plot`; the other methods delegate to the
    module-level ``render`` and ``save`` functions with ``[self]`` as the
    list of plotables.
    """

    def _plot(self, ax, **kwargs):
        """abstract method, must be overriden

        :param ax: `matplotlib.axis`
        :param kwargs: drawing options passed by the caller
        :return ax: `matplotlib.axis` after plot
        :raise NotImplementedError: always, in this base class
        """
        raise NotImplementedError('objects derived from plot.PLot must define a _plot method')

    def render(self, fmt='svg', **kwargs):
        """render this single object through the module-level ``render``

        :param fmt: string output format, passed on unchanged
        :param kwargs: extra options, passed on unchanged
        :return: the result of the module-level ``render``
        :raise AttributeError: if rendering fails with any ordinary exception
        """
        try:
            return render([self], fmt, **kwargs)  # call global function
        except Exception:
            # "except Exception" rather than a bare "except" so that
            # KeyboardInterrupt / SystemExit are not disguised as AttributeError
            raise AttributeError('%s has no %s representation' % (self, fmt))

    def save(self, filename, **kwargs):
        """save this single object through the module-level ``save``

        :param filename: string path of the file to write
        :param kwargs: extra options, passed on unchanged
        :return: the result of the module-level ``save``
        """
        return save([self], filename, **kwargs)  # call global function

    # for IPython notebooks
    def _repr_html_(self, **kwargs):
        """default rich format is svg plot

        Falls back to an inline base64 png ``<img>`` tag when
        :meth:`_repr_svg_` raises NotImplementedError.

        :param kwargs: extra options passed on to the renderer, so that
            :meth:`html` can forward its own keyword arguments
        :return: string of html
        """
        try:
            return self._repr_svg_(**kwargs)
        except NotImplementedError:
            pass
        # this returns the same as _repr_png_, but is Table compatible
        buffer = self.render('png', **kwargs)
        s = base64.b64encode(buffer).decode('utf-8')
        return '<img src="data:image/png;base64,%s">' % s

    def html(self, **kwargs):
        """:return: `IPython.display.HTML` wrapping :meth:`_repr_html_`"""
        from IPython.display import HTML
        return HTML(self._repr_html_(**kwargs))

    def svg(self, **kwargs):
        """:return: `IPython.display.SVG` wrapping :meth:`_repr_svg_`"""
        from IPython.display import SVG
        return SVG(self._repr_svg_(**kwargs))

    def _repr_svg_(self, **kwargs):
        """:return: string, the svg rendering decoded as utf-8"""
        return self.render(fmt='svg', **kwargs).decode('utf-8')

    def png(self, **kwargs):
        """:return: embedded `IPython.display.Image` wrapping :meth:`_repr_png_`"""
        from IPython.display import Image
        return Image(self._repr_png_(**kwargs), embed=True)

    def _repr_png_(self, **kwargs):
        """:return: the png rendering, as returned by :meth:`render`"""
        return self.render(fmt='png', **kwargs)

    def plot(self, **kwargs):
        """ renders on IPython Notebook
        (alias to make usage more straightforward)
        """
        return self.svg(**kwargs)
def render(plotables, fmt='svg', **kwargs):
    """renders several Plot objects on one shared figure

    :param plotables: sequence of objects providing a ``_plot(ax, **kwargs)``
        method that returns the axis to use next
    :param fmt: string output format handed to ``fig.savefig``
    :param kwargs: optional arguments:

        * dpi, transparent, facecolor, background, figsize : handed to ``fig.savefig``
        * ylim, xlim : axis limits, applied only when truthy
        * title : figure title; when empty, taken from the first object's
          ``_repr_latex_()`` or, failing that, from its label
        * labels : one label per object; missing or None entries default
          to ``str(obj)``. The caller's list is not modified
        * offset : each object i is plotted with ``offset=i*offset``
        * anything else is passed to every ``_plot`` call

    :return: content of the rendered figure, as read from a ``six.BytesIO``
    """
    import matplotlib.pyplot as plt

    # extract optional arguments used for rasterization
    printargs, kwargs = itertools2.dictsplit(
        kwargs,
        ['dpi', 'transparent', 'facecolor', 'background', 'figsize']
    )
    # NOTE(review): 'figsize' and 'background' end up in fig.savefig below;
    # confirm the installed matplotlib accepts them there

    ylim = kwargs.pop('ylim', None)
    xlim = kwargs.pop('xlim', None)
    title = kwargs.pop('title', None)
    offset = kwargs.pop('offset', 0)  # slightly shift the points to make superimposed curves more visible

    # work on a private copy so the caller's list is never written to,
    # and pad it so a short list cannot raise IndexError in the loop
    labels = kwargs.pop('labels', None)
    labels = [] if labels is None else list(labels)
    labels.extend([None] * (len(plotables) - len(labels)))

    fig, ax = plt.subplots()
    try:
        for i, obj in enumerate(plotables):
            if labels[i] is None:
                labels[i] = str(obj)
            if not title:
                try:
                    title = obj._repr_latex_()
                except Exception:  # no usable latex representation
                    title = labels[i]
            ax = obj._plot(ax, label=labels[i], offset=i * offset, **kwargs)

        if ylim:
            plt.ylim(ylim)
        if xlim:
            plt.xlim(xlim)
        if title:
            ax.set_title(title)
        if len(labels) > 1:
            ax.legend()

        output = six.BytesIO()
        fig.savefig(output, format=fmt, **printargs)
        return output.getvalue()
    finally:
        # close the figure even when plotting or saving fails
        plt.close(fig)
def png(plotables, **kwargs):
    """renders several Plot objects as one embedded png image

    :param plotables: objects handed to :func:`render`
    :param kwargs: extra options handed to :func:`render`
    :return: `IPython.display.Image` built with ``embed=True``
    """
    from IPython.display import Image
    data = render(plotables, 'png', **kwargs)
    return Image(data, embed=True)
def svg(plotables, **kwargs):
    """renders several Plot objects as one svg image

    :param plotables: objects handed to :func:`render`
    :param kwargs: extra options handed to :func:`render`
    :return: `IPython.display.SVG` built from the rendered data
    """
    from IPython.display import SVG
    data = render(plotables, 'svg', **kwargs)
    return SVG(data)


# module-level alias: plot(...) is the same function as svg(...)
plot = svg
def save(plotables, filename, **kwargs):
    """renders several Plot objects and writes the result to a file

    The output format is the lower-cased text after the last '.' of filename.

    :param plotables: objects handed to :func:`render`
    :param filename: string path of the file to write
    :param kwargs: extra options handed to :func:`render`;
        ``dpi`` defaults to 600 when not given
    :return: the result of the file's ``write`` call
    """
    ext = filename.split('.')[-1].lower()
    kwargs.setdefault('dpi', 600)  # force good quality
    # render before opening, so a failed render does not leave an empty file
    data = render(plotables, ext, **kwargs)
    # "with" closes the file deterministically instead of relying on garbage collection
    with open(filename, 'wb') as f:
        return f.write(data)