
Issue fine tuning Falcon #3996

Open
vinven7 opened this issue Apr 22, 2024 · 2 comments
vinven7 commented Apr 22, 2024

Describe the bug
I am trying to fine-tune tiiuae/falcon-7b-instruct and I am getting this error:

TypeError: where(): argument 'condition' (position 1) must be Tensor, not bool

To Reproduce
Steps to reproduce the behavior:

import logging

import yaml
from ludwig.api import LudwigModel

qlora_fine_tuning_config = yaml.safe_load(
"""
model_type: llm
base_model: tiiuae/falcon-7b-instruct

input_features:
  - name: Prompt
    type: text
    preprocessing:
      max_sequence_length: 256

output_features:
  - name: Responses
    type: text
    preprocessing:
      max_sequence_length: 256

prompt:
   template: >-

     ### Prompt: {Prompt}

     ### responses : 

generation:
  temperature: 0.1
  max_new_tokens: 256

adapter:
  type: lora

quantization:
  bits: 4

preprocessing:
  split:
     probabilities:
      - 1.0
      - 0.0
      - 0.0

trainer:
  type: finetune
  # epochs: 5
  # epochs: 3
  train_steps: 5
  batch_size: 1
  eval_batch_size: 2
  gradient_accumulation_steps: 16  # effective batch size = batch size * gradient_accumulation_steps
  learning_rate: 2.0e-4
  enable_gradient_checkpointing: true
  learning_rate_scheduler:
    decay: cosine
    warmup_fraction: 0.03
    reduce_on_plateau: 0
"""
)

new_model = LudwigModel(config=qlora_fine_tuning_config, logging_level=logging.INFO)
results = new_model.train(dataset=train_df)

Here is the full trace:

-------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 65
      1 qlora_fine_tuning_config = yaml.safe_load(
      2 """
      3     model_type: llm
   (...)
     61     """
     62   )
     64 new_model = LudwigModel(config=qlora_fine_tuning_config, logging_level=logging.INFO)
---> 65 results = new_model.train(dataset=train_df)

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/api.py:682, in LudwigModel.train(self, dataset, training_set, validation_set, test_set, training_set_metadata, data_format, experiment_name, model_name, model_resume_path, skip_save_training_description, skip_save_training_statistics, skip_save_model, skip_save_progress, skip_save_log, skip_save_processed_input, output_directory, random_seed, **kwargs)
    675     callback.on_train_start(
    676         model=self.model,
    677         config=self.config_obj.to_dict(),
    678         config_fp=self.config_fp,
    679     )
    681 try:
--> 682     train_stats = trainer.train(
    683         training_set,
    684         validation_set=validation_set,
    685         test_set=test_set,
    686         save_path=model_dir,
    687     )
    688     (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats
    690     # Calibrates output feature probabilities on validation set if calibration is enabled.
    691     # Must be done after training, and before final model parameters are saved.

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/trainers/trainer.py:1050, in Trainer.train(self, training_set, validation_set, test_set, save_path, return_state_dict, **kwargs)
   1047 self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path))
   1049 # Trains over a full epoch of data or up to the last training step, whichever is sooner.
-> 1050 should_break, has_nan_or_inf_tensors = self._train_loop(
   1051     batcher,
   1052     progress_tracker,
   1053     save_path,
   1054     train_summary_writer,
   1055     progress_bar,
   1056     training_set,
   1057     validation_set,
   1058     test_set,
   1059     start_time,
   1060     validation_summary_writer,
   1061     test_summary_writer,
   1062     model_hyperparameters_path,
   1063     output_features,
   1064     metrics_names,
   1065     checkpoint_manager,
   1066     final_steps_per_checkpoint,
   1067     early_stopping_steps,
   1068     profiler,
   1069 )
   1070 if self.is_coordinator():
   1071     # ========== Save training progress ==========
   1072     logger.debug(
   1073         f"Epoch {progress_tracker.epoch} took: "
   1074         f"{time_utils.strdelta((time.time() - start_time) * 1000.0)}."
   1075     )

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/trainers/trainer.py:1247, in Trainer._train_loop(self, batcher, progress_tracker, save_path, train_summary_writer, progress_bar, training_set, validation_set, test_set, start_time, validation_summary_writer, test_summary_writer, model_hyperparameters_path, output_features, metrics_names, checkpoint_manager, final_steps_per_checkpoint, early_stopping_steps, profiler)
   1238 inputs = {
   1239     i_feat.feature_name: torch.from_numpy(np.array(batch[i_feat.proc_column], copy=True)).to(self.device)
   1240     for i_feat in self.model.input_features.values()
   1241 }
   1242 targets = {
   1243     o_feat.feature_name: torch.from_numpy(np.array(batch[o_feat.proc_column], copy=True)).to(self.device)
   1244     for o_feat in self.model.output_features.values()
   1245 }
-> 1247 loss, all_losses, used_tokens = self.train_step(inputs, targets, should_step=should_step, profiler=profiler)
   1249 # Update LR schduler here instead of train loop to avoid updating during batch size tuning, etc.
   1250 self.scheduler.step()

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/trainers/trainer.py:340, in Trainer.train_step(self, inputs, targets, should_step, profiler)
    337 with torch.cuda.amp.autocast() if self.use_amp else contextlib.nullcontext():
    338     with self.distributed.prepare_model_update(self.dist_model, should_step=should_step):
    339         # Obtain model predictions and loss
--> 340         model_outputs = self.dist_model((inputs, targets))
    341         loss, all_losses = self.model.train_loss(
    342             targets, model_outputs, self.regularization_type, self.regularization_lambda
    343         )
    344         loss = loss / self.gradient_accumulation_steps

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/models/llm.py:265, in LLM.forward(self, inputs, mask)
    261 input_ids, target_ids = self._unpack_inputs(inputs)
    263 # Generate merged input_id, target_id pairs for the model, and create corresponding attention masks
    264 # We save them as class variables so that we can use them when realigning target and prediction tensors
--> 265 self.model_inputs, self.attention_masks = generate_merged_ids(
    266     input_ids, target_ids, self.tokenizer, self.global_max_sequence_length
    267 )
    269 # Wrap with flash attention backend for faster generation
    270 with (
    271     torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
    272     if (torch.cuda.is_available() and self.curr_device.type == "cuda")
    273     else contextlib.nullcontext()
    274 ):
    275     # TODO (jeffkinnison): Determine why the 8-bit `SCB` and `CB` matrices are deleted in the forward pass

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/utils/llm_utils.py:479, in generate_merged_ids(input_ids, target_ids, tokenizer, max_sequence_length)
    476 # Merge input_ids and target_ids by concatenating them together.
    477 # We remove the left padding from both input_ids and target_ids before concatenating them.
    478 for input_id_sample, target_id_sample in zip(input_ids, target_ids):
--> 479     input_id_sample_no_padding = remove_left_padding(input_id_sample, tokenizer)[0]
    480     target_id_sample_no_padding = remove_left_padding(target_id_sample, tokenizer)[0]
    481     target_id_sample_no_padding = torch.cat((target_id_sample_no_padding, eos_tensor), dim=-1)

File ~/anaconda3/envs/vineeth_10/lib/python3.10/site-packages/ludwig/utils/llm_utils.py:291, in remove_left_padding(input_ids_sample, tokenizer)
    288     input_ids_no_padding = input_ids_sample[pad_idx + 1 :]
    290 # Start from the first BOS token
--> 291 bos_idxs = torch.where(input_ids_no_padding == tokenizer.bos_token_id)[0]  # all BOS token locations
    292 if len(bos_idxs) != 0:
    293     bos_idx = bos_idxs[0]  # get first BOS token location

TypeError: where(): argument 'condition' (position 1) must be Tensor, not bool
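The last frame suggests a likely root cause: Falcon's tokenizer defines no BOS token, so tokenizer.bos_token_id is None, and comparing a tensor against None yields a plain Python bool rather than a boolean tensor, which torch.where() rejects. A minimal sketch of that failure mode (the bos_token_id = None stand-in is an assumption for illustration, not read from the tokenizer here):

import torch

input_ids_no_padding = torch.tensor([11, 42, 7])
bos_token_id = None  # assumption: the tokenizer has no BOS token, so bos_token_id is None

condition = input_ids_no_padding == bos_token_id
print(type(condition))  # <class 'bool'> -- comparing a tensor to None falls back to Python's False

torch.where(condition)  # raises TypeError: where(): argument 'condition' (position 1) must be Tensor, not bool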

Environment (please complete the following information):

OS: Linux, Jupyter Notebook
Version: [e.g. 22]
Python version: 3.0
Ludwig version: 0.10.3

arnavgarg1 self-assigned this Apr 22, 2024
arnavgarg1 (Contributor) commented Apr 22, 2024

Hi @vinven7, I have a potential fix here: #3997
Are you able to pull down my branch and give it a try, and let me know if it works?
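For reference, one way to install from the PR branch without cloning, assuming the standard ludwig-ai/ludwig repository on GitHub, is:

pip install "git+https://github.com/ludwig-ai/ludwig.git@refs/pull/3997/head"

(GitHub exposes every pull request at refs/pull/<number>/head, so pip can install directly from that ref.)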

vinven7 (Author) commented May 1, 2024

Hi @arnavgarg1, I tried this and got the same error:


│ Experiment name  │ api_experiment                                                                                            │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ Model name       │ run                                                                                                       │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ Output directory │ /home/jupyter/Vineeth/MatKG_RLHF/results/api_experiment_run_26                                            │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ ludwig_version   │ '0.10.3.dev'                                                                                              │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ command          │ ('/home/synthesisproject/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ipykernel_launcher.py ' │
│                  │  '-f '                                                                                                    │
│                  │  '/home/synthesisproject/.local/share/jupyter/runtime/kernel-3ad33b30-7b97-4f7a-9f68-070d36c1d4b3.json')  │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ commit_hash      │ '46e384352263'                                                                                            │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ random_seed      │ 42                                                                                                        │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ data_format      │ "<class 'pandas.core.frame.DataFrame'>"                                                                   │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ torch_version    │ '2.0.1'                                                                                                   │
├──────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ compute          │ {   'arch_list': [   'sm_37',                                                                             │
│                  │                      'sm_50',                                                                             │
│                  │                      'sm_60',                                                                             │
│                  │                      'sm_61',                                                                             │
│                  │                      'sm_70',                                                                             │
│                  │                      'sm_75',                                                                             │
│                  │                      'sm_80',                                                                             │
│                  │                      'sm_86',                                                                             │
│                  │                      'sm_90',                                                                             │
│                  │                      'compute_37'],                                                                       │
│                  │     'devices': {   0: {   'device_capability': (8, 6),                                                    │
│                  │                           'device_properties': "_CudaDeviceProperties(name='NVIDIA "                      │
│                  │                                                "RTX A5000', major=8, minor=6, "                           │
│                  │                                                'total_memory=24256MB, '                                   │
│                  │                                                'multi_processor_count=64)',                               │
│                  │                           'gpu_type': 'NVIDIA RTX A5000'}},                                               │
│                  │     'gencode_flags': '-gencode compute=compute_37,code=sm_37 -gencode '                                   │
│                  │                      'compute=compute_50,code=sm_50 -gencode '                                            │
│                  │                      'compute=compute_60,code=sm_60 -gencode '                                            │
│                  │                      'compute=compute_61,code=sm_61 -gencode '                                            │
│                  │                      'compute=compute_70,code=sm_70 -gencode '                                            │
│                  │                      'compute=compute_75,code=sm_75 -gencode '                                            │
│                  │                      'compute=compute_80,code=sm_80 -gencode '                                            │
│                  │                      'compute=compute_86,code=sm_86 -gencode '                                            │
│                  │                      'compute=compute_90,code=sm_90 -gencode '                                            │
│                  │                      'compute=compute_37,code=compute_37',                                                │
│                  │     'gpus_per_node': 1,                                                                                   │
│                  │     'num_nodes': 1}                                                                                       │
╘══════════════════╧═══════════════════════════════════════════════════════════════════════════════════════════════════════════╛

╒═══════════════╕
│ LUDWIG CONFIG │
╘═══════════════╛

User-specified config (with upgrades):

{   'adapter': {'type': 'lora'},
    'base_model': 'tiiuae/falcon-7b-instruct',
    'generation': {'max_new_tokens': 256, 'temperature': 0.1},
    'input_features': [   {   'name': 'Prompt',
                              'preprocessing': {'max_sequence_length': 256},
                              'type': 'text'}],
    'ludwig_version': '0.10.3.dev',
    'model_type': 'llm',
    'output_features': [   {   'name': 'Responses',
                               'preprocessing': {'max_sequence_length': 256},
                               'type': 'text'}],
    'preprocessing': {'split': {'probabilities': [1.0, 0.0, 0.0]}},
    'prompt': {'template': '\n### Prompt: {Prompt}\n### responses : '},
    'quantization': {'bits': 8},
    'trainer': {   'batch_size': 1,
                   'enable_gradient_checkpointing': True,
                   'epochs': 10,
                   'eval_batch_size': 1,
                   'gradient_accumulation_steps': 16,
                   'learning_rate': 1e-05,
                   'learning_rate_scheduler': {   'reduce_on_plateau': 0,
                                                  'warmup_fraction': 0.03},
                   'optimizer': {   'params': {   'betas': [0.9, 0.999],
                                                  'eps': 1e-08,
                                                  'weight_decay': 0},
                                    'type': 'paged_adam'},
                   'type': 'finetune'}}

Full config saved to:
/home/jupyter/Vineeth/MatKG_RLHF/results/api_experiment_run_26/api_experiment/model/model_hyperparameters.json

╒═══════════════╕
│ PREPROCESSING │
╘═══════════════╛

No cached dataset found at /home/jupyter/Vineeth/MatKG_RLHF/828481b807b711ef8a2b3cecef329410.training.hdf5. Preprocessing the dataset.
Using full dataframe
Building dataset (it may take a while)
Loaded HuggingFace implementation of tiiuae/falcon-7b-instruct tokenizer
No padding token id found. Using eos_token as pad_token.
Max length of feature 'None': 144 (without start and stop symbols)
Max sequence length is 144 for feature 'None'
Loaded HuggingFace implementation of tiiuae/falcon-7b-instruct tokenizer
No padding token id found. Using eos_token as pad_token.
Max length of feature 'Responses': 49 (without start and stop symbols)
Max sequence length is 49 for feature 'Responses'
Loaded HuggingFace implementation of tiiuae/falcon-7b-instruct tokenizer
No padding token id found. Using eos_token as pad_token.
Loaded HuggingFace implementation of tiiuae/falcon-7b-instruct tokenizer
No padding token id found. Using eos_token as pad_token.
Building dataset: DONE
Writing preprocessed training set cache to /home/jupyter/Vineeth/MatKG_RLHF/828481b807b711ef8a2b3cecef329410.training.hdf5
Writing preprocessed validation set cache to /home/jupyter/Vineeth/MatKG_RLHF/828481b807b711ef8a2b3cecef329410.validation.hdf5
Writing preprocessed test set cache to /home/jupyter/Vineeth/MatKG_RLHF/828481b807b711ef8a2b3cecef329410.test.hdf5
Writing train set metadata to /home/jupyter/Vineeth/MatKG_RLHF/828481b807b711ef8a2b3cecef329410.meta.json
Validation set empty. If this is unintentional, please check the preprocessing configuration.
Test set empty. If this is unintentional, please check the preprocessing configuration.

Dataset Statistics
╒═══════════╤═══════════════╤════════════════════╕
│ Dataset   │   Size (Rows) │ Size (In Memory)   │
╞═══════════╪═══════════════╪════════════════════╡
│ Training  │          9418 │ 2.16 Mb            │
╘═══════════╧═══════════════╧════════════════════╛

╒═══════╕
│ MODEL │
╘═══════╛

Warnings and other logs:
Loading large language model...
We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
/home/synthesisproject/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Done.
No padding token id found. Using eos_token as pad_token.
Loaded HuggingFace implementation of tiiuae/falcon-7b-instruct tokenizer
No padding token id found. Using eos_token as pad_token.
==================================================
Trainable Parameter Summary For Fine-Tuning
Fine-tuning with adapter: lora
trainable params: 2,359,296 || all params: 6,924,080,000 || trainable%: 0.03407378308742822
==================================================
Gradient checkpointing enabled for training.

╒══════════╕
│ TRAINING │
╘══════════╛

Creating fresh model training run.
Training for 94180 step(s), approximately 10 epoch(s).
Early stopping policy: 5 round(s) of evaluation, or 47090 step(s), approximately 5 epoch(s).

Starting with step 0, epoch: 0
Training:   0%|          | 0/94180 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 65
      1 qlora_fine_tuning_config = yaml.safe_load(
      2 """
      3     model_type: llm
   (...)
     61     """
     62   )
     64 new_model = LudwigModel(config=qlora_fine_tuning_config, logging_level=logging.INFO)
---> 65 results = new_model.train(dataset=train_df)

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/api.py:682, in LudwigModel.train(self, dataset, training_set, validation_set, test_set, training_set_metadata, data_format, experiment_name, model_name, model_resume_path, skip_save_training_description, skip_save_training_statistics, skip_save_model, skip_save_progress, skip_save_log, skip_save_processed_input, output_directory, random_seed, **kwargs)
    675     callback.on_train_start(
    676         model=self.model,
    677         config=self.config_obj.to_dict(),
    678         config_fp=self.config_fp,
    679     )
    681 try:
--> 682     train_stats = trainer.train(
    683         training_set,
    684         validation_set=validation_set,
    685         test_set=test_set,
    686         save_path=model_dir,
    687     )
    688     (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats
    690     # Calibrates output feature probabilities on validation set if calibration is enabled.
    691     # Must be done after training, and before final model parameters are saved.

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/trainers/trainer.py:1050, in Trainer.train(self, training_set, validation_set, test_set, save_path, return_state_dict, **kwargs)
   1047 self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path))
   1049 # Trains over a full epoch of data or up to the last training step, whichever is sooner.
-> 1050 should_break, has_nan_or_inf_tensors = self._train_loop(
   1051     batcher,
   1052     progress_tracker,
   1053     save_path,
   1054     train_summary_writer,
   1055     progress_bar,
   1056     training_set,
   1057     validation_set,
   1058     test_set,
   1059     start_time,
   1060     validation_summary_writer,
   1061     test_summary_writer,
   1062     model_hyperparameters_path,
   1063     output_features,
   1064     metrics_names,
   1065     checkpoint_manager,
   1066     final_steps_per_checkpoint,
   1067     early_stopping_steps,
   1068     profiler,
   1069 )
   1070 if self.is_coordinator():
   1071     # ========== Save training progress ==========
   1072     logger.debug(
   1073         f"Epoch {progress_tracker.epoch} took: "
   1074         f"{time_utils.strdelta((time.time() - start_time) * 1000.0)}."
   1075     )

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/trainers/trainer.py:1247, in Trainer._train_loop(self, batcher, progress_tracker, save_path, train_summary_writer, progress_bar, training_set, validation_set, test_set, start_time, validation_summary_writer, test_summary_writer, model_hyperparameters_path, output_features, metrics_names, checkpoint_manager, final_steps_per_checkpoint, early_stopping_steps, profiler)
   1238 inputs = {
   1239     i_feat.feature_name: torch.from_numpy(np.array(batch[i_feat.proc_column], copy=True)).to(self.device)
   1240     for i_feat in self.model.input_features.values()
   1241 }
   1242 targets = {
   1243     o_feat.feature_name: torch.from_numpy(np.array(batch[o_feat.proc_column], copy=True)).to(self.device)
   1244     for o_feat in self.model.output_features.values()
   1245 }
-> 1247 loss, all_losses, used_tokens = self.train_step(inputs, targets, should_step=should_step, profiler=profiler)
   1249 # Update LR schduler here instead of train loop to avoid updating during batch size tuning, etc.
   1250 self.scheduler.step()

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/trainers/trainer.py:340, in Trainer.train_step(self, inputs, targets, should_step, profiler)
    337 with torch.cuda.amp.autocast() if self.use_amp else contextlib.nullcontext():
    338     with self.distributed.prepare_model_update(self.dist_model, should_step=should_step):
    339         # Obtain model predictions and loss
--> 340         model_outputs = self.dist_model((inputs, targets))
    341         loss, all_losses = self.model.train_loss(
    342             targets, model_outputs, self.regularization_type, self.regularization_lambda
    343         )
    344         loss = loss / self.gradient_accumulation_steps

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/models/llm.py:265, in LLM.forward(self, inputs, mask)
    261 input_ids, target_ids = self._unpack_inputs(inputs)
    263 # Generate merged input_id, target_id pairs for the model, and create corresponding attention masks
    264 # We save them as class variables so that we can use them when realigning target and prediction tensors
--> 265 self.model_inputs, self.attention_masks = generate_merged_ids(
    266     input_ids, target_ids, self.tokenizer, self.global_max_sequence_length
    267 )
    269 # Wrap with flash attention backend for faster generation
    270 with (
    271     torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
    272     if (torch.cuda.is_available() and self.curr_device.type == "cuda")
    273     else contextlib.nullcontext()
    274 ):
    275     # TODO (jeffkinnison): Determine why the 8-bit `SCB` and `CB` matrices are deleted in the forward pass

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/utils/llm_utils.py:479, in generate_merged_ids(input_ids, target_ids, tokenizer, max_sequence_length)
    476 # Merge input_ids and target_ids by concatenating them together.
    477 # We remove the left padding from both input_ids and target_ids before concatenating them.
    478 for input_id_sample, target_id_sample in zip(input_ids, target_ids):
--> 479     input_id_sample_no_padding = remove_left_padding(input_id_sample, tokenizer)[0]
    480     target_id_sample_no_padding = remove_left_padding(target_id_sample, tokenizer)[0]
    481     target_id_sample_no_padding = torch.cat((target_id_sample_no_padding, eos_tensor), dim=-1)

File ~/anaconda3/envs/vineeth_10.1/lib/python3.10/site-packages/ludwig/utils/llm_utils.py:291, in remove_left_padding(input_ids_sample, tokenizer)
    288     input_ids_no_padding = input_ids_sample[pad_idx + 1 :]
    290 # Start from the first BOS token
--> 291 bos_idxs = torch.where(input_ids_no_padding == tokenizer.bos_token_id)[0]  # all BOS token locations
    292 if len(bos_idxs) != 0:
    293     bos_idx = bos_idxs[0]  # get first BOS token location

TypeError: where(): argument 'condition' (position 1) must be Tensor, not bool
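For context, the crash could presumably be avoided by skipping the BOS search when the tokenizer defines no BOS token. A minimal sketch of that kind of guard (illustrative only, with a hypothetical helper name, and not necessarily what #3997 implements):

import torch

def first_bos_index(input_ids_no_padding: torch.Tensor, bos_token_id):
    # Guard: if the tokenizer has no BOS token, bos_token_id is None and comparing
    # the tensor against it would produce a Python bool, not a boolean tensor.
    if bos_token_id is None:
        return None
    bos_idxs = torch.where(input_ids_no_padding == bos_token_id)[0]  # all BOS token locations
    return bos_idxs[0] if len(bos_idxs) > 0 else None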
