From bcb610febe85f3a46f073dbc9b69fc0b55f86a61 Mon Sep 17 00:00:00 2001 From: LukasK13 Date: Thu, 9 Apr 2020 16:53:56 +0200 Subject: [PATCH] Addition, subtraction and multiplication added --- esbo_etc/classes/SpectralQty.py | 167 ++++++++++++++++++++++++-------- tests/test_SpectralQty.py | 70 ++++++++++--- 2 files changed, 184 insertions(+), 53 deletions(-) diff --git a/esbo_etc/classes/SpectralQty.py b/esbo_etc/classes/SpectralQty.py index 89eb6ee..6da746c 100644 --- a/esbo_etc/classes/SpectralQty.py +++ b/esbo_etc/classes/SpectralQty.py @@ -1,9 +1,8 @@ from esbo_etc.lib.helpers import error -# import numpy as np -# from scipy.integrate import cumtrapz from scipy.interpolate import interp1d import astropy.units as u import math +from typing import Union class SpectralQty: @@ -20,10 +19,22 @@ class SpectralQty: wl : Quantity The binned wavelengths qty : Quantity - The quantity values corresponding to the binned wavelengths + The quantity values corresponding to the binned wavelengths. If the values are supplied without a unit, + they are assumed to be dimensionless. """ - self.wl = wl - self.qty = qty + # Check if both lengths are equal + if len(wl) == len(qty): + # check if units are given. If not, add a dimensionless unit + if hasattr(wl, "unit"): + self.wl = wl + else: + self.wl = wl * u.dimensionless_unscaled + if hasattr(qty, "unit"): + self.qty = qty + else: + self.qty = qty * u.dimensionless_unscaled + else: + error("Lengths not matching") def __eq__(self, other) -> bool: """ @@ -40,16 +51,117 @@ class SpectralQty: Result of the comparison """ return self.wl.unit == other.wl.unit and self.qty.unit == other.qty.unit and \ + len(self.wl) == len(other.wl) and len(self.qty) == len(other.qty) and \ all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) - def add(self, sqty: "SpectralQty"): - pass + def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + """ + Calculate the sum with another object - def multiply(self, sqty: "SpectralQty"): - pass + Parameters + ---------- + other : Union[int, float, u.Quantity, "SpectralQty"] + Addend to be added to this object. If the binning of the object on the right hand side differs + from the binning of the left object, the object on the right hand side will be rebinned. - def rebin(self, wl: u.Quantity): + Returns + ------- + sum : SpectralQty + The sum of both objects + """ + # Summand is of type int or float, use same unit + if isinstance(other, int) or isinstance(other, float): + return SpectralQty(self.wl, self.qty + other * self.qty.unit) + # Summand is of type Quantity + elif isinstance(other, u.Quantity): + if other.unit == self.qty.unit: + return SpectralQty(self.wl, self.qty + other) + else: + raise TypeError("Units are not matching for addition.") + # Summand is of type SpectralQty + else: + if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): + # Wavelengths are matching, just add the quantities + if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + return SpectralQty(self.wl, self.qty + other.qty) + # Wavelengths are not matching, rebinning needed + else: + return SpectralQty(self.wl, self.qty + other.rebin(self.wl).qty) + else: + error("Units are not matching for addition.") + + __radd__ = __add__ + + def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + """ + Calculate the difference to another object + + Parameters + ---------- + other : Union[int, float, u.Quantity, "SpectralQty"] + Subtrahend to be subtracted from this object. If the binning of the object on the right hand side differs + from the binning of the left object, the object on the right hand side will be rebinned. + + Returns + ------- + sum : SpectralQty + The difference of both objects + """ + # Subtrahend is of type int or float, use same unit + if isinstance(other, int) or isinstance(other, float): + return SpectralQty(self.wl, self.qty - other * self.qty.unit) + # Subtrahend is of type Quantity + elif isinstance(other, u.Quantity): + if other.unit == self.qty.unit: + return SpectralQty(self.wl, self.qty - other) + else: + raise TypeError('Units are not matching for subtraction.') + # Subtrahend is of type SpectralQty + else: + if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): + # Wavelengths are matching, just subtract the quantities + if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + return SpectralQty(self.wl, self.qty - other.qty) + # Wavelengths are not matching, rebinning needed + else: + return SpectralQty(self.wl, self.qty - other.rebin(self.wl).qty) + else: + error("Units are not matching for substraction.") + + def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + """ + Calculate the product with another object + + Parameters + ---------- + other : Union[int, float, u.Quantity, "SpectralQty"] + Factor to be multiplied with this object. If the binning of the object on the right hand side differs + from the binning of the left object, the object on the right hand side will be rebinned. + + Returns + ------- + sum : SpectralQty + The product of both objects + """ + # Factor is of type int, float or Quantity, just multiply + if isinstance(other, int) or isinstance(other, float) or isinstance(other, u.Quantity): + return SpectralQty(self.wl, self.qty * other) + # Factor is of type SpectralQty + else: + if other.wl.unit.is_equivalent(self.wl.unit): + # Wavelengths are matching, just multiply the quantities + if len(self.wl) == len(other.wl) and all(self.wl == other.wl): + return SpectralQty(self.wl, self.qty * other.qty) + # Wavelengths are not matching, rebinning needed + else: + return SpectralQty(self.wl, self.qty * other.rebin(self.wl).qty) + else: + error("Units are not matching for multiplication.") + + __rmul__ = __mul__ + + def rebin(self, wl: u.Quantity) -> "SpectralQty": """ Resample the spectral quantity sqty(wl) over the new grid wl, rebinning if necessary, otherwise interpolates. Copied from ExoSim (https://github.com/ExoSim/ExoSimPublic). @@ -61,40 +173,11 @@ class SpectralQty: Returns ------- + sqty : SpectralQty + The rebinned spectral quantity """ if wl.unit != self.wl.unit: error("Mismatching units for rebinning: " + wl.unit + ", " + self.wl.unit) - - # idx = np.where(np.logical_and(self.wl > 0.9 * wl.min(), self.wl < 1.1 * wl.max()))[0] - # wl_old = self.wl[idx] - # qty_old = self.qty[idx] - # - # if np.diff(wl_old).min() < np.diff(wl).min(): - # # Binning - # c = cumtrapz(qty_old, x=wl_old) * qty_old.unit * wl_old.unit - # print(c) - # xpc = wl_old[1:] - # - # delta = np.gradient(wl) - # new_c_1 = np.interp(wl - 0.5 * delta, xpc, c, left=0.0, right=0.0) * c.unit - # new_c_2 = np.interp(wl + 0.5 * delta, xpc, c, left=0.0, right=0.0) * c.unit - # qty = (new_c_2 - new_c_1) / delta - # else: - # # Interpolation - # qty = np.interp(wl, wl_old, qty_old, left=0.0, right=0.0) - f = interp1d(self.wl, self.qty, fill_value="extrapolate") - qty = f(wl) * self.qty.unit - - self.wl = wl - self.qty = qty - - # import matplotlib.pyplot as plt - # plt.plot(wl_old, qty_old, '-') - # plt.plot(wl, qty, '.-') - # plt.show() - # # check - # print(np.trapz(qty, wl)) - # idx = np.where(np.logical_and(wl_old >= wl.min(), wl_old <= wl.max())) - # print(np.trapz(qty_old[idx], wl_old[idx])) + return SpectralQty(wl, f(wl) * self.qty.unit) diff --git a/tests/test_SpectralQty.py b/tests/test_SpectralQty.py index d2e7c69..542c857 100644 --- a/tests/test_SpectralQty.py +++ b/tests/test_SpectralQty.py @@ -5,27 +5,75 @@ import numpy as np class TestSpectralQty(TestCase): - qty = np.arange(1.1e-15, 2.0e-15, 1e-16) << u.W / (u.m ** 2 * u.nm) - wl = np.arange(200, 210, 1) << u.nm + qty = np.arange(1.1, 1.5, 0.1) << u.W / (u.m ** 2 * u.nm) + wl = np.arange(200, 204, 1) << u.nm def setUp(self): self.sqty = SpectralQty(self.wl, self.qty) - def test_equality(self): + def test___eq__(self): sqty_2 = SpectralQty(self.wl, self.qty) - self.assertTrue(self.sqty.__eq__(sqty_2)) + self.assertEqual(self.sqty, sqty_2) + + def test___mul__(self): + # Integer + self.assertEqual(self.sqty * 2, SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(2.2, 3.0, 2e-1) << u.W / (u.m ** 2 * u.nm))) + self.assertEqual(2 * self.sqty, SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(2.2, 3.0, 2e-1) << u.W / (u.m ** 2 * u.nm))) + # Float + self.assertEqual(self.sqty * 2., SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(2.2, 3.0, 2e-1) << u.W / (u.m ** 2 * u.nm))) + self.assertEqual(2. * self.sqty, SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(2.2, 3.0, 2e-1) << u.W / (u.m ** 2 * u.nm))) + # SpectralQty + self.assertEqual(self.sqty * SpectralQty(self.wl, np.arange(1, 5, 1) << u.m), + SpectralQty(self.wl, [1.1, 2.4, 3.9, 5.6] << u.W / (u.m * u.nm))) + self.assertEqual(SpectralQty(self.wl, np.arange(1, 5, 1) << u.m) * self.sqty, + SpectralQty(self.wl, [1.1, 2.4, 3.9, 5.6] << u.W / (u.m * u.nm))) + # rebin + self.assertEqual(self.sqty * SpectralQty(np.arange(200.5, 204.5, 1) << u.nm, np.arange(1, 5, 1) << u.m), + SpectralQty(self.wl, [0.55, 1.8, 3.25, 4.9] << u.W / (u.m * u.nm))) + + def test___sub__(self): + # Quantity + self.assertEqual(self.sqty - 0.1 * u.W / (u.m ** 2 * u.nm), + SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(1.0, 1.4, 0.1) << u.W / (u.m ** 2 * u.nm))) + # SpectralQty + self.assertEqual( + self.sqty - SpectralQty(np.arange(200, 204, 1) << u.nm, np.arange(1, 5, 1) << u.W / (u.m ** 2 * u.nm)), + SpectralQty(self.wl, [0.1, -0.8, -1.7, -2.6] * u.W / (u.m ** 2 * u.nm))) + # rebin + self.assertEqual( + self.sqty - SpectralQty(np.arange(200.5, 204.5, 1) << u.nm, np.arange(1, 5, 1) << u.W / (u.m ** 2 * u.nm)), + SpectralQty(self.wl, [0.6, -0.3, -1.2, -2.1] * u.W / (u.m ** 2 * u.nm))) + + def test___add__(self): + # Quantity + self.assertEqual(self.sqty + 1.0 * u.W / (u.m ** 2 * u.nm), + SpectralQty(np.arange(200, 204, 1) << u.nm, + np.arange(2.1, 2.5, 0.1) << u.W / (u.m ** 2 * u.nm))) + # SpectralQty + self.assertEqual( + self.sqty + SpectralQty(np.arange(200, 204, 1) << u.nm, np.arange(1, 5, 1) << u.W / (u.m ** 2 * u.nm)), + SpectralQty(self.wl, [2.1, 3.2, 4.3, 5.4] * u.W / (u.m ** 2 * u.nm))) + # rebin + self.assertEqual( + self.sqty + SpectralQty(np.arange(200.5, 204.5, 1) << u.nm, np.arange(1, 5, 1) << u.W / (u.m ** 2 * u.nm)), + SpectralQty(self.wl, [1.6, 2.7, 3.8, 4.9] * u.W / (u.m ** 2 * u.nm))) def test_rebinning(self): # Test interpolation wl_new = np.arange(200.5, 210.5, 1) << u.nm - sqty_new = SpectralQty(wl_new, [1.15e-15, 1.25e-15, 1.35e-15, 1.45e-15, 1.55e-15, 1.65e-15, 1.75e-15, 1.85e-15, - 1.95e-15, 2.05e-15] << u.W / (u.m ** 2 * u.nm)) - self.sqty.rebin(wl_new) - self.assertTrue(self.sqty.__eq__(sqty_new)) + sqty_res = SpectralQty(wl_new, [1.15, 1.25, 1.35, 1.45, 1.55, 1.65, 1.75, 1.85, + 1.95, 2.05] << u.W / (u.m ** 2 * u.nm)) + sqty_rebin = self.sqty.rebin(wl_new) + self.assertEqual(sqty_rebin, sqty_res) # Test binning self.setUp() wl_new = np.arange(200.5, 210, 2) << u.nm - sqty_new = SpectralQty(wl_new, [1.15e-15, 1.35e-15, 1.55e-15, 1.75e-15, 1.95e-15] << u.W / (u.m ** 2 * u.nm)) - self.sqty.rebin(wl_new) - self.assertTrue(self.sqty.__eq__(sqty_new)) + sqty_res = SpectralQty(wl_new, [1.15, 1.35, 1.55, 1.75, 1.95] << u.W / (u.m ** 2 * u.nm)) + sqty_rebin = self.sqty.rebin(wl_new) + self.assertEqual(sqty_rebin, sqty_res)