Skip to content

Commit

Permalink
Merge branch 'main' into llama_adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
sywangyi committed May 22, 2024
2 parents a022e12 + 6be3e99 commit 820d120
Show file tree
Hide file tree
Showing 55 changed files with 1,046 additions and 589 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ slow_tests_trl: test_installs
python -m pip install peft==0.7.0
python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss"

slow_tests_object_segmentation: test_installs
python -m pytest tests/test_object_segmentation.py

# Check if examples are up to date with the Transformers library
example_diff_tests: test_installs
python -m pytest tests/test_examples_match_transformers.py
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ The following model architectures, tasks and device distributions have been vali
| ESMFold | | <div style="text-align:left"><li>Single card</li></div> | <li>[protein folding](https://github.com/huggingface/optimum-habana/tree/main/examples/protein-folding)</li> |
| Blip | | <div style="text-align:left"><li>Single card</li></div> | <li>[visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)</li><li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |

</div>

Expand Down Expand Up @@ -247,4 +248,4 @@ Please refer to the Intel Gaudi AI Accelerator official [installation guide](htt

## Development

Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions.
Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions.
4 changes: 3 additions & 1 deletion docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| ESMFold | | <div style="text-align:left"><li>Single card</li></div> | <li>[protein folding](https://github.com/huggingface/optimum-habana/tree/main/examples/protein-folding)</li> |
| Blip | | <div style="text-align:left"><li>Single card</li></div> | <li>[visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)</li><li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |


- Diffusers

Expand Down Expand Up @@ -113,4 +115,4 @@ Besides, [this page](https://github.com/huggingface/optimum-habana/tree/main/exa
<p class="text-gray-700">Technical descriptions of how the Habana classes and methods of 🤗 Optimum Habana work.</p>
</a>
</div>
</div>
</div>
2 changes: 1 addition & 1 deletion docs/source/package_reference/gaudi_config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ gaudi_config = GaudiConfig.from_pretrained(
gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
```
and pass it to the trainer with the `gaudi_config` argument.
Expand Down
18 changes: 1 addition & 17 deletions examples/audio-classification/run_audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from random import randint
from typing import Optional
Expand Down Expand Up @@ -164,12 +163,6 @@ class ModelArguments:
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -199,15 +192,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_audio_classification", model_args, data_args)
Expand All @@ -233,7 +217,7 @@ def main():
training_args.gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Log on each process the small summary:
Expand Down
18 changes: 1 addition & 17 deletions examples/contrastive-image-text/run_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

Expand Down Expand Up @@ -99,12 +98,6 @@ class ModelArguments:
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -265,15 +258,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_bridgetower", model_args, data_args)
Expand All @@ -299,7 +283,7 @@ def main():
training_args.gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Log on each process the small summary:
Expand Down
18 changes: 1 addition & 17 deletions examples/contrastive-image-text/run_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

Expand Down Expand Up @@ -104,12 +103,6 @@ class ModelArguments:
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -263,15 +256,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clip", model_args, data_args)
Expand All @@ -297,7 +281,7 @@ def main():
training_args.gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Log on each process the small summary:
Expand Down
18 changes: 1 addition & 17 deletions examples/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

Expand Down Expand Up @@ -169,12 +168,6 @@ class ModelArguments:
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -204,15 +197,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_image_classification", model_args, data_args)
Expand All @@ -238,7 +222,7 @@ def main():
training_args.gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Log on each process the small summary:
Expand Down
31 changes: 30 additions & 1 deletion examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,36 @@ python3 run_lora_clm.py \
--validation_split_percentage 4 \
--adam_epsilon 1e-08
```

- Single-card finetuning of Mistral-7B-Instruct-v0.2 with fp8:
```bash
python3 run_lora_clm.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2\
--dataset_name tatsu-lab/alpaca \
--fp8 True \
--output_dir ./model_lora_mistral \
--num_train_epochs 3 \
--per_device_train_batch_size 8 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 4e-4 \
--warmup_ratio 0.03 \
--lr_scheduler_type "constant" \
--max_grad_norm 0.3 \
--logging_steps 1 \
--do_train \
--use_habana \
--use_lazy_mode \
--throughput_warmup_steps 5 \
--lora_rank=8 \
--lora_target_modules "v_proj" "q_proj" \
--lora_alpha=16 \
--lora_dropout=0.05 \
--dataset_concatenation \
--max_seq_length 512 \
--low_cpu_mem_usage True \
--validation_split_percentage 4 \
--adam_epsilon 1e-08
```
- Single-card finetuning of Falcon-40B:
```bash
LOWER_LIST=ops_bf16.txt python3 run_lora_clm.py \
Expand Down
24 changes: 7 additions & 17 deletions examples/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import math
import os
import sys
import warnings
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional
Expand Down Expand Up @@ -128,12 +127,6 @@ class ModelArguments:
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -275,15 +268,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm", model_args, data_args)
Expand All @@ -310,7 +294,7 @@ def main():
training_args.gaudi_config_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)

# Log on each process the small summary:
Expand Down Expand Up @@ -592,9 +576,15 @@ def group_texts(examples):
)

if training_args.do_train:

def tensor_mapper(x):
return {i: torch.tensor(x[i], dtype=torch.int32) for i in x}

if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint != "":
train_dataset = train_dataset.map(tensor_mapper)
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
Expand Down
Loading

0 comments on commit 820d120

Please sign in to comment.