# Source code for yt.fields.derived_field

"""
Derived field base class.

"""

#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import contextlib
import inspect

from yt.funcs import \
    ensure_list
from yt.units.yt_array import \
    YTArray
from .field_exceptions import \
    ValidationException, \
    NeedsGridType, \
    NeedsOriginalGrid, \
    NeedsDataField, \
    NeedsProperty, \
    NeedsParameter, \
    FieldUnitsError
from .field_detector import \
    FieldDetector
from yt.units.unit_object import \
    Unit

def derived_field(**kwargs):
    """
    Decorator that registers the decorated function as a derived field.

    All keyword arguments are forwarded to ``add_field``.  If ``name`` is
    not given, the decorated function's ``__name__`` is used.  The decorated
    function itself is returned unchanged.
    """
    def inner_decorator(function):
        # Work on a per-decoration copy: mutating ``kwargs`` directly would
        # leak the first decorated function's name/function into any later
        # use of the same decorator object.
        field_kwargs = dict(kwargs)
        field_kwargs.setdefault('name', function.__name__)
        field_kwargs['function'] = function
        # NOTE(review): ``add_field`` is neither defined nor imported in this
        # module's visible source -- confirm where it is expected to come from.
        add_field(**field_kwargs)
        return function
    return inner_decorator

def TranslationFunc(field_name):
    """
    Build a field function that returns a copy of another field.

    The returned callable has the standard ``(field, data)`` field-function
    signature and yields ``data[field_name].copy()``.
    """
    def _TranslationFunc(field, data):
        source = data[field_name]
        # Consumers modify the result in place, so never hand back the
        # stored object itself.
        return source.copy()
    return _TranslationFunc

def NullFunc(field, data):
    """
    Placeholder field function that always raises ``YTFieldNotFound``.

    It never returns a value.
    """
    # YTFieldNotFound is not among this module's top-level imports, so the
    # bare reference previously raised NameError instead of the intended
    # error; import it locally.
    from yt.utilities.exceptions import YTFieldNotFound
    # NOTE(review): the dataset is passed as the second argument because
    # YTFieldNotFound's constructor is believed to take (fname, ds) --
    # confirm against yt.utilities.exceptions.
    raise YTFieldNotFound(field.name, getattr(data, 'ds', None))
 
class DerivedField(object):
    """
    This is the base class used to describe a cell-by-cell derived field.

    Parameters
    ----------

    name : str
       is the name of the field.
    function : callable
       A function handle that defines the field.  Should accept
       arguments (field, data)
    units : str
       A plain text string encoding the unit.  Powers must be in
       python syntax (** instead of ^). If set to "auto" the units will be
       inferred from the units of the return value of the field function.
    take_log : bool
       Describes whether the field should be logged
    validators : list
       A list of :class:`FieldValidator` objects
    particle_type : bool
       Is this a particle (1D) field?
    vector_field : bool
       Describes the dimensionality of the field.  Currently unused.
    display_field : bool
       Governs its appearance in the dropdowns in Reason
    not_in_all : bool
       Used for baryon fields from the data that are not in all the grids
    display_name : str
       A name used in the plots
    output_units : str
       For fields that exist on disk, which we may want to convert to other
       fields or that get aliased to themselves, we can specify a different
       desired output unit than the unit found on disk.
    """
    def __init__(self, name, function, units=None,
                 take_log=True, validators=None,
                 particle_type=False, vector_field=False, display_field=True,
                 not_in_all=False, display_name=None, output_units=None):
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        self.particle_type = particle_type
        self.vector_field = vector_field
        # Without an explicit output unit, output in the field's own units.
        if output_units is None:
            output_units = units
        self.output_units = output_units
        self._function = function
        if validators:
            self.validators = ensure_list(validators)
        else:
            self.validators = []

        # Normalise ``units`` to a string; None means "infer" ("auto").
        if units is None:
            self.units = ''
        elif isinstance(units, str):
            if units.lower() == 'auto':
                self.units = None
            else:
                self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        else:
            raise FieldUnitsError("Cannot handle units '%s' (type %s). "
                                  "Please provide a string or Unit "
                                  "object." % (units, type(units)))

    def _copy_def(self):
        """
        Return a dict of constructor keyword arguments describing this field.

        ``function`` and ``output_units`` are not included.
        """
        dd = {}
        dd['name'] = self.name
        dd['units'] = self.units
        dd['take_log'] = self.take_log
        dd['validators'] = list(self.validators)
        dd['particle_type'] = self.particle_type
        dd['vector_field'] = self.vector_field
        # NOTE(review): always True rather than self.display_field -- confirm
        # that copies are meant to be displayed regardless of the original.
        dd['display_field'] = True
        dd['not_in_all'] = self.not_in_all
        dd['display_name'] = self.display_name
        return dd

    def get_units(self):
        """Return the LaTeX representation of this field's units."""
        u = Unit(self.units)
        return u.latex_representation()

    def get_projected_units(self):
        """Return the LaTeX representation of this field's units times cm."""
        u = Unit(self.units)*Unit('cm')
        return u.latex_representation()

    def check_available(self, data):
        """
        This raises an exception of the appropriate type if the set of
        validation mechanisms are not met, and otherwise returns True.
        """
        for validator in self.validators:
            validator(data)
        # If we don't get an exception, we're good to go
        return True

    def get_dependencies(self, *args, **kwargs):
        """
        Probe this field with a FieldDetector and return the detector.

        The detector is constructed from ``*args``/``**kwargs``.  For lambda
        field functions the field is not evaluated; instead this field's
        own name is appended to the detector's ``requested`` list.
        Otherwise the field is looked up on the detector.
        """
        e = FieldDetector(*args, **kwargs)
        if self._function.__name__ == '<lambda>':
            e.requested.append(self.name)
        else:
            e[self.name]
        return e

    # Registry active while the field function runs; None outside of
    # ``unit_registry`` contexts.
    _unit_registry = None

    @contextlib.contextmanager
    def unit_registry(self, data):
        """
        Temporarily set ``self._unit_registry`` from *data*.

        The registry is taken from ``data.unit_registry`` if present,
        otherwise from ``data.ds.unit_registry``, otherwise it is None.
        The previous value is restored on exit, including when the body of
        the ``with`` statement raises.
        """
        old_registry = self._unit_registry
        if hasattr(data, 'unit_registry'):
            ur = data.unit_registry
        elif hasattr(data, 'ds'):
            ur = data.ds.unit_registry
        else:
            ur = None
        self._unit_registry = ur
        try:
            yield
        finally:
            self._unit_registry = old_registry

    def __call__(self, data):
        """
        Return the value of the field in a given *data* object.

        Validators run first and raise on failure.  Any keys that the field
        function adds to *data* while computing the value are deleted again
        before the value is returned.
        """
        self.check_available(data)
        # Snapshot the keys: under Python 3 ``keys()`` is a live view that
        # would grow along with ``data`` and make the cleanup a no-op.
        original_fields = set(data.keys())
        if self._function is NullFunc:
            raise RuntimeError(
                "Something has gone terribly wrong, _function is NullFunc " +
                "for %s" % (self.name,))
        with self.unit_registry(data):
            dd = self._function(self, data)
        # Iterate over a copy of the keys because entries are deleted here.
        for field_name in list(data.keys()):
            if field_name not in original_fields:
                del data[field_name]
        return dd

    def get_source(self):
        """
        Return a string containing the source of the function (if possible.)
        """
        return inspect.getsource(self._function)

    def get_label(self, projected=False):
        """
        Return a data label for the given field, including units.
        """
        # Tuple names are treated as (field_type, field_name) pairs and
        # labelled by their second element; a plain string name is used
        # whole instead of being indexed character-wise.
        if isinstance(self.name, tuple):
            name = self.name[1]
        else:
            name = self.name
        if self.display_name is not None:
            name = self.display_name

        # Start with the field name
        data_label = r"$\rm{%s}" % name

        # Grab the correct units
        if projected:
            raise NotImplementedError
        else:
            units = Unit(self.units)
        # Add unit label
        if not units.is_dimensionless:
            data_label += r"\ \ (%s)" % (units)

        data_label += r"$"
        return data_label

class FieldValidator(object):
    """
    Common base class for the checks a DerivedField runs on a data object.

    Concrete validators are called with the data object and raise when
    their requirement is not met.
    """

class ValidateParameter(FieldValidator):
    def __init__(self, parameters):
        """
        This validator ensures that the data object has the given field
        parameter(s) set.
        """
        FieldValidator.__init__(self)
        self.parameters = ensure_list(parameters)

    def __call__(self, data):
        # Collect every missing parameter so the error reports all of them.
        missing = [param for param in self.parameters
                   if not data.has_field_parameter(param)]
        if missing:
            raise NeedsParameter(missing)
        return True

class ValidateDataField(FieldValidator):
    def __init__(self, field):
        """
        This validator ensures that the field(s) appear in the data
        object's ``index.field_list``, i.e. are stored in the output file.
        """
        FieldValidator.__init__(self)
        self.fields = ensure_list(field)

    def __call__(self, data):
        # FieldDetector instances (used for dependency probing) always pass.
        if isinstance(data, FieldDetector):
            return True
        missing = [fname for fname in self.fields
                   if fname not in data.index.field_list]
        if missing:
            raise NeedsDataField(missing)
        return True

class ValidateProperty(FieldValidator):
    def __init__(self, prop):
        """
        This validator ensures that the data object has the given python
        attribute(s).
        """
        FieldValidator.__init__(self)
        self.prop = ensure_list(prop)

    def __call__(self, data):
        missing = [attr for attr in self.prop if not hasattr(data, attr)]
        if missing:
            raise NeedsProperty(missing)
        return True

class ValidateSpatial(FieldValidator):
    def __init__(self, ghost_zones=0, fields=None):
        """
        This validator ensures that the data handed to the field is of
        spatial nature -- that is to say, 3-D -- and carries at least
        ``ghost_zones`` ghost zones.
        """
        FieldValidator.__init__(self)
        self.ghost_zones = ghost_zones
        self.fields = fields

    def __call__(self, data):
        # "Spatial" is signalled by a truthy ``_spatial`` attribute; only
        # then is the ghost-zone count consulted.
        spatial = getattr(data, '_spatial', False)
        if spatial and self.ghost_zones <= data._num_ghost_zones:
            return True
        raise NeedsGridType(self.ghost_zones, self.fields)

class ValidateGridType(FieldValidator):
    def __init__(self):
        """
        This validator ensures that the data handed to the field is an
        actual grid patch, not a covering grid of any kind.
        """
        FieldValidator.__init__(self)

    def __call__(self, data):
        # Accept dependency-probing detectors and objects whose
        # ``_type_name`` is 'grid'; everything else needs the original grid.
        is_detector = isinstance(data, FieldDetector)
        if is_detector or getattr(data, "_type_name", None) == 'grid':
            return True
        raise NeedsOriginalGrid()