"""
The basic field info container resides here. These classes, both code-specific
and universal, are the means by which we access fields across yt, both derived
and native.
"""
#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------
import numpy as np
from numbers import Number as numeric_type
from yt.funcs import mylog, only_on_root
from yt.units.unit_object import Unit
from .derived_field import \
DerivedField, \
NullFunc, \
TranslationFunc
from yt.utilities.exceptions import \
YTFieldNotFound
from .field_plugin_registry import \
field_plugins
from .particle_fields import \
particle_deposition_functions, \
particle_vector_functions, \
particle_scalar_functions, \
standard_particle_fields, \
add_volume_weighted_smoothed_field, \
sph_whitelist_fields
class FieldInfoContainer(dict):
    """
    This is a generic field container. It contains a list of potential derived
    fields, all of which know how to act on a data object and return a value.
    This object handles converting units as well as validating the availability
    of a given field.
    """
    # Optional secondary container consulted by __contains__, __missing__,
    # __iter__, keys and has_key when a field is not stored in this one.
    fallback = None
    # Sequences of (name, (units, aliases, display_name)) entries; subclasses
    # override these to describe the fields their frontend knows about.
    known_other_fields = ()
    known_particle_fields = ()

    def __init__(self, ds, field_list, slice_info=None):
        """
        Create the container and immediately register the non-particle
        fields of ``field_list`` (see ``setup_fluid_aliases``).

        Parameters
        ----------
        ds : dataset
            The dataset whose fields are described; its ``field_units``,
            ``particle_types`` and ``unit_registry`` are consulted during setup.
        field_list : list of tuple
            The (field_type, field_name) tuples present in the dataset.
        slice_info : optional
            Passed through unchanged to field plugins.
        """
        self._show_field_errors = []
        self.ds = ds
        # Now we start setting things up.
        self.field_list = field_list
        self.slice_info = slice_info
        self.field_aliases = {}
        self.species_names = []
        self.setup_fluid_aliases()

    def setup_fluid_fields(self):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def setup_particle_fields(self, ptype, ftype='gas', num_neighbors=64):
        """
        Register output fields, aliases and derived fields for particle
        type ``ptype``, then set up its smoothed fields.

        Parameters
        ----------
        ptype : str
            The particle type to set up.
        ftype : str
            Field type under which smoothed-field aliases are created.
        num_neighbors : int
            Forwarded to ``setup_smoothed_fields``.
        """
        # Units that are kept as-is rather than converted to CGS for output.
        skip_output_units = ("code_length",)
        for f, (units, aliases, dn) in sorted(self.known_particle_fields):
            # Only fields present in field_list are registered; check this
            # first so no unit conversion is attempted for absent fields.
            if (ptype, f) not in self.field_list:
                continue
            units = self.ds.field_units.get((ptype, f), units)
            if (f in aliases or ptype not in self.ds.particle_types_raw) and \
                    units not in skip_output_units:
                u = Unit(units, registry=self.ds.unit_registry)
                output_units = str(u.get_cgs_equivalent())
            else:
                output_units = units
            self.add_output_field((ptype, f),
                                  units=units, particle_type=True,
                                  display_name=dn, output_units=output_units)
            for alias in aliases:
                self.alias((ptype, alias), (ptype, f), units=output_units)

        # We'll either have particle_position or particle_position_[xyz]
        if (ptype, "particle_position") in self.field_list or \
                (ptype, "particle_position") in self.field_aliases:
            particle_scalar_functions(ptype,
                                      "particle_position", "particle_velocity",
                                      self)
        else:
            # We need to check to make sure that there's a "known field" that
            # overlaps with one of the vector fields. For instance, if we are
            # in the Stream frontend, and we have a set of scalar position
            # fields, they will overlap with -- and be overridden by -- the
            # "known" vector field that the frontend creates. So the easiest
            # thing to do is to simply remove the on-disk field (which doesn't
            # exist) and replace it with a derived field.
            if (ptype, "particle_position") in self and \
                    self[ptype, "particle_position"]._function == NullFunc:
                self.pop((ptype, "particle_position"))
            particle_vector_functions(ptype,
                ["particle_position_%s" % ax for ax in 'xyz'],
                ["particle_velocity_%s" % ax for ax in 'xyz'],
                self)
        particle_deposition_functions(ptype, "particle_position",
                                      "particle_mass", self)
        standard_particle_fields(self, ptype)
        # Now we check for any leftover particle fields
        for field in sorted(self.field_list):
            if field in self:
                continue
            if not isinstance(field, tuple):
                raise RuntimeError(
                    "Expected a (type, name) tuple in field_list, got %r"
                    % (field,))
            if field[0] not in self.ds.particle_types:
                continue
            self.add_output_field(field,
                                  units=self.ds.field_units.get(field, ""),
                                  particle_type=True)
        self.setup_smoothed_fields(ptype,
                                   num_neighbors=num_neighbors,
                                   ftype=ftype)

    def setup_smoothed_fields(self, ptype, num_neighbors=64, ftype="gas"):
        """
        Create smoothed versions of whitelisted ``ptype`` fields and alias
        them under ``ftype`` (with any "particle_" prefix removed).

        Does nothing unless ``(ptype, "density")`` is available.
        """
        # We can in principle compute this, but it is not yet implemented.
        if (ptype, "density") not in self:
            return
        if (ptype, "smoothing_length") in self:
            sml_name = "smoothing_length"
        else:
            sml_name = None
        new_aliases = []
        # Iterate over a snapshot: add_volume_weighted_smoothed_field
        # receives ``self`` and registers new fields into it.
        for ptype2, alias_name in list(self.keys()):
            if ptype2 != ptype:
                continue
            if alias_name not in sph_whitelist_fields:
                continue
            fn = add_volume_weighted_smoothed_field(
                ptype, "particle_position", "particle_mass",
                sml_name, "density", alias_name, self,
                num_neighbors)
            alias_name = alias_name.replace('particle_', '')
            new_aliases.append(((ftype, alias_name), fn[0]))
        # Aliases are applied after the loop so they are not themselves
        # visited as smoothing candidates.
        for alias, source in new_aliases:
            self.alias(alias, source)

    def setup_fluid_aliases(self):
        """
        Register every non-particle entry of ``field_list`` as an output
        field, using units/aliases/display names from ``known_other_fields``
        (overridable via ``ds.field_units``), and create its "gas" aliases.

        Raises
        ------
        RuntimeError
            If an entry of ``field_list`` is not a tuple.
        """
        known_other_fields = dict(self.known_other_fields)
        for field in sorted(self.field_list):
            if not isinstance(field, tuple):
                raise RuntimeError(
                    "Expected a (type, name) tuple in field_list, got %r"
                    % (field,))
            if field[0] in self.ds.particle_types:
                continue
            args = known_other_fields.get(field[1], ("", [], None))
            units, aliases, display_name = args
            # We allow field_units to override this. First we check if the
            # field *name* is in there, then the field *tuple*.
            units = self.ds.field_units.get(field[1], units)
            units = self.ds.field_units.get(field, units)
            # A non-string override is treated as a scale factor on the
            # known units.
            if not isinstance(units, str) and args[0] != "":
                units = "((%s)*%s)" % (args[0], units)
            if isinstance(units, (numeric_type, np.number, np.ndarray)) and \
                    args[0] == "" and units != 1.0:
                mylog.warning("Cannot interpret units: %s * %s, " +
                              "setting to dimensionless.", units, args[0])
                units = ""
            elif units == 1.0:
                units = ""
            self.add_output_field(field, units=units,
                                  display_name=display_name)
            for alias in aliases:
                self.alias(("gas", alias), field)

    def add_field(self, name, function=None, **kwargs):
        """
        Add a new field, along with supplemental metadata, to the list of
        available fields. This respects a number of arguments, all of which
        are passed on to the constructor for
        :class:`~yt.data_objects.api.DerivedField`.

        If ``name`` is already registered and ``force_override`` is not set,
        nothing is added.

        Parameters
        ----------
        name : str or tuple
            is the name of the field.
        function : callable
            A function handle that defines the field. Should accept
            arguments (field, data). If omitted, a decorator is returned.
        units : str
            A plain text string encoding the unit. Powers must be in
            python syntax (** instead of ^). If set to "auto" the units
            will be inferred from the return value of the field function.
        take_log : bool
            Describes whether the field should be logged
        validators : list
            A list of :class:`FieldValidator` objects
        particle_type : bool
            Is this a particle (1D) field?
        vector_field : bool
            Describes the dimensionality of the field. Currently unused.
        display_name : str
            A name used in the plots
        force_override : bool
            Replace an existing field of the same name.

        Returns
        -------
        None when ``function`` is given; otherwise a decorator that returns
        the decorated function unchanged.
        """
        override = kwargs.pop("force_override", False)
        # Handle the case where the field has already been added.
        if not override and name in self:
            # Keep decorator usage working: hand back a no-op decorator.
            if function is None:
                def create_function(f):
                    return f
                return create_function
            return
        # add_field can be used in two different ways: it can be called
        # directly, or used as a decorator. If called directly, the
        # function will be passed in as an argument, and we simply create
        # the derived field and exit. If used as a decorator, function will
        # be None. In that case, we return a function that will be applied
        # to the function that the decorator is applied to.
        if function is None:
            def create_function(f):
                self[name] = DerivedField(name, f, **kwargs)
                return f
            return create_function
        self[name] = DerivedField(name, function, **kwargs)

    def load_all_plugins(self, ftype="gas"):
        """
        Load every registered field plugin (in sorted name order) and then
        resolve the dependencies of all newly added fields.
        """
        loaded = []
        for n in sorted(field_plugins):
            new_fields = self.load_plugin(n, ftype)
            loaded += new_fields
            # Report this plugin's contribution, not the running total.
            only_on_root(mylog.info, "Loaded %s (%s new fields)",
                         n, len(new_fields))
        self.find_dependencies(loaded)

    def load_plugin(self, plugin_name, ftype="gas", skip_check=False):
        """
        Run a single field plugin against this container.

        Parameters
        ----------
        plugin_name : str or callable
            A key of ``field_plugins``, or the plugin callable itself; it is
            called as ``plugin(self, ftype, slice_info=self.slice_info)``.
        ftype : str
            Field type passed to the plugin.
        skip_check : bool
            Unused; kept for interface compatibility.

        Returns
        -------
        list
            Names whose (name, field) entries were added or replaced.
        """
        if callable(plugin_name):
            f = plugin_name
        else:
            f = field_plugins[plugin_name]
        orig = set(self.items())
        f(self, ftype, slice_info=self.slice_info)
        return [name for name, _ in set(self.items()).difference(orig)]

    def find_dependencies(self, loaded):
        """
        Check ``loaded`` fields, record their dependencies on the dataset and
        add them to ``ds.derived_field_list``.

        Returns
        -------
        tuple
            ``(loaded, unavailable)``.
        """
        deps, unavailable = self.check_derived_fields(loaded)
        self.ds.field_dependencies.update(deps)
        # Note we may have duplicated
        dfl = set(self.ds.derived_field_list).union(deps.keys())
        self.ds.derived_field_list = list(sorted(dfl))
        return loaded, unavailable

    def add_output_field(self, name, **kwargs):
        """Register ``name`` as a field with NullFunc as its function."""
        self[name] = DerivedField(name, NullFunc, **kwargs)

    def alias(self, alias_name, original_name, units=None):
        """
        Register ``alias_name`` as a translation of ``original_name``.

        Does nothing if ``original_name`` is unknown. When ``units`` is None
        the CGS equivalent of the original field's units is used.
        """
        if original_name not in self:
            return
        original = self[original_name]
        if units is None:
            # We default to CGS here, but in principle, this can be pluggable
            # as well.
            u = Unit(original.units, registry=self.ds.unit_registry)
            units = str(u.get_cgs_equivalent())
        self.field_aliases[alias_name] = original_name
        self.add_field(alias_name,
                       function=TranslationFunc(original_name),
                       particle_type=original.particle_type,
                       display_name=original.display_name,
                       units=units)

    def has_key(self, key):
        """Return True if ``key`` is here or in the fallback container."""
        # This gets used a lot
        if key in self:
            return True
        if self.fallback is None:
            return False
        return key in self.fallback

    def __missing__(self, key):
        # Called by dict.__getitem__ for absent keys: defer to the fallback.
        if self.fallback is None:
            raise KeyError("No field named %s" % (key,))
        return self.fallback[key]

    @classmethod
    def create_with_fallback(cls, fallback, name=""):
        """
        Build an empty container that defers lookups to ``fallback``.

        The instance is constructed with no dataset and an empty field list;
        previously this called ``cls()``, which always failed because
        ``__init__`` requires ``ds`` and ``field_list``.
        """
        obj = cls(None, [])
        obj.fallback = fallback
        obj.name = name
        return obj

    def __contains__(self, key):
        if dict.__contains__(self, key):
            return True
        if self.fallback is None:
            return False
        return key in self.fallback

    def __iter__(self):
        # Own keys first, then the fallback's (duplicates are not removed).
        for f in dict.__iter__(self):
            yield f
        if self.fallback is not None:
            for f in self.fallback:
                yield f

    def keys(self):
        """
        Return a list of this container's keys followed by the fallback's.

        A list (not a dict view) is returned so the result can be extended
        and safely iterated while the container is modified.
        """
        keys = list(dict.keys(self))
        if self.fallback:
            keys += list(self.fallback.keys())
        return keys

    def check_derived_fields(self, fields_to_check=None):
        """
        Determine the dependencies of each field, dropping fields that fail.

        Parameters
        ----------
        fields_to_check : list, optional
            Fields to check; when None or empty, all keys are checked.

        Returns
        -------
        tuple
            ``(deps, unavailable)``: ``deps`` maps field -> dependency object
            (with ``requested`` converted to a set); ``unavailable`` lists
            fields removed because something they request is not in
            ``field_list``. Fields whose dependency detection raised are
            removed too, but are not listed.

        Raises
        ------
        RuntimeError
            If a field to check is not known to this container.
        """
        deps = {}
        unavailable = []
        fields_to_check = fields_to_check or list(self.keys())
        for field in fields_to_check:
            mylog.debug("Checking %s", field)
            if field not in self:
                raise RuntimeError(
                    "Field %s is not known to this container" % (field,))
            fi = self[field]
            try:
                fd = fi.get_dependencies(ds=self.ds)
            except Exception as e:
                # Fields listed in _show_field_errors propagate their errors
                # instead of being silently dropped.
                if field in self._show_field_errors:
                    raise
                if not isinstance(e, YTFieldNotFound):
                    mylog.debug("Raises %s during field %s detection.",
                                str(type(e)), field)
                self.pop(field)
                continue
            # Drop fields whose requested inputs are not all in field_list.
            # We also manually update the 'requested' attribute
            missing = not all(f in self.field_list for f in fd.requested)
            if missing:
                self.pop(field)
                unavailable.append(field)
                continue
            fd.requested = set(fd.requested)
            deps[field] = fd
            mylog.debug("Succeeded with %s (needs %s)", field, fd.requested)
        dfl = set(self.ds.derived_field_list).union(deps.keys())
        self.ds.derived_field_list = list(sorted(dfl))
        return deps, unavailable