Skip to content

Commit

Permalink
Merge pull request #2288 from cta-observatory/fix_load_obsinfo
Browse files Browse the repository at this point in the history
Only remove sort index from table if input did not contain it
  • Loading branch information
maxnoe committed Mar 15, 2023
2 parents c24f38a + 92df57c commit c169e25
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
15 changes: 10 additions & 5 deletions ctapipe/io/tableloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,23 @@ def _get_sort_index(self, start=None, stop=None):
return table

@staticmethod
def _sort_to_original_order(table, include_tel_id=False):
def _sort_to_original_order(table, include_tel_id=False, keep_index=False):
if len(table) == 0:
return
if include_tel_id:
table.sort(("__index__", "tel_id"))
else:
table.sort("__index__")
table.remove_column("__index__")

if not keep_index:
table.remove_column("__index__")

@staticmethod
def _add_index_if_needed(table):
if "__index__" not in table.colnames:
table["__index__"] = np.arange(len(table))
return True
return False

def read_simulation_configuration(self):
"""
Expand Down Expand Up @@ -350,7 +354,7 @@ def _join_observation_info(self, table):
obs_table["obs_id"] = obs_table["obs_id"].astype(table["obs_id"].dtype)

# to be able to sort to original table order
self._add_index_if_needed(table)
index_added = self._add_index_if_needed(table)

joint = join_allow_empty(
table,
Expand All @@ -360,8 +364,9 @@ def _join_observation_info(self, table):
)

# sort back to original order and remove index col
self._sort_to_original_order(joint)
del table["__index__"]
self._sort_to_original_order(joint, keep_index=not index_added)
if index_added:
table.remove_column("__index__")
return joint

def read_subarray_events(self, start=None, stop=None, keep_order=True):
Expand Down
6 changes: 5 additions & 1 deletion ctapipe/io/tests/test_table_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,17 @@ def test_order_merged():

path = get_dataset_path("gamma_diffuse_dl2_train_small.dl2.h5")

trigger = read_table(path, "/dl1/event/subarray/trigger")
tel_trigger = read_table(path, "/dl1/event/telescope/trigger")
with TableLoader(
path,
load_dl1_parameters=False,
load_dl1_parameters=True,
load_dl2=True,
load_observation_info=True,
) as loader:
events = loader.read_subarray_events()
check_equal_array_event_order(events, trigger)

tables = loader.read_telescope_events_by_id()

for tel_id, table in tables.items():
Expand Down
2 changes: 2 additions & 0 deletions docs/changes/2288.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``TableLoader.read_subarray_events`` raising an exception when
``load_observation_info=True``.

0 comments on commit c169e25

Please sign in to comment.