Skip to content

Commit

Permalink
Fix handling of telescope parameters from cli for string values, fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Dec 3, 2020
1 parent dd1b5ee commit f1cd0bb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
27 changes: 27 additions & 0 deletions ctapipe/core/tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,33 @@ class SomeComponent(TelescopeComponent):
assert config["SomeComponent"]["tel_param1"] == [("type", "*", 6.0)]


def test_telescope_parameter_from_cli(mock_subarray):
"""
Test we can pass single default for telescope components via cli
see #1559
"""

from ctapipe.core import Tool, run_tool

class SomeComponent(TelescopeComponent):
path = TelescopeParameter(Path(), default_value=None).tag(config=True)
val = TelescopeParameter(Float(), default_value=1.0).tag(config=True)

class TelescopeTool(Tool):
def setup(self):
self.comp = SomeComponent(subarray=mock_subarray, parent=self)

tool = TelescopeTool()
run_tool(tool)
assert tool.comp.path == [("type", "*", None)]
assert tool.comp.val == [("type", "*", 1.0)]

tool = TelescopeTool()
run_tool(tool, ["--SomeComponent.path", "test.h5", "--SomeComponent.val", "2.0"])
assert tool.comp.path == [("type", "*", pathlib.Path("test.h5").absolute())]
assert tool.comp.val == [("type", "*", 2.0)]


def test_datetimes():
from astropy import time as t

Expand Down
19 changes: 12 additions & 7 deletions ctapipe/core/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,25 @@ class TelescopeParameter(List):
"""

klass = TelescopePatternList
_valid_defaults = (object,) # allow everything, we validate the default ourselves

def __init__(self, trait, default_value=Undefined, **kwargs):

if not isinstance(trait, TraitType):
raise TypeError("trait must be a TraitType instance")

if (
not isinstance(default_value, (UserList, list, List))
and default_value is not Undefined
):
default_value = trait.validate(self, default_value)
default_value = [("type", "*", default_value)]
self._trait = trait
if default_value != Undefined:
default_value = self.validate(self, default_value)

super().__init__(default_value=default_value, **kwargs)
self._trait = trait

def from_string(self, s):
val = super().from_string(s)
# for strings, parsing fails and traitlets returns None
if val == [("type", "*", None)] and s != "None":
val = [("type", "*", self._trait.from_string(s))]
return val

def validate(self, obj, value):
# Support a single value for all (check and convert into a default value)
Expand Down

0 comments on commit f1cd0bb

Please sign in to comment.