Multiply and subtraction with lambdas, code clean up

This commit is contained in:
Lukas Klass 2020-04-15 15:38:05 +02:00
parent f28c785d64
commit 92fa74d7f2

View File

@ -1,4 +1,4 @@
from esbo_etc.lib.helpers import error, isLambda from ..lib.helpers import error, isLambda
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
import astropy.units as u import astropy.units as u
import math import math
@ -8,12 +8,13 @@ from astropy.io import ascii
import re import re
# noinspection PyUnresolvedReferences
class SpectralQty: class SpectralQty:
""" """
A class to hold and work with spectral quantities A class to hold and work with spectral quantities
""" """
def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False) -> "SpectralQty": def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False):
""" """
Initialize a new spectral quantity Initialize a new spectral quantity
@ -80,6 +81,7 @@ class SpectralQty:
# Check if units are given in column headers # Check if units are given in column headers
if all([re.search("\\[.+\\]", x) for x in data.colnames]): if all([re.search("\\[.+\\]", x) for x in data.colnames]):
# Extract units from headers and apply them on the columns # Extract units from headers and apply them on the columns
# noinspection PyArgumentList
units = [u.Unit(re.findall("(?<=\\[).+(?=\\])", x)[0]) for x in data.colnames] units = [u.Unit(re.findall("(?<=\\[).+(?=\\])", x)[0]) for x in data.colnames]
data[data.colnames[0]].unit = units[0] data[data.colnames[0]].unit = units[0]
data[data.colnames[1]].unit = units[1] data[data.colnames[1]].unit = units[1]
@ -104,9 +106,9 @@ class SpectralQty:
Result of the comparison Result of the comparison
""" """
return self.wl.unit == other.wl.unit and self.qty.unit == other.qty.unit and \ return self.wl.unit == other.wl.unit and self.qty.unit == other.qty.unit and \
len(self.wl) == len(other.wl) and len(self.qty) == len(other.qty) and \ len(self.wl) == len(other.wl) and len(self.qty) == len(other.qty) and \
all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \
all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)])
def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty": def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty":
""" """
@ -139,13 +141,13 @@ class SpectralQty:
else: else:
if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit):
# Wavelengths are matching, just add the quantities # Wavelengths are matching, just add the quantities
if len(self.wl) == len(other.wl) and all(self.wl == other.wl): if len(self.wl) == len(other.wl) and (self.wl == other.wl).all():
return SpectralQty(self.wl, self.qty + other.qty) return SpectralQty(self.wl, self.qty + other.qty)
# Wavelengths are not matching, rebinning needed # Wavelengths are not matching, rebinning needed
else: else:
# Rebin addend # Rebin addend
other_rebinned = other.rebin(self.wl) other_rebinned = other.rebin(self.wl)
if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all():
return SpectralQty(self.wl, self.qty + other_rebinned.qty) return SpectralQty(self.wl, self.qty + other_rebinned.qty)
else: else:
# Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity
@ -155,13 +157,13 @@ class SpectralQty:
__radd__ = __add__ __radd__ = __add__
def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty":
""" """
Calculate the difference to another object Calculate the difference to another object
Parameters Parameters
---------- ----------
other : Union[int, float, u.Quantity, "SpectralQty"] other : Union[int, float, u.Quantity, "SpectralQty", Callable]
Subtrahend to be subtracted from this object. If the binning of the object on the right hand side differs Subtrahend to be subtracted from this object. If the binning of the object on the right hand side differs
from the binning of the left object, the object on the right hand side will be rebinned. from the binning of the left object, the object on the right hand side will be rebinned.
@ -179,17 +181,20 @@ class SpectralQty:
return SpectralQty(self.wl, self.qty - other) return SpectralQty(self.wl, self.qty - other)
else: else:
raise TypeError('Units are not matching for subtraction.') raise TypeError('Units are not matching for subtraction.')
# Subtrahend is of type lambda
elif isLambda(other):
return SpectralQty(self.wl, self.qty - [other(wl).value for wl in self.wl] * other(self.wl[0]).unit)
# Subtrahend is of type SpectralQty # Subtrahend is of type SpectralQty
else: else:
if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit):
# Wavelengths are matching, just subtract the quantities # Wavelengths are matching, just subtract the quantities
if len(self.wl) == len(other.wl) and all(self.wl == other.wl): if len(self.wl) == len(other.wl) and (self.wl == other.wl).all():
return SpectralQty(self.wl, self.qty - other.qty) return SpectralQty(self.wl, self.qty - other.qty)
# Wavelengths are not matching, rebinning needed # Wavelengths are not matching, rebinning needed
else: else:
# Rebin subtrahend # Rebin subtrahend
other_rebinned = other.rebin(self.wl) other_rebinned = other.rebin(self.wl)
if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all():
return SpectralQty(self.wl, self.qty - other_rebinned.qty) return SpectralQty(self.wl, self.qty - other_rebinned.qty)
else: else:
# Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity
@ -197,13 +202,13 @@ class SpectralQty:
else: else:
error("Units are not matching for substraction.") error("Units are not matching for substraction.")
def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty":
""" """
Calculate the product with another object Calculate the product with another object
Parameters Parameters
---------- ----------
other : Union[int, float, u.Quantity, "SpectralQty"] other : Union[int, float, u.Quantity, "SpectralQty", Callable]
Factor to be multiplied with this object. If the binning of the object on the right hand side differs Factor to be multiplied with this object. If the binning of the object on the right hand side differs
from the binning of the left object, the object on the right hand side will be rebinned. from the binning of the left object, the object on the right hand side will be rebinned.
@ -215,17 +220,20 @@ class SpectralQty:
# Factor is of type int, float or Quantity, just multiply # Factor is of type int, float or Quantity, just multiply
if isinstance(other, int) or isinstance(other, float) or isinstance(other, u.Quantity): if isinstance(other, int) or isinstance(other, float) or isinstance(other, u.Quantity):
return SpectralQty(self.wl, self.qty * other) return SpectralQty(self.wl, self.qty * other)
# Factor is of type lambda
elif isLambda(other):
return SpectralQty(self.wl, self.qty * [other(wl).value for wl in self.wl] * other(self.wl[0]).unit)
# Factor is of type SpectralQty # Factor is of type SpectralQty
else: else:
if other.wl.unit.is_equivalent(self.wl.unit): if other.wl.unit.is_equivalent(self.wl.unit):
# Wavelengths are matching, just multiply the quantities # Wavelengths are matching, just multiply the quantities
if len(self.wl) == len(other.wl) and all(self.wl == other.wl): if len(self.wl) == len(other.wl) and (self.wl == other.wl).all():
return SpectralQty(self.wl, self.qty * other.qty) return SpectralQty(self.wl, self.qty * other.qty)
# Wavelengths are not matching, rebinning needed # Wavelengths are not matching, rebinning needed
else: else:
# Rebin factor # Rebin factor
other_rebinned = other.rebin(self.wl) other_rebinned = other.rebin(self.wl)
if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all():
return SpectralQty(self.wl, self.qty * other_rebinned.qty) return SpectralQty(self.wl, self.qty * other_rebinned.qty)
else: else:
# Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity