From 92fa74d7f2235c06581ed7a161964babacd95052 Mon Sep 17 00:00:00 2001 From: LukasK13 Date: Wed, 15 Apr 2020 15:38:05 +0200 Subject: [PATCH] Multiply and subtraction with lambdas, code clean up --- esbo_etc/classes/SpectralQty.py | 38 ++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/esbo_etc/classes/SpectralQty.py b/esbo_etc/classes/SpectralQty.py index 43993f3..4f1f34b 100644 --- a/esbo_etc/classes/SpectralQty.py +++ b/esbo_etc/classes/SpectralQty.py @@ -1,4 +1,4 @@ -from esbo_etc.lib.helpers import error, isLambda +from ..lib.helpers import error, isLambda from scipy.interpolate import interp1d import astropy.units as u import math @@ -8,12 +8,13 @@ from astropy.io import ascii import re +# noinspection PyUnresolvedReferences class SpectralQty: """ A class to hold and work with spectral quantities """ - def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False) -> "SpectralQty": + def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False): """ Initialize a new spectral quantity @@ -80,6 +81,7 @@ class SpectralQty: # Check if units are given in column headers if all([re.search("\\[.+\\]", x) for x in data.colnames]): # Extract units from headers and apply them on the columns + # noinspection PyArgumentList units = [u.Unit(re.findall("(?<=\\[).+(?=\\])", x)[0]) for x in data.colnames] data[data.colnames[0]].unit = units[0] data[data.colnames[1]].unit = units[1] @@ -104,9 +106,9 @@ class SpectralQty: Result of the comparison """ return self.wl.unit == other.wl.unit and self.qty.unit == other.qty.unit and \ - len(self.wl) == len(other.wl) and len(self.qty) == len(other.qty) and \ - all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ - all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) + len(self.wl) == len(other.wl) and len(self.qty) == len(other.qty) and \ + all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ + all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty": """ @@ -139,13 +141,13 @@ class SpectralQty: else: if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): # Wavelengths are matching, just add the quantities - if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + if len(self.wl) == len(other.wl) and (self.wl == other.wl).all(): return SpectralQty(self.wl, self.qty + other.qty) # Wavelengths are not matching, rebinning needed else: # Rebin addend other_rebinned = other.rebin(self.wl) - if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): + if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all(): return SpectralQty(self.wl, self.qty + other_rebinned.qty) else: # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity @@ -155,13 +157,13 @@ class SpectralQty: __radd__ = __add__ - def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": + def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty": """ Calculate the difference to another object Parameters ---------- - other : Union[int, float, u.Quantity, "SpectralQty"] + other : Union[int, float, u.Quantity, "SpectralQty", Callable] Subtrahend to be subtracted from this object. If the binning of the object on the right hand side differs from the binning of the left object, the object on the right hand side will be rebinned. @@ -179,17 +181,20 @@ class SpectralQty: return SpectralQty(self.wl, self.qty - other) else: raise TypeError('Units are not matching for subtraction.') + # Subtrahend is of type lambda + elif isLambda(other): + return SpectralQty(self.wl, self.qty - [other(wl).value for wl in self.wl] * other(self.wl[0]).unit) # Subtrahend is of type SpectralQty else: if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): # Wavelengths are matching, just subtract the quantities - if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + if len(self.wl) == len(other.wl) and (self.wl == other.wl).all(): return SpectralQty(self.wl, self.qty - other.qty) # Wavelengths are not matching, rebinning needed else: # Rebin subtrahend other_rebinned = other.rebin(self.wl) - if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): + if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all(): return SpectralQty(self.wl, self.qty - other_rebinned.qty) else: # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity @@ -197,13 +202,13 @@ class SpectralQty: else: error("Units are not matching for substraction.") - def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": + def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty": """ Calculate the product with another object Parameters ---------- - other : Union[int, float, u.Quantity, "SpectralQty"] + other : Union[int, float, u.Quantity, "SpectralQty", Callable] Factor to be multiplied with this object. If the binning of the object on the right hand side differs from the binning of the left object, the object on the right hand side will be rebinned. @@ -215,17 +220,20 @@ class SpectralQty: # Factor is of type int, float or Quantity, just multiply if isinstance(other, int) or isinstance(other, float) or isinstance(other, u.Quantity): return SpectralQty(self.wl, self.qty * other) + # Factor is of type lambda + elif isLambda(other): + return SpectralQty(self.wl, self.qty * [other(wl).value for wl in self.wl] * other(self.wl[0]).unit) # Factor is of type SpectralQty else: if other.wl.unit.is_equivalent(self.wl.unit): # Wavelengths are matching, just multiply the quantities - if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + if len(self.wl) == len(other.wl) and (self.wl == other.wl).all(): return SpectralQty(self.wl, self.qty * other.qty) # Wavelengths are not matching, rebinning needed else: # Rebin factor other_rebinned = other.rebin(self.wl) - if len(self.wl) == len(other_rebinned.wl) and all(self.wl == other_rebinned.wl): + if len(self.wl) == len(other_rebinned.wl) and (self.wl == other_rebinned.wl).all(): return SpectralQty(self.wl, self.qty * other_rebinned.qty) else: # Wavelengths are still not matching as extrapolation is disabled, rebin this spectral quantity