kto error when assign dataset to device #1620

Closed
mostafamdy opened this issue May 4, 2024 · 3 comments
mostafamdy commented May 4, 2024

Hi,
when I set the device on the dataset and train with KTO, I get the error below. When I remove this line, it works:

dataset.set_format("torch", device="cuda:1")

Code:

from peft import LoraConfig
from trl import KTOConfig, KTOTrainer

peft_config = LoraConfig(
    task_type="CAUSAL_LM",  # TaskType.SEQ_2_SEQ_LM
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
kto_config = KTOConfig(
    truncation_mode="keep_start",
    output_dir="/kaggle/working/kto1",
    max_completion_length=1000,
    per_device_train_batch_size=1,
)

# Initialize the KTO trainer
kto_trainer = KTOTrainer(
    model,
    args=kto_config,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)

# Train the model
kto_trainer.train()

Error:

Tokenizing train dataset: 100% 1319/1319 [00:05<00:00, 242.58 examples/s]
Extracting KL train dataset: 34% 444/1319 [00:00<00:01, 860.87 examples/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[38], line 16
      8 kto_config=KTOConfig(
      9     truncation_mode="keep_start",
     10     output_dir="/kaggle/working/kto1",
     11     max_completion_length=1000,
     12     per_device_train_batch_size=1,
     13 )
     15 # Initialize the KTO trainer
---> 16 kto_trainer = KTOTrainer(
     17     model,
     18 #     model_ref,
     19     args=kto_config,
     20     train_dataset=formatted_dataset,
     21 #     eval_dataset=formatted_dataset["test"],
     22     tokenizer=tokenizer,
     23     peft_config=peft_config,
     24 )
     26 # Train and push the model to the Hub
     27 kto_trainer.train()

File /opt/conda/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:517, in KTOTrainer.__init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer, data_collator, model_init, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, compute_metrics, model_adapter_name, ref_adapter_name)
    512     raise ValueError(
    513         "Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
    514     )
    515 # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
    516 # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
--> 517 train_kl_dataset = train_dataset.map(
    518     _get_kl_dataset, batched=True, batch_size=total_batch_size, desc="Extracting KL train dataset"
    519 )
    520 # Prepare the datasets
    521 fn_kwargs = {
    522     "prefix": "",
    523     "is_encoder_decoder": self.is_encoder_decoder,
   (...)
    528     "max_prompt_length": self.max_prompt_length,
    529 }

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:593, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    591     self: "Dataset" = kwargs.pop("self")
    592 # apply actual function
--> 593 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    594 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    595 for dataset in datasets:
    596     # Remove task templates if a column mapping of the template is no longer valid

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:558, in transmit_format.<locals>.wrapper(*args, **kwargs)
    551 self_format = {
    552     "type": self._format_type,
    553     "format_kwargs": self._format_kwargs,
    554     "columns": self._format_columns,
    555     "output_all_columns": self._output_all_columns,
    556 }
    557 # apply actual function
--> 558 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    559 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    560 # re-apply format to the output

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:3105, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3099 if transformed_dataset is None:
   3100     with hf_tqdm(
   3101         unit=" examples",
   3102         total=pbar_total,
   3103         desc=desc or "Map",
   3104     ) as pbar:
-> 3105         for rank, done, content in Dataset._map_single(**dataset_kwargs):
   3106             if done:
   3107                 shards_done += 1

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:3482, in Dataset._map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
   3478 indices = list(
   3479     range(*(slice(i, i + batch_size).indices(shard.num_rows)))
   3480 )  # Something simpler?
   3481 try:
-> 3482     batch = apply_function_on_filtered_inputs(
   3483         batch,
   3484         indices,
   3485         check_same_num_examples=len(shard.list_indexes()) > 0,
   3486         offset=offset,
   3487     )
   3488 except NumExamplesMismatchError:
   3489     raise DatasetTransformationNotAllowedError(
   3490         "Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it."
   3491     ) from None

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:3361, in Dataset._map_single.<locals>.apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
   3359 if with_rank:
   3360     additional_args += (rank,)
-> 3361 processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
   3362 if isinstance(processed_inputs, LazyDict):
   3363     processed_inputs = {
   3364         k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
   3365     }

File /opt/conda/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:72, in _get_kl_dataset(batch)
     70 def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     71     """Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
---> 72     batch["answer_input_ids"] = batch["answer_input_ids"][::-1]
     73     batch["answer_attention_mask"] = batch["answer_attention_mask"][::-1]
     74     return batch

ValueError: step must be greater than zero
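If I read the traceback right, the failure is in _get_kl_dataset, which reverses the completion columns with [::-1]. With dataset.set_format("torch", device="cuda:1"), those columns come back as torch.Tensors rather than Python lists, and tensor slicing does not allow negative steps. A minimal standalone sketch (hypothetical values) that reproduces the same ValueError:

import torch

ids = [101, 2023, 2003, 102]
print(ids[::-1])                    # plain list: negative-step slicing works

ids_t = torch.tensor(ids)
# ids_t[::-1]                       # ValueError: step must be greater than zero
print(torch.flip(ids_t, dims=[0]))  # the tensor equivalent of [::-1]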
mostafamdy (Author) commented May 4, 2024

I tried setting the device because I got this error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)
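FWIW, a common workaround for that mismatch (not specific to KTO) is to expose only a single GPU to the process before anything initializes CUDA, rather than moving the dataset. A sketch, with model_name as a hypothetical placeholder:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must run before torch/transformers touch CUDA

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # hypothetical placeholder; substitute your checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# With one visible GPU, the trainer moves the model and each batch there,
# so the dataset can keep its default (CPU) format.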

kawine (Contributor) commented May 10, 2024

@mostafamdy I think this should have been fixed in the repo but may not have made its way into a release yet.

Can you try installing trl directly from the cloned repo and seeing if it works then?
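For reference, installing straight from the main branch looks something like:

pip install git+https://github.com/huggingface/trl.git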


github-actions bot commented Jun 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
