"""
Derived field base class.
"""
#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------
import contextlib
import inspect
from yt.funcs import \
ensure_list
from yt.units.yt_array import \
YTArray
from .field_exceptions import \
ValidationException, \
NeedsGridType, \
NeedsOriginalGrid, \
NeedsDataField, \
NeedsProperty, \
NeedsParameter, \
FieldUnitsError
from .field_detector import \
FieldDetector
from yt.units.unit_object import \
Unit
def derived_field(**kwargs):
    """
    Return a decorator that registers the decorated function as a field.

    Parameters
    ----------
    **kwargs
        Keyword arguments forwarded to ``add_field``. If ``name`` is not
        supplied, the decorated function's ``__name__`` is used.

    Returns
    -------
    callable
        A decorator that registers its argument via ``add_field`` and
        returns the function unchanged.
    """
    def inner_decorator(function):
        # Build a fresh dict per decorated function so that applying the same
        # decorator object to several functions does not carry the first
        # function's ``name``/``function`` over into later registrations.
        field_kwargs = dict(kwargs)
        field_kwargs.setdefault('name', function.__name__)
        field_kwargs['function'] = function
        # NOTE(review): ``add_field`` is neither defined nor imported in this
        # module as shown; this call raises NameError unless it is provided
        # elsewhere — confirm where it is meant to come from.
        add_field(**field_kwargs)
        return function
    return inner_decorator
def TranslationFunc(field_name):
    """
    Build a field function that returns a copy of another field's data.

    Parameters
    ----------
    field_name
        Key used to look up the source field in the ``data`` object.

    Returns
    -------
    callable
        A function with the ``(field, data)`` signature used by derived
        fields.
    """
    def _TranslationFunc(field, data):
        source_values = data[field_name]
        # The result gets modified in place downstream, so never hand back
        # the stored object itself.
        return source_values.copy()
    return _TranslationFunc
def NullFunc(field, data):
    """
    Placeholder field function that always reports the field as missing.

    Parameters
    ----------
    field : DerivedField
        The field being evaluated; only ``field.name`` is used.
    data
        Ignored.

    Raises
    ------
    YTFieldNotFound
        Always, carrying ``field.name``.
    """
    # This module does not import YTFieldNotFound at top level, so without
    # this import the function raised NameError instead. The import is
    # local to keep the module's import graph unchanged.
    from yt.utilities.exceptions import YTFieldNotFound
    raise YTFieldNotFound(field.name)
class DerivedField(object):
    """
    This is the base class used to describe a cell-by-cell derived field.

    Parameters
    ----------
    name : str or tuple
       The name of the field. ``get_label`` reads ``name[1]``, so a
       ``(field_type, field_name)`` tuple is presumably expected — confirm.
    function : callable
       A function handle that defines the field.  Should accept
       arguments (field, data)
    units : str
       A plain text string encoding the unit.  Powers must be in
       python syntax (** instead of ^). If set to "auto" the units will be
       inferred from the units of the return value of the field function.
    take_log : bool
       Describes whether the field should be logged
    validators : list
       A list of :class:`FieldValidator` objects
    particle_type : bool
       Is this a particle (1D) field?
    vector_field : bool
       Describes the dimensionality of the field.  Currently unused.
    display_field : bool
       Governs its appearance in the dropdowns in Reason
    not_in_all : bool
       Used for baryon fields from the data that are not in all the grids
    display_name : str
       A name used in the plots
    output_units : str
       For fields that exist on disk, which we may want to convert to other
       fields or that get aliased to themselves, we can specify a different
       desired output unit than the unit found on disk.
    """

    # Unit registry of the data object currently being evaluated; set
    # temporarily by ``unit_registry`` and restored on exit.
    _unit_registry = None

    def __init__(self, name, function, units=None,
                 take_log=True, validators=None,
                 particle_type=False, vector_field=False, display_field=True,
                 not_in_all=False, display_name=None, output_units=None):
        self.name = name
        self.take_log = take_log
        self.display_name = display_name
        self.not_in_all = not_in_all
        self.display_field = display_field
        self.particle_type = particle_type
        self.vector_field = vector_field
        # Without an explicit output unit, fall back to the raw ``units``
        # argument (taken before the normalisation below).
        if output_units is None:
            output_units = units
        self.output_units = output_units
        self._function = function
        if validators:
            self.validators = ensure_list(validators)
        else:
            self.validators = []
        # Normalise ``units``: None -> '', 'auto' (any case) -> None so the
        # units can be inferred from the function's return value, and a Unit
        # object -> its string form.
        if units is None:
            self.units = ''
        elif isinstance(units, str):
            if units.lower() == 'auto':
                self.units = None
            else:
                self.units = units
        elif isinstance(units, Unit):
            self.units = str(units)
        else:
            raise FieldUnitsError("Cannot handle units '%s' (type %s). "
                                  "Please provide a string or Unit "
                                  "object." % (units, type(units)))

    def _copy_def(self):
        """
        Return a dict of this field's constructor-style settings.

        The ``function`` and ``output_units`` are not included.
        """
        # NOTE(review): display_field is always copied as True rather than
        # self.display_field — confirm this is intentional.
        return {
            'name': self.name,
            'units': self.units,
            'take_log': self.take_log,
            'validators': list(self.validators),
            'particle_type': self.particle_type,
            'vector_field': self.vector_field,
            'display_field': True,
            'not_in_all': self.not_in_all,
            'display_name': self.display_name,
        }

    def get_units(self):
        """Return the LaTeX representation of this field's units."""
        u = Unit(self.units)
        return u.latex_representation()

    def get_projected_units(self):
        """
        Return the LaTeX representation of this field's units times ``cm``
        (presumably the units after a line-of-sight projection).
        """
        u = Unit(self.units)*Unit('cm')
        return u.latex_representation()

    def check_available(self, data):
        """
        This raises an exception of the appropriate type if the set of
        validation mechanisms are not met, and otherwise returns True.
        """
        for validator in self.validators:
            validator(data)
        # If we don't get an exception, we're good to go
        return True

    def get_dependencies(self, *args, **kwargs):
        """
        Probe this field with a FieldDetector and return the detector.

        All arguments are forwarded to ``FieldDetector``. For lambda field
        functions the field's own name is appended to ``requested`` directly;
        otherwise the field is looked up on the detector (``e[self.name]``),
        which presumably records the fields accessed in ``requested``.
        """
        e = FieldDetector(*args, **kwargs)
        if self._function.__name__ == '<lambda>':
            e.requested.append(self.name)
        else:
            e[self.name]
        return e

    @contextlib.contextmanager
    def unit_registry(self, data):
        """
        Temporarily set ``self._unit_registry`` from *data*.

        Uses ``data.unit_registry`` if present, otherwise
        ``data.ds.unit_registry``, otherwise None. The previous value is
        restored on exit, including when the ``with`` body raises.
        """
        old_registry = self._unit_registry
        if hasattr(data, 'unit_registry'):
            ur = data.unit_registry
        elif hasattr(data, 'ds'):
            ur = data.ds.unit_registry
        else:
            ur = None
        self._unit_registry = ur
        try:
            yield
        finally:
            self._unit_registry = old_registry

    def __call__(self, data):
        """ Return the value of the field in a given *data* object. """
        # Raises if any validator rejects ``data``.
        self.check_available(data)
        # Snapshot the keys present before evaluation. ``keys()`` may return a
        # live view, so materialise it; a set also makes the membership test
        # below constant-time.
        original_fields = set(data.keys())
        if self._function is NullFunc:
            raise RuntimeError(
                "Something has gone terribly wrong, _function is NullFunc " +
                "for %s" % (self.name,))
        with self.unit_registry(data):
            dd = self._function(self, data)
        # Remove any fields the function caused to be stored on ``data``;
        # iterate over a copy of the keys because we delete while looping.
        for field_name in list(data.keys()):
            if field_name not in original_fields:
                del data[field_name]
        return dd

    def get_source(self):
        """
        Return a string containing the source of the function (if possible.)
        """
        return inspect.getsource(self._function)

    def get_label(self, projected=False):
        """
        Return a data label for the given field, including units.

        Raises NotImplementedError when ``projected`` is True.
        """
        name = self.name[1]
        if self.display_name is not None:
            name = self.display_name

        # Start with the field name
        data_label = r"$\rm{%s}" % name

        # Grab the correct units
        if projected:
            raise NotImplementedError
        else:
            units = Unit(self.units)
        # Add unit label
        if not units.is_dimensionless:
            data_label += r"\ \ (%s)" % (units)

        data_label += r"$"
        return data_label
class FieldValidator(object):
    """
    Base class for field validators.

    Subclasses implement ``__call__(data)``, returning True when ``data``
    meets the requirement and raising an exception (such as
    ``NeedsParameter``) otherwise.
    """
class ValidateParameter(FieldValidator):
    """Validator requiring that the data object has given field parameters."""

    def __init__(self, parameters):
        """
        This validator ensures that the data object has the given field
        parameter(s).

        Parameters
        ----------
        parameters
            A parameter name or list of names (passed through ``ensure_list``).
        """
        FieldValidator.__init__(self)
        self.parameters = ensure_list(parameters)

    def __call__(self, data):
        """
        Return True if ``data.has_field_parameter`` is true for every
        required parameter.

        Raises
        ------
        NeedsParameter
            With the list of missing parameter names.
        """
        missing = [p for p in self.parameters
                   if not data.has_field_parameter(p)]
        if missing:
            raise NeedsParameter(missing)
        return True
class ValidateDataField(FieldValidator):
    """Validator requiring fields to be present in the on-disk field list."""

    def __init__(self, field):
        """
        This validator ensures that the output file has a given data field stored
        in it.

        Parameters
        ----------
        field
            A field name or list of names (passed through ``ensure_list``).
        """
        FieldValidator.__init__(self)
        self.fields = ensure_list(field)

    def __call__(self, data):
        """
        Return True if every required field is in ``data.index.field_list``.

        FieldDetector instances always pass.

        Raises
        ------
        NeedsDataField
            With the list of missing field names.
        """
        if isinstance(data, FieldDetector):
            return True
        missing = [f for f in self.fields
                   if f not in data.index.field_list]
        if missing:
            raise NeedsDataField(missing)
        return True
class ValidateProperty(FieldValidator):
    """Validator requiring the data object to expose given attributes."""

    def __init__(self, prop):
        """
        This validator ensures that the data object has a given python attribute.

        Parameters
        ----------
        prop
            An attribute name or list of names (passed through
            ``ensure_list``).
        """
        FieldValidator.__init__(self)
        self.prop = ensure_list(prop)

    def __call__(self, data):
        """
        Return True if ``hasattr(data, p)`` holds for every required name.

        Raises
        ------
        NeedsProperty
            With the list of missing attribute names.
        """
        missing = [p for p in self.prop if not hasattr(data, p)]
        if missing:
            raise NeedsProperty(missing)
        return True
class ValidateSpatial(FieldValidator):
    """Validator requiring spatial data with enough ghost zones."""

    def __init__(self, ghost_zones=0, fields=None):
        """
        This validator ensures that the data handed to the field is of spatial
        nature -- that is to say, 3-D.

        Parameters
        ----------
        ghost_zones : int
            Minimum number of ghost zones the data must carry.
        fields
            Passed through to ``NeedsGridType`` when validation fails.
        """
        FieldValidator.__init__(self)
        self.ghost_zones = ghost_zones
        self.fields = fields

    def __call__(self, data):
        """
        Return True if ``data._spatial`` is truthy and
        ``data._num_ghost_zones`` is at least ``self.ghost_zones``.

        Raises
        ------
        NeedsGridType
            Otherwise, carrying ``ghost_zones`` and ``fields``.
        """
        # "Spatial" means the data has a three-dimensional structure.
        # FieldDetector instances get no special exemption here; they pass
        # only if they satisfy the same attribute checks.
        if (getattr(data, '_spatial', False)
                and self.ghost_zones <= data._num_ghost_zones):
            return True
        raise NeedsGridType(self.ghost_zones, self.fields)
class ValidateGridType(FieldValidator):
    """Validator requiring an actual grid patch as the data object."""

    def __init__(self):
        """
        This validator ensures that the data handed to the field is an actual
        grid patch, not a covering grid of any kind.
        """
        FieldValidator.__init__(self)

    def __call__(self, data):
        """
        Return True for FieldDetector instances or objects whose
        ``_type_name`` is ``'grid'``.

        Raises
        ------
        NeedsOriginalGrid
            For any other data object.
        """
        # We need to make sure that it's an actual AMR grid
        if (isinstance(data, FieldDetector)
                or getattr(data, "_type_name", None) == 'grid'):
            return True
        raise NeedsOriginalGrid()