Skip to content

Commit

Permalink
Avoid deepcopying defaults for containers in stage1 (#1538)
Browse files Browse the repository at this point in the history
* Avoid deepcopying defaults for containers in stage1

* Add type argument to field

* Fix dl1 source container reading

* Add missing type
  • Loading branch information
maxnoe committed Dec 3, 2020
1 parent 8f17f02 commit dd1b5ee
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 35 deletions.
10 changes: 5 additions & 5 deletions ctapipe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,20 @@ class DL1CameraContainer(Container):
peak_time = Field(
None,
"Numpy array containing position of the peak of the pulse as determined by "
"the extractor. Shape: (n_pixel)",
"the extractor. Shape: (n_pixel, )",
dtype=np.float32,
ndim=1,
)

image_mask = Field(
None,
"Boolean numpy array where True means the pixel has passed cleaning. Shape: ("
"n_pixel)",
"Boolean numpy array where True means the pixel has passed cleaning."
" Shape: (n_pixel, )",
dtype=np.bool,
ndim=1,
)

parameters = Field(ImageParametersContainer(), "Parameters derived from images")
parameters = Field(None, "Image parameters", type=ImageParametersContainer)


class DL1Container(Container):
Expand Down Expand Up @@ -404,7 +404,7 @@ class SimulatedCameraContainer(Container):
)

true_parameters = Field(
ImageParametersContainer(), "Parameters derived from the true_image"
None, "Parameters derived from the true_image", type=ImageParametersContainer
)


Expand Down
23 changes: 22 additions & 1 deletion ctapipe/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import numpy as np
from astropy.units import UnitConversionError, Quantity, Unit

import logging


log = logging.getLogger(__name__)


class FieldValidationError(ValueError):
pass
Expand All @@ -26,8 +31,11 @@ class Field:
unit to convert to when writing output, or None for no conversion
ucd: str
universal content descriptor (see Virtual Observatory standards)
type: type
expected type of value
dtype: str or np.dtype
expected data type of the value, None to ignore in validation.
Means value is expected to be a numpy array or astropy quantity
ndim: int or None
expected dimensionality of the data, for arrays, None to ignore
allow_none:
Expand All @@ -41,6 +49,7 @@ def __init__(
unit=None,
ucd=None,
dtype=None,
type=None,
ndim=None,
allow_none=True,
):
Expand All @@ -50,6 +59,7 @@ def __init__(
self.unit = Unit(unit) if unit is not None else None
self.ucd = ucd
self.dtype = np.dtype(dtype) if dtype is not None else None
self.type = type
self.ndim = ndim
self.allow_none = allow_none

Expand Down Expand Up @@ -84,6 +94,11 @@ def validate(self, value):

errorstr = f"the value '{value}' ({type(value)}) is invalid: "

if self.type is not None and not isinstance(value, self.type):
raise FieldValidationError(
f"{errorstr} Should be an instance of {self.type}"
)

if self.unit is not None:
if not isinstance(value, Quantity):
raise FieldValidationError(
Expand Down Expand Up @@ -221,7 +236,13 @@ def __init__(self, **fields):
self.prefix = self.container_prefix

for k in set(self.fields).difference(fields):
setattr(self, k, deepcopy(self.fields[k].default))

# deepcopy of None is surprisingly slow
default = self.fields[k].default
if default is not None:
default = deepcopy(default)

setattr(self, k, default)

for k, v in fields.items():
setattr(self, k, v)
Expand Down
5 changes: 5 additions & 0 deletions ctapipe/core/tests/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ def test_field_validation():
with pytest.raises(FieldValidationError):
field_n2.validate(None)

field_type = Field(None, "foo", type=str)
field_type.validate("foo")
with pytest.raises(FieldValidationError):
field_type.validate(5)


def test_container_validation():
""" check that we can validate all fields in a container"""
Expand Down
11 changes: 8 additions & 3 deletions ctapipe/image/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
)


DEFAULT_IMAGE_PARAMETERS = ImageParametersContainer()
DEFAULT_TIMING_PARAMETERS = TimingParametersContainer()
DEFAULT_PEAKTIME_STATISTICS = PeakTimeStatisticsContainer()


class ImageQualityQuery(QualityQuery):
""" for configuring image-wise data checks """

Expand Down Expand Up @@ -143,8 +148,8 @@ def _parameterize_image(
container_class=PeakTimeStatisticsContainer,
)
else:
timing = TimingParametersContainer()
peak_time_statistics = PeakTimeStatisticsContainer()
timing = DEFAULT_TIMING_PARAMETERS
peak_time_statistics = DEFAULT_PEAKTIME_STATISTICS

return ImageParametersContainer(
hillas=hillas,
Expand All @@ -158,7 +163,7 @@ def _parameterize_image(

# return the default container (containing nan values) for no
# parameterization
return ImageParametersContainer()
return DEFAULT_IMAGE_PARAMETERS

def _process_telescope_event(self, event):
"""
Expand Down
31 changes: 17 additions & 14 deletions ctapipe/io/dl1eventsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PeakTimeStatisticsContainer,
TimingParametersContainer,
TriggerContainer,
ImageParametersContainer,
)
from ctapipe.utils import IndexFinder

Expand Down Expand Up @@ -352,13 +353,15 @@ def _generate_events(self):
# Best would probbaly be if we could directly read
# into the ImageParametersContainer
params = next(param_readers[f"tel_{tel:03d}"])
dl1.parameters.hillas = params[0]
dl1.parameters.timing = params[1]
dl1.parameters.leakage = params[2]
dl1.parameters.concentration = params[3]
dl1.parameters.morphology = params[4]
dl1.parameters.intensity_statistics = params[5]
dl1.parameters.peak_time_statistics = params[6]
dl1.parameters = ImageParametersContainer(
hillas=params[0],
timing=params[1],
leakage=params[2],
concentration=params[3],
morphology=params[4],
intensity_statistics=params[5],
peak_time_statistics=params[6],
)

if self.has_simulated_dl1:
if f"tel_{tel:03d}" not in param_readers.keys():
Expand All @@ -371,13 +374,13 @@ def _generate_events(self):
simulated_params = next(
simulated_param_readers[f"tel_{tel:03d}"]
)
simulated.true_parameters.hillas = simulated_params[0]
simulated.true_parameters.leakage = simulated_params[1]
simulated.true_parameters.concentration = simulated_params[2]
simulated.true_parameters.morphology = simulated_params[3]
simulated.true_parameters.intensity_statistics = simulated_params[
4
]
simulated.true_parameters = ImageParametersContainer(
hillas=simulated_params[0],
leakage=simulated_params[1],
concentration=simulated_params[2],
morphology=simulated_params[3],
intensity_statistics=simulated_params[4],
)

yield data

Expand Down
36 changes: 24 additions & 12 deletions ctapipe/io/simteleventsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
SimulationConfigContainer,
SimulatedCameraContainer,
SimulatedShowerContainer,
TelescopePointingContainer,
TelescopeTriggerContainer,
)
from ..coordinates import CameraFrame
from ..core.traits import Bool, CaselessStrEnum, create_class_enum_trait
Expand Down Expand Up @@ -403,8 +405,8 @@ def _generate_events(self):
true_image=true_image
)

self._fill_event_pointing(
data.pointing.tel[tel_id], tracking_positions[tel_id]
data.pointing.tel[tel_id] = self._fill_event_pointing(
tracking_positions[tel_id]
)

r0 = data.r0.tel[tel_id]
Expand All @@ -430,23 +432,25 @@ def _generate_events(self):
yield data

@staticmethod
def _fill_event_pointing(pointing, tracking_position):
def _fill_event_pointing(tracking_position):
azimuth_raw = tracking_position["azimuth_raw"]
altitude_raw = tracking_position["altitude_raw"]
azimuth_cor = tracking_position.get("azimuth_cor", np.nan)
altitude_cor = tracking_position.get("altitude_cor", np.nan)

# take pointing corrected position if available
if np.isnan(azimuth_cor):
pointing.azimuth = u.Quantity(azimuth_raw, u.rad)
azimuth = u.Quantity(azimuth_raw, u.rad, copy=False)
else:
pointing.azimuth = u.Quantity(azimuth_cor, u.rad)
azimuth = u.Quantity(azimuth_cor, u.rad, copy=False)

# take pointing corrected position if available
if np.isnan(altitude_cor):
pointing.altitude = u.Quantity(altitude_raw, u.rad)
altitude = u.Quantity(altitude_raw, u.rad, copy=False)
else:
pointing.altitude = u.Quantity(altitude_cor, u.rad)
altitude = u.Quantity(altitude_cor, u.rad, copy=False)

return TelescopePointingContainer(azimuth=azimuth, altitude=altitude)

@staticmethod
def _fill_trigger_info(data, array_event):
Expand All @@ -471,23 +475,31 @@ def _fill_trigger_info(data, array_event):
for tel_id, time in zip(
trigger["triggered_telescopes"], trigger["trigger_times"]
):
# time is relative to central trigger in nano seconds
trigger = data.trigger.tel[tel_id]
trigger.time = Time(
# telesocpe time is relative to central trigger in ns
time = Time(
central_time.jd1,
central_time.jd2 + time / NANOSECONDS_PER_DAY,
scale=central_time.scale,
format="jd",
)

# triggered pixel info
n_trigger_pixels = -1
trigger_pixels = None

tel_event = array_event["telescope_events"].get(tel_id)
if tel_event:
# code 0 = trigger pixels
pixel_list = tel_event["pixel_lists"].get(0)
if pixel_list:
trigger.n_trigger_pixels = pixel_list["pixels"]
trigger.trigger_pixels = pixel_list["pixel_list"]
n_trigger_pixels = pixel_list["pixels"]
trigger_pixels = pixel_list["pixel_list"]

trigger = data.trigger.tel[tel_id] = TelescopeTriggerContainer(
time=time,
n_trigger_pixels=n_trigger_pixels,
trigger_pixels=trigger_pixels,
)

def _fill_array_pointing(self, data):
if self.file_.header["tracking_mode"] == 0:
Expand Down

0 comments on commit dd1b5ee

Please sign in to comment.