Source code for yt.units.yt_array

"""
YTArray class.



"""
from __future__ import print_function
#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import copy
import numpy as np

from functools import wraps
from numpy import \
    add, subtract, multiply, divide, logaddexp, logaddexp2, true_divide, \
    floor_divide, negative, power, remainder, mod, fmod, absolute, rint, \
    sign, conj, exp, exp2, log, log2, log10, expm1, log1p, sqrt, square, \
    reciprocal, ones_like, sin, cos, tan, arcsin, arccos, arctan, arctan2, \
    hypot, sinh, cosh, tanh, arcsinh, arccosh, arctanh, deg2rad, rad2deg, \
    greater, greater_equal, less, less_equal, not_equal, equal, logical_and, \
    logical_or, logical_xor, logical_not, maximum, minimum, isreal, iscomplex, \
    isfinite, isinf, isnan, signbit, copysign, nextafter, modf, frexp, \
    floor, ceil, trunc, fmax, fmin

from yt.units.unit_object import Unit, UnitParseError
from yt.units.unit_registry import UnitRegistry
from yt.units.dimensions import dimensionless, current_mks, em_dimensions
from yt.utilities.exceptions import \
    YTUnitOperationError, YTUnitConversionError, \
    YTUfuncUnitError, YTIterableUnitCoercionError, \
    YTInvalidUnitEquivalence, YTEquivalentDimsError
from numbers import Number as numeric_type
from yt.utilities.on_demand_imports import _astropy
from sympy import Rational
from yt.units.unit_lookup_table import unit_prefixes, prefixable_units
from yt.units.equivalencies import equivalence_registry

# Shared default unit, built with no arguments.  It is the fallback used in
# this module whenever an operand has no ``units`` attribute.
NULL_UNIT = Unit()

# redefine this here to avoid a circular import from yt.funcs
def iterable(obj):
    """
    Report whether *obj* supports ``len()``.

    Parameters
    ----------
    obj : object
        Any object.

    Returns
    -------
    bool
        True if ``len(obj)`` succeeds, False if it raises an ordinary
        exception (for example TypeError for scalars and generators).
    """
    try:
        len(obj)
    except Exception:
        # Only ordinary errors mean "not iterable"; KeyboardInterrupt and
        # SystemExit must keep propagating instead of being swallowed.
        return False
    return True

def return_arr(func):
    """
    Decorator for callables that return a ``(data, units)`` pair.

    The wrapped callable's result is turned into a unit-carrying object:
    a YTQuantity when ``data`` has shape ``()``, otherwise an instance of the
    class of the first positional argument.
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        data, units = func(*args, **kwargs)
        if data.shape != ():
            # The first positional argument may be a subclass instance, so
            # build the result from its class rather than from YTArray.
            owner_cls = type(args[0])
            return owner_cls(data, units)
        return YTQuantity(data, units)
    return wrapped

def sqrt_unit(unit):
    """Return *unit* raised to the power 0.5."""
    half = 0.5
    return unit ** half

def multiply_units(unit1, unit2):
    """Return the product ``unit1 * unit2``."""
    product = unit1 * unit2
    return product

def preserve_units(unit1, unit2):
    """Return *unit1* unchanged; *unit2* is ignored."""
    kept = unit1
    return kept

def power_unit(unit, power):
    """Return *unit* raised to *power*."""
    raised = unit ** power
    return raised

def square_unit(unit):
    """Return *unit* multiplied by itself."""
    squared = unit * unit
    return squared

def divide_units(unit1, unit2):
    """Return the quotient ``unit1 / unit2``."""
    quotient = unit1 / unit2
    return quotient

def reciprocal_unit(unit):
    """Return *unit* raised to the power -1."""
    exponent = -1
    return unit ** exponent

def passthrough_unit(unit):
    """Return *unit* unchanged."""
    same = unit
    return same

def return_without_unit(unit):
    """Return None regardless of *unit* (the result carries no unit)."""
    no_unit = None
    return no_unit

def arctan2_unit(unit1, unit2):
    """Return the module-level NULL_UNIT; both arguments are ignored."""
    result_unit = NULL_UNIT
    return result_unit

def comparison_unit(unit1, unit2):
    """Return None regardless of the operand units."""
    no_unit = None
    return no_unit

def coerce_iterable_units(input_object):
    """
    Turn a sized container holding YTArray items into a single YTArray.

    ndarrays, objects without a length, and containers holding no YTArray are
    returned untouched.  If any item is a YTArray, every item must compare
    equal in units to the first one (items without a ``units`` attribute count
    as NULL_UNIT); otherwise YTIterableUnitCoercionError is raised.
    """
    if isinstance(input_object, np.ndarray) or not iterable(input_object):
        return input_object

    holds_yt_array = any(isinstance(item, YTArray) for item in input_object)
    if not holds_yt_array:
        return input_object

    first_units = getattr(input_object[0], 'units', NULL_UNIT)
    for item in input_object:
        if first_units != getattr(item, 'units', NULL_UNIT):
            raise YTIterableUnitCoercionError(input_object)

    # This will create a copy of the data in the iterable.
    return YTArray(input_object)

def sanitize_units_mul(this_object, other_object):
    """
    Prepare the second operand of a multiplication or division.

    Parameters
    ----------
    this_object : YTArray
        The operand whose units act as the reference.
    other_object : object
        The other operand; containers of YTArray items are coerced first.

    Returns
    -------
    The coerced *other_object*.  If it is a YTArray with the same dimensions
    as *this_object*, it is returned converted to the units of *this_object*.
    """
    inp = coerce_iterable_units(this_object)
    ret = coerce_iterable_units(other_object)
    # If the other object is a YTArray and has the same dimensions as the object
    # under consideration, convert so we don't mix units with the same
    # dimensions.
    if isinstance(ret, YTArray):
        if inp.units.same_dimensions_as(ret.units):
            # in_units returns a converted copy, so the result has to be
            # kept; calling it without assigning left ``ret`` unconverted.
            ret = ret.in_units(inp.units)
    return ret

def sanitize_units_add(this_object, other_object, op_string):
    """
    Prepare the second operand of an addition or subtraction.

    A YTArray operand must share dimensions with *this_object* and is returned
    converted to its units.  Any other operand is only accepted when
    *this_object* is dimensionless.  Violations raise YTUnitOperationError
    labelled with *op_string*.
    """
    lhs = coerce_iterable_units(this_object)
    rhs = coerce_iterable_units(other_object)

    if not isinstance(rhs, YTArray):
        # Without units on the other side, the only valid case is adding
        # dimensionless things.
        if lhs.units.is_dimensionless:
            return rhs
        raise YTUnitOperationError(op_string, lhs.units, dimensionless)

    if not lhs.units.same_dimensions_as(rhs.units):
        raise YTUnitOperationError(op_string, lhs.units, rhs.units)
    return rhs.in_units(lhs.units)

# ufuncs taking a single operand.  YTArray.__array_wrap__ checks membership
# here and then calls the ufunc's _ufunc_registry entry with one unit.
unary_operators = (
    negative, absolute, rint, ones_like, sign, conj, exp, exp2, log, log2,
    log10, expm1, log1p, sqrt, square, reciprocal, sin, cos, tan, arcsin,
    arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh, arctanh, deg2rad,
    rad2deg, logical_not, isreal, iscomplex, isfinite, isinf, isnan,
    signbit, floor, ceil, trunc, modf, frexp,
)

# ufuncs taking two operands.  YTArray.__array_wrap__ checks membership here
# and then calls the ufunc's _ufunc_registry entry with both operand units.
# NOTE(review): floor_divide has a _ufunc_registry entry but is not listed
# here, so __array_wrap__ appears to reject it with
# RuntimeError("Operation is not defined.") -- confirm this is intentional.
binary_operators = (
    add, subtract, multiply, divide, logaddexp, logaddexp2, true_divide, power,
    remainder, mod, arctan2, hypot, greater, greater_equal, less, less_equal,
    not_equal, equal, logical_and, logical_or, logical_xor, maximum, minimum,
    fmax, fmin, copysign, nextafter, fmod,
)

class YTArray(np.ndarray):
    """
    An ndarray subclass that attaches a symbolic unit object to the array data.

    Parameters
    ----------

    input_array : iterable
        A tuple, list, or array to attach units to
    input_units : String unit specification, unit symbol object, or astropy units
        The units of the array. Powers must be specified using python
        syntax (cm**3, not cm^3).
    registry : A UnitRegistry object
        The registry to create units from. If input_units is already associated
        with a unit registry and this is specified, this will be used instead of
        the registry associated with the unit object.
    dtype : string or NumPy dtype object
        The dtype of the array data.

    Examples
    --------

    >>> from yt import YTArray
    >>> a = YTArray([1, 2, 3], 'cm')
    >>> b = YTArray([4, 5, 6], 'm')
    >>> a + b
    YTArray([ 401.,  502.,  603.]) cm
    >>> b + a
    YTArray([ 4.01,  5.02,  6.03]) m

    NumPy ufuncs will pass through units where appropriate.

    >>> import numpy as np
    >>> a = YTArray(np.arange(8), 'g/cm**3')
    >>> np.ones_like(a)
    YTArray([1, 1, 1, 1, 1, 1, 1, 1]) g/cm**3

    and strip them when it would be annoying to deal with them.

    >>> np.log10(a)
    array([       -inf,  0.        ,  0.30103   ,  0.47712125,  0.60205999,
            0.69897   ,  0.77815125,  0.84509804])

    YTArray is tightly integrated with yt datasets:

    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> a = ds.arr(np.ones(5), 'code_length')
    >>> a.in_cgs()
    YTArray([  3.08600000e+24,   3.08600000e+24,   3.08600000e+24,
             3.08600000e+24,   3.08600000e+24]) cm

    This is equivalent to:

    >>> b = YTArray(np.ones(5), 'code_length', registry=ds.unit_registry)
    >>> np.all(a == b)
    True

    """

    # Maps each supported ufunc to the helper that computes the unit of its
    # result from the operand unit(s).  Helpers returning None mark results
    # that are handed back as plain ndarrays by __array_wrap__.
    _ufunc_registry = {
        add: preserve_units,
        subtract: preserve_units,
        multiply: multiply_units,
        divide: divide_units,
        logaddexp: return_without_unit,
        logaddexp2: return_without_unit,
        true_divide: divide_units,
        floor_divide: divide_units,
        negative: passthrough_unit,
        power: power_unit,
        remainder: preserve_units,
        mod: preserve_units,
        fmod: preserve_units,
        absolute: passthrough_unit,
        rint: return_without_unit,
        sign: return_without_unit,
        conj: passthrough_unit,
        exp: return_without_unit,
        exp2: return_without_unit,
        log: return_without_unit,
        log2: return_without_unit,
        log10: return_without_unit,
        expm1: return_without_unit,
        log1p: return_without_unit,
        sqrt: sqrt_unit,
        square: square_unit,
        reciprocal: reciprocal_unit,
        ones_like: passthrough_unit,
        sin: return_without_unit,
        cos: return_without_unit,
        tan: return_without_unit,
        sinh: return_without_unit,
        cosh: return_without_unit,
        tanh: return_without_unit,
        arcsin: return_without_unit,
        arccos: return_without_unit,
        arctan: return_without_unit,
        arctan2: arctan2_unit,
        arcsinh: return_without_unit,
        arccosh: return_without_unit,
        arctanh: return_without_unit,
        hypot: preserve_units,
        deg2rad: return_without_unit,
        rad2deg: return_without_unit,
        greater: comparison_unit,
        greater_equal: comparison_unit,
        less: comparison_unit,
        less_equal: comparison_unit,
        not_equal: comparison_unit,
        equal: comparison_unit,
        logical_and: comparison_unit,
        logical_or: comparison_unit,
        logical_xor: comparison_unit,
        logical_not: return_without_unit,
        maximum: preserve_units,
        minimum: preserve_units,
        fmax: preserve_units,
        fmin: preserve_units,
        isreal: return_without_unit,
        iscomplex: return_without_unit,
        isfinite: return_without_unit,
        isinf: return_without_unit,
        isnan: return_without_unit,
        signbit: return_without_unit,
        copysign: passthrough_unit,
        nextafter: preserve_units,
        modf: passthrough_unit,
        frexp: return_without_unit,
        floor: passthrough_unit,
        ceil: passthrough_unit,
        trunc: passthrough_unit,
    }

    __array_priority__ = 2.0

    def __new__(cls, input_array, input_units=None, registry=None, dtype=None):
        """
        Build a unit-carrying array from *input_array*.

        If *input_array* is already a YTArray, its units (or its unit
        registry) are updated in place and that same object is returned.  A
        non-empty sequence whose first item is a YTArray takes that item's
        units.  Anything else is converted with ``np.asarray`` and viewed as
        *cls*.  Raises UnitParseError when a ``code_`` unit string is given
        without a registry.
        """
        # Default to the dtype of the input, or float64 when it has none.
        if dtype is None:
            dtype = getattr(input_array, 'dtype', np.float64)
        if input_array is NotImplemented:
            return input_array
        if registry is None and isinstance(input_units, (str, bytes)):
            if input_units.startswith('code_'):
                raise UnitParseError(
                    "Code units used without referring to a dataset. \n"
                    "Perhaps you meant to do something like this instead: \n"
                    "ds.arr(%s, \"%s\")" % (input_array, input_units)
                    )
        if isinstance(input_array, YTArray):
            # Existing YTArray: mutate and return it, no new object is made.
            if input_units is None:
                if registry is None:
                    pass
                else:
                    input_array.units.registry = registry
            elif isinstance(input_units, Unit):
                input_array.units = input_units
            else:
                input_array.units = Unit(input_units, registry=registry)
            return input_array
        elif isinstance(input_array, np.ndarray):
            pass
        elif iterable(input_array) and input_array:
            # Sequence led by a YTArray: input_units is ignored and the
            # units of the first item are used.
            if isinstance(input_array[0], YTArray):
                return YTArray(np.array(input_array, dtype=dtype),
                               input_array[0].units)

        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = np.asarray(input_array, dtype=dtype).view(cls)

        # Check units type
        if input_units is None:
            # Nothing provided. Make dimensionless...
            units = Unit()
        elif isinstance(input_units, Unit):
            units = input_units
        else:
            # units kwarg set, but it's not a Unit object.
            # don't handle all the cases here, let the Unit class handle if
            # it's a str.
            units = Unit(input_units, registry=registry)

        # Attach the units
        obj.units = units

        return obj

    def __array_finalize__(self, obj):
        """
        Copy ``units`` from *obj* onto this array, using NULL_UNIT when *obj*
        has none.  Existing units are kept when *obj* is None.
        """
        if obj is None and hasattr(self, 'units'):
            return
        self.units = getattr(obj, 'units', NULL_UNIT)

    def __repr__(self):
        """
        Return the ndarray repr followed by a space and the repr of the units.
        """
        return super(YTArray, self).__repr__()+' '+self.units.__repr__()

    def __str__(self):
        """
        Return the ndarray str followed by a space and the str of the units.
        """
        return super(YTArray, self).__str__()+' '+self.units.__str__()

    #
    # Start unit conversion methods
    #

    def _unit_repr_check_same(self, units):
        """
        Takes a Unit object, or string of known unit symbol, and check that it
        is compatible with this quantity. Returns Unit object.

        Raises YTEquivalentDimsError when the requested dimensions are the
        em_dimensions counterpart of this array's dimensions, and
        YTUnitConversionError when the dimensions differ.
        """
        # let Unit() handle units arg if it's not already a Unit obj.
        if not isinstance(units, Unit):
            units = Unit(units, registry=self.units.registry)

        equiv_dims = em_dimensions.get(self.units.dimensions,None)
        if equiv_dims == units.dimensions:
            # The label passed to the error depends on whether current_mks
            # appears in the counterpart dimensions.
            if current_mks in equiv_dims.free_symbols:
                base = "SI"
            else:
                base = "CGS"
            raise YTEquivalentDimsError(self.units, units, base)

        if not self.units.same_dimensions_as(units):
            raise YTUnitConversionError(
                self.units, self.units.dimensions, units, units.dimensions)

        return units
def convert_to_units(self, units):
    """
    Convert the array and units to the given units, in place.

    Parameters
    ----------
    units : Unit object or str
        The units you want to convert to.

    Returns
    -------
    This same array, after conversion.
    """
    target = self._unit_repr_check_same(units)
    factor, shift = self.units.get_conversion_factor(target)

    # The units are switched before the data is rescaled.
    self.units = target
    self *= factor
    if shift:
        np.subtract(self, shift*self.uq, self)

    return self
def convert_to_base(self):
    """
    Convert the array and units to the equivalent base units.
    """
    base_units = self.units.get_base_equivalent()
    return self.convert_to_units(base_units)
def convert_to_cgs(self):
    """
    Convert the array and units to the equivalent cgs units.
    """
    cgs_units = self.units.get_cgs_equivalent()
    return self.convert_to_units(cgs_units)
def convert_to_mks(self):
    """
    Convert the array and units to the equivalent mks units.
    """
    mks_units = self.units.get_mks_equivalent()
    return self.convert_to_units(mks_units)
def in_units(self, units):
    """
    Creates a copy of this array with the data in the supplied units, and
    returns it.

    Parameters
    ----------
    units : Unit object or string
        The units you want to get a new quantity in.

    Returns
    -------
    YTArray
    """
    target = self._unit_repr_check_same(units)
    factor, shift = self.units.get_conversion_factor(target)

    converted = self * factor
    converted.units = target
    if shift:
        np.subtract(converted, shift*converted.uq, converted)

    return converted
def in_base(self):
    """
    Creates a copy of this array with the data in the equivalent base units,
    and returns it.

    Returns
    -------
    Quantity object with data converted to base units.
    """
    base_units = self.units.get_base_equivalent()
    return self.in_units(base_units)
def in_cgs(self):
    """
    Creates a copy of this array with the data in the equivalent cgs units,
    and returns it.

    Returns
    -------
    Quantity object with data converted to cgs units.
    """
    cgs_units = self.units.get_cgs_equivalent()
    return self.in_units(cgs_units)
def in_mks(self):
    """
    Creates a copy of this array with the data in the equivalent mks units,
    and returns it.

    Returns
    -------
    Quantity object with data converted to mks units.
    """
    mks_units = self.units.get_mks_equivalent()
    return self.in_units(mks_units)
def to_equivalent(self, unit, equiv, **kwargs):
    """
    Convert a YTArray or YTQuantity to an equivalent, e.g., something that is
    related by only a constant factor but not in the same units.

    Parameters
    ----------
    unit : string
        The unit that you wish to convert to.
    equiv : string
        The equivalence you wish to use. To see which equivalencies are
        supported for this unitful quantity, try the
        :meth:`list_equivalencies` method.

    Raises
    ------
    YTInvalidUnitEquivalence
        If the equivalence does not apply to this array or to the target
        unit, or if its result cannot be expressed in *unit*.

    Examples
    --------
    >>> a = yt.YTArray(1.0e7,"K")
    >>> a.to_equivalent("keV", "thermal")
    """
    target = YTQuantity(1.0, unit, registry=self.units.registry)
    equivalence = equivalence_registry[equiv]()

    if not self.has_equivalent(equiv):
        raise YTInvalidUnitEquivalence(equiv, self.units, unit)
    if not (target.has_equivalent(equiv) or equivalence._one_way):
        raise YTInvalidUnitEquivalence(equiv, self.units, unit)

    converted = equivalence.convert(self, target.units.dimensions, **kwargs)
    if not isinstance(converted, tuple):
        return converted.in_units(unit)

    # A tuple result is a (data, units) pair that still has to be wrapped.
    try:
        return YTArray(converted[0], converted[1]).in_units(unit)
    except YTUnitConversionError:
        raise YTInvalidUnitEquivalence(equiv, self.units, unit)

def list_equivalencies(self):
    """
    Lists the possible equivalencies associated with this YTArray or
    YTQuantity.
    """
    for name, equivalence_cls in equivalence_registry.items():
        if not self.has_equivalent(name):
            continue
        print(equivalence_cls())

def has_equivalent(self, equiv):
    """
    Check to see if this YTArray or YTQuantity has an equivalent unit in
    *equiv*.  Raises KeyError for an unknown equivalence name.
    """
    try:
        equivalence = equivalence_registry[equiv]()
    except KeyError:
        raise KeyError("No such equivalence \"%s\"." % equiv)
    return self.units.dimensions in equivalence.dims
def ndarray_view(self):
    """
    Returns a view into the array, but as an ndarray rather than ytarray.

    Returns
    -------
    View of this array's data.
    """
    plain_cls = np.ndarray
    return self.view(plain_cls)
def to_ndarray(self):
    """
    Creates a copy of this array with the unit information stripped
    """
    stripped = np.array(self)
    return stripped
@classmethod
def from_astropy(cls, arr):
    """
    Creates a new YTArray with the same unit information from an
    AstroPy quantity *arr*.

    Returns a YTArray when ``arr.value`` is an ndarray and a YTQuantity
    otherwise.
    """
    # Converting from AstroPy Quantity
    astropy_unit = arr.unit
    factors = []
    for base, power in zip(astropy_unit.bases, astropy_unit.powers):
        symbol = base.to_string()
        # AstroPy writes the hour as "h"; it is renamed to "hr" here.
        symbol = "hr" if symbol == "h" else symbol
        factors.append("%s**(%s)" % (symbol, Rational(power)))
    unit_string = "*".join(factors)

    if not isinstance(arr.value, np.ndarray):
        return YTQuantity(arr.value, unit_string)
    return YTArray(arr.value, unit_string)
def to_astropy(self, **kwargs):
    """
    Creates a new AstroPy quantity with the same unit information.

    Keyword arguments are forwarded to the AstroPy ``Unit`` constructor.
    Raises ImportError when AstroPy is not available.
    """
    astropy_units = _astropy.units
    if astropy_units is None:
        raise ImportError("You don't have AstroPy installed, so you can't convert to " +
                          "an AstroPy quantity.")
    data = self.value
    unit = _astropy.units.Unit(str(self.units), **kwargs)
    return data*unit

#
# End unit conversion methods
#
def write_hdf5(self, filename, dataset_name=None, info=None):
    r"""Writes this YTArray to an hdf5 file.

    The unit string and the pickled unit registry lookup table are stored as
    the ``units`` and ``unit_registry`` attributes of the dataset.

    Parameters
    ----------
    filename: string
        The filename to create and write a dataset to

    dataset_name: string
        The name of the dataset to create in the file.  Defaults to
        ``'array_data'``.

    info: dictionary
        A dictionary of supplementary info to write to append as attributes
        to the dataset.  The dictionary itself is left unmodified.

    Examples
    --------
    >>> a = YTArray([1,2,3], 'cm')

    >>> myinfo = {'field':'dinosaurs', 'type':'field_data'}

    >>> a.write_hdf5('test_array_data.h5', dataset_name='dinosaurs',
    ...              info=myinfo)

    """
    import h5py
    from yt.extern.six.moves import cPickle as pickle

    # Work on a copy so the caller's dictionary does not gain the
    # 'units' and 'unit_registry' entries added below.
    info = {} if info is None else dict(info)

    info['units'] = str(self.units)
    info['unit_registry'] = np.void(pickle.dumps(self.units.registry.lut))

    if dataset_name is None:
        dataset_name = 'array_data'

    # Open explicitly in append mode: newer h5py releases default to
    # read-only, which would make every write below fail.
    f = h5py.File(filename, 'a')
    try:
        if dataset_name in f.keys():
            d = f[dataset_name]
            # Overwrite without deleting if we can get away with it.
            if d.shape == self.shape and d.dtype == self.dtype:
                d[:] = self
                # Snapshot the keys; attributes are deleted while looping.
                for k in list(d.attrs.keys()):
                    del d.attrs[k]
            else:
                del f[dataset_name]
                d = f.create_dataset(dataset_name, data=self)
        else:
            d = f.create_dataset(dataset_name, data=self)

        for k, v in info.items():
            d.attrs[k] = v
    finally:
        # Close the file even if one of the writes raised.
        f.close()
@classmethod
def from_hdf5(cls, filename, dataset_name=None):
    r"""Attempts read in and convert a dataset in an hdf5 file into a YTArray.

    Parameters
    ----------
    filename: string
        The filename to of the hdf5 file.

    dataset_name: string
        The name of the dataset to read from.  If the dataset has a units
        attribute, attempt to infer units as well.  Defaults to
        ``'array_data'``.

    Returns
    -------
    An instance of *cls* holding the dataset contents.  Units come from the
    ``units`` attribute (empty string when absent), and the registry is built
    from the ``unit_registry`` attribute when present.
    """
    import h5py
    from yt.extern.six.moves import cPickle as pickle

    if dataset_name is None:
        dataset_name = 'array_data'

    # Reading only needs read access; state the mode explicitly.
    f = h5py.File(filename, 'r')
    try:
        dataset = f[dataset_name]
        data = dataset[:]
        units = dataset.attrs.get('units', '')
        if 'unit_registry' in dataset.attrs.keys():
            raw = dataset.attrs['unit_registry']
            # tostring() is the deprecated alias of tobytes(); keep it only
            # as a fallback for NumPy versions without tobytes().
            to_bytes = raw.tobytes if hasattr(raw, 'tobytes') else raw.tostring
            # NOTE(security): unpickling executes whatever the file contains,
            # so only load files that come from a trusted source.
            unit_lut = pickle.loads(to_bytes())
        else:
            unit_lut = None
    finally:
        # Close the file even if reading or unpickling raised.
        f.close()

    registry = UnitRegistry(lut=unit_lut, add_default_symbols=False)
    return cls(data, units, registry=registry)

#
# Start convenience methods
#
@property
def value(self):
    """Get a copy of the array data as a numpy ndarray"""
    return np.array(self)

v = value

@property
def ndview(self):
    """Get a view of the array data."""
    return self.ndarray_view()

d = ndview

@property
def unit_quantity(self):
    """Get a YTQuantity with the same unit as this array and a value of 1.0"""
    return YTQuantity(1.0, self.units)

uq = unit_quantity

@property
def unit_array(self):
    """Get a YTArray filled with ones with the same unit and shape as this array"""
    return np.ones_like(self)

ua = unit_array

#
# Start operation methods
#

def __add__(self, right_object):
    """
    Add this ytarray to the object on the right of the `+` operator. Must
    check for the correct (same dimension) units.

    """
    ro = sanitize_units_add(self, right_object, "addition")
    return YTArray(super(YTArray, self).__add__(ro))

def __radd__(self, left_object):
    """ See __add__. """
    lo = sanitize_units_add(self, left_object, "addition")
    return YTArray(super(YTArray, self).__radd__(lo))

def __iadd__(self, other):
    """ See __add__. """
    oth = sanitize_units_add(self, other, "addition")
    np.add(self, oth, out=self)
    return self

def __sub__(self, right_object):
    """
    Subtract the object on the right of the `-` from this ytarray. Must
    check for the correct (same dimension) units.

    """
    ro = sanitize_units_add(self, right_object, "subtraction")
    return YTArray(super(YTArray, self).__sub__(ro))

def __rsub__(self, left_object):
    """ See __sub__. """
    lo = sanitize_units_add(self, left_object, "subtraction")
    return YTArray(super(YTArray, self).__rsub__(lo))

def __isub__(self, other):
    """ See __sub__. """
    oth = sanitize_units_add(self, other, "subtraction")
    np.subtract(self, oth, out=self)
    return self

def __neg__(self):
    """ Negate the data. """
    return YTArray(super(YTArray, self).__neg__())

def __pos__(self):
    """ Posify the data. """
    return YTArray(super(YTArray, self).__pos__(), self.units)

def __mul__(self, right_object):
    """
    Multiply this YTArray by the object on the right of the `*` operator.
    The unit objects handle being multiplied.

    """
    ro = sanitize_units_mul(self, right_object)
    return YTArray(super(YTArray, self).__mul__(ro))

def __rmul__(self, left_object):
    """ See __mul__. """
    lo = sanitize_units_mul(self, left_object)
    return YTArray(super(YTArray, self).__rmul__(lo))

def __imul__(self, other):
    """ See __mul__. """
    oth = sanitize_units_mul(self, other)
    np.multiply(self, oth, out=self)
    return self

def __div__(self, right_object):
    """
    Divide this YTArray by the object on the right of the `/` operator.

    """
    ro = sanitize_units_mul(self, right_object)
    return YTArray(super(YTArray, self).__div__(ro))

def __rdiv__(self, left_object):
    """ See __div__. """
    lo = sanitize_units_mul(self, left_object)
    return YTArray(super(YTArray, self).__rdiv__(lo))

def __idiv__(self, other):
    """ See __div__. """
    oth = sanitize_units_mul(self, other)
    np.divide(self, oth, out=self)
    return self

def __truediv__(self, right_object):
    """ True division counterpart of __div__. """
    ro = sanitize_units_mul(self, right_object)
    return YTArray(super(YTArray, self).__truediv__(ro))

def __rtruediv__(self, left_object):
    """ See __div__. """
    lo = sanitize_units_mul(self, left_object)
    return YTArray(super(YTArray, self).__rtruediv__(lo))

def __itruediv__(self, other):
    """ See __div__. """
    oth = sanitize_units_mul(self, other)
    np.true_divide(self, oth, out=self)
    return self

def __floordiv__(self, right_object):
    """ Floor division counterpart of __div__. """
    ro = sanitize_units_mul(self, right_object)
    return YTArray(super(YTArray, self).__floordiv__(ro))

def __rfloordiv__(self, left_object):
    """ See __div__. """
    lo = sanitize_units_mul(self, left_object)
    return YTArray(super(YTArray, self).__rfloordiv__(lo))

def __ifloordiv__(self, other):
    """ See __div__. """
    oth = sanitize_units_mul(self, other)
    np.floor_divide(self, oth, out=self)
    return self

#Should these raise errors?  I need to come back and check this.
def __or__(self, right_object):
    """ Bitwise or with the object on the right; no unit checks are made. """
    return YTArray(super(YTArray, self).__or__(right_object))

def __ror__(self, left_object):
    """ See __or__. """
    return YTArray(super(YTArray, self).__ror__(left_object))

def __ior__(self, other):
    """ See __or__. """
    np.bitwise_or(self, other, out=self)
    return self

def __xor__(self, right_object):
    """ Bitwise xor with the object on the right; no unit checks are made. """
    return YTArray(super(YTArray, self).__xor__(right_object))

def __rxor__(self, left_object):
    """ See __xor__. """
    return YTArray(super(YTArray, self).__rxor__(left_object))

def __ixor__(self, other):
    """ See __xor__. """
    np.bitwise_xor(self, other, out=self)
    return self

def __and__(self, right_object):
    """ Bitwise and with the object on the right; no unit checks are made. """
    return YTArray(super(YTArray, self).__and__(right_object))

def __rand__(self, left_object):
    """ See __and__. """
    return YTArray(super(YTArray, self).__rand__(left_object))

def __iand__(self, other):
    """ See __and__. """
    np.bitwise_and(self, other, out=self)
    return self

def __pow__(self, power):
    """
    Raise this YTArray to some power.

    Parameters
    ----------
    power : float or dimensionless YTArray.
        The pow value.

    Raises
    ------
    YTUnitOperationError
        If *power* is a YTArray whose units are not dimensionless.

    """
    if isinstance(power, YTArray):
        if not power.units.is_dimensionless:
            # The attribute is ``units``; ``power.unit`` does not exist and
            # turned this intended error into an AttributeError.
            raise YTUnitOperationError('power', power.units)

    # Work around a sympy issue (I think?)
    #
    # If I don't do this, super(YTArray, self).__pow__ returns a YTArray
    # with a unit attribute set to the sympy expression 1/1 rather than a
    # dimensionless Unit object.
    if self.units.is_dimensionless and power == -1:
        ret = super(YTArray, self).__pow__(power)
        return YTArray(ret, input_units='')

    return YTArray(super(YTArray, self).__pow__(power))

def __abs__(self):
    """ Return a YTArray with the abs of the data. """
    return YTArray(super(YTArray, self).__abs__())
def sqrt(self):
    """
    Return sqrt of this YTArray. We take the sqrt for the array and use
    take the 1/2 power of the units.

    Returns
    -------
    YTArray holding the element-wise square root, with units ``units**0.5``.
    """
    # ndarray has no ``sqrt`` method, so the former super().sqrt() call could
    # only raise AttributeError.  Take the root of the raw data and attach
    # the units explicitly.
    return YTArray(np.sqrt(self.ndarray_view()),
                   input_units=self.units**0.5)

#
# Start comparison operators.
#

# @todo: outsource to a single method with an op argument.
def __lt__(self, other): """ Test if this is less than the object on the right. """ # Check that other is a YTArray. if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError('less than', self.units, other.units) return np.array(self).__lt__(np.array(other.in_units(self.units))) return np.array(self).__lt__(np.array(other)) def __le__(self, other): """ Test if this is less than or equal to the object on the right. """ # Check that other is a YTArray. if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError('less than or equal', self.units, other.units) return np.array(self).__le__(np.array(other.in_units(self.units))) return np.array(self).__le__(np.array(other)) def __eq__(self, other): """ Test if this is equal to the object on the right. """ # Check that other is a YTArray. if other is None: # self is a YTArray, so it can't be None. return False if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError("equal", self.units, other.units) return np.array(self).__eq__(np.array(other.in_units(self.units))) return np.array(self).__eq__(np.array(other)) def __ne__(self, other): """ Test if this is not equal to the object on the right. """ # Check that the other is a YTArray. if other is None: return True if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError("not equal", self.units, other.units) return np.array(self).__ne__(np.array(other.in_units(self.units))) return np.array(self).__ne__(np.array(other)) def __ge__(self, other): """ Test if this is greater than or equal to other. """ # Check that the other is a YTArray. 
if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError("greater than or equal", self.units, other.units) return np.array(self).__ge__(np.array(other.in_units(self.units))) return np.array(self).__ge__(np.array(other)) def __gt__(self, other): """ Test if this is greater than the object on the right. """ # Check that the other is a YTArray. if isinstance(other, YTArray): if not self.units.same_dimensions_as(other.units): raise YTUnitOperationError("greater than", self.units, other.units) return np.array(self).__gt__(np.array(other.in_units(self.units))) return np.array(self).__gt__(np.array(other)) # # End comparison operators # # # Begin reduction operators # @return_arr
def prod(self, axis=None, dtype=None, out=None):
    """
    Return the product of the elements together with its units.

    The units are raised to the number of multiplied elements: the length of
    *axis* when one is given, otherwise the total size of the array.
    """
    count = self.size if axis is None else self.shape[axis]
    units = self.units**count
    product = super(YTArray, self).prod(axis, dtype, out)
    return product, units
@return_arr
def mean(self, axis=None, dtype=None, out=None):
    """ Return the mean of the data, carrying this array's units. """
    averaged = super(YTArray, self).mean(axis, dtype, out)
    return averaged, self.units
@return_arr
def sum(self, axis=None, dtype=None, out=None):
    """ Return the sum of the data, carrying this array's units. """
    total = super(YTArray, self).sum(axis, dtype, out)
    return total, self.units
@return_arr
def dot(self, b, out=None):
    """
    Return the dot product with *b*, carrying the product of both units.

    Parameters
    ----------
    b : array_like
        Right-hand operand.  An operand without a ``units`` attribute is
        treated as carrying NULL_UNIT.
    out : ndarray, optional
        Output argument, forwarded to ``ndarray.dot``.
    """
    # Plain ndarrays and sequences have no units; fall back to NULL_UNIT
    # instead of failing with AttributeError.
    other_units = getattr(b, 'units', NULL_UNIT)
    # ``out`` used to be accepted but silently dropped; pass it through.
    return super(YTArray, self).dot(b, out), self.units*other_units
@return_arr
def std(self, axis=None, dtype=None, out=None, ddof=0):
    """ Return the standard deviation, carrying this array's units. """
    deviation = super(YTArray, self).std(axis, dtype, out, ddof)
    return deviation, self.units
def __getitem__(self, item):
    """
    Index the array; a result of shape ``()`` is returned as a YTQuantity
    with this array's units.
    """
    ret = super(YTArray, self).__getitem__(item)
    if ret.shape == ():
        return YTQuantity(ret, self.units)
    else:
        return ret

def __array_wrap__(self, out_arr, context=None):
    """
    Attach the correct units to the result of a ufunc.

    ``context`` is the ``(ufunc, arguments, ...)`` tuple supplied by NumPy.
    The unit of the result is computed by the ufunc's ``_ufunc_registry``
    entry from the operand units.

    Raises
    ------
    YTUnitOperationError
        For operands with incompatible dimensions, or a power operand that
        is a YTArray with units.
    YTUfuncUnitError
        For operands with the same dimensions but different units, in ufuncs
        that need identical units.
    RuntimeError
        For ufuncs listed in neither unary_operators nor binary_operators.
    """
    ret = super(YTArray, self).__array_wrap__(out_arr, context)
    if isinstance(ret, YTQuantity) and ret.shape != ():
        ret = ret.view(YTArray)
    if context is None:
        if ret.shape == ():
            return ret[()]
        else:
            return ret
    elif context[0] in unary_operators:
        u = getattr(context[1][0], 'units', None)
        if u is None:
            u = NULL_UNIT
        unit = self._ufunc_registry[context[0]](u)
        ret_class = type(self)
    elif context[0] in binary_operators:
        oper1 = coerce_iterable_units(context[1][0])
        oper2 = coerce_iterable_units(context[1][1])
        cls1 = type(oper1)
        cls2 = type(oper2)
        unit1 = getattr(oper1, 'units', None)
        unit2 = getattr(oper2, 'units', None)
        ret_class = get_binary_op_return_class(cls1, cls2)
        # An operand without units gets a default Unit built on the other
        # operand's registry.
        if unit1 is None:
            unit1 = Unit(registry=getattr(unit2, 'registry', None))
        if unit2 is None and context[0] is not power:
            unit2 = Unit(registry=getattr(unit1, 'registry', None))
        elif context[0] is power:
            # For power the second "unit" is the exponent itself.
            unit2 = oper2
            if isinstance(unit2, np.ndarray):
                if isinstance(unit2, YTArray):
                    if unit2.units.is_dimensionless:
                        pass
                    else:
                        raise YTUnitOperationError(context[0], unit1, unit2)
                unit2 = 1.0
        unit_operator = self._ufunc_registry[context[0]]
        if unit_operator in (preserve_units, comparison_unit, arctan2_unit):
            if unit1 != unit2:
                if not unit1.same_dimensions_as(unit2):
                    raise YTUnitOperationError(context[0], unit1, unit2)
                else:
                    raise YTUfuncUnitError(context[0], unit1, unit2)
        unit = self._ufunc_registry[context[0]](unit1, unit2)
        if unit_operator in (multiply_units, divide_units):
            # Same-dimension operands in different units leave a numeric
            # factor in the dimensionless result; fold it into the data.
            if unit.is_dimensionless and unit.base_value != 1.0:
                if not unit1.is_dimensionless:
                    if unit1.dimensions == unit2.dimensions:
                        np.multiply(out_arr.view(np.ndarray),
                                    unit.base_value, out=out_arr)
                        unit = Unit(registry=unit.registry)
    else:
        raise RuntimeError("Operation is not defined.")
    if unit is None:
        # Unit-less result: hand back a plain ndarray.  asarray is the
        # portable spelling of np.array(..., copy=False).
        out_arr = np.asarray(out_arr)
        return out_arr
    out_arr.units = unit
    if out_arr.size == 1:
        return YTQuantity(np.array(out_arr), unit)
    else:
        if ret_class is YTQuantity:
            # This happens if you do ndarray * YTQuantity. Explicitly
            # casting to YTArray avoids creating a YTQuantity with size > 1
            #
            # The unit belongs to YTArray, not to np.array, where it was
            # taken as the dtype argument and the units were lost.
            return YTArray(np.array(out_arr), unit)
        return ret_class(np.asarray(out_arr), unit)

def __reduce__(self):
    """Pickle reduction method

    See the documentation for the standard library pickle module:
    http://docs.python.org/2/library/pickle.html

    Unit metadata is encoded in the zeroth element of third element of the
    returned tuple, itself a tuple used to restore the state of the ndarray.
    This is always defined for numpy arrays.
    """
    np_ret = super(YTArray, self).__reduce__()
    obj_state = np_ret[2]
    unit_state = (((str(self.units), self.units.registry.lut),) + obj_state[:],)
    new_ret = np_ret[:2] + unit_state + np_ret[3:]
    return new_ret

def __setstate__(self, state):
    """Pickle setstate method

    This is called inside pickle.read() and restores the unit data from the
    metadata extracted in __reduce__ and then serialized by pickle.
    """
    super(YTArray, self).__setstate__(state[1:])
    unit, lut = state[0]
    registry = UnitRegistry(lut=lut, add_default_symbols=False)
    self.units = Unit(unit, registry=registry)

def __deepcopy__(self, memodict=None):
    """copy.deepcopy implementation

    This is necessary for stdlib deepcopy of arrays and quantities.
    """
    if memodict is None:
        memodict = {}
    ret = super(YTArray, self).__deepcopy__(memodict)
    return type(self)(ret, copy.deepcopy(self.units))
class YTQuantity(YTArray):
    """
    A scalar associated with a unit.

    Parameters
    ----------

    input_scalar : an integer or floating point scalar
        The scalar to attach units to
    input_units : String unit specification, unit symbol object, or astropy units
        The units of the quantity. Powers must be specified using python syntax
        (cm**3, not cm^3).
    registry : A UnitRegistry object
        The registry to create units from. If input_units is already associated
        with a unit registry and this is specified, this will be used instead of
        the registry associated with the unit object.
    dtype : string or NumPy dtype object
        The dtype of the array data.

    Examples
    --------

    >>> from yt import YTQuantity
    >>> a = YTQuantity(1, 'cm')
    >>> b = YTQuantity(2, 'm')
    >>> a + b
    201.0 cm
    >>> b + a
    2.01 m

    NumPy ufuncs will pass through units where appropriate.

    >>> import numpy as np
    >>> a = YTQuantity(12, 'g/cm**3')
    >>> np.ones_like(a)
    1 g/cm**3

    and strip them when it would be annoying to deal with them.

    >>> print(np.log10(a))
    1.07918124605

    YTQuantity is tightly integrated with yt datasets:

    >>> import yt
    >>> ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
    >>> a = ds.quan(5, 'code_length')
    >>> a.in_cgs()
    1.543e+25 cm

    This is equivalent to:

    >>> b = YTQuantity(5, 'code_length', registry=ds.unit_registry)
    >>> np.all(a == b)
    True

    """
    def __new__(cls, input_scalar, input_units=None, registry=None,
                dtype=np.float64):
        """
        Build the quantity through YTArray.__new__.

        Raises RuntimeError when *input_scalar* is not a number or an
        ndarray, or when the resulting array holds more than one element.
        """
        numeric_kinds = (numeric_type, np.number, np.ndarray)
        if not isinstance(input_scalar, numeric_kinds):
            raise RuntimeError("YTQuantity values must be numeric")

        quantity = YTArray.__new__(cls, input_scalar, input_units, registry,
                                   dtype=dtype)
        if quantity.size > 1:
            raise RuntimeError("YTQuantity instances must be scalars")
        return quantity

    def __repr__(self):
        """ Use the same text for repr as for str. """
        text = str(self)
        return text
def validate_numpy_wrapper_units(v, arrs):
    """
    Attach the common units of *arrs* to the numpy result *v*.

    *v* is returned untouched when no input is a YTArray.  RuntimeError is
    raised when only some inputs are YTArrays, or when their units differ.
    """
    if not any(isinstance(a, YTArray) for a in arrs):
        return v
    if not all(isinstance(a, YTArray) for a in arrs):
        raise RuntimeError("Not all of your arrays are YTArrays.")
    a1 = arrs[0]
    if not all(a.units == a1.units for a in arrs[1:]):
        raise RuntimeError("Your arrays must have identical units.")
    v.units = a1.units
    return v

def uconcatenate(arrs, axis=0):
    """Concatenate a sequence of arrays.

    This wrapper around numpy.concatenate preserves units. All input arrays must
    have the same units.  See the documentation of numpy.concatenate for full
    details.

    Examples
    --------
    >>> A = yt.YTArray([1, 2, 3], 'cm')
    >>> B = yt.YTArray([2, 3, 4], 'cm')
    >>> uconcatenate((A, B))
    YTArray([ 1., 2., 3., 2., 3., 4.]) cm

    """
    v = np.concatenate(arrs, axis=axis)
    v = validate_numpy_wrapper_units(v, arrs)
    return v

def ucross(arr1, arr2, registry=None, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Applies the cross product to two YT arrays.

    This wrapper around numpy.cross preserves units.  The result carries the
    product of the units of *arr1* and *arr2*.  The axis arguments are passed
    straight to numpy.cross and default to numpy's own defaults.  See the
    documentation of numpy.cross for full details.
    """
    v = np.cross(arr1, arr2, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
    units = arr1.units * arr2.units
    arr = YTArray(v, units, registry=registry)
    return arr

def uintersect1d(arr1, arr2, assume_unique=False):
    """Find the sorted unique elements of the two input arrays.

    A wrapper around numpy.intersect1d that preserves units.  All input arrays
    must have the same units.  See the documentation of numpy.intersect1d for
    full details.

    Examples
    --------
    >>> A = yt.YTArray([1, 2, 3], 'cm')
    >>> B = yt.YTArray([2, 3, 4], 'cm')
    >>> uintersect1d(A, B)
    YTArray([ 2., 3.]) cm

    """
    v = np.intersect1d(arr1, arr2, assume_unique=assume_unique)
    v = validate_numpy_wrapper_units(v, [arr1, arr2])
    return v

def uunion1d(arr1, arr2):
    """Find the union of two arrays.

    A wrapper around numpy.union1d that preserves units.  All input arrays
    must have the same units.  See the documentation of numpy.union1d for
    full details.

    Examples
    --------
    >>> A = yt.YTArray([1, 2, 3], 'cm')
    >>> B = yt.YTArray([2, 3, 4], 'cm')
    >>> uunion1d(A, B)
    YTArray([ 1., 2., 3., 4.]) cm

    """
    v = np.union1d(arr1, arr2)
    v = validate_numpy_wrapper_units(v, [arr1, arr2])
    return v

def array_like_field(data, x, field):
    """
    Return *x* expressed in the units of *field* of the data object *data*.

    A YTArray is deep-copied and the copy converted to the field units; an
    ndarray is wrapped with ``data.ds.arr``; anything else with
    ``data.ds.quan``.
    """
    field = data._determine_fields(field)[0]
    if isinstance(field, tuple):
        units = data.ds._get_field_info(field[0], field[1]).units
    else:
        units = data.ds._get_field_info(field).units
    if isinstance(x, YTArray):
        # Convert a copy so the caller's array keeps its units.
        arr = copy.deepcopy(x)
        arr.convert_to_units(units)
        return arr
    if isinstance(x, np.ndarray):
        return data.ds.arr(x, units)
    else:
        return data.ds.quan(x, units)

def get_binary_op_return_class(cls1, cls2):
    """
    Pick the class of the result of a binary operation on two operand classes.

    Identical classes win outright.  ndarray, numbers, lists and tuples yield
    to the other operand; a YTQuantity yields to the other operand; otherwise
    the more derived of two related classes is used.  Unrelated classes raise
    RuntimeError.
    """
    if cls1 is cls2:
        return cls1
    if cls1 is np.ndarray or issubclass(cls1, (numeric_type, np.number, list, tuple)):
        return cls2
    if cls2 is np.ndarray or issubclass(cls2, (numeric_type, np.number, list, tuple)):
        return cls1
    if issubclass(cls1, YTQuantity):
        return cls2
    if issubclass(cls2, YTQuantity):
        return cls1
    if issubclass(cls1, cls2):
        return cls1
    if issubclass(cls2, cls1):
        return cls2
    else:
        raise RuntimeError("Undefined operation for a YTArray subclass. "
                           "Received operand types (%s) and (%s)" % (cls1, cls2))