Function to read a CSV file with units added

This commit is contained in:
Lukas Klass 2020-05-18 10:36:21 +02:00
parent 2bedde4bf7
commit d57add3b9c

View File

@ -2,6 +2,10 @@ import logging
import sys import sys
import traceback import traceback
import numpy as np import numpy as np
from astropy.io import ascii
from astropy.table import Table
import astropy.units as u
import re
def error(msg: str, exit_: bool = True): def error(msg: str, exit_: bool = True):
@ -87,3 +91,42 @@ def rasterizeCircle(grid: np.ndarray, radius: float, xc: float, yc: float):
# ax.add_artist(circle) # ax.add_artist(circle)
# plt.show() # plt.show()
return grid return grid
def readCSV(file: str, units: list = None, format_: str = None) -> Table:
"""
Read a CSV file and parse the units in the header
Parameters
----------
file : str
The path to the file to read.
units : list
A list of the default units for the columns.
format_ : str
The format to be used for reading (see also astropy table formats).
Returns
-------
data : Table
The read table as astropy Table object.
"""
# Read the file
data = ascii.read(file, format=format_)
# Check if units are given
if data[data.colnames[0]].unit is None:
# Convert values to float
for i in range(len(data.columns)):
data[data.colnames[i]] = list(map(float, data[data.colnames[i]]))
# Check if units are given in column headers
if all([re.search("\\[.+\\]", x) for x in data.colnames]):
# Extract units from headers and apply them on the columns
# noinspection PyArgumentList
units = [u.Unit(re.findall("(?<=\\[).+(?=\\])", x)[0]) for x in data.colnames]
for i in range(len(data.columns)):
data[data.colnames[i]].unit = units[i]
# Use default units
elif units is not None and len(units) == len(data.columns):
for i in range(len(data.columns)):
data[data.colnames[i]].unit = units[i]
return data