diff --git a/esbo_etc/classes/SpectralQty.py b/esbo_etc/classes/SpectralQty.py index 055ab1a..4100395 100644 --- a/esbo_etc/classes/SpectralQty.py +++ b/esbo_etc/classes/SpectralQty.py @@ -4,6 +4,8 @@ import astropy.units as u import math from typing import Union import logging +from astropy.io import ascii +import re class SpectralQty: @@ -11,7 +13,7 @@ class SpectralQty: A class to hold and work with spectral quantities """ - def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False): + def __init__(self, wl: u.Quantity, qty: u.Quantity, extrapolate: bool = False) -> "SpectralQty": """ Initialize a new spectral quantity @@ -24,6 +26,10 @@ class SpectralQty: they are assumed to be dimensionless. extrapolate : bool Whether extrapolation should be allowed. If disabled, the spectrum will be truncated and a warning given. + Returns + ------- + sqty : SpectralQty + The created spectral quantity. """ # Check if both lengths are equal if len(wl) == len(qty): @@ -40,6 +46,49 @@ class SpectralQty: error("Lengths not matching") self._extrapolate = extrapolate + @classmethod + def fromFile(cls, file: str, wl_unit_default: u.Quantity = None, qty_unit_default: u.Quantity = None, + extrapolate: bool = False) -> "SpectralQty": + """ + Initialize a new spectral quantity and read the values from a file + + Parameters + ---------- + file : str + Path to the file to read the values from. The file needs to provide two columns: wavelength + and the corresponding spectral quantity. The format of the file will be guessed by + `astropy.io.ascii.read(). If the file doesn't provide units via astropy's enhanced CSV format, the units + will be read from the column headers or otherwise assumed to be *wl_unit_default* and *qty_unit_default*. + wl_unit_default : Quantity + Default unit to be used for the wavelength column if no units are provided by the file. + qty_unit_default : Quantity + Default unit to be used for the quantity column if no units are provided by the file. + extrapolate : bool + Whether extrapolation should be allowed. If disabled, the spectrum will be truncated and a warning given. + Returns + ------- + sqty : SpectralQty + The created spectral quantity. + """ + # Read the file + data = ascii.read(file) + # Check if units are given + if data[data.colnames[0]].unit is None: + # Convert values to float + data[data.colnames[0]] = list(map(float, data[data.colnames[0]])) + data[data.colnames[1]] = list(map(float, data[data.colnames[1]])) + # Check if units are given in column headers + if all([re.search("\\[.+\\]", x) for x in data.colnames]): + # Extract units from headers and apply them on the columns + units = [u.Unit(re.findall("(?<=\\[).+(?=\\])", x)[0]) for x in data.colnames] + data[data.colnames[0]].unit = units[0] + data[data.colnames[1]].unit = units[1] + # Use default units + elif wl_unit_default is not None and qty_unit_default is not None: + data[data.colnames[0]].unit = wl_unit_default + data[data.colnames[1]].unit = qty_unit_default + return cls(data[data.colnames[0]].quantity, data[data.colnames[1]].quantity, extrapolate=extrapolate) + def __eq__(self, other) -> bool: """ Check if this object is equal to another object @@ -59,7 +108,7 @@ class SpectralQty: all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.wl.value, other.wl.value)]) and \ all([math.isclose(x, y, rel_tol=1e-5) for x, y in zip(self.qty.value, other.qty.value)]) - def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + def __add__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": """ Calculate the sum with another object @@ -103,7 +152,7 @@ class SpectralQty: __radd__ = __add__ - def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + def __sub__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": """ Calculate the difference to another object @@ -145,7 +194,7 @@ class SpectralQty: else: error("Units are not matching for substraction.") - def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty"]): + def __mul__(self, other: Union[int, float, u.Quantity, "SpectralQty"]) -> "SpectralQty": """ Calculate the product with another object diff --git a/tests/test_SpectralQty.py b/tests/test_SpectralQty.py index 4e10c49..d30890e 100644 --- a/tests/test_SpectralQty.py +++ b/tests/test_SpectralQty.py @@ -97,3 +97,14 @@ class TestSpectralQty(TestCase): sqty_res = SpectralQty(wl_new[:2], [1.15, 1.35] << u.W / (u.m ** 2 * u.nm)) sqty_rebin = self.sqty.rebin(wl_new) self.assertEqual(sqty_rebin, sqty_res) + + def test_fromFile(self): + sqty = SpectralQty.fromFile("data/target/target_demo_1.csv", u.nm, u.W / (u.m ** 2 * u.nm)) + print(sqty.qty) + print(sqty.wl) + res = SpectralQty(np.arange(1.1, 2.1, 0.1) * 1e-15 << u.W / (u.m**2 * u.nm), + np.arange(200, 210, 1) << u.nm) + print(res.qty) + print(res.wl) + self.assertEqual(sqty, SpectralQty(np.arange(200, 210, 1) << u.nm, + np.arange(1.1, 2.1, 0.1) * 1e-15 << u.W / (u.m**2 * u.nm)))