diff --git a/esbo_etc/classes/SpectralQty.py b/esbo_etc/classes/SpectralQty.py index c5be7f7..43993f3 100644 --- a/esbo_etc/classes/SpectralQty.py +++ b/esbo_etc/classes/SpectralQty.py @@ -1,8 +1,8 @@ -from esbo_etc.lib.helpers import error +from esbo_etc.lib.helpers import error, isLambda from scipy.interpolate import interp1d import astropy.units as u import math -from typing import Union +from typing import Union, Callable import logging from astropy.io import ascii import re @@ -108,13 +108,13 @@ class SpectralQty: all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) - def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": + def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty", Callable]) -> "SpectralQty": """ Calculate the sum with another object Parameters ---------- - other : Union[int, float, u.Quantity, "SpectralQty"] + other : Union[int, float, u.Quantity, "SpectralQty", Callable] Addend to be added to this object. If the binning of the object on the right hand side differs from the binning of the left object, the object on the right hand side will be rebinned. @@ -132,6 +132,9 @@ class SpectralQty: return SpectralQty(self.wl, self.qty + other) else: raise TypeError("Units are not matching for addition.") + # Summand is of type lambda + elif isLambda(other): + return SpectralQty(self.wl, self.qty + [other(wl).value for wl in self.wl] * other(self.wl[0]).unit) # Summand is of type SpectralQty else: if other.wl.unit.is_equivalent(self.wl.unit) and other.qty.unit.is_equivalent(self.qty.unit): diff --git a/tests/test_SpectralQty.py b/tests/test_SpectralQty.py index 717d50d..18703eb 100644 --- a/tests/test_SpectralQty.py +++ b/tests/test_SpectralQty.py @@ -75,6 +75,10 @@ class TestSpectralQty(TestCase): self.assertEqual( self.sqty + SpectralQty(np.arange(200.5, 204.5, 1) << u.nm, np.arange(1, 5, 1) << u.W / (u.m ** 2 * u.nm)), SpectralQty(range(201, 204) << u.nm, [2.7, 3.8, 4.9] << u.W / (u.m ** 2 * u.nm))) + # lambda + sqty_2 = lambda wl: 1 * u.W / (u.m ** 2 * u.nm ** 2) * wl + self.assertEqual(self.sqty + sqty_2, + SpectralQty(self.wl, [201.1, 202.2, 203.3, 204.4] << u.W / (u.m**2 * u.nm))) def test_rebinning(self): # Test interpolation