Source code for yt.visualization.plot_container

"""
A base class for "image" plots with colorbars.



"""

#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------
from yt.extern.six.moves import builtins
from yt.extern.six import iteritems

import base64
import errno
import numpy as np
import matplotlib
import os

from functools import wraps
from matplotlib.font_manager import FontProperties

from ._mpl_imports import FigureCanvasAgg
from .tick_locators import LogLocator, LinearLocator

from yt.funcs import \
    defaultdict, get_image_suffix, \
    get_ipython_api_version, iterable, \
    ensure_list
from yt.utilities.exceptions import \
    YTNotInsideNotebook

def invalidate_data(f):
    @wraps(f)
    def newfunc(*args, **kwargs):
        rv = f(*args, **kwargs)
        args[0]._data_valid = False
        args[0]._plot_valid = False
        return rv
    return newfunc


def invalidate_figure(f):
    @wraps(f)
    def newfunc(*args, **kwargs):
        rv = f(*args, **kwargs)
        for field in args[0].plots.keys():
            args[0].plots[field].figure = None
            args[0].plots[field].axes = None
            args[0].plots[field].cax = None
        return rv
    return newfunc


def invalidate_plot(f):
    @wraps(f)
    def newfunc(*args, **kwargs):
        rv = f(*args, **kwargs)
        args[0]._plot_valid = False
        return rv
    return newfunc

def validate_plot(f):
    @wraps(f)
    def newfunc(*args, **kwargs):
        if hasattr(args[0], '_data_valid'):
            if not args[0]._data_valid:
                args[0]._recreate_frb()
        if not args[0]._plot_valid:
            args[0]._setup_plots()
            if hasattr(args[0], 'run_callbacks'):
                args[0].run_callbacks()
        rv = f(*args, **kwargs)
        return rv
    return newfunc

def apply_callback(f):
    @wraps(f)
    def newfunc(*args, **kwargs):
        #rv = f(*args[1:], **kwargs)
        args[0]._callbacks.append((f.__name__, (args, kwargs)))
        return args[0]
    return newfunc

def get_log_minorticks(vmin, vmax):
    """calculate positions of linear minorticks on a log colorbar

    Parameters
    ----------
    vmin : float
        the minimum value in the colorbar
    vmax : float
        the maximum value in the colorbar

    """
    expA = np.floor(np.log10(vmin))
    expB = np.floor(np.log10(vmax))
    cofA = np.ceil(vmin/10**expA)
    cofB = np.floor(vmax/10**expB)
    lmticks = []
    while cofA*10**expA <= cofB*10**expB:
        if expA < expB:
            lmticks = np.hstack( (lmticks, np.linspace(cofA, 9, 10-cofA)*10**expA) )
            cofA = 1
            expA += 1
        else:
            lmticks = np.hstack( (lmticks, np.linspace(cofA, cofB, cofB-cofA+1)*10**expA) )
            expA += 1
    return np.array(lmticks)

def get_symlog_minorticks(linthresh, vmin, vmax):
    """calculate positions of linear minorticks on a symmetric log colorbar

    Parameters
    ----------
    linthresh: float
        the threshold for the linear region
    vmin : float
        the minimum value in the colorbar
    vmax : float
        the maximum value in the colorbar

    """
    if vmin >= 0 or vmax <= 0:
        return get_log_minorticks(vmin, vmax)
    else:
        return np.hstack( (-get_log_minorticks(linthresh,-vmin)[::-1], 0,
                            get_log_minorticks(linthresh, vmax)) )

field_transforms = {}


class FieldTransform(object):
    def __init__(self, name, func, locator):
        self.name = name
        self.func = func
        self.locator = locator
        field_transforms[name] = self

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def ticks(self, mi, ma):
        try:
            ticks = self.locator(mi, ma)
        except:
            ticks = []
        return ticks

log_transform = FieldTransform('log10', np.log10, LogLocator())
linear_transform = FieldTransform('linear', lambda x: x, LinearLocator())
symlog_transform = FieldTransform('symlog', None, LogLocator())

class PlotDictionary(defaultdict):
    def __getitem__(self, item):
        return defaultdict.__getitem__(
            self, self.data_source._determine_fields(item)[0])

    def __setitem__(self, item, value):
        return defaultdict.__setitem__(
            self, self.data_source._determine_fields(item)[0], value)

    def __contains__(self, item):
        return defaultdict.__contains__(
            self, self.data_source._determine_fields(item)[0])

    def __init__(self, data_source, default_factory=None):
        self.data_source = data_source
        return defaultdict.__init__(self, default_factory)

class ImagePlotContainer(object):
    """A container for plots with colorbars.

    """
    _plot_type = None
    _plot_valid = False
    _colorbar_valid = False

    def __init__(self, data_source, figure_size, fontsize):
        self.data_source = data_source
        if iterable(figure_size):
            self.figure_size = float(figure_size[0]), float(figure_size[1])
        else:
            self.figure_size = float(figure_size)
        self.plots = PlotDictionary(data_source)
        self._callbacks = []
        self._field_transform = {}
        self._colormaps = defaultdict(lambda: 'algae')
        font_path = matplotlib.get_data_path() + '/fonts/ttf/STIXGeneral.ttf'
        self._font_properties = FontProperties(size=fontsize, fname=font_path)
        self._font_color = None
        self._xlabel = None
        self._ylabel = None
        self._minorticks = {}
        self._cbar_minorticks = {}
        self._colorbar_label = PlotDictionary(
            self.data_source, lambda: None)

    @invalidate_plot
    def set_log(self, field, log, linthresh=None):
        """set a field to log or linear.

        Parameters
        ----------
        field : string
            the field to set a transform
        log : boolean
            Log on/off.
        linthresh : float (must be positive)
            linthresh will be enabled for symlog scale only when log is true

        """
        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = [field]
        for field in self.data_source._determine_fields(fields):
            if log:
                if linthresh is not None:
                    if not linthresh > 0.: 
                        raise ValueError('\"linthresh\" must be positive')
                    self._field_transform[field] = symlog_transform
                    self._field_transform[field].func = linthresh 
                else:
                    self._field_transform[field] = log_transform
            else:
                self._field_transform[field] = linear_transform
        return self

    def get_log(self, field):
        """get the transform type of a field.

        Parameters
        ----------
        field : string
            the field to get a transform

        """
        log = {}
        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = [field]
        for field in self.data_source._determine_fields(fields):
            if self._field_transform[field] == log_transform:
                log[field] = True
            else:
                log[field] = False
        return log

    @invalidate_plot
    def set_transform(self, field, name):
        field = self.data_source._determine_fields(field)[0]
        if name not in field_transforms:
            raise KeyError(name)
        self._field_transform[field] = field_transforms[name]
        return self

    @invalidate_plot
    def set_cmap(self, field, cmap):
        """set the colormap for one of the fields

        Parameters
        ----------
        field : string
            the field to set the colormap
            if field == 'all', applies to all plots.
        cmap : string or tuple
            If a string, will be interpreted as name of the colormap.
            If a tuple, it is assumed to be of the form (name, type, number)
            to be used for palettable functionality. (name, type, number, bool)
            can be used to specify if a reverse colormap is to be used.

        """

        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = [field]
        for field in self.data_source._determine_fields(fields):
            self._colorbar_valid = False
            self._colormaps[field] = cmap
        return self

    @invalidate_plot
    def set_zlim(self, field, zmin, zmax, dynamic_range=None):
        """set the scale of the colormap

        Parameters
        ----------
        field : string
            the field to set a colormap scale
            if field == 'all', applies to all plots.
        zmin : float
            the new minimum of the colormap scale. If 'min', will
            set to the minimum value in the current view.
        zmax : float
            the new maximum of the colormap scale. If 'max', will
            set to the maximum value in the current view.

        Other Parameters
        ----------------
        dynamic_range : float (default: None)
            The dynamic range of the image.
            If zmin == None, will set zmin = zmax / dynamic_range
            If zmax == None, will set zmax = zmin * dynamic_range
            When dynamic_range is specified, defaults to setting
            zmin = zmax / dynamic_range.

        """
        if field is 'all':
            fields = list(self.plots.keys())
        else:
            fields = ensure_list(field)
        for field in self.data_source._determine_fields(fields):
            myzmin = zmin
            myzmax = zmax
            if zmin == 'min':
                myzmin = self.plots[field].image._A.min()
            if zmax == 'max':
                myzmax = self.plots[field].image._A.max()
            if dynamic_range is not None:
                if zmax is None:
                    myzmax = myzmin * dynamic_range
                else:
                    myzmin = myzmax / dynamic_range

            self.plots[field].zmin = myzmin
            self.plots[field].zmax = myzmax
        return self

    @invalidate_plot
    def set_minorticks(self, field, state):
        """turn minor ticks on or off in the current plot

        Displaying minor ticks reduces performance; turn them off
        using set_minorticks('all', 'off') if drawing speed is a problem.

        Parameters
        ----------
        field : string
            the field to remove minorticks
        state : string
            the state indicating 'on' or 'off'

        """
        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = [field]
        for field in self.data_source._determine_fields(fields):
            if state == 'on':
                self._minorticks[field] = True
            else:
                self._minorticks[field] = False
        return self

    @invalidate_plot
    def set_cbar_minorticks(self, field, state):
        """turn colorbar minor ticks on or off in the current plot

        Displaying minor ticks reduces performance; turn them off 
        using set_cbar_minorticks('all', 'off') if drawing speed is a problem.

        Parameters
        ----------
        field : string
            the field to remove colorbar minorticks
        state : string
            the state indicating 'on' or 'off'

        """
        if field == 'all':
            fields = list(self.plots.keys())
        else:
            fields = [field]
        for field in self.data_source._determine_fields(fields):
            if state == 'on':
                self._cbar_minorticks[field] = True
            else:
                self._cbar_minorticks[field] = False
        return self

    def _setup_plots(self):
        # Left blank to be overriden in subclasses
        pass

    def _switch_ds(self, new_ds, data_source=None):
        old_object = self.data_source
        name = old_object._type_name
        kwargs = dict((n, getattr(old_object, n))
                      for n in old_object._con_args)
        kwargs['center'] = getattr(old_object, 'center', None)
        if data_source is not None:
            if name != "proj":
                raise RuntimeError("The data_source keyword argument "
                                   "is only defined for projections.")
            kwargs['data_source'] = data_source
        new_object = getattr(new_ds, name)(**kwargs)
        self.ds = new_ds
        self.data_source = new_object
        self._data_valid = self._plot_valid = False
        for d in 'xyz':
            lim_name = d+'lim'
            if hasattr(self, lim_name):
                lim = getattr(self, lim_name)
                lim = tuple(new_ds.quan(l.value, str(l.units)) for l in lim)
                setattr(self, lim_name, lim)
        self._setup_plots()

    @validate_plot
    def __getitem__(self, item):
        return self.plots[item]

    def _set_font_properties(self):
        for f in self.plots:
            ax = self.plots[f].axes
            cbax = self.plots[f].cb.ax
            labels = ax.xaxis.get_ticklabels() + ax.yaxis.get_ticklabels()
            labels += cbax.yaxis.get_ticklabels()
            labels += [ax.title, ax.xaxis.label, ax.yaxis.label,
                       cbax.yaxis.label, cbax.yaxis.get_offset_text(),
                       ax.xaxis.get_offset_text(), ax.yaxis.get_offset_text()]
            for label in labels:
                label.set_fontproperties(self._font_properties)
                if self._font_color is not None:
                    label.set_color(self._font_color)

    @invalidate_plot
    @invalidate_figure
    def set_font(self, font_dict=None):
        """set the font and font properties

        Parameters
        ----------
        font_dict : dict
        A dict of keyword parameters to be passed to
        :py:class:`matplotlib.font_manager.FontProperties`.

        Possible keys include
        * family - The font family. Can be serif, sans-serif, cursive, 'fantasy' or
          'monospace'.
        * style - The font style. Either normal, italic or oblique.
        * color - A valid color string like 'r', 'g', 'red', 'cobalt', and
          'orange'.
        * variant: Either normal or small-caps.
        * size: Either an relative value of xx-small, x-small, small, medium,
          large, x-large, xx-large or an absolute font size, e.g. 12
        * stretch: A numeric value in the range 0-1000 or one of
          ultra-condensed, extra-condensed, condensed, semi-condensed, normal,
          semi-expanded, expanded, extra-expanded or ultra-expanded
        * weight: A numeric value in the range 0-1000 or one of ultralight,
          light, normal, regular, book, medium, roman, semibold, demibold, demi,
          bold, heavy, extra bold, or black

        See the matplotlib font manager API documentation for more details.
        http://matplotlib.org/api/font_manager_api.html

        Notes
        -----
        Mathtext axis labels will only obey the `size` and `color` keyword.

        Examples
        --------
        This sets the font to be 24-pt, blue, sans-serif, italic, and
        bold-face.

        >>> slc = SlicePlot(ds, 'x', 'Density')
        >>> slc.set_font({'family':'sans-serif', 'style':'italic',
                          'weight':'bold', 'size':24, 'color':'blue'})

        """
        if font_dict is None:
            font_dict = {}
        if 'color' in font_dict:
            self._font_color = font_dict.pop('color')
        # Set default values if the user does not explicitly set them.
        # this prevents reverting to the matplotlib defaults.
        font_dict.setdefault('family', 'stixgeneral')
        font_dict.setdefault('size', 18)
        self._font_properties = \
            FontProperties(**font_dict)
        return self

    def set_font_size(self, size):
        """Set the size of the font used in the plot

        This sets the font size by calling the set_font function.  See set_font
        for more font customization options.

        Parameters
        ----------
        size : float
        The absolute size of the font in points (1 pt = 1/72 inch).

        """
        return self.set_font({'size': size})

    @invalidate_plot
    @invalidate_figure
    def set_figure_size(self, size):
        """Sets a new figure size for the plot

        parameters
        ----------
        size : float
            The size of the figure on the longest axis (in units of inches),
            including the margins but not the colorbar.
        """
        self.figure_size = float(size)
        return self

    @validate_plot
    def save(self, name=None, suffix=None, mpl_kwargs=None):
        """saves the plot to disk.

        Parameters
        ----------
        name : string
           The base of the filename.  If name is a directory or if name is not
           set, the filename of the dataset is used.
        suffix : string
           Specify the image type by its suffix. If not specified, the output
           type will be inferred from the filename. Defaults to PNG.
        mpl_kwargs : dict
           A dict of keyword arguments to be passed to matplotlib.

        >>> slc.save(mpl_kwargs={'bbox_inches':'tight'})

        """
        names = []
        if mpl_kwargs is None: mpl_kwargs = {}
        if name is None:
            name = str(self.ds)
        name = os.path.expanduser(name)
        if name[-1] == os.sep and not os.path.isdir(name):
            try:
                os.mkdir(name)
            except OSError as e:
                if e.errno == errno.EEXIST:
                    pass
                else:
                    raise
        if os.path.isdir(name) and name != str(self.ds):
            name = name + (os.sep if name[-1] != os.sep else '') + str(self.ds)
        if suffix is None:
            suffix = get_image_suffix(name)
            if suffix != '':
                for k, v in iteritems(self.plots):
                    names.append(v.save(name, mpl_kwargs))
                return names
        axis = self.ds.coordinates.axis_name.get(
            self.data_source.axis, '')
        weight = None
        type = self._plot_type
        if type in ['Projection', 'OffAxisProjection']:
            weight = self.data_source.weight_field
            if weight is not None:
                weight = weight[1].replace(' ', '_')
        if 'Cutting' in self.data_source.__class__.__name__:
            type = 'OffAxisSlice'
        for k, v in iteritems(self.plots):
            if isinstance(k, tuple):
                k = k[1]
            if axis:
                n = "%s_%s_%s_%s" % (name, type, axis, k.replace(' ', '_'))
            else:
                # for cutting planes
                n = "%s_%s_%s" % (name, type, k.replace(' ', '_'))
            if weight:
                n += "_%s" % (weight)
            if suffix != '':
                n = ".".join([n,suffix])
            names.append(v.save(n, mpl_kwargs))
        return names

    @invalidate_data
    def refresh(self):
        # invalidate_data will take care of everything
        return self

    @validate_plot
    def _send_zmq(self):
        try:
            # pre-IPython v1.0
            from IPython.zmq.pylab.backend_inline import send_figure as display
        except ImportError:
            # IPython v1.0+
            from IPython.core.display import display
        for k, v in sorted(iteritems(self.plots)):
            # Due to a quirk in the matplotlib API, we need to create
            # a dummy canvas variable here that is never used.
            canvas = FigureCanvasAgg(v.figure)  # NOQA
            display(v.figure)

    @validate_plot
    def show(self):
        r"""This will send any existing plots to the IPython notebook.

        If yt is being run from within an IPython session, and it is able to
        determine this, this function will send any existing plots to the
        notebook for display.

        If yt can't determine if it's inside an IPython session, it will raise
        YTNotInsideNotebook.

        Examples
        --------

        >>> from yt.mods import SlicePlot
        >>> slc = SlicePlot(ds, "x", ["Density", "VelocityMagnitude"])
        >>> slc.show()

        """
        if "__IPYTHON__" in dir(builtins):
            api_version = get_ipython_api_version()
            if api_version in ('0.10', '0.11'):
                self._send_zmq()
            else:
                from IPython.display import display
                display(self)
        else:
            raise YTNotInsideNotebook

    @validate_plot
    def display(self, name=None, mpl_kwargs=None):
        """Will attempt to show the plot in in an IPython notebook.
        Failing that, the plot will be saved to disk."""
        try:
            return self.show()
        except YTNotInsideNotebook:
            return self.save(name=name, mpl_kwargs=mpl_kwargs)

    @validate_plot
    def _repr_html_(self):
        """Return an html representation of the plot object. Will display as a
        png for each WindowPlotMPL instance in self.plots"""
        ret = ''
        for field in self.plots:
            img = base64.b64encode(self.plots[field]._repr_png_()).decode()
            ret += r'<img style="max-width:100%%;max-height:100%%;" ' \
                   r'src="data:image/png;base64,{0}"><br>'.format(img)
        return ret

    @invalidate_plot
    def set_xlabel(self, label):
        r"""
        Allow the user to modify the X-axis title
        Defaults to the global value. Fontsize defaults
        to 18.

        Parameters
        ----------
        x_title: str
              The new string for the x-axis.

        >>>  plot.set_xlabel("H2I Number Density (cm$^{-3}$)")

        """
        self._xlabel = label
        return self

    @invalidate_plot
    def set_ylabel(self, label):
        r"""
        Allow the user to modify the Y-axis title
        Defaults to the global value.

        Parameters
        ----------
        label: str
          The new string for the y-axis.

        >>>  plot.set_ylabel("Temperature (K)")

        """
        self._ylabel = label
        return self

    @invalidate_plot
    def set_colorbar_label(self, field, label):
        r"""
        Sets the colorbar label.

        Parameters
        ----------
        field: str or tuple
          The name of the field to modify the label for.
        label: str
          The new label

        >>>  plot.set_colorbar_label("density", "Dark Matter Density (g cm$^{-3}$)")

        """
        self._colorbar_label[field] = label
        return self

    def _get_axes_labels(self, field):
        return(self._xlabel, self._ylabel, self._colorbar_label[field])