Unable to load saved model #7

Open
deewhy26 opened this issue Sep 24, 2023 · 2 comments

@deewhy26

deewhy26 commented Sep 24, 2023

Hello, sorry, I'm quite new to writing issues.
I trained a joint token classification and sequence classification model. To save it, I used this:
trainer.save_model("multi_task/")
However, when trying to load the same model, I got this error:
model_2 = tn.load_pipeline("/kaggle/working/multi_task","intent_classification", adapt_task_embedding=True)

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1                                                                                    │
│                                                                                                  │
│ ❱ 1 model_2 =  tn.load_pipeline("/kaggle/working/multi_task","intent_classification", adapt_     │
│   2                                                                                              │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasknet/utils.py:198 in load_pipeline                    │
│                                                                                                  │
│   195 │   │   import tasksource                                                                  │
│   196 │   except:                                                                                │
│   197 │   │   raise ImportError("Requires tasksource.\n pip install tasksource")                 │
│ ❱ 198 │   task = tasksource.load_task(task_name, multilingual=multilingual)                      │
│   199 │                                                                                          │
│   200 │   model = AutoModelForSequenceClassification.from_pretrained(                            │
│   201 │   │   model_name, ignore_mismatched_sizes=True                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasksource/access.py:102 in load_task                    │
│                                                                                                  │
│    99 │   query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)           │
│   100 │   query = {k:v for k,v in query.items() if v}                                            │
│   101 │   _tasks = (lmtasks if multilingual else tasks)                                          │
│ ❱ 102 │   preprocessing = load_preprocessing(_tasks, **query)                                    │
│   103 │   dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name, **load   │
│   104 │   dataset= preprocessing(dataset,max_rows, max_rows_eval)                                │
│   105 │   dataset.task_type = preprocessing.__class__.__name__                                   │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasksource/access.py:90 in load_preprocessing            │
│                                                                                                  │
│    87                                                                                            │
│    88 def load_preprocessing(tasks=tasks, **kwargs):                                             │
│    89 │   _tasks_df = list_tasks(multilingual=tasks==lmtasks)                                    │
│ ❱  90 │   y = _tasks_df.copy().query(dict_to_query(**kwargs)).iloc[0]                            │
│    91 │   preprocessing= copy.copy(getattr(tasks, y.preprocessing_name))                         │
│    92 │   for c in 'dataset_name','config_name':                                                 │
│    93 │   │   if not isinstance(getattr(preprocessing,c), str):                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1073 in __getitem__              │
│                                                                                                  │
│   1070 │   │   │   axis = self.axis or 0                                                         │
│   1071 │   │   │                                                                                 │
│   1072 │   │   │   maybe_callable = com.apply_if_callable(key, self.obj)                         │
│ ❱ 1073 │   │   │   return self._getitem_axis(maybe_callable, axis=axis)                          │
│   1074 │                                                                                         │
│   1075 │   def _is_scalar_access(self, key: tuple):                                              │
│   1076 │   │   raise NotImplementedError()                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1625 in _getitem_axis            │
│                                                                                                  │
│   1622 │   │   │   │   raise TypeError("Cannot index by location index with a non-integer key")  │
│   1623 │   │   │                                                                                 │
│   1624 │   │   │   # validate the location                                                       │
│ ❱ 1625 │   │   │   self._validate_integer(key, axis)                                             │
│   1626 │   │   │                                                                                 │
│   1627 │   │   │   return self.obj._ixs(key, axis=axis)                                          │
│   1628                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1557 in _validate_integer        │
│                                                                                                  │
│   1554 │   │   """                                                                               │
│   1555 │   │   len_axis = len(self.obj._get_axis(axis))                                          │
│   1556 │   │   if key >= len_axis or key < -len_axis:                                            │
│ ❱ 1557 │   │   │   raise IndexError("single positional indexer is out-of-bounds")                │
│   1558 │                                                                                         │
│   1559 │   # -------------------------------------------------------------------                 │
│   1560                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: single positional indexer is out-of-bounds
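
A note on reading the traceback: the failing frame in `load_preprocessing` filters the task table with a pandas query and then takes `.iloc[0]`. When a query matches no rows, pandas raises this same `IndexError`. That would be consistent with `"intent_classification"` not matching any row in the tasksource table, though this is an inference from the frames above and has not been confirmed by the maintainer. A minimal sketch that reproduces the failure mode on a toy table follows; the column names, rows, and both helper functions are made up for illustration and are not tasksource's API.

```python
import pandas as pd

# Toy stand-in for a task table. The rows and column names are hypothetical.
tasks_df = pd.DataFrame(
    {
        "task_name": ["nli", "sentiment"],
        "preprocessing_name": ["toy_nli", "toy_sentiment"],
    }
)


def first_match(df: pd.DataFrame, task_name: str) -> pd.Series:
    """Same pattern as the failing frame: filter, then take the first row."""
    return df.query("task_name == @task_name").iloc[0]


def first_match_checked(df: pd.DataFrame, task_name: str) -> pd.Series:
    """Hypothetical guarded variant that reports which name failed to match."""
    matches = df.query("task_name == @task_name")
    if matches.empty:
        known = sorted(df["task_name"])
        raise ValueError(f"no task named {task_name!r}; known tasks: {known}")
    return matches.iloc[0]


if __name__ == "__main__":
    # A name that exists in the table resolves normally.
    print(first_match(tasks_df, "nli")["preprocessing_name"])

    # A name that is absent yields an empty frame, so .iloc[0] raises.
    try:
        first_match(tasks_df, "intent_classification")
    except IndexError as err:
        print(f"IndexError: {err}")
```

If this reading is right, the task name passed to `load_pipeline` is what needs checking, not the files written by `trainer.save_model`.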

I intend to load the model from this checkpoint for domain adaptation.
How do I address this issue? Also, if you could point me to some resources for understanding adapters, that would be very helpful (I found out about them from the other issue posted here).
PS: Thank you for this incredible library. It saved me a lot of time.

@sileod
Owner

sileod commented Sep 25, 2023

Hi, thank you for your issue. Would it be easy to reproduce the error in a Colab?

@deewhy26
Author

Hi, thank you for your issue. Would it be easy to reproduce the error in a Colab?

Sorry, I don't understand. Do you want me to try this on Colab? The error above occurred on Kaggle.
