From 15fb785b2d5f19d1642cd29cc8faeee3717e8a32 Mon Sep 17 00:00:00 2001 From: LukasK13 Date: Wed, 8 Apr 2020 09:45:59 +0200 Subject: [PATCH] Code clean up --- esbo_etc/classes/config.py | 131 +++++++++++++++++-------------------- 1 file changed, 61 insertions(+), 70 deletions(-) diff --git a/esbo_etc/classes/config.py b/esbo_etc/classes/config.py index b91ac18..f163983 100644 --- a/esbo_etc/classes/config.py +++ b/esbo_etc/classes/config.py @@ -1,54 +1,54 @@ import xml.etree.ElementTree as eT import numpy as np -import quantities as pq +import astropy.units as u import os import logging -import sys +from esbo_etc.lib.helpers import error class Entry(object): """ A class used to represent a configuration entry. + Copied from ExoSim (https://github.com/ExoSim/ExoSimPublic) """ - val = None - attrib = None - xml_entry = None - def __call__(self): - return self.val + return self.val if hasattr(self, "val") else None def parse(self, xml): """ - Parse a XML tree element + Parse attributes of a XML element - :param xml: XML element tree to parse + :param xml: XML element to parse the attributes from """ - self.attrib = xml.attrib - for attr in self.attrib.keys(): - setattr(self, attr, self.attrib[attr]) + # Copy the XML attributes to object attributes + for attrib in xml.attrib.keys(): + setattr(self, attrib, xml.attrib[attrib]) + # Convert to python datatype and apply the corresponding unit (if applicable) if hasattr(self, 'units'): try: - self.val = pq.Quantity(list(map(float, self.val.split(','))), - self.units).simplified - if self.units == 'deg': - self.val = [x * pq.rad for x in self.val] # workaround for qt unit conversion + self.val = u.Quantity(list(map(float, self.val.split(','))), + self.units) + # if self.units == 'deg': + # self.val = [val * pq.rad for val in self.val] # workaround for qt unit conversion if len(self.val) == 1: self.val = self.val[0] except (ValueError, LookupError): - logging.error('unable to convert units in entry [tag, units, value]: ', - xml.tag, self.units, self.val) + error("unable to convert units in entry '" + xml.tag + "': " + self.val + " " + self.units, exit_=False) + elif hasattr(self, "val") and self.val.lower() in ["false", "true"]: + self.val = (self.val.lower() == "true") class Configuration(object): """ A Class to parse the XML configuration file. + Adapted from ExoSim (https://github.com/ExoSim/ExoSimPublic) Attributes ---------- - conf : str - Parsed XML tree + conf : Entry + Parsed configuration file as Entry-tree """ conf = None @@ -63,83 +63,74 @@ class Configuration(object): default_path : str default path to use for relative paths """ - if not os.path.exists(filename): - logging.error("Configuration file '" + filename + "' doesn't exist.") - sys.exit(1) + # Check if configuration file exists + if not os.path.exists(filename): + error("Configuration file '" + filename + "' doesn't exist.") + + # Read configuration file logging.info("Reading configuration from file '" + filename + "'.") self.conf = self.parser(eT.parse(filename).getroot()) - if default_path: - setattr(self.conf, "__path__", default_path) - elif hasattr(self.conf.common, "ConfigPath"): - setattr(self.conf, "__path__", - os.path.expanduser(self.conf.common.ConfigPath().replace('__path__', os.getcwd()))) - else: - logging.error("Path to config files not defined") - - self.validate_options() self.calc_metaoptions() - def parser(self, root): + def parser(self, parent): """ - Parse a XML configuration file. + Parse a XML element tree to an Entry-tree Parameters ---------- - root : ElementTree - The XML tree to be parsed + parent : ElementTree + The parent XML tree to be parsed Returns ------- obj : Entry The parsed XML tree """ + + # Initialize empty Entry object obj = Entry() - for ch in root: - retval = self.parser(ch) - retval.parse(ch) + for child in parent: + # recursively parse children of child element + parsed_child = self.parser(child) + # parse attributes of child element + parsed_child.parse(child) - if hasattr(obj, ch.tag): - if isinstance(getattr(obj, ch.tag), list): - getattr(obj, ch.tag).append(retval) + # Add or append the parsed child to the prepared Entry object + if hasattr(obj, child.tag): + if isinstance(getattr(obj, child.tag), list): + getattr(obj, child.tag).append(parsed_child) else: - setattr(obj, ch.tag, [getattr(obj, ch.tag), retval]) + setattr(obj, child.tag, [getattr(obj, child.tag), parsed_child]) else: - setattr(obj, ch.tag, retval) + setattr(obj, child.tag, parsed_child) return obj - def validate_options(self): - self.validate_is_list() - self.validate_True_False_spelling() - - def validate_is_list(self): - if not isinstance(self.conf.common_optics.optical_component, list): - self.conf.common_optics.optical_component = [self.conf.common_optics.optical_component] - if not isinstance(self.conf.instrument, list): - self.conf.instrument = [self.conf.instrument] - - def validate_True_False_spelling(self): - accepted_values = ['True', 'False'] - test_cases = [ - 'noise/EnableJitter', - 'noise/EnableShotNoise', - 'noise/EnableReadoutNoise', - ] - for item in test_cases: - if hasattr(self.conf, item.split('/')[0]): - if not self.conf.__getattribute__(item.split('/')[0]).__dict__[item.split('/')[1]]() in accepted_values: - raise ValueError("Accepted values for [%s] are 'True' or 'False'" % item) - def calc_metaoptions(self): + """ + Calculate additional attributes e.g. the wavelength grid + Returns + ------- + + """ self.calc_metaoption_wl_delta() def calc_metaoption_wl_delta(self): - wl_delta = self.conf.common.wl_min() / self.conf.common.logbinres() - setattr(self.conf.common, 'common_wl', (np.arange(self.conf.common.wl_min(), - self.conf.common.wl_max(), - wl_delta) * wl_delta.units).rescale(pq.um)) + """ + Calculate the wavelength grid used for the calculations. + Returns + ------- + + """ + if hasattr(self.conf.common, "wl_delta"): + wl_delta = self.conf.common.wl_delta() + else: + wl_delta = self.conf.common.wl_min() / self.conf.common.res() + setattr(self.conf.common, 'wl_bins', np.arange(self.conf.common.wl_min().to(u.micron).value, + self.conf.common.wl_max().to(u.micron).value, + wl_delta.to(u.micron).value) * u.micron) if __name__ == "__main__":