# Source code for yt.visualization.plot_modifications

"""

Callbacks to add additional functionality on to plots.



"""
from __future__ import absolute_import
from yt.extern.six import string_types

#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import numpy as np
import h5py

from distutils.version import LooseVersion

from matplotlib.patches import Circle
from matplotlib.colors import colorConverter
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

from yt.funcs import *
from yt.extern.six import add_metaclass
from ._mpl_imports import *
from yt.utilities.physical_constants import \
    sec_per_Gyr, sec_per_Myr, \
    sec_per_kyr, sec_per_year, \
    sec_per_day, sec_per_hr
from yt.units.yt_array import YTQuantity, YTArray
from yt.visualization.image_writer import apply_colormap
from yt.utilities.lib.geometry_utils import triangle_plane_intersect
from yt.analysis_modules.cosmological_observation.light_ray.light_ray \
     import periodic_ray
import warnings

from . import _MPL

# Maps the name of every class created through RegisteredCallback to the
# class object itself.
callback_registry = dict()

class RegisteredCallback(type):
    """
    Metaclass that records each class it creates in ``callback_registry``,
    keyed by the class name.
    """
    def __init__(cls, name, b, d):
        super(RegisteredCallback, cls).__init__(name, b, d)
        callback_registry[name] = cls

@add_metaclass(RegisteredCallback)
class PlotCallback(object):
    """
    Base class for plot annotation callbacks.

    Every class built from this one goes through the RegisteredCallback
    metaclass and is therefore recorded in ``callback_registry``.  The
    methods here are shared helpers for moving coordinates between the
    supported coordinate systems and for styling text labels.
    """
    def __init__(self, *args, **kwargs):
        pass

    def project_coords(self, plot, coord):
        """
        Convert coordinates from simulation data coordinates to projected
        data coordinates.  Simulation data coordinates are three dimensional,
        and can either be specified as a YTArray or as a list or array in
        code_length units.  Projected data units are 2D versions of the
        simulation data units relative to the axes of the final plot.

        Raises SyntaxError if *coord* does not have exactly three elements,
        or if ``plot.data.axis`` is neither 0, 1, 2 (on-axis) nor 4
        (off-axis).
        """
        # Only full 3D positions can be projected onto the plot plane.
        if len(coord) != 3:
            raise SyntaxError("'data' coordinates must be 3 dimensions")

        if not isinstance(coord, YTArray):
            coord = plot.data.ds.arr(coord, 'code_length')
        ax = plot.data.axis

        # if this is an on-axis projection or slice, then
        # just grab the appropriate 2 coords for the on-axis view
        if 0 <= ax <= 2:
            (xi, yi) = (plot.data.ds.coordinates.x_axis[ax],
                        plot.data.ds.coordinates.y_axis[ax])
            return (coord[xi], coord[yi])

        # if this is an off-axis project or slice (ie cutting plane)
        # we have to calculate where the data coords fall in the projected
        # plane
        if ax == 4:
            coord_vectors = coord - plot.data.center
            x = np.dot(coord_vectors, plot.data.orienter.unit_vectors[1])
            y = np.dot(coord_vectors, plot.data.orienter.unit_vectors[0])
            # Transpose into image coords. Due to VR being not a
            # right-handed coord system
            return (y, x)

        raise SyntaxError("Object being plot must have a `data.axis` "
                          "defined")

    def convert_to_plot(self, plot, coord, offset=True):
        """
        Convert coordinates from projected data coordinates to PlotWindow 
        plot coordinates.  Projected data coordinates are two dimensional
        and refer to the location relative to the specific axes being plotted,
        although still in simulation units.  PlotWindow plot coordinates
        are locations as found in the final plot, usually with the origin
        in the center of the image and the extent of the image defined by
        the final plot axis markers.

        *coord* is either a single (x, y) pair or a 2 x ncoord array-like.
        The *offset* argument is accepted but not used by this method.
        Returns a tuple of two arrays (x, y), each of length ncoord.
        """
        # coord should be a 2 x ncoord array-like datatype; anything with
        # fewer than two dimensions is treated as a single coordinate.
        ccoord = np.array(coord)
        ncoord = ccoord.shape[1] if ccoord.ndim > 1 else 1

        # Convert the data and plot limits to tiled numpy arrays so that
        # convert_to_plot is automatically vectorized.
        x0 = np.array(np.tile(plot.xlim[0], ncoord))
        x1 = np.array(np.tile(plot.xlim[1], ncoord))
        axes_xlim = plot._axes.get_xlim()
        xx0 = np.tile(axes_xlim[0], ncoord)
        xx1 = np.tile(axes_xlim[1], ncoord)

        y0 = np.array(np.tile(plot.ylim[0], ncoord))
        y1 = np.array(np.tile(plot.ylim[1], ncoord))
        axes_ylim = plot._axes.get_ylim()
        yy0 = np.tile(axes_ylim[0], ncoord)
        yy1 = np.tile(axes_ylim[1], ncoord)

        # We need a special case for when we are only given one coordinate.
        if ccoord.shape == (2,):
            return ((ccoord[0]-x0)/(x1-x0)*(xx1-xx0) + xx0,
                    (ccoord[1]-y0)/(y1-y0)*(yy1-yy0) + yy0)
        else:
            return ((ccoord[0][:]-x0)/(x1-x0)*(xx1-xx0) + xx0,
                    (ccoord[1][:]-y0)/(y1-y0)*(yy1-yy0) + yy0)

    def sanitize_coord_system(self, plot, coord, coord_system):
        """
        Given a set of x,y (and z) coordinates and a coordinate system, 
        convert the coordinates (and transformation) ready for final plotting.

        As a side effect, ``self.transform`` is set to the matplotlib
        transform matching *coord_system*.  Raises SyntaxError for an
        unknown *coord_system*, for 'data' coordinates with fewer than
        three elements, and for 'axis' coordinates with more than two.

        Coordinate systems
        ------------------

        data : 3D data coordinates relative to original dataset

        plot : 2D coordinates as defined by the final axis locations

        axis : 2D coordinates within the axis object from (0,0) in lower left 
               to (1,1) in upper right.  Same as matplotlib axis coords.

        figure : 2D coordinates within figure object from (0,0) in lower left 
                 to (1,1) in upper right.  Same as matplotlib figure coords.
        """
        if coord_system == "data":
            # data coords are projected to plot coords first
            if len(coord) < 3:
                raise SyntaxError("Coordinates in 'data' coordinate system " 
                                  "need to be in 3D")
            coord = self.project_coords(plot, coord)
            coord = self.convert_to_plot(plot, coord)
            self.transform = plot._axes.transData
        elif coord_system == "plot":
            self.transform = plot._axes.transData
        elif coord_system == "axis":
            self.transform = plot._axes.transAxes
            if len(coord) > 2:
                raise SyntaxError("Coordinates in 'axis' coordinate system " 
                                  "need to be in 2D")
        elif coord_system == "figure":
            self.transform = plot._figure.transFigure
        else:
            raise SyntaxError("Argument coord_system must have a value of "
                              "'data', 'plot', 'axis', or 'figure'.")
        return coord

    def pixel_scale(self, plot):
        """
        Return (dx, dy): the ratio of the matplotlib axes extent to the
        extent given by ``plot.xlim`` / ``plot.ylim`` along each direction.
        """
        x0, x1 = np.array(plot.xlim)
        xx0, xx1 = plot._axes.get_xlim()
        dx = (xx1 - xx0)/(x1 - x0)

        y0, y1 = np.array(plot.ylim)
        yy0, yy1 = plot._axes.get_ylim()
        dy = (yy1 - yy0)/(y1 - y0)

        return (dx,dy)

    def _set_font_properties(self, plot, labels, **kwargs):
        """
        This sets all of the text instances created by a callback to have
        the same font size and properties as all of the other fonts in the
        figure.  If kwargs are set, they override the defaults.
        """
        # This is a little messy because there is no trivial way to update
        # a MPL.font_manager.FontProperties object with new attributes
        # aside from setting them individually.  So we pick out the relevant
        # MPL.Text() kwargs from the local kwargs and let them override the
        # defaults.
        local_font_properties = plot.font_properties.copy()

        # Turn off the default TT font file, otherwise none of this works.
        local_font_properties.set_file(None)
        local_font_properties.set_family('stixgeneral')

        # Each recognised keyword maps onto the FontProperties setter of the
        # same name (set_<keyword>); they are applied in this order.
        font_keys = ('family', 'file', 'fontconfig_pattern', 'name', 'size',
                     'slant', 'stretch', 'style', 'variant', 'weight')
        for key in font_keys:
            if key in kwargs:
                getattr(local_font_properties, 'set_' + key)(kwargs[key])

        # For each label, set the font properties and color to the figure
        # defaults if not already set in the callback itself
        use_plot_color = (plot.font_color is not None
                          and 'color' not in kwargs)
        for label in labels:
            if use_plot_color:
                label.set_color(plot.font_color)
            label.set_fontproperties(local_font_properties)

class VelocityCallback(PlotCallback):
    """
    annotate_velocity(factor=16, scale=None, scale_units=None, normalize=False):

    Adds a 'quiver' plot of velocity to the plot, skipping all but
    every *factor* datapoint. *scale* is the data units per arrow
    length unit using *scale_units* (see
    matplotlib.axes.Axes.quiver for more info). if *normalize* is
    True, the velocity fields will be scaled by their local
    (in-plane) length, allowing morphological features to be more
    clearly seen for fields with substantial variation in field
    strength (normalize is not implemented and thus ignored for
    Cutting Planes).
    """
    _type_name = "velocity"

    def __init__(self, factor=16, scale=None, scale_units=None,
                 normalize=False):
        PlotCallback.__init__(self)
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize

    def __call__(self, plot):
        """Build the matching quiver callback for *plot* and run it."""
        # Instantiation of these is cheap
        if plot._type_name == "CuttingPlane":
            qcb = CuttingQuiverCallback("cutting_plane_velocity_x",
                                        "cutting_plane_velocity_y",
                                        self.factor)
        else:
            ax = plot.data.axis
            (xi, yi) = (plot.data.ds.coordinates.x_axis[ax],
                        plot.data.ds.coordinates.y_axis[ax])
            axis_names = plot.data.ds.coordinates.axis_name
            xv = "velocity_%s" % (axis_names[xi])
            yv = "velocity_%s" % (axis_names[yi])

            # The in-plane components of the bulk velocity (if the data
            # object defines one) are handed to QuiverCallback.
            bv = plot.data.get_field_parameter("bulk_velocity")
            if bv is not None:
                bv_x = bv[xi]
                bv_y = bv[yi]
            else:
                bv_x = bv_y = YTQuantity(0, 'cm/s')

            qcb = QuiverCallback(xv, yv, self.factor, scale=self.scale,
                                 scale_units=self.scale_units,
                                 normalize=self.normalize,
                                 bv_x=bv_x, bv_y=bv_y)
        return qcb(plot)
class MagFieldCallback(PlotCallback):
    """
    annotate_magnetic_field(factor=16, scale=None, scale_units=None, normalize=False):

    Adds a 'quiver' plot of magnetic field to the plot, skipping all but
    every *factor* datapoint. *scale* is the data units per arrow
    length unit using *scale_units* (see
    matplotlib.axes.Axes.quiver for more info). if *normalize* is
    True, the magnetic fields will be scaled by their local
    (in-plane) length, allowing morphological features to be more
    clearly seen for fields with substantial variation in field strength.
    """
    _type_name = "magnetic_field"

    def __init__(self, factor=16, scale=None, scale_units=None,
                 normalize=False):
        PlotCallback.__init__(self)
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize

    def __call__(self, plot):
        """Build the matching quiver callback for *plot* and run it."""
        # Instantiation of these is cheap
        if plot._type_name == "CuttingPlane":
            qcb = CuttingQuiverCallback("cutting_plane_bx",
                                        "cutting_plane_by",
                                        self.factor)
        else:
            xax = plot.data.ds.coordinates.x_axis[plot.data.axis]
            yax = plot.data.ds.coordinates.y_axis[plot.data.axis]
            axis_names = plot.data.ds.coordinates.axis_name
            xv = "magnetic_field_%s" % (axis_names[xax])
            yv = "magnetic_field_%s" % (axis_names[yax])
            qcb = QuiverCallback(xv, yv, self.factor, scale=self.scale,
                                 scale_units=self.scale_units,
                                 normalize=self.normalize)
        return qcb(plot)
class QuiverCallback(PlotCallback):
    """
    annotate_quiver(field_x, field_y, factor=16, scale=None, scale_units=None,
                    normalize=False, bv_x=0, bv_y=0):

    Adds a 'quiver' plot to any plot, using the *field_x* and *field_y*
    from the associated data, skipping every *factor* datapoints
    *scale* is the data units per arrow length unit using *scale_units*
    (see matplotlib.axes.Axes.quiver for more info)
    """
    _type_name = "quiver"

    def __init__(self, field_x, field_y, factor=16, scale=None,
                 scale_units=None, normalize=False, bv_x=0, bv_y=0):
        PlotCallback.__init__(self)
        self.field_x = field_x
        self.field_y = field_y
        self.bv_x = bv_x
        self.bv_y = bv_y
        self.factor = factor
        self.scale = scale
        self.scale_units = scale_units
        self.normalize = normalize

    def __call__(self, plot):
        """Pixelize both fields on a coarse grid and draw them as arrows."""
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        plot._axes.hold(True)
        # Grid sizes must be integers: they are used both as buffer
        # dimensions and as the sample count for np.linspace.
        nx = int(plot.image._A.shape[0] // self.factor)
        ny = int(plot.image._A.shape[1] // self.factor)
        # periodicity
        ax = plot.data.axis
        ds = plot.data.ds
        (xi, yi) = (ds.coordinates.x_axis[ax],
                    ds.coordinates.y_axis[ax])
        period_x = ds.domain_width[xi]
        period_y = ds.domain_width[yi]
        periodic = int(any(ds.periodicity))
        fv_x = plot.data[self.field_x]
        if self.bv_x != 0.0:
            # Workaround for 0.0 without units.  Subtract out of place so
            # the object returned by plot.data[...] is left untouched.
            fv_x = fv_x - self.bv_x
        fv_y = plot.data[self.field_y]
        if self.bv_y != 0.0:
            # Workaround for 0.0 without units
            fv_y = fv_y - self.bv_y
        pixX = _MPL.Pixelize(plot.data['px'],
                             plot.data['py'],
                             plot.data['pdx'],
                             plot.data['pdy'],
                             fv_x,
                             nx, ny,
                             (x0, x1, y0, y1), 0,  # bounds, antialias
                             (period_x, period_y), periodic,
                             ).transpose()
        pixY = _MPL.Pixelize(plot.data['px'],
                             plot.data['py'],
                             plot.data['pdx'],
                             plot.data['pdy'],
                             fv_y,
                             nx, ny,
                             (x0, x1, y0, y1), 0,  # bounds, antialias
                             (period_x, period_y), periodic,
                             ).transpose()
        X, Y = np.meshgrid(np.linspace(xx0, xx1, nx, endpoint=True),
                           np.linspace(yy0, yy1, ny, endpoint=True))
        if self.normalize:
            # Scale every vector by its local in-plane magnitude.
            nn = np.sqrt(pixX**2 + pixY**2)
            pixX /= nn
            pixY /= nn
        plot._axes.quiver(X, Y, pixX, pixY, scale=self.scale,
                          scale_units=self.scale_units)
        # Restore the limits that were in place before drawing.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)
class ContourCallback(PlotCallback):
    """
    annotate_contour(field, ncont=5, factor=4, take_log=None, clim=None,
                     plot_args=None, label=False, text_args=None,
                     data_source=None):

    Add contours in *field* to the plot.  *ncont* governs the number of
    contours generated, *factor* governs the number of points used in the
    interpolation, *take_log* governs how it is contoured and *clim* gives
    the (upper, lower) limits for contouring.  An alternate data source can
    be specified with *data_source*, but by default the plot's data source
    will be queried.
    """
    _type_name = "contour"

    def __init__(self, field, ncont=5, factor=4, clim=None,
                 plot_args=None, label=False, take_log=None,
                 label_args=None, text_args=None, data_source=None):
        PlotCallback.__init__(self)
        def_plot_args = {'color':'k'}
        def_text_args = {'color':'w'}
        self.ncont = ncont
        self.field = field
        self.factor = factor
        self.clim = clim
        self.take_log = take_log
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args
        self.label = label
        # label_args is the deprecated spelling of text_args and wins when
        # both are given.
        if label_args is not None:
            text_args = label_args
            warnings.warn("The label_args keyword is deprecated.  Please use "
                          "the text_args keyword instead.")
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        self.data_source = data_source

    def __call__(self, plot):
        """Interpolate *field* onto a coarse grid and draw its contours."""
        # These need to be in code_length
        x0, x1 = (v.in_units("code_length") for v in plot.xlim)
        y0, y1 = (v.in_units("code_length") for v in plot.ylim)

        # These are in plot coordinates, which may not be code coordinates.
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        plot._axes.hold(True)

        numPoints_x = plot.image._A.shape[0]
        numPoints_y = plot.image._A.shape[1]

        # Multiply by dx and dy to go from data->plot
        dx = (xx1 - xx0) / (x1-x0)
        dy = (yy1 - yy0) / (y1-y0)

        # We want xi, yi in plot coordinates
        xi, yi = np.mgrid[xx0:xx1:numPoints_x/(self.factor*1j),
                          yy0:yy1:numPoints_y/(self.factor*1j)]
        data = self.data_source or plot.data

        if plot._type_name in ['CuttingPlane','Projection','Slice']:
            if plot._type_name == 'CuttingPlane':
                x = data["px"]*dx
                y = data["py"]*dy
                z = data[self.field]
            elif plot._type_name in ['Projection','Slice']:
                # Makes a copy of the position fields "px" and "py" and adds
                # the appropriate shift to the copied field.
                AllX = np.zeros(data["px"].size, dtype='bool')
                AllY = np.zeros(data["py"].size, dtype='bool')
                XShifted = data["px"].copy()
                YShifted = data["py"].copy()
                dom_x, dom_y = plot._period
                for shift in np.mgrid[-1:1:3j]:
                    xlim = ((data["px"] + shift*dom_x >= x0) &
                            (data["px"] + shift*dom_x <= x1))
                    ylim = ((data["py"] + shift*dom_y >= y0) &
                            (data["py"] + shift*dom_y <= y1))
                    XShifted[xlim] += shift * dom_x
                    YShifted[ylim] += shift * dom_y
                    AllX |= xlim
                    AllY |= ylim

                # At this point XShifted and YShifted are the shifted arrays
                # of position data in data coordinates
                wI = (AllX & AllY)

                # This converts XShifted and YShifted into plot coordinates
                x = ((XShifted[wI]-x0)*dx).ndarray_view() + xx0
                y = ((YShifted[wI]-y0)*dy).ndarray_view() + yy0
                z = data[self.field][wI]

            # Both the input and output from the triangulator are in plot
            # coordinates
            if LooseVersion(matplotlib.__version__) < LooseVersion("1.4.0"):
                from matplotlib.delaunay.triangulate import Triangulation as \
                    triang
                zi = triang(x,y).nn_interpolator(z)(xi,yi)
            else:
                from matplotlib.tri import Triangulation, LinearTriInterpolator
                triangulation = Triangulation(x, y)
                zi = LinearTriInterpolator(triangulation, z)(xi,yi)
        elif plot._type_name == 'OffAxisProjection':
            zi = plot.frb[self.field][::self.factor,::self.factor].transpose()

        # The resolved settings are kept in locals rather than written back
        # to self, so calling this object again starts from the values the
        # user supplied (clim is not logged twice, ncont stays a count).
        take_log = self.take_log
        if take_log is None:
            field = data._determine_fields([self.field])[0]
            take_log = plot.ds._get_field_info(*field).take_log

        if take_log:
            zi = np.log10(zi)

        clim = self.clim
        levels = self.ncont
        if clim is not None:
            if take_log:
                clim = (np.log10(clim[0]), np.log10(clim[1]))
            # With explicit limits, ncont evenly spaced levels are drawn
            # between them.
            levels = np.linspace(clim[0], clim[1], self.ncont)

        cset = plot._axes.contour(xi, yi, zi, levels, **self.plot_args)
        plot._axes.set_xlim(xx0,xx1)
        plot._axes.set_ylim(yy0,yy1)
        plot._axes.hold(False)

        if self.label:
            plot._axes.clabel(cset, **self.text_args)
class GridBoundaryCallback(PlotCallback):
    """
    annotate_grids(alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False,
                   periodic=True, min_level=None, max_level=None,
                   cmap='B-W LINEAR_r', edgecolors=None, linewidth=1.0):

    Draws grids on an existing PlotWindow object.  Adds grid boundaries to a
    plot, optionally with alpha-blending. By default, colors different levels
    of grids with different colors going from white to black, but you can
    change to any arbitrary colormap with cmap keyword, to all black grid
    edges for all levels with cmap=None and edgecolors=None, or to an
    arbitrary single color for grid edges with edgecolors='YourChosenColor'
    defined in any of the standard ways (e.g., edgecolors='white',
    edgecolors='r', edgecolors='#00FFFF', or edgecolor='0.3', where the last
    is a float in 0-1 scale indicating gray).  Note that setting edgecolors
    overrides cmap if you have both set to non-None values.  Cutoff for
    display is at min_pix wide. draw_ids puts the grid id in the corner of
    the grid.  (Not so great in projections...).  One can set min and maximum
    level of grids to display, and can change the linewidth of the displayed
    grids.
    """
    _type_name = "grids"

    def __init__(self, alpha=0.7, min_pix=1, min_pix_ids=20, draw_ids=False,
                 periodic=True, min_level=None, max_level=None,
                 cmap='B-W LINEAR_r', edgecolors=None, linewidth=1.0):
        PlotCallback.__init__(self)
        self.alpha = alpha
        self.min_pix = min_pix
        self.min_pix_ids = min_pix_ids
        self.draw_ids = draw_ids  # put grid numbers in the corner.
        self.periodic = periodic
        self.min_level = min_level
        self.max_level = max_level
        self.linewidth = linewidth
        self.cmap = cmap
        self.edgecolors = edgecolors

    def __call__(self, plot):
        """Draw the outline of every sufficiently large grid on *plot*."""
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        (dx, dy) = self.pixel_scale(plot)
        (xpix, ypix) = plot.image._A.shape
        ax = plot.data.axis
        px_index = plot.data.ds.coordinates.x_axis[ax]
        py_index = plot.data.ds.coordinates.y_axis[ax]
        DW = plot.data.ds.domain_width
        # Offsets (in units of the domain width) at which the grids are
        # drawn: the 3x3 set of -1/0/+1 shifts when periodic, else just 0.
        if self.periodic:
            pxs, pys = np.mgrid[-1:1:3j,-1:1:3j]
        else:
            pxs, pys = np.mgrid[0:0:1j,0:0:1j]
        GLE, GRE, levels = [], [], []
        for block, mask in plot.data.blocks:
            GLE.append(block.LeftEdge.in_units("code_length"))
            GRE.append(block.RightEdge.in_units("code_length"))
            levels.append(block.Level)
        if not GLE:
            return
        # Retain both units and registry
        GLE = YTArray(GLE, input_units = GLE[0].units)
        GRE = YTArray(GRE, input_units = GRE[0].units)
        levels = np.array(levels)
        # Compare against None explicitly so that max_level=0 restricts the
        # display to level 0 instead of being treated as "not set".
        min_level = 0 if self.min_level is None else self.min_level
        if self.max_level is None:
            max_level = levels.max()
        else:
            max_level = self.max_level

        # sorts the three arrays in order of ascending level - this makes
        # images look nicer
        new_indices = np.argsort(levels)
        levels = levels[new_indices]
        GLE = GLE[new_indices]
        GRE = GRE[new_indices]

        for px_off, py_off in zip(pxs.ravel(), pys.ravel()):
            pxo = px_off * DW[px_index]
            pyo = py_off * DW[py_index]
            left_edge_x = np.array((GLE[:,px_index]+pxo-x0)*dx) + xx0
            left_edge_y = np.array((GLE[:,py_index]+pyo-y0)*dy) + yy0
            right_edge_x = np.array((GRE[:,px_index]+pxo-x0)*dx) + xx0
            right_edge_y = np.array((GRE[:,py_index]+pyo-y0)*dy) + yy0
            # Grid extents measured in image pixels.
            xwidth = xpix * (right_edge_x - left_edge_x) / (xx1 - xx0)
            ywidth = ypix * (right_edge_y - left_edge_y) / (yy1 - yy0)
            visible = np.logical_and(
                np.logical_and(xwidth > self.min_pix, ywidth > self.min_pix),
                np.logical_and(levels >= min_level, levels <= max_level))

            # Grids can either be set by edgecolors OR a colormap.
            if self.edgecolors is not None:
                edgecolors = colorConverter.to_rgba(
                    self.edgecolors, alpha=self.alpha)
            else:
                # use colormap if not explicitly overridden by edgecolors
                if self.cmap is not None:
                    color_bounds = [0,plot.data.ds.index.max_level]
                    edgecolors = apply_colormap(
                        levels[visible]*1.0, color_bounds=color_bounds,
                        cmap_name=self.cmap)[0,:,:]*1.0/255.
                    edgecolors[:,3] = self.alpha
                else:
                    edgecolors = (0.0,0.0,0.0,self.alpha)

            if visible.nonzero()[0].size == 0:
                continue
            verts = np.array(
                [(left_edge_x, left_edge_x, right_edge_x, right_edge_x),
                 (left_edge_y, right_edge_y, right_edge_y, left_edge_y)])
            verts = verts.transpose()[visible,:,:]
            grid_collection = matplotlib.collections.PolyCollection(
                verts, facecolors="none",
                edgecolors=edgecolors,
                linewidth=self.linewidth)
            plot._axes.hold(True)
            plot._axes.add_collection(grid_collection)

            if self.draw_ids:
                visible_ids = np.logical_and(
                    np.logical_and(xwidth > self.min_pix_ids,
                                   ywidth > self.min_pix_ids),
                    np.logical_and(levels >= min_level, levels <= max_level))
                active_ids = np.unique(plot.data['grid_indices'])
                for i in np.where(visible_ids)[0]:
                    # Text is placed two pixels in from the lower-left corner.
                    plot._axes.text(
                        left_edge_x[i] + (2 * (xx1 - xx0) / xpix),
                        left_edge_y[i] + (2 * (yy1 - yy0) / ypix),
                        "%d" % active_ids[i], clip_on=True)
            plot._axes.hold(False)
class StreamlineCallback(PlotCallback):
    """
    annotate_streamlines(field_x, field_y, factor=16, density=1,
                         plot_args=None):

    Add streamlines to any plot, using the *field_x* and *field_y*
    from the associated data, skipping every *factor* datapoints like
    'quiver'. *density* is the index of the amount of the streamlines.
    """
    _type_name = "streamlines"

    def __init__(self, field_x, field_y, factor = 16,
                 density = 1, plot_args=None):
        PlotCallback.__init__(self)
        def_plot_args = {}
        self.field_x = field_x
        self.field_y = field_y
        self.factor = factor
        self.dens = density
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args

    def __call__(self, plot):
        """Pixelize both fields on a coarse grid and draw streamlines."""
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        plot._axes.hold(True)
        # Grid sizes must be integers: they are used both as buffer
        # dimensions and as the sample count for np.linspace.
        nx = int(plot.image._A.shape[0] // self.factor)
        ny = int(plot.image._A.shape[1] // self.factor)
        pixX = _MPL.Pixelize(plot.data['px'],
                             plot.data['py'],
                             plot.data['pdx'],
                             plot.data['pdy'],
                             plot.data[self.field_x],
                             nx, ny,
                             (x0, x1, y0, y1),).transpose()
        pixY = _MPL.Pixelize(plot.data['px'],
                             plot.data['py'],
                             plot.data['pdx'],
                             plot.data['pdy'],
                             plot.data[self.field_y],
                             nx, ny,
                             (x0, x1, y0, y1),).transpose()
        X, Y = (np.linspace(xx0, xx1, nx, endpoint=True),
                np.linspace(yy0, yy1, ny, endpoint=True))
        # User-supplied plot_args are applied last and so take precedence.
        streamplot_args = {'x': X, 'y': Y, 'u': pixX, 'v': pixY,
                           'density': self.dens}
        streamplot_args.update(self.plot_args)
        plot._axes.streamplot(**streamplot_args)
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)
class LinePlotCallback(PlotCallback):
    """
    annotate_line(p1, p2, coord_system="data", plot_args=None):

    Overplot a line with endpoints at p1 and p2.  p1 and p2
    should be 2D or 3D coordinates consistent with the coordinate
    system denoted in the "coord_system" keyword.

    Parameters
    ----------
    p1, p2 : 2- or 3-element tuples, lists, or arrays
        These are the coordinates of the endpoints of the line.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates p1 and p2.
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL plot function for generating
        the line.  By default, it is: {'color':'white', 'linewidth':2}

    Examples
    --------

    >>> # Overplot a diagonal white line from the lower left corner to upper
    >>> # right corner
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_line([0,0], [1,1], coord_system='axis')
    >>> s.save()

    >>> # Overplot a red dashed line from data coordinate (0.1, 0.2, 0.3) to
    >>> # (0.5, 0.6, 0.7)
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_line([0.1, 0.2, 0.3], [0.5, 0.6, 0.7],
    ...                 coord_system='data',
    ...                 plot_args={'color':'red', 'linestyle':'--'})
    >>> s.save()

    """
    _type_name = "line"

    def __init__(self, p1, p2, data_coords=False, coord_system="data",
                 plot_args=None):
        PlotCallback.__init__(self)
        def_plot_args = {'color':'white', 'linewidth':2}
        self.p1 = p1
        self.p2 = p2
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args
        # data_coords is the deprecated way of asking for coord_system="data"
        # and overrides coord_system when set.
        if data_coords:
            coord_system = "data"
            warnings.warn("The data_coords keyword is deprecated.  Please set "
                          "the keyword coord_system='data' instead.")
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        """Draw the line between p1 and p2 on *plot*."""
        # sanitize_coord_system also sets self.transform as a side effect.
        p1 = self.sanitize_coord_system(plot, self.p1,
                                        coord_system=self.coord_system)
        p2 = self.sanitize_coord_system(plot, self.p2,
                                        coord_system=self.coord_system)
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        plot._axes.hold(True)
        plot._axes.plot([p1[0], p2[0]], [p1[1], p2[1]],
                        transform=self.transform, **self.plot_args)
        # Restore the limits that were in place before drawing.
        plot._axes.set_xlim(xx0,xx1)
        plot._axes.set_ylim(yy0,yy1)
        plot._axes.hold(False)
class ImageLineCallback(LinePlotCallback):
    """
    annotate_image_line(p1, p2, coord_system="axis", plot_args=None):

    This callback is deprecated, as it is simply a wrapper around
    the LinePlotCallback (i.e. annotate_line()).  The only difference is
    that it uses coord_system="axis" by default.  Please see
    LinePlotCallback for more information.
    """
    _type_name = "image_line"

    def __init__(self, p1, p2, data_coords=False, coord_system='axis',
                 plot_args=None):
        super(ImageLineCallback, self).__init__(p1, p2, data_coords,
                                                coord_system, plot_args)
        warnings.warn("The ImageLineCallback (annotate_image_line()) is "
                      "deprecated.  Please use the LinePlotCallback "
                      "(annotate_line()) instead.")

    def __call__(self, plot):
        """Delegate drawing to LinePlotCallback."""
        super(ImageLineCallback, self).__call__(plot)
class CuttingQuiverCallback(PlotCallback):
    """
    annotate_cquiver(field_x, field_y, factor)

    Get a quiver plot on top of a cutting plane, using *field_x* and
    *field_y*, skipping every *factor* datapoint in the discretization.
    """
    _type_name = "cquiver"

    def __init__(self, field_x, field_y, factor):
        PlotCallback.__init__(self)
        self.field_x = field_x
        self.field_y = field_y
        self.factor = factor

    def __call__(self, plot):
        """Pixelize both fields on a coarse grid and draw them as arrows."""
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        plot._axes.hold(True)
        # Grid sizes must be integers: they are used both as buffer
        # dimensions and as the sample count for np.linspace.
        nx = int(plot.image._A.shape[0] // self.factor)
        ny = int(plot.image._A.shape[1] // self.factor)
        # Indices that order the elements by decreasing 'dx'.
        indices = np.argsort(plot.data['dx'])[::-1]
        pixX = _MPL.CPixelize(plot.data['x'], plot.data['y'], plot.data['z'],
                              plot.data['px'], plot.data['py'],
                              plot.data['pdx'], plot.data['pdy'],
                              plot.data['pdz'],
                              plot.data.center, plot.data._inv_mat, indices,
                              plot.data[self.field_x],
                              nx, ny,
                              (x0, x1, y0, y1),).transpose()
        pixY = _MPL.CPixelize(plot.data['x'], plot.data['y'], plot.data['z'],
                              plot.data['px'], plot.data['py'],
                              plot.data['pdx'], plot.data['pdy'],
                              plot.data['pdz'],
                              plot.data.center, plot.data._inv_mat, indices,
                              plot.data[self.field_y],
                              nx, ny,
                              (x0, x1, y0, y1),).transpose()
        X, Y = np.meshgrid(np.linspace(xx0, xx1, nx, endpoint=True),
                           np.linspace(yy0, yy1, ny, endpoint=True))
        plot._axes.quiver(X, Y, pixX, pixY)
        # Restore the limits that were in place before drawing.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)
class ClumpContourCallback(PlotCallback):
    """
    annotate_clumps(clumps, plot_args=None)

    Take a list of *clumps* and plot them as a set of contours.
    """
    _type_name = "clumps"

    def __init__(self, clumps, plot_args=None):
        PlotCallback.__init__(self)
        self.clumps = clumps
        if plot_args is None:
            plot_args = {}
        self.plot_args = plot_args

    def __call__(self, plot):
        """
        Pixelize every clump into one shared buffer and contour the result.
        The contour set returned by matplotlib is stored in ``self.rv``.
        """
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        extent = [xx0,xx1,yy0,yy1]

        plot._axes.hold(True)

        ax = plot.data.axis
        px_index = plot.data.ds.coordinates.x_axis[ax]
        py_index = plot.data.ds.coordinates.y_axis[ax]

        # Names of the position fields and their cell-width counterparts
        # along the two image axes.
        xf = plot.data.ds.coordinates.axis_name[px_index]
        yf = plot.data.ds.coordinates.axis_name[py_index]
        dxf = "d%s" % xf
        dyf = "d%s" % yf

        nx, ny = plot.image._A.shape
        buff = np.zeros((nx,ny),dtype='float64')
        for i,clump in enumerate(reversed(self.clumps)):
            mylog.info("Pixelizing contour %s", i)

            xf_copy = clump[xf].copy().in_units("code_length")
            yf_copy = clump[yf].copy().in_units("code_length")

            # Each clump is painted with the constant value i+1; the
            # element-wise maximum keeps the highest value per pixel.
            temp = _MPL.Pixelize(xf_copy, yf_copy,
                                 clump[dxf].in_units("code_length")/2.0,
                                 clump[dyf].in_units("code_length")/2.0,
                                 clump[dxf].d*0.0+i+1, # inits inside Pixelize
                                 int(nx), int(ny),
                                 (x0, x1, y0, y1), 0).transpose()
            buff = np.maximum(temp, buff)
        self.rv = plot._axes.contour(buff, np.unique(buff),
                                     extent=extent, **self.plot_args)
        plot._axes.hold(False)
class ArrowCallback(PlotCallback):
    """
    annotate_arrow(pos, length=0.03, coord_system='data', plot_args=None):

    Overplot an arrow pointing at a position for highlighting a specific
    feature.  Arrow points from lower left to the designated position with
    arrow length "length".

    Parameters
    ----------
    pos : 2- or 3-element tuple, list, or array
        These are the coordinates to which the arrow is pointing

    length : float, optional
        The length, in axis units, of the arrow.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL arrow function for generating
        the arrow.  By default, it is: {'color':'white', 'linewidth':2}

    Examples
    --------

    >>> # Overplot an arrow pointing to feature at data coord: (0.2, 0.3, 0.4)
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_arrow([0.2,0.3,0.4])
    >>> s.save()

    >>> # Overplot a red arrow with longer length pointing to plot coordinate
    >>> # (0.1, -0.1)
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_arrow([0.1, -0.1], length=0.06, coord_system='plot',
    ...                  plot_args={'color':'red'})
    >>> s.save()

    """
    _type_name = "arrow"

    def __init__(self, pos, code_size=None, length=0.03, coord_system='data',
                 plot_args=None):
        def_plot_args = {'color':'white', 'linewidth':2}
        self.pos = pos
        self.code_size = code_size
        self.length = length
        self.coord_system = coord_system
        self.transform = None
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args

    def __call__(self, plot):
        """Draw the arrow on *plot*, with its tip at the requested position."""
        # sanitize_coord_system also sets self.transform as a side effect.
        x,y = self.sanitize_coord_system(plot, self.pos,
                                         coord_system=self.coord_system)
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        if self.code_size is not None:
            warnings.warn("The code_size keyword is deprecated.  Please use "
                          "the length keyword in 'axis' units instead. "
                          "Setting code_size overrides length value.")
            # Work on a local copy: writing the converted value back to
            # self.code_size would rescale it again on every later call.
            code_size = self.code_size
            if iterable(code_size):
                # A (value, unit) pair is converted to the plot's x units.
                code_size = plot.data.ds.quan(code_size[0], code_size[1])
                code_size = np.float64(code_size.in_units(plot.xlim[0].units))
            code_size = code_size * self.pixel_scale(plot)[0]
            dx = dy = code_size
        else:
            dx = (xx1-xx0) * self.length
            dy = (yy1-yy0) * self.length
        plot._axes.hold(True)
        from matplotlib.patches import Arrow
        # The arrow starts at (x-dx, y-dy) and spans (dx, dy), so its tip
        # lands on (x, y).
        arrow = Arrow(x-dx, y-dy, dx, dy, width=dx,
                      transform=self.transform, **self.plot_args)
        plot._axes.add_patch(arrow)
        # Restore the limits that were in place before drawing.
        plot._axes.set_xlim(xx0,xx1)
        plot._axes.set_ylim(yy0,yy1)
        plot._axes.hold(False)
class MarkerAnnotateCallback(PlotCallback):
    """
    annotate_marker(pos, marker='x', coord_system="data", plot_args=None):

    Overplot a marker on a position for highlighting specific features.

    Parameters
    ----------
    pos : 2- or 3-element tuple, list, or array
        These are the coordinates where the marker will be overplotted

    marker : string, optional
        The shape of the marker to be passed to the MPL scatter function.
        By default, it is 'x', but other acceptable values are: '.', 'o', 'v',
        '^', 's', 'p' '*', etc.  See matplotlib.markers for more information.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    plot_args : dictionary, optional
        This dictionary is passed to the MPL scatter function for generating
        the marker.  By default, it is: {'color':'white', 's':50}

    Examples
    --------

    >>> # Overplot a white X on a feature at data location (0.4, 0.5, 0.6)
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_marker([0.4, 0.5, 0.6])
    >>> s.save()

    >>> # Overplot a big yellow circle at axis location (0.1, 0.2)
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_marker([0.1, 0.2], marker='o', coord_system='axis',
    ...                   plot_args={'color':'yellow', 's':200})
    >>> s.save()

    """
    _type_name = "marker"

    def __init__(self, pos, marker='x', coord_system="data", plot_args=None):
        self.pos = pos
        self.marker = marker
        self.coord_system = coord_system
        self.transform = None
        # Fall back to a white, size-50 marker when no styling is supplied.
        self.plot_args = ({'color':'w', 's':50} if plot_args is None
                          else plot_args)

    def __call__(self, plot):
        axes = plot._axes
        x, y = self.sanitize_coord_system(plot, self.pos,
                                          coord_system=self.coord_system)
        # Remember the current view so it can be restored after plotting.
        saved_xlim = axes.get_xlim()
        saved_ylim = axes.get_ylim()
        axes.hold(True)
        axes.scatter(x, y, marker=self.marker, transform=self.transform,
                     **self.plot_args)
        axes.set_xlim(saved_xlim[0], saved_xlim[1])
        axes.set_ylim(saved_ylim[0], saved_ylim[1])
        axes.hold(False)
class SphereCallback(PlotCallback):
    """
    annotate_sphere(center, radius, circle_args=None,
                    coord_system='data', text=None, text_args=None):

    Overplot a circle with designated center and radius with optional text.

    Parameters
    ----------
    center : 2- or 3-element tuple, list, or array
        These are the coordinates where the circle will be overplotted

    radius : YTArray, float, or (1, ('kpc')) style tuple
        The radius of the circle in code coordinates

    circle_args : dict, optional
        This dictionary is passed to the MPL circle object. By default,
        {'color':'white'}.  'fill' is set to False unless given explicitly.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    text : string, optional
        Optional text to include next to the circle.

    text_args : dictionary, optional
        This dictionary is passed to the MPL text function. By default,
        it is: {'color':'white'}

    Examples
    --------

    >>> # Overplot a white circle of radius 100 kpc over the central galaxy
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_sphere([0.5, 0.5, 0.5], radius=(100, 'kpc'))
    >>> s.save()

    """
    _type_name = "sphere"

    def __init__(self, center, radius, circle_args=None,
                 text=None, coord_system='data', text_args=None):
        def_text_args = {'color':'white'}
        def_circle_args = {'color':'white'}
        self.center = center
        self.radius = radius
        if circle_args is None:
            circle_args = def_circle_args
        else:
            # Copy so that adding 'fill' below does not modify the
            # dictionary owned by the caller.
            circle_args = dict(circle_args)
        if 'fill' not in circle_args:
            circle_args['fill'] = False
        self.circle_args = circle_args
        self.text = text
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        from matplotlib.patches import Circle

        # Use a local value: overwriting self.radius with the scaled radius
        # would rescale it again each time the callback is re-run.
        radius = self.radius
        if iterable(radius):
            radius = plot.data.ds.quan(radius[0], radius[1])
            radius = np.float64(radius.in_units(plot.xlim[0].units))

        # This assures the radius has the appropriate size in
        # the different coordinate systems, since one cannot simply
        # apply a different transform for a length in the same way
        # you can for a coordinate.
        if self.coord_system == 'data' or self.coord_system == 'plot':
            radius = radius * self.pixel_scale(plot)[0]
        else:
            radius = radius / (plot.xlim[1]-plot.xlim[0]).v

        x, y = self.sanitize_coord_system(plot, self.center,
                                          coord_system=self.coord_system)

        cir = Circle((x, y), radius, transform=self.transform,
                     **self.circle_args)
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        plot._axes.hold(True)
        plot._axes.add_patch(cir)
        if self.text is not None:
            label = plot._axes.text(x, y, self.text, transform=self.transform,
                                    **self.text_args)
            self._set_font_properties(plot, [label], **self.text_args)

        # Restore the limits in case adding the artists changed them.
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)
class TextLabelCallback(PlotCallback):
    """
    annotate_text(pos, text, coord_system='data', text_args=None,
                  inset_box_args=None):

    Overplot text on the plot at a specified position. If you desire an inset
    box around your text, set one with the inset_box_args dictionary
    keyword.

    Parameters
    ----------
    pos : 2- or 3-element tuple, list, or array
        These are the coordinates where the text will be overplotted

    text : string
        The text you wish to include

    data_coords : boolean, optional
        Deprecated.  If True, forces coord_system='data' and issues a
        warning.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    text_args : dictionary, optional
        This dictionary is passed to the MPL text function for generating
        the text.  By default, it is: {'color':'white'} and uses the defaults
        for the other fonts in the image.

    inset_box_args : dictionary, optional
        A dictionary of any arbitrary parameters to be passed to the Matplotlib
        FancyBboxPatch object as the inset box around the text.  Default: {}
        (no inset box is drawn).

    Examples
    --------

    >>> # Overplot white text at data location [0.55, 0.7, 0.4]
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_text([0.55, 0.7, 0.4], "Here is a galaxy")
    >>> s.save()

    >>> # Overplot yellow text at axis location [0.2, 0.8] with
    >>> # a shaded inset box
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_text([0.2, 0.8], "Here is a galaxy", coord_system='axis',
    ...                 text_args={'color':'yellow'},
    ...                 inset_box_args={'boxstyle':'square,pad=0.3',
    ...                                 'facecolor':'black',
    ...                                 'linewidth':3,
    ...                                 'edgecolor':'white', 'alpha':0.5})
    >>> s.save()
    """
    _type_name = "text"

    def __init__(self, pos, text, data_coords=False, coord_system='data',
                 text_args=None, inset_box_args=None):
        def_text_args = {'color':'white'}
        self.pos = pos
        self.text = text
        if data_coords:
            coord_system = 'data'
            warnings.warn("The data_coords keyword is deprecated. Please set "
                          "the keyword coord_system='data' instead.")
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        if inset_box_args is None:
            inset_box_args = {}
        self.inset_box_args = inset_box_args
        self.coord_system = coord_system
        self.transform = None

    def __call__(self, plot):
        kwargs = self.text_args.copy()
        x, y = self.sanitize_coord_system(plot, self.pos,
                                          coord_system=self.coord_system)

        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        # An empty dictionary means "no inset box".  Hand matplotlib None
        # in that case: recent matplotlib versions draw a default box for
        # any bbox value that is not None, including {}.
        bbox = self.inset_box_args or None

        plot._axes.hold(True)
        label = plot._axes.text(x, y, self.text, transform=self.transform,
                                bbox=bbox, **kwargs)
        # Set the font properties of text from this callback to be
        # consistent with other text labels in this figure
        self._set_font_properties(plot, [label], **kwargs)
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)
class PointAnnotateCallback(TextLabelCallback):
    """
    annotate_point(pos, text, coord_system='data', text_args=None,
                   inset_box_args=None)

    This callback is deprecated, as it is simply a wrapper around
    the TextLabelCallback (ie annotate_text()).  Please see TextLabelCallback
    for more information.

    """
    _type_name = "point"

    def __init__(self, pos, text, data_coords=False, coord_system='data',
                 text_args=None, inset_box_args=None):
        super(PointAnnotateCallback, self).__init__(pos, text, data_coords,
                                                    coord_system, text_args,
                                                    inset_box_args)
        # Point users at the replacement API, annotate_text(), rather than
        # back at the deprecated annotate_point().
        warnings.warn("The PointAnnotateCallback (annotate_point()) is "
                      "deprecated. Please use the TextLabelCallback "
                      "(annotate_text()) instead.")

    def __call__(self, plot):
        # Drawing is delegated unchanged to TextLabelCallback.
        super(PointAnnotateCallback, self).__call__(plot)
class HaloCatalogCallback(PlotCallback):
    """
    annotate_halos(halo_catalog, circle_args=None,
        width=None, annotate_field=None,
        text_args=None, factor=1.0)

    Plots circles at the locations of all the halos
    in a halo catalog with radii corresponding to the
    virial radius of each halo.

    circle_args: Contains the arguments controlling the
        appearance of the circles, supplied to the
        Matplotlib patch Circle.
    width: the width over which to select halos to plot,
        useful when overplotting to a slice plot. Accepts
        a tuple in the form (1.0, 'Mpc').
    annotate_field: Accepts a field contained in the
        halo catalog to add text to the plot near the halo.
        Example: annotate_field = 'particle_mass' will
        write the halo mass next to each halo.
    text_args: Contains the arguments controlling the text
        appearance of the annotated field.
    factor: A number the virial radius is multiplied by for
        plotting the circles. Ex: factor = 2.0 will plot
        circles with twice the radius of each halo virial radius.

    The circle_kwargs and font_kwargs keywords are deprecated aliases of
    circle_args and text_args; when given they take precedence and a
    warning is issued.
    """

    _type_name = 'halos'
    region = None
    _descriptor = None

    def __init__(self, halo_catalog, circle_args=None, circle_kwargs=None,
                 width=None, annotate_field=None, text_args=None,
                 font_kwargs=None, factor=1.0):

        PlotCallback.__init__(self)
        def_circle_args = {'edgecolor':'white', 'facecolor':'None'}
        def_text_args = {'color':'white'}
        self.halo_catalog = halo_catalog
        self.width = width
        self.annotate_field = annotate_field
        if circle_kwargs is not None:
            circle_args = circle_kwargs
            warnings.warn("The circle_kwargs keyword is deprecated. Please "
                          "use the circle_args keyword instead.")
        if font_kwargs is not None:
            text_args = font_kwargs
            warnings.warn("The font_kwargs keyword is deprecated. Please use "
                          "the text_args keyword instead.")
        if circle_args is None:
            circle_args = def_circle_args
        self.circle_args = circle_args
        if text_args is None:
            text_args = def_text_args
        self.text_args = text_args
        self.factor = factor

    def __call__(self, plot):
        data = plot.data
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()

        halo_data = self.halo_catalog.halos_ds.all_data()
        axis_names = plot.data.ds.coordinates.axis_name
        xax = plot.data.ds.coordinates.x_axis[data.axis]
        yax = plot.data.ds.coordinates.y_axis[data.axis]
        field_x = "particle_position_%s" % axis_names[xax]
        field_y = "particle_position_%s" % axis_names[yax]
        field_z = "particle_position_%s" % axis_names[data.axis]
        plot._axes.hold(True)

        # Set up scales for pixel size and original data
        pixel_scale = self.pixel_scale(plot)[0]
        data_scale = data.ds.length_unit
        units = data_scale.units

        # Convert halo positions to code units of the plotted data
        # and then to units of the plotted window
        px = halo_data[field_x][:].in_units(units) / data_scale
        py = halo_data[field_y][:].in_units(units) / data_scale
        px, py = self.convert_to_plot(plot, [px, py])

        # Convert halo radii to a radius in pixels
        radius = halo_data['virial_radius'][:].in_units(units)
        radius = np.array(radius*pixel_scale*self.factor/data_scale)

        # Index selection of the halos inside the requested slab, or None
        # when every halo is plotted.
        indices = None
        if self.width:
            pz = halo_data[field_z][:].in_units(units)/data_scale
            pz = data.ds.arr(pz, 'code_length')
            c = data.center[data.axis]

            # I should catch an error here if width isn't in this form
            # but I dont really want to reimplement get_sanitized_width...
            width = data.ds.arr(self.width[0],
                                self.width[1]).in_units('code_length')

            indices = np.where((pz > c-width) & (pz < c+width))

            px = px[indices]
            py = py[indices]
            radius = radius[indices]

        for x, y, r in zip(px, py, radius):
            plot._axes.add_artist(Circle(xy=(x, y), radius=r,
                                         **self.circle_args))

        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)

        if self.annotate_field:
            annotate_dat = halo_data[self.annotate_field]
            # Apply the same selection as the positions; otherwise the
            # labels would be paired with the wrong halos.
            if indices is not None:
                annotate_dat = annotate_dat[indices]
            texts = ['{:g}'.format(float(dat)) for dat in annotate_dat]
            labels = [plot._axes.text(pos_x, pos_y, t, **self.text_args)
                      for pos_x, pos_y, t in zip(px, py, texts)]

            # Set the font properties of text from this callback to be
            # consistent with other text labels in this figure
            self._set_font_properties(plot, labels, **self.text_args)
class ParticleCallback(PlotCallback):
    """
    annotate_particles(width, p_size=1.0, col='k', marker='o', stride=1.0,
                       ptype='all', minimum_mass=None, alpha=1.0)

    Adds particle positions, based on a thick slab along *axis* with a
    *width* along the line of sight.  *p_size* controls the number of
    pixels per particle, and *col* governs the color.  *stride* plots only
    every stride-th selected particle (it is truncated to an integer, with
    a minimum of 1).  *ptype* will restrict plotted particles to only those
    that are of a given type.  Particles with masses below *minimum_mass*
    will not be plotted.  *alpha* determines the opacity of the marker
    symbol used in the scatter plot.
    """
    _type_name = "particles"
    region = None
    _descriptor = None

    def __init__(self, width, p_size=1.0, col='k', marker='o', stride=1.0,
                 ptype='all', minimum_mass=None, alpha=1.0):
        PlotCallback.__init__(self)
        self.width = width
        self.p_size = p_size
        self.color = col
        self.marker = marker
        self.stride = stride
        self.ptype = ptype
        self.minimum_mass = minimum_mass
        self.alpha = alpha

    def __call__(self, plot):
        data = plot.data
        # Normalise the slab width to a plain float in code_length.
        if iterable(self.width):
            w = plot.data.ds.quan(self.width[0],
                                  self.width[1]).in_units("code_length")
            self.width = np.float64(w)
        elif isinstance(self.width, YTQuantity):
            self.width = np.float64(
                plot.data.ds.quan(self.width).in_units("code_length"))
        # we construct a rectangular prism
        x0, x1 = plot.xlim
        y0, y1 = plot.ylim
        xx0, xx1 = plot._axes.get_xlim()
        yy0, yy1 = plot._axes.get_ylim()
        reg = self._get_region((x0, x1), (y0, y1), plot.data.axis, data)
        ax = data.axis
        xax = plot.data.ds.coordinates.x_axis[ax]
        yax = plot.data.ds.coordinates.y_axis[ax]
        axis_names = plot.data.ds.coordinates.axis_name
        field_x = "particle_position_%s" % axis_names[xax]
        field_y = "particle_position_%s" % axis_names[yax]
        pt = self.ptype
        pos_x = reg[pt, field_x]
        pos_y = reg[pt, field_y]
        # Boolean mask of the particles inside the plotted window.
        gg = ((pos_x >= x0) & (pos_x <= x1) &
              (pos_y >= y0) & (pos_y <= y1))
        if self.minimum_mass is not None:
            gg &= (reg[pt, "particle_mass"] >= self.minimum_mass)
        if gg.sum() == 0:
            return
        # The default stride is the float 1.0, but slice steps must be
        # integers (and may not be zero).
        stride = max(int(self.stride), 1)
        plot._axes.hold(True)
        px, py = self.convert_to_plot(plot,
                                      [np.array(pos_x[gg][::stride]),
                                       np.array(pos_y[gg][::stride])])
        plot._axes.scatter(px, py, edgecolors='None', marker=self.marker,
                           s=self.p_size, c=self.color, alpha=self.alpha)
        plot._axes.set_xlim(xx0, xx1)
        plot._axes.set_ylim(yy0, yy1)
        plot._axes.hold(False)

    def _get_region(self, xlim, ylim, axis, data):
        """
        Return a region spanning *xlim* and *ylim* in the plot plane and
        ``self.width`` along *axis*, centred on ``data.center``.

        The previously built region is reused when it already encloses the
        requested box.
        """
        LE, RE = [None]*3, [None]*3
        ds = data.ds
        xax = ds.coordinates.x_axis[axis]
        yax = ds.coordinates.y_axis[axis]
        zax = axis
        LE[xax], RE[xax] = xlim
        LE[yax], RE[yax] = ylim
        LE[zax] = data.center[zax].ndarray_view() - self.width*0.5
        RE[zax] = data.center[zax].ndarray_view() + self.width*0.5
        if self.region is not None \
            and np.all(self.region.left_edge <= LE) \
            and np.all(self.region.right_edge >= RE):
            return self.region
        self.region = data.ds.region(data.center, LE, RE)
        return self.region
class TitleCallback(PlotCallback):
    """
    annotate_title(title)

    Accepts a *title* and adds it to the plot
    """
    _type_name = "title"

    def __init__(self, title):
        PlotCallback.__init__(self)
        self.title = title

    def __call__(self, plot):
        axes = plot._axes
        axes.set_title(self.title)
        # Give the title the same font properties as the other text
        # labels in this figure.
        self._set_font_properties(plot, [axes.title])
class TriangleFacetsCallback(PlotCallback):
    """
    annotate_triangle_facets(triangle_vertices, plot_args=None)

    Intended for representing a slice of a triangular faceted
    geometry in a slice plot.

    Uses a set of *triangle_vertices* to find all triangles the plane of a
    SlicePlot intersects with. The lines between the intersection points
    of the triangles are then added to the plot to create an outline
    of the geometry represented by the triangles.

    *triangle_vertices* without units are taken to be in code_length.
    *plot_args* is passed to the Matplotlib LineCollection.
    """
    _type_name = "triangle_facets"

    def __init__(self, triangle_vertices, plot_args=None):
        super(TriangleFacetsCallback, self).__init__()
        self.plot_args = {} if plot_args is None else plot_args
        self.vertices = triangle_vertices

    def __call__(self, plot):
        plot._axes.hold(True)
        ax = plot.data.axis
        xax = plot.data.ds.coordinates.x_axis[ax]
        yax = plot.data.ds.coordinates.y_axis[ax]

        if not hasattr(self.vertices, "in_units"):
            # Use the dataset handle "ds", as the rest of this module does,
            # rather than the deprecated "pf" alias.
            vertices = plot.data.ds.arr(self.vertices, "code_length")
        else:
            vertices = self.vertices
        # Keep only the two in-plane coordinates of each intersection point.
        l_cy = triangle_plane_intersect(plot.data.axis, plot.data.coord,
                                        vertices)[:,:,(xax, yax)]
        # reformat for conversion to plot coordinates
        l_cy = np.rollaxis(l_cy, 0, 3)
        # convert all line starting points
        l_cy[0] = self.convert_to_plot(plot, l_cy[0])
        # convert all line ending points
        l_cy[1] = self.convert_to_plot(plot, l_cy[1])
        # move the line index back to the first axis for LineCollection
        l_cy = np.rollaxis(l_cy, 2, 0)
        # create line collection and add it to the plot
        lc = matplotlib.collections.LineCollection(l_cy, **self.plot_args)
        plot._axes.add_collection(lc)
        plot._axes.hold(False)
class TimestampCallback(PlotCallback):
    """
    annotate_timestamp(x_pos=None, y_pos=None, corner='lower_left', time=True,
                       redshift=False, time_format="t = {time:.1f} {units}",
                       time_unit=None, redshift_format="z = {redshift:.2f}",
                       draw_inset_box=False, coord_system='axis',
                       text_args=None, inset_box_args=None)

    Annotates the timestamp and/or redshift of the data output at a specified
    location in the image (either in a preset corner, or by specifying
    (x,y) image coordinates with the x_pos, y_pos arguments).  If no
    time_unit is specified, it will automatically choose appropriate units.
    It allows for custom formatting of the time and redshift information, as
    well as the specification of an inset box around the text.

    Parameters
    ----------
    x_pos, y_pos : floats, optional
        The image location of the timestamp in the coord system defined by the
        coord_system kwarg.  Setting x_pos and y_pos overrides the corner
        parameter.

    corner : string, optional
        Corner sets up one of 4 predetermined locations for the timestamp
        to be displayed in the image: 'upper_left', 'upper_right',
        'lower_left', 'lower_right' (also allows None, which centers the
        text). This value will be overridden by the optional x_pos and
        y_pos keywords.

    time : boolean, optional
        Whether or not to show the ds.current_time of the data output.  Can
        be used solo or in conjunction with redshift parameter.

    redshift : boolean, optional
        Whether or not to show the ds.current_redshift of the data output.
        Can be used solo or in conjunction with the time parameter.

    time_format : string, optional
        This specifies the format of the time output assuming "time" is the
        number of time and "units" is units of the time (e.g. 's', 'Myr',
        etc.)  The time can be specified to arbitrary precision according to
        printf formatting codes (defaults to .1f -- a float with 1 digits
        after decimal).  Example: "Age = {time:.2f} {units}".

    time_unit : string, optional
        time_unit must be a valid yt time unit (e.g. 's', 'min', 'hr', 'yr',
        'Myr', etc.)

    redshift_format : string, optional
        This specifies the format of the redshift output.  The redshift can
        be specified to arbitrary precision according to printf formatting
        codes (defaults to 0.2f -- a float with 2 digits after decimal).
        Example: "REDSHIFT = {redshift:03.3g}",

    draw_inset_box : boolean, optional
        Whether or not an inset box should be included around the text
        If so, it uses the inset_box_args to set the matplotlib
        FancyBboxPatch object.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

    text_args : dictionary, optional
        A dictionary of any arbitrary parameters to be passed to the Matplotlib
        text object.  Defaults: {'color':'white',
        'horizontalalignment':'center', 'verticalalignment':'top'}.
        When the position comes from the corner argument, the two alignment
        entries are replaced to suit that corner.

    inset_box_args : dictionary, optional
        A dictionary of any arbitrary parameters to be passed to the Matplotlib
        FancyBboxPatch object as the inset box around the text.
        Defaults: {'boxstyle':'square,pad=0.3', 'facecolor':'black',
                  'linewidth':3, 'edgecolor':'white', 'alpha':0.5}

    Example
    -------

    >>> import yt
    >>> ds = yt.load('Enzo_64/DD0020/data0020')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_timestamp()
    """
    _type_name = "timestamp"

    # corner name -> (axis position, horizontal alignment, vertical alignment)
    _corner_layout = {
        'upper_left': ((0.03, 0.96), 'left', 'top'),
        'upper_right': ((0.97, 0.96), 'right', 'top'),
        'lower_left': ((0.03, 0.03), 'left', 'bottom'),
        'lower_right': ((0.97, 0.03), 'right', 'bottom'),
        None: ((0.5, 0.5), 'center', 'center'),
    }

    def __init__(self, x_pos=None, y_pos=None, corner='lower_left', time=True,
                 redshift=False, time_format="t = {time:.1f} {units}",
                 time_unit=None, redshift_format="z = {redshift:.2f}",
                 draw_inset_box=False, coord_system='axis',
                 text_args=None, inset_box_args=None):

        def_text_args = {'color':'white',
                         'horizontalalignment':'center',
                         'verticalalignment':'top'}
        def_inset_box_args = {'boxstyle':'square,pad=0.3',
                              'facecolor':'black',
                              'linewidth':3,
                              'edgecolor':'white',
                              'alpha':0.5}

        # The position is resolved from the corner argument at draw time
        # unless both x_pos and y_pos are given.
        self.pos = (x_pos, y_pos)
        self.corner = corner
        self.time = time
        self.redshift = redshift
        self.time_format = time_format
        self.redshift_format = redshift_format
        self.time_unit = time_unit
        self.coord_system = coord_system
        if text_args is None:
            text_args = def_text_args
        # Keep a private copy: __call__ rewrites the alignment entries and
        # must not modify a dictionary owned by the caller.
        self.text_args = dict(text_args)
        if inset_box_args is None:
            inset_box_args = def_inset_box_args
        self.inset_box_args = inset_box_args

        # if inset box is not desired, set inset_box_args to {}
        if not draw_inset_box:
            self.inset_box_args = {}

    def __call__(self, plot):
        # Setting pos overrides corner argument
        if self.pos[0] is None or self.pos[1] is None:
            try:
                pos, halign, valign = self._corner_layout[self.corner]
            except (KeyError, TypeError):
                raise SyntaxError("Argument 'corner' must be set to "
                                  "'upper_left', 'upper_right', 'lower_left', "
                                  "'lower_right', or None")
            self.pos = pos
            self.text_args['horizontalalignment'] = halign
            self.text_args['verticalalignment'] = valign

        self.text = ""

        # If we're annotating the time, put it in the correct format
        if self.time:
            # If no time_units are set, then identify a best fit time unit
            if self.time_unit is None:
                self.time_unit = plot.ds.get_smallest_appropriate_unit(
                    plot.ds.current_time, quantity='time')
            t = plot.ds.current_time.in_units(self.time_unit)
            self.text += self.time_format.format(time=float(t),
                                                 units=self.time_unit)

        # If time and redshift both shown, do one on top of the other
        if self.time and self.redshift:
            self.text += "\n"

        # If we're annotating the redshift, put it in the correct format
        if self.redshift:
            try:
                z = np.abs(plot.data.ds.current_redshift)
            except AttributeError:
                raise AttributeError("Dataset does not have current_redshift. "
                                     "Set redshift=False.")
            self.text += self.redshift_format.format(redshift=float(z))

        # This is just a fancy wrapper around the TextLabelCallback
        tcb = TextLabelCallback(self.pos, self.text,
                                coord_system=self.coord_system,
                                text_args=self.text_args,
                                inset_box_args=self.inset_box_args)
        return tcb(plot)
class ScaleCallback(PlotCallback):
    """
    annotate_scale(corner='lower_right', coeff=None, unit=None, pos=None,
                   max_frac=0.16, min_frac=0.015, coord_system='axis',
                   size_bar_args=None, draw_inset_box=False,
                   inset_box_args=None)

    Annotates the scale of the plot at a specified location in the image
    (either in a preset corner, or by specifying (x,y) image coordinates with
    the pos argument).  Coeff and units (e.g. 1 Mpc or 100 kpc) refer to the
    distance scale you desire to show on the plot.  If no coeff and units are
    specified, an appropriate pair will be determined such that your scale bar
    is never greater than max_frac of your plottable axis length.  Additional
    customization of the scale bar is possible by adjusting the size_bar_args
    dictionary.  This accepts keyword arguments for the AnchoredSizeBar class
    in matplotlib's axes_grid toolkit.

    Parameters
    ----------
    corner : string, optional
        Corner sets up one of 4 predetermined locations for the scale bar
        to be displayed in the image: 'upper_left', 'upper_right',
        'lower_left', 'lower_right' (also allows None, which centers the
        bar).  This value will be overridden by the optional 'pos' keyword.

    coeff : float, optional
        The coefficient of the unit defining the distance scale (e.g. 10 kpc
        or 100 Mpc) for overplotting.  Defaults to 1 when a unit is given.
        Whenever unit is None, the coefficient is determined automatically
        together with the unit and any value passed here is replaced.

    unit : string, optional
        unit must be a valid yt distance unit (e.g. 'm', 'km', 'AU', 'pc',
        'kpc', etc.) or set to None.  If set to None, will be automatically
        determined to be the best-fit to the data.

    pos : 2- or 3-element tuples, lists, or arrays, optional
        The image location of the scale bar in the plot coordinate system.
        Setting pos overrides the corner parameter.

    min_frac, max_frac: float, optional
        The minimum/maximum fraction of the axis width for the scale bar to
        extend.  A value of 1 would allow the scale bar to extend across the
        entire axis width.  Only relevant when the unit is fitted
        automatically.  Note that the automatic fit currently selects the
        scale from max_frac alone; min_frac is accepted but does not
        influence the result.

    coord_system : string, optional
        This string defines the coordinate system of the coordinates of pos
        Valid coordinates are:

            "data" -- the 3D dataset coordinates

            "plot" -- the 2D coordinates defined by the actual plot limits

            "axis" -- the MPL axis coordinates: (0,0) is lower left; (1,1) is
                      upper right

            "figure" -- the MPL figure coordinates: (0,0) is lower left, (1,1)
                        is upper right

        Note that this value is stored but the bar is currently always
        anchored in axis coordinates.

    size_bar_args : dictionary, optional
        A dictionary of parameters to be passed to the Matplotlib
        AnchoredSizeBar initializer.  A supplied dictionary replaces the
        defaults entirely (it is not merged with them) and is never modified.
        Defaults: {'pad': 0.05, 'sep': 5, 'borderpad': 1, 'color': 'w'}

    draw_inset_box : boolean, optional
        Whether or not an inset box should be included around the scale bar.

    inset_box_args : dictionary, optional
        A dictionary of keyword arguments to be passed to the matplotlib Patch
        object that represents the inset box.  A supplied dictionary replaces
        the defaults entirely.
        Defaults: {'facecolor': 'black', 'linewidth': 3, 'edgecolor': 'white',
                   'alpha': 0.5, 'boxstyle': 'square'}

    Example
    -------

    >>> import yt
    >>> ds = yt.load('Enzo_64/DD0020/data0020')
    >>> s = yt.SlicePlot(ds, 'z', 'density')
    >>> s.annotate_scale()
    """
    _type_name = "scale"

    # Axis-coordinate anchor used for each accepted value of ``corner``
    _corner_positions = {
        'upper_left': (0.11, 0.952),
        'upper_right': (0.89, 0.952),
        'lower_left': (0.11, 0.052),
        'lower_right': (0.89, 0.052),
        None: (0.5, 0.5),
    }

    def __init__(self, corner='lower_right', coeff=None, unit=None, pos=None,
                 max_frac=0.16, min_frac=0.015, coord_system='axis',
                 size_bar_args=None, draw_inset_box=False,
                 inset_box_args=None):
        PlotCallback.__init__(self)

        def_size_bar_args = {
            'pad': 0.05,
            'sep': 5,
            'borderpad': 1,
            'color': 'w',
        }

        def_inset_box_args = {
            'facecolor': 'black',
            'linewidth': 3,
            'edgecolor': 'white',
            'alpha': 0.5,
            'boxstyle': 'square',
        }

        # The actual position is resolved from corner/pos at draw time.
        self.corner = corner
        self.coeff = coeff
        self.unit = unit
        self.pos = pos
        self.max_frac = max_frac
        self.min_frac = min_frac
        self.coord_system = coord_system
        if size_bar_args is None:
            self.size_bar_args = def_size_bar_args
        else:
            self.size_bar_args = size_bar_args
        if inset_box_args is None:
            self.inset_box_args = def_inset_box_args
        else:
            self.inset_box_args = inset_box_args
        self.draw_inset_box = draw_inset_box

    def __call__(self, plot):
        """
        Draw the scale bar on ``plot`` and return ``plot``.

        The resolved position, coefficient and unit are kept local so that
        the configuration stored on the instance (and the dictionaries the
        caller passed in) are left untouched; only ``self.scale`` is set, to
        the quantity represented by the bar.

        Raises
        ------
        NotImplementedError
            If ``plot.aspect`` is not 1.0.
        SyntaxError
            If no position was given and ``self.corner`` is not one of the
            supported values.
        """
        # Callback only works for plots with axis ratios of 1
        xsize = plot.xlim[1] - plot.xlim[0]
        if plot.aspect != 1.0:
            raise NotImplementedError(
                "Scale callback has only been implemented for plots with no "
                "aspect ratio scaling. (aspect = {0})".format(plot.aspect))

        # Setting pos overrides corner argument
        pos = self.pos
        if pos is None:
            try:
                pos = self._corner_positions[self.corner]
            except (KeyError, TypeError):
                # unknown (or unhashable) corner value
                raise SyntaxError("Argument 'corner' must be set to "
                                  "'upper_left', 'upper_right', 'lower_left', "
                                  "'lower_right', or None")

        # When identifying a best fit distance unit, do not allow scale marker
        # to be greater than max_frac fraction of xaxis
        max_scale = self.max_frac * xsize
        min_scale = self.min_frac * xsize

        coeff = self.coeff
        unit = self.unit
        if coeff is None:
            coeff = 1.

        # If no units are set, then identify a best fit distance unit
        if unit is None:
            # NOTE: min_scale is evaluated as before, but its result does not
            # take part in choosing the scale.
            min_scale = plot.ds.get_smallest_appropriate_unit(
                min_scale, return_quantity=True)
            max_scale = plot.ds.get_smallest_appropriate_unit(
                max_scale, return_quantity=True)
            coeff = max_scale.v
            unit = max_scale.units
        self.scale = YTQuantity(coeff, unit)

        # Show whole numbers without a decimal point, but do not truncate a
        # fractional coefficient: the label has to match the drawn length.
        shown_coeff = coeff
        if float(coeff).is_integer():
            shown_coeff = int(coeff)
        text = "{scale} {units}".format(scale=shown_coeff, units=unit)

        # Length of the bar as a fraction of the x axis
        image_scale = (plot.frb.convert_distance_x(self.scale) /
                       plot.frb.convert_distance_x(xsize)).v

        # Work on a copy: the stored dictionary may belong to the caller and
        # must survive being used for more than one draw.
        size_bar_args = dict(self.size_bar_args)
        size_vertical = size_bar_args.pop('size_vertical', .005)
        fontproperties = size_bar_args.pop(
            'fontproperties', plot.font_properties)
        frameon = size_bar_args.pop('frameon', self.draw_inset_box)

        # this "anchors" the size bar to a box centered on pos in axis
        # coordinates
        size_bar_args['bbox_to_anchor'] = pos
        size_bar_args['bbox_transform'] = plot._axes.transAxes

        # 10 is matplotlib's location code for 'center'
        bar = AnchoredSizeBar(plot._axes.transAxes, image_scale, text, 10,
                              size_vertical=size_vertical,
                              fontproperties=fontproperties,
                              frameon=frameon,
                              **size_bar_args)

        bar.patch.set(**self.inset_box_args)

        plot._axes.add_artist(bar)

        return plot


class RayCallback(PlotCallback):
    """
    annotate_ray(ray, plot_args=None)

    Adds a line representing the projected path of a ray across the plot.
    The ray can be either a YTOrthoRayBase, YTRayBase, or a LightRay object.
    annotate_ray() will properly account for periodic rays across the volume.

    Parameters
    ----------

    ray : YTOrthoRayBase, YTRayBase, or LightRay
        Ray is the object that we want to include.  We overplot the projected
        trajectory of the ray.  If the object is an
        analysis_modules.cosmological_observation.light_ray.light_ray.LightRay
        object, it will only plot the segment of the LightRay that intersects
        the dataset currently displayed.

    plot_args : dictionary, optional
        A dictionary of any arbitrary parameters to be passed to the Matplotlib
        line object.  Defaults: {'color':'white', 'linewidth':2}.

    Examples
    --------

    >>> # Overplot a ray and an ortho_ray object on a projection
    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> oray = ds.ortho_ray(1, (0.3, 0.4)) # orthoray down the y axis
    >>> ray = ds.ray((0.1, 0.2, 0.3), (0.6, 0.7, 0.8)) # arbitrary ray
    >>> p = yt.ProjectionPlot(ds, 'z', 'density')
    >>> p.annotate_ray(oray)
    >>> p.annotate_ray(ray)
    >>> p.save()

    >>> # Overplot a LightRay object on a projection
    >>> import yt
    >>> from yt.analysis_modules.cosmological_observation.api import LightRay
    >>> ds = yt.load('enzo_cosmology_plus/RD0004/RD0004')
    >>> lr = LightRay("enzo_cosmology_plus/AMRCosmology.enzo",
    ...               'Enzo', 0.0, 0.1, time_data=False)
    >>> lray = lr.make_light_ray(seed=1)
    >>> p = yt.ProjectionPlot(ds, 'z', 'density')
    >>> p.annotate_ray(lr)
    >>> p.save()

    """
    _type_name = "ray"

    def __init__(self, ray, plot_args=None):
        PlotCallback.__init__(self)
        def_plot_args = {'color':'white', 'linewidth':2}
        self.ray = ray
        if plot_args is None:
            plot_args = def_plot_args
        self.plot_args = plot_args

    def _process_ray(self):
        """
        Get the start_coord and end_coord of a ray object
        """
        return (self.ray.start_point, self.ray.end_point)

    def _process_ortho_ray(self):
        """
        Get the start_coord and end_coord of an ortho_ray object

        The ray spans the domain from its left edge to its right edge; the
        two image-plane coordinates are overwritten with the ortho-ray's own
        coordinates.
        """
        start_coord = self.ray.ds.domain_left_edge.copy()
        end_coord = self.ray.ds.domain_right_edge.copy()

        xax = self.ray.ds.coordinates.x_axis[self.ray.axis]
        yax = self.ray.ds.coordinates.y_axis[self.ray.axis]
        start_coord[xax] = end_coord[xax] = self.ray.coords[0]
        start_coord[yax] = end_coord[yax] = self.ray.coords[1]
        return (start_coord, end_coord)

    def _process_light_ray(self, plot):
        """
        Get the start_coord and end_coord of a LightRay object.
        Identify which of the sections of the LightRay is in the
        dataset that is currently being plotted.  If there is one, return the
        start and end of the corresponding ray segment.  Otherwise return
        ((False, False), (False, False)) as a "no segment" marker.
        """
        for ray_ds in self.ray.light_ray_solution:
            if ray_ds['unique_identifier'] == plot.ds.unique_identifier:
                start_coord = ray_ds['start']
                end_coord = ray_ds['end']
                return (start_coord, end_coord)
        # if no intersection between the plotted dataset and the LightRay
        # return a false tuple to pass to start_coord
        return ((False, False), (False, False))

    def __call__(self, plot):
        """
        Draw the ray on ``plot`` as one or more line segments and return
        ``plot``.

        Raises
        ------
        SyntaxError
            If ``self.ray`` is not recognised as a ray, an ortho-ray or an
            object carrying a ``light_ray_solution``.
        """
        type_name = getattr(self.ray, "_type_name", None)

        if type_name == "ray":
            start_coord, end_coord = self._process_ray()

        elif type_name == "ortho_ray":
            start_coord, end_coord = self._process_ortho_ray()

        elif hasattr(self.ray, "light_ray_solution"):
            start_coord, end_coord = self._process_light_ray(plot)

        else:
            raise SyntaxError("ray must be a YTRayBase, YTOrthoRayBase, or "
                              "LightRay object.")

        # if start_coord and end_coord are all False, it means no intersecting
        # ray segment with this plot.  Only skip when *every* component is
        # falsy: a legitimate endpoint may contain zero-valued coordinates.
        if not any(start_coord) and not any(end_coord):
            return plot

        # if possible, break periodic ray into non-periodic
        # segments and add each of them individually
        if any(plot.ds.periodicity):
            segments = periodic_ray(start_coord, end_coord,
                                    left=plot.ds.domain_left_edge,
                                    right=plot.ds.domain_right_edge)
        else:
            segments = [[start_coord, end_coord]]

        for segment in segments:
            lcb = LinePlotCallback(segment[0], segment[1],
                                   coord_system='data',
                                   plot_args=self.plot_args)
            lcb(plot)

        return plot