
RecursionError escapes during inference #9139

Closed
DrShushen opened this issue Oct 11, 2023 · 2 comments · Fixed by pylint-dev/astroid#2432
Labels: Crash 💥 (a bug that makes pylint crash) · Needs astroid update (needs an astroid update, probably a release too, before being mergeable) · Needs PR (this issue is accepted, sufficiently specified and now needs an implementation)

Comments

@DrShushen

Note: I am reporting this with the default pre-filled template, so the code below is definitely not a minimal example.

Issue title:
Crash with astroid-error

Bug description

Note: This issue does not occur with pylint 2.17.7 / astroid 2.15.8, but it does occur with pylint 3.0.1 / astroid 3.0.0.
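
The title describes a RecursionError that escapes from astroid's inference into a pylint checker; in the visible traceback frames, pylint's `_safe_infer_call_result` calls `next(inferit)` on a chain of nested inference generators. The toy sketch below illustrates that failure mode without using astroid at all: `toy_infer`, `unguarded_first` and `guarded_first` are hypothetical names, and the guard shown is only one way a caller could contain the error, not necessarily how pylint-dev/astroid#2432 resolved it.

```python
from typing import Iterator, Optional, TypeVar

T = TypeVar("T")


def toy_infer() -> Iterator[int]:
    """Hypothetical stand-in for a chain of inference generators.

    Each level delegates to the next with ``yield from`` (loosely like the
    repeated ``infer`` -> ``_infer`` -> ``_infer_stmts`` frames in the
    traceback) and never reaches a base case, so advancing it exhausts
    the interpreter's recursion limit.
    """
    yield from toy_infer()


def unguarded_first(gen: Iterator[T]) -> T:
    # Mirrors a bare ``next(inferit)``: a RecursionError raised deep in the
    # generator chain propagates straight out to the caller.
    return next(gen)


def guarded_first(gen: Iterator[T]) -> Optional[T]:
    """Return the first inferred value, or None if inference fails."""
    try:
        return next(gen)
    except StopIteration:
        return None  # nothing could be inferred
    except RecursionError:
        return None  # treat runaway recursion as "inference failed"


if __name__ == "__main__":
    print(guarded_first(iter([42])))  # a well-behaved inference result
    print(guarded_first(toy_infer()))  # recursion is contained
    try:
        unguarded_first(toy_infer())
    except RecursionError as exc:
        print("escaped:", type(exc).__name__)
```

In a linter, the unguarded variant turns one pathological inference path into a crash of the whole run, which matches the "Crash" label on this issue.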

When parsing the following ts_model.py:

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pydantic
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, sampler
from tsai.models.InceptionTime import InceptionTime
from tsai.models.InceptionTimePlus import InceptionTimePlus
from tsai.models.OmniScaleCNN import OmniScaleCNN
from tsai.models.ResCNN import ResCNN
from tsai.models.RNN_FCN import MLSTM_FCN
from tsai.models.TCN import TCN
from tsai.models.TransformerModel import TransformerModel
from tsai.models.XceptionTime import XceptionTime
from tsai.models.XCM import XCM
from typing_extensions import Literal

from tempor.log import logger as log
from tempor.models import constants
from tempor.models.constants import DEVICE, ModelTaskType, Nonlin
from tempor.models.mlp import MLP, MultiActivationHead
from tempor.models.samplers import ImbalancedDatasetSampler
from tempor.models.utils import enable_reproducibility, get_nonlin

TSModelMode = Literal[
    "LSTM",
    "GRU",
    "RNN",
    "Transformer",
    "MLSTM_FCN",
    "TCN",
    "InceptionTime",
    "InceptionTimePlus",
    "XceptionTime",
    "ResCNN",
    "OmniScaleCNN",
    "XCM",
]


class TimeSeriesModel(nn.Module):
    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def __init__(
        self,
        task_type: ModelTaskType,
        n_static_units_in: int,
        n_temporal_units_in: int,
        n_temporal_window: int,
        output_shape: List[int],
        n_static_units_hidden: int = 102,
        n_static_layers_hidden: int = 2,
        n_temporal_units_hidden: int = 102,
        n_temporal_layers_hidden: int = 2,
        n_iter: int = 500,
        mode: TSModelMode = "RNN",
        n_iter_print: int = 10,
        batch_size: int = 100,
        lr: float = 1e-3,
        weight_decay: float = 1e-3,
        window_size: int = 1,
        device: Any = DEVICE,
        dataloader_sampler: Optional[sampler.Sampler] = None,
        nonlin_out: Optional[List[Tuple[Nonlin, int]]] = None,
        loss: Optional[Callable] = None,
        dropout: float = 0.0,
        nonlin: Nonlin = "relu",
        random_state: int = 0,
        clipping_value: int = 1,
        patience: int = 20,
        train_ratio: float = 0.8,
        use_horizon_condition: bool = True,
    ) -> None:
        """Basic neural net for time series.

        Args:
            task_type (ModelTaskType):
                The type of the problem. Available options: :obj:`~tempor.models.constants.ModelTaskType`.
            n_static_units_in (int):
                Number of input units for the static data.
            n_temporal_units_in (int):
                Number of units for the temporal features.
            n_temporal_window (int):
                Number of temporal observations for each subject.
            output_shape (List[int]):
                Shape of the output tensor.
            n_static_units_hidden (int, optional):
                Number of hidden units for the static features. Defaults to ``102``.
            n_static_layers_hidden (int, optional):
                Number of hidden layers for the static features. Defaults to ``2``.
            n_temporal_units_hidden (int, optional):
                Number of hidden units for the temporal features. Defaults to ``102``.
            n_temporal_layers_hidden (int, optional):
                Number of hidden layers for the temporal features. Defaults to ``2``.
            n_iter (int, optional):
                Number of epochs. Defaults to ``500``.
            mode (TSModelMode, optional):
                Core neural net architecture. Available options: :obj:`~tempor.models.ts_model.TSModelMode`.
                Defaults to ``"RNN"``.
            n_iter_print (int, optional):
                Number of epochs to print the loss. Defaults to ``10``.
            batch_size (int, optional):
                Batch size. Defaults to ``100``.
            lr (float, optional):
                Learning rate. Defaults to ``1e-3``.
            weight_decay (float, optional):
                 l2 (ridge) penalty for the weights. Defaults to ``1e-3``.
            window_size (int, optional):
                How many hidden states to use for the outcome. Defaults to ``1``.
            device (Any, optional):
                PyTorch device to use. Defaults to :obj:`~tempor.models.constants.DEVICE`.
            dataloader_sampler (Optional[sampler.Sampler], optional):
                Custom data sampler for training. Defaults to None.
            nonlin_out (Optional[List[Tuple[Nonlin, int]]], optional):
                List of activations for the output. Example ``[("tanh", 1), ("softmax", 3)]`` - means the output layer
                will apply ``"tanh"`` for the first unit, and ``"softmax"`` for the following 3 units in the output.
                Defaults to `None`.
            loss (Optional[Callable], optional):
                Custom additional loss. Defaults to `None`.
            dropout (float, optional):
                Dropout value. Defaults to ``0.0``.
            nonlin (Nonlin, optional):
                Activation for hidden layers. Available options: :obj:`~tempor.models.constants.Nonlin`.
                Defaults to ``"relu"``.
            random_state (int, optional):
                Random seed. Defaults to ``0``.
            clipping_value (int, optional):
                Gradients clipping value. Zero disables the feature. Defaults to ``1``.
            patience (int, optional):
                How many ``epoch * n_iter_print`` to wait without loss improvement. Defaults to ``20``.
            train_ratio (float, optional):
                Train/test split ratio. Defaults to ``0.8``.
            use_horizon_condition (bool, optional):
                Whether to predict using the observation times (`True`) or just the covariates (`False`).
                Defaults to `True`.
        """
        super(TimeSeriesModel, self).__init__()

        enable_reproducibility(random_state)
        if len(output_shape) == 0:
            raise ValueError("Invalid output shape")

        self.task_type = task_type

        if loss is not None:
            self.loss = loss
        elif task_type == "regression":
            self.loss = nn.MSELoss()
        elif task_type == "classification":
            self.loss = nn.CrossEntropyLoss()
        else:  # Prevented by pydantic.  # pragma: no cover
            raise ValueError(f"Invalid task type {task_type}")

        self.n_iter = n_iter
        self.n_iter_print = n_iter_print
        self.batch_size = batch_size
        self.n_static_units_in = n_static_units_in
        self.n_temporal_units_in = n_temporal_units_in
        self.n_temporal_window = n_temporal_window
        self.n_static_units_hidden = n_static_units_hidden
        self.n_temporal_units_hidden = n_temporal_units_hidden
        self.n_static_layers_hidden = n_static_layers_hidden
        self.n_temporal_layers_hidden = n_temporal_layers_hidden
        self.device = device
        self.window_size = window_size
        self.dataloader_sampler = dataloader_sampler
        self.lr = lr
        self.output_shape = output_shape
        self.n_units_out = int(np.prod(self.output_shape))
        self.clipping_value = clipping_value
        self.use_horizon_condition = use_horizon_condition

        self.patience = patience
        self.train_ratio = train_ratio
        self.random_state = random_state

        self.temporal_layer = TimeSeriesLayer(
            n_static_units_in=n_static_units_in,
            n_temporal_units_in=n_temporal_units_in + int(use_horizon_condition),  # measurements + horizon
            n_temporal_window=n_temporal_window,
            n_units_out=self.n_units_out,
            n_static_units_hidden=n_static_units_hidden,
            n_static_layers_hidden=n_static_layers_hidden,
            n_temporal_units_hidden=n_temporal_units_hidden,
            n_temporal_layers_hidden=n_temporal_layers_hidden,
            mode=mode,
            window_size=window_size,
            device=device,
            dropout=dropout,
            nonlin=nonlin,
        )

        self.mode = mode

        self.out_activation: Optional[nn.Module] = None
        self.n_act_out: Optional[int] = None

        if nonlin_out is not None:
            self.n_act_out = 0
            activations = []
            for nonlin, nonlin_len in nonlin_out:
                self.n_act_out += nonlin_len
                activations.append((get_nonlin(nonlin), nonlin_len))

            if self.n_units_out % self.n_act_out != 0:
                raise RuntimeError(
                    f"Shape mismatch for the output layer. Expected length {self.n_units_out}, but got "
                    f"{nonlin_out} with length {self.n_act_out}"
                )
            self.out_activation = MultiActivationHead(activations, device=device)
        elif self.task_type == "classification":
            self.n_act_out = self.n_units_out
            self.out_activation = MultiActivationHead([(nn.Softmax(dim=-1), self.n_units_out)], device=device)

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )  # optimize all rnn parameters

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def forward(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        observation_times: torch.Tensor,
    ) -> torch.Tensor:
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)

        if torch.isnan(static_data).sum() != 0:
            raise ValueError("NaNs detected in the static data")
        if torch.isnan(temporal_data).sum() != 0:
            raise ValueError("NaNs detected in the temporal data")
        if torch.isnan(observation_times).sum() != 0:
            raise ValueError("NaNs detected in the temporal horizons")

        if self.use_horizon_condition:
            temporal_data_merged = torch.cat([temporal_data, observation_times.unsqueeze(2)], dim=2)
        else:
            temporal_data_merged = temporal_data

        if torch.isnan(temporal_data_merged).sum() != 0:  # pragma: no cover
            raise ValueError("NaNs detected in the temporal merged data")

        pred = self.temporal_layer(static_data, temporal_data_merged)

        if self.out_activation is not None:
            pred = pred.reshape(-1, self.n_act_out)
            pred = self.out_activation(pred)

        pred = pred.reshape(-1, *self.output_shape)

        return pred

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def predict(
        self,
        static_data: Union[List, np.ndarray],
        temporal_data: Union[List, np.ndarray],
        observation_times: Union[List, np.ndarray],
    ) -> np.ndarray:
        self.eval()
        with torch.no_grad():
            (
                static_data_t,
                temporal_data_t,
                observation_times_t,
                _,
                window_batches,
            ) = self._prepare_input(static_data, temporal_data, observation_times)

            yt = torch.zeros(len(temporal_data), *self.output_shape).to(self.device)
            for widx in range(len(temporal_data_t)):
                window_size = len(observation_times_t[widx][0])
                local_yt = self(
                    static_data_t[widx],
                    temporal_data_t[widx],
                    observation_times_t[widx],
                )
                yt[window_batches[window_size]] = local_yt

            if self.task_type == "classification":
                return np.argmax(yt.cpu().numpy(), -1)
            else:
                return yt.cpu().numpy()

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def predict_proba(
        self,
        static_data: Union[List, np.ndarray],
        temporal_data: Union[List, np.ndarray],
        observation_times: Union[List, np.ndarray],
    ) -> np.ndarray:
        self.eval()
        if self.task_type != "classification":
            raise RuntimeError("Task valid only for classification")
        with torch.no_grad():
            (
                static_data_t,
                temporal_data_t,
                observation_times_t,
                _,
                window_batches,
            ) = self._prepare_input(static_data, temporal_data, observation_times)

            yt = torch.zeros(len(temporal_data), *self.output_shape).to(self.device)
            for widx in range(len(temporal_data_t)):
                window_size = len(observation_times_t[widx][0])
                local_yt = self(
                    static_data_t[widx],
                    temporal_data_t[widx],
                    observation_times_t[widx],
                )
                yt[window_batches[window_size]] = local_yt

            return yt.cpu().numpy()

    def score(
        self,
        static_data: Union[List, np.ndarray],
        temporal_data: Union[List, np.ndarray],
        observation_times: Union[List, np.ndarray],
        outcome: np.ndarray,
    ) -> float:
        y_pred = self.predict(static_data, temporal_data, observation_times)
        if self.task_type == "classification":
            return np.mean(y_pred.astype(int) == outcome.astype(int))
        else:
            return np.mean(np.inner(outcome - y_pred, outcome - y_pred) / 2.0)

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def fit(
        self,
        static_data: Union[List, np.ndarray],
        temporal_data: Union[List, np.ndarray],
        observation_times: Union[List, np.ndarray],
        outcome: Union[List, np.ndarray],
    ) -> Any:
        (
            static_data_t,
            temporal_data_t,
            observation_times_t,
            outcome_t,
            _,
        ) = self._prepare_input(static_data, temporal_data, observation_times, outcome)

        return self._train(static_data_t, temporal_data_t, observation_times_t, outcome_t)

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def _train(
        self,
        static_data: List[torch.Tensor],
        temporal_data: List[torch.Tensor],
        observation_times: List[torch.Tensor],
        outcome: List[torch.Tensor],
    ) -> Any:
        patience = 0
        prev_error = np.inf

        train_dataloaders = []
        test_dataloaders = []
        for widx in range(len(temporal_data)):
            train_dl, test_dl = self.dataloader(
                static_data[widx],
                temporal_data[widx],
                observation_times[widx],
                outcome[widx],
            )
            train_dataloaders.append(train_dl)
            test_dataloaders.append(test_dl)

        # training and testing
        for it in range(self.n_iter):
            train_loss = self._train_epoch(train_dataloaders)
            if it % self.n_iter_print == 0:
                val_loss = self._test_epoch(test_dataloaders)
                log.info(f"Epoch:{it}| train loss: {train_loss}, validation loss: {val_loss}")
                if val_loss < prev_error:
                    patience = 0
                    prev_error = val_loss
                else:
                    patience += 1
                if patience > self.patience:
                    break

        return self

    def _train_epoch(self, loaders: List[DataLoader]) -> float:
        self.train()

        losses = []
        for loader in loaders:
            for step, (static_mb, temporal_mb, horizons_mb, y_mb) in enumerate(  # pylint: disable=unused-variable
                loader
            ):
                self.optimizer.zero_grad()  # clear gradients for this training step

                pred = self(static_mb, temporal_mb, horizons_mb)  # rnn output

                loss = self.loss(pred.squeeze(), y_mb.squeeze())

                loss.backward()  # backpropagation, compute gradients
                if self.clipping_value > 0:
                    torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)  # pyright: ignore
                self.optimizer.step()  # apply gradients

                losses.append(loss.detach().cpu())

        return float(np.mean(losses))

    def _test_epoch(self, loaders: List[DataLoader]) -> float:
        self.eval()

        losses = []
        for loader in loaders:
            for step, (static_mb, temporal_mb, horizons_mb, y_mb) in enumerate(  # pylint: disable=unused-variable
                loader
            ):
                pred = self(static_mb, temporal_mb, horizons_mb)  # rnn output
                loss = self.loss(pred.squeeze(), y_mb.squeeze())

                losses.append(loss.detach().cpu())

        return float(np.mean(losses))

    def dataloader(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        observation_times: torch.Tensor,
        outcome: torch.Tensor,
    ) -> Tuple[DataLoader, DataLoader]:
        stratify = None
        _, out_counts = torch.unique(outcome, return_counts=True)
        if out_counts.min() > 1:
            stratify = outcome.cpu()

        split: Tuple[torch.Tensor, ...] = train_test_split(
            static_data.cpu(),
            temporal_data.cpu(),
            observation_times.cpu(),
            outcome.cpu(),
            train_size=self.train_ratio,
            random_state=self.random_state,
            stratify=stratify,
        )
        (
            static_data_train,
            static_data_test,
            temporal_data_train,
            temporal_data_test,
            observation_times_train,
            observation_times_test,
            outcome_train,
            outcome_test,
        ) = split
        train_dataset = TensorDataset(
            static_data_train.to(self.device),
            temporal_data_train.to(self.device),
            observation_times_train.to(self.device),
            outcome_train.to(self.device),
        )
        test_dataset = TensorDataset(
            static_data_test.to(self.device),
            temporal_data_test.to(self.device),
            observation_times_test.to(self.device),
            outcome_test.to(self.device),
        )

        sampler_ = self.dataloader_sampler
        if sampler_ is None and self.task_type == "classification":
            sampler_ = ImbalancedDatasetSampler(outcome_train.squeeze().cpu().numpy().tolist())

        return (
            DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                sampler=sampler_,
                pin_memory=False,
            ),
            DataLoader(
                test_dataset,
                batch_size=self.batch_size,
                pin_memory=False,
            ),
        )

    def _check_tensor(self, X: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        if isinstance(X, torch.Tensor):
            return X.to(self.device)
        else:
            return torch.from_numpy(np.asarray(X)).to(self.device)

    def _prepare_input(
        self,
        static_data: Union[List, np.ndarray],
        temporal_data: Union[List, np.ndarray],
        observation_times: Union[List, np.ndarray],
        outcome: Optional[Union[List, np.ndarray]] = None,
    ) -> Tuple:
        static_data = np.asarray(static_data)
        temporal_data = np.asarray(temporal_data)
        observation_times = np.asarray(observation_times)
        if outcome is not None:
            outcome = np.asarray(outcome)

        window_batches: Dict[int, List[int]] = {}
        for idx, item in enumerate(observation_times):
            window_len = len(item)
            if window_len not in window_batches:
                window_batches[window_len] = []
            window_batches[window_len].append(idx)

        static_data_mb = []
        temporal_data_mb = []
        observation_times_mb = []
        outcome_mb = []

        for widx in window_batches:
            indices = window_batches[widx]

            static_data_t = self._check_tensor(static_data[indices]).float()

            local_temporal_data = np.array(temporal_data[indices].tolist()).astype(float)
            temporal_data_t = self._check_tensor(local_temporal_data).float()
            local_observation_times = np.array(observation_times[indices].tolist()).astype(float)
            observation_times_t = self._check_tensor(local_observation_times).float()

            static_data_mb.append(static_data_t)
            temporal_data_mb.append(temporal_data_t)
            observation_times_mb.append(observation_times_t)

            if outcome is not None:
                outcome_t = self._check_tensor(outcome[indices]).float()

                if self.task_type == "classification":
                    outcome_t = outcome_t.long()
                outcome_mb.append(outcome_t)

        return (
            static_data_mb,
            temporal_data_mb,
            observation_times_mb,
            outcome_mb,
            window_batches,
        )


class TimeSeriesLayer(nn.Module):
    def __init__(
        self,
        n_static_units_in: int,
        n_temporal_units_in: int,
        n_temporal_window: int,
        n_units_out: int,
        n_static_units_hidden: int = 100,
        n_static_layers_hidden: int = 2,
        n_temporal_units_hidden: int = 100,
        n_temporal_layers_hidden: int = 2,
        mode: str = "RNN",
        window_size: int = 1,
        device: Any = constants.DEVICE,
        dropout: float = 0,
        nonlin: Nonlin = "relu",
    ) -> None:
        super(TimeSeriesLayer, self).__init__()
        temporal_params = {
            "input_size": n_temporal_units_in,
            "hidden_size": n_temporal_units_hidden,
            "num_layers": n_temporal_layers_hidden,
            "dropout": 0 if n_temporal_layers_hidden == 1 else dropout,
            "batch_first": True,
        }
        temporal_models = {
            "RNN": nn.RNN,
            "LSTM": nn.LSTM,
            "GRU": nn.GRU,
        }

        if mode in ["RNN", "LSTM", "GRU"]:
            self.temporal_layer = temporal_models[mode](**temporal_params)
        elif mode == "MLSTM_FCN":
            self.temporal_layer = MLSTM_FCN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                hidden_size=n_temporal_units_hidden,
                rnn_layers=n_temporal_layers_hidden,
                fc_dropout=dropout,
                seq_len=n_temporal_window,
                shuffle=False,
            )
        elif mode == "TCN":
            self.temporal_layer = TCN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                fc_dropout=dropout,
            )
        elif mode == "InceptionTime":
            self.temporal_layer = InceptionTime(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                depth=n_temporal_layers_hidden,
                seq_len=n_temporal_window,
            )
        elif mode == "InceptionTimePlus":
            self.temporal_layer = InceptionTimePlus(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                depth=n_temporal_layers_hidden,
                seq_len=n_temporal_window,
            )
        elif mode == "XceptionTime":
            self.temporal_layer = XceptionTime(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
            )
        elif mode == "ResCNN":
            self.temporal_layer = ResCNN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
            )
        elif mode == "OmniScaleCNN":
            self.temporal_layer = OmniScaleCNN(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=max(n_temporal_window, 10),
            )
        elif mode == "XCM":
            self.temporal_layer = XCM(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                seq_len=n_temporal_window,
                fc_dropout=dropout,
            )
        elif mode == "Transformer":
            self.temporal_layer = TransformerModel(
                c_in=n_temporal_units_in,
                c_out=n_temporal_units_hidden,
                dropout=dropout,
                n_layers=n_temporal_layers_hidden,
            )
        else:
            raise RuntimeError(f"Unknown TS mode {mode}")

        self.device = device
        self.mode = mode

        if mode in ["RNN", "LSTM", "GRU"]:
            self.out = WindowLinearLayer(
                n_static_units_in=n_static_units_in,
                n_temporal_units_in=n_temporal_units_hidden,
                window_size=window_size,
                n_units_out=n_units_out,
                n_layers=n_static_layers_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )
        else:
            self.out = MLP(
                task_type="regression",
                n_units_in=n_static_units_in + n_temporal_units_hidden,
                n_units_out=n_units_out,
                n_layers_hidden=n_static_layers_hidden,
                n_units_hidden=n_static_units_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )

        self.temporal_layer.to(device)
        self.out.to(device)

    def forward(self, static_data: torch.Tensor, temporal_data: torch.Tensor) -> torch.Tensor:
        if self.mode in ["RNN", "LSTM", "GRU"]:
            X_interm, _ = self.temporal_layer(temporal_data)

            if torch.isnan(X_interm).sum() != 0:
                raise RuntimeError("NaNs detected in the temporal embeddings")

            return self.out(static_data, X_interm)
        else:
            X_interm = self.temporal_layer(torch.swapaxes(temporal_data, 1, 2))

            if torch.isnan(X_interm).sum() != 0:
                raise RuntimeError("NaNs detected in the temporal embeddings")

            return self.out(torch.cat([static_data, X_interm], dim=1))


class WindowLinearLayer(nn.Module):
    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def __init__(
        self,
        n_static_units_in: int,
        n_temporal_units_in: int,
        window_size: int,
        n_units_out: int,
        n_units_hidden: int = 100,
        n_layers: int = 1,
        dropout: float = 0,
        nonlin: Nonlin = "relu",
        device: Any = constants.DEVICE,
    ) -> None:
        super(WindowLinearLayer, self).__init__()

        self.device = device
        self.window_size = window_size
        self.n_static_units_in = n_static_units_in
        self.model = MLP(
            task_type="regression",
            n_units_in=n_static_units_in + n_temporal_units_in * window_size,
            n_units_out=n_units_out,
            n_layers_hidden=n_layers,
            n_units_hidden=n_units_hidden,
            dropout=dropout,
            nonlin=nonlin,
            device=device,
        )

    @pydantic.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))  # type: ignore [operator]
    def forward(self, static_data: torch.Tensor, temporal_data: torch.Tensor) -> torch.Tensor:
        if self.n_static_units_in > 0 and len(static_data) != len(temporal_data):
            raise ValueError("Length mismatch between static and temporal data")

        batch_size, seq_len, n_feats = temporal_data.shape
        temporal_batch = temporal_data[:, seq_len - self.window_size :, :].reshape(
            batch_size, n_feats * self.window_size
        )
        batch = torch.cat([static_data, temporal_batch], dim=1)

        return self.model(batch).to(self.device)

Command used

pylint a.py
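
The traceback paths below point at the copy of pylint bundled with the VS Code Pylint extension (`ms-python.pylint-2023.9.12791029/bundled/libs`), which may differ from the pylint installed in the project's environment. Because the note above ties the crash to specific pylint/astroid pairs, a small sketch like this one (the helper names `loaded_version` and `report` are hypothetical) can confirm which versions and files a given interpreter actually imports; `pylint --version` reports similar version information from the command line.

```python
import importlib
from typing import Dict, Iterable, Optional, Tuple


def loaded_version(module_name: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (__version__, __file__) for a module, or (None, None) if it is not importable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None, None
    return getattr(module, "__version__", None), getattr(module, "__file__", None)


def report(names: Iterable[str] = ("pylint", "astroid")) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Map each module name to the version and location this interpreter would load."""
    return {name: loaded_version(name) for name in names}


if __name__ == "__main__":
    for name, (version, path) in report().items():
        print(f"{name}: {version} ({path})")
```

Running it with the same interpreter that VS Code uses, and again with the project's interpreter, shows whether both see the same pylint/astroid pair.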

Pylint output

pylint crashed with a ``AstroidError`` and with the following stacktrace:
Traceback (most recent call last):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/lint/pylinter.py", line 788, in _lint_file
    check_astroid_module(module)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/lint/pylinter.py", line 1017, in check_astroid_module
    retval = self._check_astroid_module(
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/lint/pylinter.py", line 1069, in _check_astroid_module
    walker.walk(node)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/utils/ast_walker.py", line 94, in walk
    self.walk(child)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/utils/ast_walker.py", line 94, in walk
    self.walk(child)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/utils/ast_walker.py", line 91, in walk
    callback(astroid)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/checkers/classes/special_methods_checker.py", line 183, in visit_functiondef
    inferred = _safe_infer_call_result(node, node)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/checkers/classes/special_methods_checker.py", line 48, in _safe_infer_call_result
    next(inferit)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/scoped_nodes/scoped_nodes.py", line 1650, in infer_call_result
    yield from returnnode.value.infer(context)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 179, in _infer_stmts
    for inf in stmt.infer(context=context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 179, in _infer_stmts
    for inf in stmt.infer(context=context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_classes.py", line 1756, in _infer
    for callee in self.func.infer(context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_classes.py", line 1090, in _infer_attribute
    for owner in node.expr.infer(context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 179, in _infer_stmts
    for inf in stmt.infer(context=context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 179, in _infer_stmts
    for inf in stmt.infer(context=context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_ng.py", line 169, in infer
    for i, result in enumerate(self._infer(context=context, **kwargs)):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 103, in inner
    yield from generator
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/decorators.py", line 49, in wrapped
    for res in _func(node, context, **kwargs):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/nodes/node_classes.py", line 1765, in _infer
    yield from callee.infer_call_result(
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 331, in infer_call_result
    for res in node.infer_call_result(caller, context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 331, in infer_call_result
    for res in node.infer_call_result(caller, context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 331, in infer_call_result
    for res in node.infer_call_result(caller, context):
  [Previous line repeated 925 more times]
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 323, in infer_call_result
    for res in self.igetattr(caller.func.attrname, context):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 279, in igetattr
    yield from _infer_stmts(
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/astroid/bases.py", line 179, in _infer_stmts
    for inf in stmt.infer(context=context):
RecursionError: maximum recursion depth exceeded

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/lint/pylinter.py", line 752, in _lint_files
    self._lint_file(fileitem, module, check_astroid_module)
  File "/home/essav/.vscode-insiders/extensions/ms-python.pylint-2023.9.12791029/bundled/libs/pylint/lint/pylinter.py", line 790, in _lint_file
    raise astroid.AstroidError from e
astroid.exceptions.AstroidError

Expected behavior

No crash.

Pylint version

pylint 3.0.1
astroid 3.0.0
Python 3.8.18 (default, Sep 11 2023, 13:40:15) 
[GCC 11.2.0]

OS / Environment

linux (Linux)

Additional dependencies

Pierre-Sassoulas added the Crash 💥 A bug that makes pylint crash label Oct 11, 2023
@jacobtylerwalls
Member

Thanks for the report. We always appreciate it when people fill out the crash template.

We may in the future mark this as a duplicate of #8842.

@jacobtylerwalls
Member

#8842 has to do with recursion errors during AST build/transform, which we should handle in astroid.

I'll let this issue be the top-level duplicate for catching RecursionError during inference, which we can handle in pylint.
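A minimal sketch of what catching RecursionError during inference could look like on the caller side. It assumes inference results are consumed lazily from a generator, the way `_safe_infer_call_result` calls `next(inferit)` in the traceback. `infer_call_result`, `safe_infer_first` and `infer_constant` are hypothetical stand-ins, not pylint or astroid APIs, and this is not the fix that was eventually merged.

```python
from typing import Callable, Iterator, Optional, TypeVar

T = TypeVar("T")


def infer_call_result(depth: int = 0) -> Iterator[str]:
    """Hypothetical inference step that delegates to itself with no
    termination check, mimicking the repeated frame at bases.py:331."""
    yield from infer_call_result(depth + 1)


def infer_constant() -> Iterator[str]:
    """Hypothetical inference step that succeeds immediately."""
    yield "ok"


def safe_infer_first(infer: Callable[[], Iterator[T]]) -> Optional[T]:
    """Return the first inferred value, or None if inference fails.

    Hypothetical helper: RecursionError is treated like any other
    inference failure instead of escaping to the linter's top-level
    handler, which would re-raise it as astroid.AstroidError.
    """
    try:
        return next(infer())
    except StopIteration:
        # Inference produced no values at all.
        return None
    except RecursionError:
        # Inference recursed past sys.getrecursionlimit(); give up quietly.
        return None
```

Under this approach a checker would receive `None` and skip its check for that node, trading a possibly missed message for not crashing the whole run.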

jacobtylerwalls changed the title Crash with astroid-error RecursionError escapes during inference Feb 20, 2024
jacobtylerwalls added the Needs PR This issue is accepted, sufficiently specified and now needs an implementation label Feb 20, 2024
jacobtylerwalls added this to the 3.1.1 milestone Feb 26, 2024
jacobtylerwalls self-assigned this May 4, 2024
jacobtylerwalls modified the milestones: 3.1.1, 3.2.0 May 12, 2024
jacobtylerwalls added the Needs astroid update Needs an astroid update (probably a release too) before being mergable label May 12, 2024
jacobtylerwalls modified the milestones: 3.2.0, 3.2.1 May 13, 2024