diff --git a/ctapipe/core/tests/test_traits.py b/ctapipe/core/tests/test_traits.py index 43a90ffa4f6..205475f9e03 100644 --- a/ctapipe/core/tests/test_traits.py +++ b/ctapipe/core/tests/test_traits.py @@ -433,6 +433,33 @@ class SomeComponent(TelescopeComponent): assert config["SomeComponent"]["tel_param1"] == [("type", "*", 6.0)] +def test_telescope_parameter_from_cli(mock_subarray): + """ + Test we can pass single default for telescope components via cli + see #1559 + """ + + from ctapipe.core import Tool, run_tool + + class SomeComponent(TelescopeComponent): + path = TelescopeParameter(Path(), default_value=None).tag(config=True) + val = TelescopeParameter(Float(), default_value=1.0).tag(config=True) + + class TelescopeTool(Tool): + def setup(self): + self.comp = SomeComponent(subarray=mock_subarray, parent=self) + + tool = TelescopeTool() + run_tool(tool) + assert tool.comp.path == [("type", "*", None)] + assert tool.comp.val == [("type", "*", 1.0)] + + tool = TelescopeTool() + run_tool(tool, ["--SomeComponent.path", "test.h5", "--SomeComponent.val", "2.0"]) + assert tool.comp.path == [("type", "*", pathlib.Path("test.h5").absolute())] + assert tool.comp.val == [("type", "*", 2.0)] + + def test_datetimes(): from astropy import time as t diff --git a/ctapipe/core/traits.py b/ctapipe/core/traits.py index 0aab8fbf848..2f2d5b0f00d 100644 --- a/ctapipe/core/traits.py +++ b/ctapipe/core/traits.py @@ -357,20 +357,25 @@ class TelescopeParameter(List): """ klass = TelescopePatternList + _valid_defaults = (object,) # allow everything, we validate the default ourselves def __init__(self, trait, default_value=Undefined, **kwargs): + if not isinstance(trait, TraitType): raise TypeError("trait must be a TraitType instance") - if ( - not isinstance(default_value, (UserList, list, List)) - and default_value is not Undefined - ): - default_value = trait.validate(self, default_value) - default_value = [("type", "*", default_value)] + self._trait = trait + if default_value != Undefined: + default_value = self.validate(self, default_value) super().__init__(default_value=default_value, **kwargs) - self._trait = trait + + def from_string(self, s): + val = super().from_string(s) + # for strings, parsing fails and traitlets returns None + if val == [("type", "*", None)] and s != "None": + val = [("type", "*", self._trait.from_string(s))] + return val def validate(self, obj, value): # Support a single value for all (check and convert into a default value)