TypeError when using PiecewiseLinearOutput as the distribution output with the PyTorch version of GluonTS 0.14.3 #3114

Open
zijiexia opened this issue Jan 31, 2024 · 0 comments
Labels
bug Something isn't working

@zijiexia

Description

Hi team, I was trying to reimplement my DeepAR model with the PyTorch version of GluonTS, using PiecewiseLinearOutput from gluonts.torch.distributions.piecewise_linear as the model's distribution output. However, training fails with the following error:

TypeError: unsupported operand type(s) for +: 'NoneType' and 'Tensor'

The same model works as expected with the MXNet-based GluonTS, version 0.9.4.

To Reproduce

import pandas as pd

from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch import DeepAREstimator

from gluonts.torch.distributions.piecewise_linear import PiecewiseLinearOutput

# Load data from a CSV file into a PandasDataset
df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/"
    "TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv",
    index_col=0,
    parse_dates=True,
)
dataset = PandasDataset(df, target="#Passengers")

# Split the data for training and testing
training_data, test_gen = split(dataset, offset=-36)
test_data = test_gen.generate_instances(prediction_length=12, windows=3)

distr_output = PiecewiseLinearOutput(num_pieces=4)

# Train the model and make predictions
model = DeepAREstimator(
    prediction_length=12, freq="30D", distr_output=distr_output, trainer_kwargs={"max_epochs": 5}
).train(training_data)

Error message or code output


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 25
     22 distr_output = PiecewiseLinearOutput(num_pieces=4)
     24 # Train the model and make predictions
---> 25 model = DeepAREstimator(
     26     prediction_length=12, freq="30D", distr_output=distr_output, trainer_kwargs={"max_epochs": 5}
     27 ).train(training_data)

File ~\anaconda3\envs\gluontstest\lib\site-packages\gluonts\torch\model\estimator.py:246, in PyTorchLightningEstimator.train(self, training_data, validation_data, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
    237 def train(
    238     self,
    239     training_data: Dataset,
   (...)
    244     **kwargs,
    245 ) -> PyTorchPredictor:
--> 246     return self.train_model(
    247         training_data,
    248         validation_data,
    249         shuffle_buffer_length=shuffle_buffer_length,
    250         cache_data=cache_data,
    251         ckpt_path=ckpt_path,
    252     ).predictor

File ~\anaconda3\envs\gluontstest\lib\site-packages\gluonts\torch\model\estimator.py:209, in PyTorchLightningEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
    200 custom_callbacks = self.trainer_kwargs.pop("callbacks", [])
    201 trainer = pl.Trainer(
    202     **{
    203         "accelerator": "auto",
   (...)
    206     }
    207 )
--> 209 trainer.fit(
    210     model=training_network,
    211     train_dataloaders=training_data_loader,
    212     val_dataloaders=validation_data_loader,
    213     ckpt_path=ckpt_path,
    214 )
    216 if checkpoint.best_model_path != "":
    217     logger.info(
    218         f"Loading best model from {checkpoint.best_model_path}"
    219     )

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\trainer\trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\trainer\call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\trainer\trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\trainer\trainer.py:969, in Trainer._run(self, model, ckpt_path)
    967 # hook
    968 if self.state.fn == TrainerFn.FITTING:
--> 969     call._call_callback_hooks(self, "on_fit_start")
    970     call._call_lightning_module_hook(self, "on_fit_start")
    972 _log_hyperparams(self)

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\trainer\call.py:208, in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
    206     if callable(fn):
    207         with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 208             fn(trainer, trainer.lightning_module, *args, **kwargs)
    210 if pl_module:
    211     # restore current_fx when nested context
    212     pl_module._current_fx_name = prev_fx_name

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\callbacks\model_summary.py:60, in ModelSummary.on_fit_start(self, trainer, pl_module)
     57 if not self._max_depth:
     58     return
---> 60 model_summary = self._summary(trainer, pl_module)
     61 summary_data = model_summary._get_summary_data()
     62 total_parameters = model_summary.total_parameters

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\callbacks\model_summary.py:74, in ModelSummary._summary(self, trainer, pl_module)
     72 if isinstance(trainer.strategy, DeepSpeedStrategy) and trainer.strategy.zero_stage_3:
     73     return DeepSpeedSummary(pl_module, max_depth=self._max_depth)
---> 74 return summarize(pl_module, max_depth=self._max_depth)

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\utilities\model_summary\model_summary.py:473, in summarize(lightning_module, max_depth)
    460 def summarize(lightning_module: "pl.LightningModule", max_depth: int = 1) -> ModelSummary:
    461     """Summarize the LightningModule specified by `lightning_module`.
    462 
    463     Args:
   (...)
    471 
    472     """
--> 473     return ModelSummary(lightning_module, max_depth=max_depth)

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\utilities\model_summary\model_summary.py:209, in ModelSummary.__init__(self, model, max_depth)
    206     raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
    208 self._max_depth = max_depth
--> 209 self._layer_summary = self.summarize()
    210 # 1 byte -> 8 bits
    211 # TODO: how do we compute precision_megabytes in case of mixed precision?
    212 precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\utilities\model_summary\model_summary.py:270, in ModelSummary.summarize(self)
    268 summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
    269 if self._model.example_input_array is not None:
--> 270     self._forward_example_input()
    271 for layer in summary.values():
    272     layer.detach_hook()

File ~\anaconda3\envs\gluontstest\lib\site-packages\lightning\pytorch\utilities\model_summary\model_summary.py:300, in ModelSummary._forward_example_input(self)
    298     model(*input_)
    299 elif isinstance(input_, dict):
--> 300     model(**input_)
    301 else:
    302     model(input_)

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~\anaconda3\envs\gluontstest\lib\site-packages\gluonts\torch\model\deepar\lightning_module.py:68, in DeepARLightningModule.forward(self, *args, **kwargs)
     67 def forward(self, *args, **kwargs):
---> 68     return self.model(*args, **kwargs)

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\nn\modules\module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559     args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run

File ~\anaconda3\envs\gluontstest\lib\site-packages\gluonts\torch\model\deepar\module.py:451, in DeepARModel.forward(self, feat_static_cat, feat_static_real, past_time_feat, past_target, past_observed_values, future_time_feat, num_parallel_samples)
    444 repeated_params = [
    445     s.repeat_interleave(repeats=num_parallel_samples, dim=0)
    446     for s in params
    447 ]
    448 distr = self.output_distribution(
    449     repeated_params, trailing_n=1, scale=repeated_scale
    450 )
--> 451 next_sample = distr.sample()
    452 future_samples = [next_sample]
    454 for k in range(1, self.prediction_length):

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\distributions\transformed_distribution.py:142, in TransformedDistribution.sample(self, sample_shape)
    140 x = self.base_dist.sample(sample_shape)
    141 for transform in self.transforms:
--> 142     x = transform(x)
    143 return x

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\distributions\transforms.py:156, in Transform.__call__(self, x)
    152 """
    153 Computes the transform `x => y`.
    154 """
    155 if self._cache_size == 0:
--> 156     return self._call(x)
    157 x_old, y_old = self._cached_x_y
    158 if x is x_old:

File ~\anaconda3\envs\gluontstest\lib\site-packages\torch\distributions\transforms.py:780, in AffineTransform._call(self, x)
    779 def _call(self, x):
--> 780     return self.loc + self.scale * x

TypeError: unsupported operand type(s) for +: 'NoneType' and 'Tensor'
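The final frame shows torch's AffineTransform._call evaluating self.loc + self.scale * x while loc is still None. Judging from the DeepARModel.forward frame, which passes only scale=repeated_scale to output_distribution, it appears that the PyTorch PiecewiseLinearOutput hands a None loc straight to AffineTransform. That is an inference from the traceback, not a confirmed reading of the GluonTS source. The sketch below uses plain torch.distributions only (a Uniform base stands in for the piecewise-linear distribution) to reproduce the same TypeError in isolation and to show that substituting a zero loc avoids it. The helper affine_transformed is hypothetical and is not part of GluonTS or PyTorch.

```python
import torch
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform

torch.manual_seed(0)

# Stand-in base distribution (Uniform, NOT PiecewiseLinear) with batch shape (3,).
base = Uniform(torch.zeros(3), torch.ones(3))
scale = torch.full((3,), 2.0)

# Mirrors the failing path: a scale is supplied but loc is left as None.
# Construction succeeds because AffineTransform stores loc without checking it;
# the error only surfaces when the transform is applied during sample().
broken = TransformedDistribution(base, [AffineTransform(loc=None, scale=scale)])
try:
    broken.sample()
except TypeError as err:
    print("sample() failed:", err)


def affine_transformed(base, loc=None, scale=None):
    """Hypothetical helper (not GluonTS/PyTorch API): default a missing loc to zeros."""
    if scale is None:
        # Nothing to rescale; return the base distribution unchanged.
        return base
    if loc is None:
        # A zero shift keeps y = scale * x, which is what a missing loc should mean.
        loc = torch.zeros_like(scale)
    return TransformedDistribution(base, [AffineTransform(loc=loc, scale=scale)])


fixed = affine_transformed(base, scale=scale)
sample = fixed.sample()
print(sample.shape)  # torch.Size([3])
```

A fix along these lines inside GluonTS would default loc to zeros (or apply a scale-only transform) whenever a scale is given without a loc; where exactly that change belongs in 0.14.3 is left to the maintainers.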

Environment

I've reproduced this in several environments:

  • Operating system: Windows / macOS / Linux
  • Python version: 3.8.0/3.9.0
  • GluonTS version: 0.14.3
  • PyTorch version: 2.2.0
  • NumPy version: 1.24.4/1.26.3
  • pandas version: 2.0.3/2.2.0
  • PyTorch Lightning version: 2.1.3
zijiexia added the bug (Something isn't working) label on Jan 31, 2024