TE FP8 integration #1096

Open · wants to merge 3 commits into base: main
40 changes: 40 additions & 0 deletions examples/language-modeling/README.md
@@ -574,6 +574,46 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--flash_attention_causal_mask True
```

- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization, LoRA and FP8 precision:
> [Review comment, Collaborator] Maybe we can just modify the existing command for FP8 precision (instead of adding a new one).


> The following command requires Habana DeepSpeed 1.13.0 or later.
```bash
PT_HPU_MAX_COMPOUND_OP_SIZE=10 \
python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--deepspeed llama2_ds_zero3_config.json \
--dataset_name tatsu-lab/alpaca \
--bf16 True \
--output_dir ./lora_out \
--num_train_epochs 2 \
--max_seq_len 2048 \
--per_device_train_batch_size 10 \
--per_device_eval_batch_size 1 \
--gradient_checkpointing \
--evaluation_strategy epoch \
--eval_delay 2 \
--save_strategy no \
--learning_rate 0.0018 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--dataset_concatenation \
--attn_softmax_bf16 True \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--pipelining_fwd_bwd \
--throughput_warmup_steps 3 \
--lora_rank 4 \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True \
--flash_attention_causal_mask True \
--fp8 True
```

- Multi-card finetuning of Llama2-70B with FSDP and LoRA:

```bash
55 changes: 16 additions & 39 deletions optimum/habana/accelerate/accelerator.py
@@ -37,7 +37,6 @@
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
FP8RecipeKwargs,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -73,12 +72,11 @@
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
te_forward_convert,
te_setup_fp8_recipe_handler,
te_wrap_fp8,
te_wrap_fp8_forward_convert,
convert_model,
get_fp8_recipe,
)


@@ -113,7 +111,6 @@ def __init__(
dynamo_backend: GaudiDynamoBackend | str | None = None,
distribution_strategy: str = None,
force_autocast: bool = False,
fp8_recipe_format: str = None,
):
self.trackers = []
if project_config is not None:
@@ -181,7 +178,6 @@ def __init__(
self.scaler_handler = None
self.init_handler = None
self.fp8_recipe_handler = None
self.fp8_recipe_format = None
self.autocast_handler = None
if kwargs_handlers is not None:
for handler in kwargs_handlers:
@@ -203,9 +199,9 @@ def __init__(
raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
else:
self.init_handler = handler
elif isinstance(handler, FP8RecipeKwargs):
elif isinstance(handler, GaudiFP8RecipeKwargs):
if self.fp8_recipe_handler is not None:
raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
raise ValueError("You can only pass one `GaudiFP8RecipeKwargs` in `kwargs_handler`.")
else:
self.fp8_recipe_handler = handler
elif isinstance(handler, AutocastKwargs):
@@ -225,8 +221,8 @@ def __init__(
_from_accelerator=True,
**kwargs,
)
if self.fp8_recipe_handler is None and self.state.is_fp8_enabled:
self.fp8_recipe_handler = te_setup_fp8_recipe_handler(self.fp8_recipe_format)

if self.state.is_fp8_enabled:
if self.fp8_recipe_handler is None:
self.fp8_recipe_handler = GaudiFP8RecipeKwargs()
# Handling FP8 recipe creation in init since both `prepare_model` and `_prepare_deepspeed` require it.
# (Base accelerator handles this in `prepare_model` function)
self.fp8_recipe_handler = get_fp8_recipe(self.fp8_recipe_handler)

trackers = filter_trackers(log_with, self.logging_dir)
if len(trackers) < 1 and log_with is not None:
warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
@@ -349,31 +351,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
else:
model.forward = convert_outputs_to_fp32(new_forward)
elif self.state.is_fp8_enabled:
model = te_wrap_fp8_forward_convert(model, self.fp8_recipe_handler)
# FP8 is not supported on Gaudi2 yet
# elif self.mixed_precision == "fp8":
# if not has_transformer_engine_layers(model):
# with torch.no_grad():
# convert_model(model)
# model._converted_to_transformer_engine = True
# model._original_forward = model.forward

# kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
# if "fp8_format" in kwargs:
# kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
# fp8_recipe = te_recipe.DelayedScaling(**kwargs)
# cuda_device_capacity = torch.cuda.get_device_capability()
# fp8_enabled = cuda_device_capacity[0] >= 9 or (
# cuda_device_capacity[0] == 8 and cuda_device_capacity[1] >= 9
# )
# if not fp8_enabled:
# logger.warn(
# f"The current device has compute capability of {cuda_device_capacity} which is "
# "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
# "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
# )
# model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
if self.state.is_fp8_enabled:
model = convert_model(model)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
@@ -469,7 +448,7 @@ def _prepare_deepspeed(self, *args):
result = [
self._prepare_one(obj, first_pass=True)
if isinstance(obj, torch.utils.data.DataLoader)
else te_wrap_fp8(obj)
else convert_model(obj)
if isinstance(obj, torch.nn.Module) and self.state.is_fp8_enabled
else obj
for obj in args
@@ -685,8 +664,6 @@ def _prepare_deepspeed(self, *args):
result[i] = scheduler
# pointing for deepspeed_engine_wrapped.backward()
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
if self.state.is_fp8_enabled:
model = te_forward_convert(engine, self.fp8_recipe_handler)
self._models.append(engine)
if optimizer is not None:
self._optimizers.append(optimizer)
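For context on how the new helpers fit together: `prepare_model` and `_prepare_deepspeed` now both call `convert_model`, and `__init__` turns the `GaudiFP8RecipeKwargs` handler into a recipe via `get_fp8_recipe`. Below is a minimal sketch of what such a `get_fp8_recipe` helper could look like, modeled on the upstream `accelerate` code that was previously commented out in `prepare_model`. The import path and recipe classes are assumptions about the Gaudi transformer-engine package; the actual implementation in `optimum/habana/accelerate/utils/transformer_engine.py` may differ.

```python
# Hypothetical sketch of get_fp8_recipe; the real implementation lives in
# optimum/habana/accelerate/utils/transformer_engine.py and may differ.
# Assumes the Gaudi transformer-engine package exposes recipe.Format and
# recipe.DelayedScaling, mirroring the upstream transformer_engine API.
import habana_frameworks.torch.hpex.experimental.transformer_engine as te


def get_fp8_recipe(fp8_recipe_handler):
    """Translate a GaudiFP8RecipeKwargs handler into a DelayedScaling recipe."""
    # to_kwargs() returns only the fields that differ from the dataclass defaults.
    kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
    if "fp8_format" in kwargs:
        # Map the string name ("E5M2" or "HYBRID") to the corresponding TE enum member.
        kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"])
    # DelayedScaling is the recipe object consumed by fp8_autocast at runtime.
    return te.recipe.DelayedScaling(**kwargs)
```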
8 changes: 4 additions & 4 deletions optimum/habana/accelerate/utils/__init__.py
@@ -1,12 +1,12 @@
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFP8RecipeKwargs,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
from .transformer_engine import (
te_forward_convert,
te_setup_fp8_recipe_handler,
te_wrap_fp8,
te_wrap_fp8_forward_convert,
FP8ContextWrapper,
convert_model,
get_fp8_recipe,
)
46 changes: 45 additions & 1 deletion optimum/habana/accelerate/utils/dataclasses.py
@@ -20,7 +20,7 @@
import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
from accelerate.utils.dataclasses import BaseEnum, KwargsHandler, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool


@@ -144,3 +144,47 @@ def __post_init__(self):
if self.sync_module_states:
device = torch.device("hpu")
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)


@dataclass
class GaudiFP8RecipeKwargs(KwargsHandler):
"""
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `transformer-engine`.

Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/dataclasses.py#L180

Args:
margin (`int`, *optional*, defaults to 0):
The margin to use for the scaling factor computation.
interval (`int`, *optional*, defaults to 16):
How often the scaling factor is recomputed.
fp8_format (`str`, *optional*, defaults to "HYBRID"):
The format to use for the FP8 recipe. Must be one of `E5M2` or `HYBRID`.
amax_history_len (`int`, *optional*, defaults to 1):
The length of the history to use for the scaling factor computation.
amax_compute_algo (`str`, *optional*, defaults to "most_recent"):
The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
reduce_amax (`bool`, *optional*, defaults to `False`):
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
call). This keeps the amaxes and scaling factors synced across the given
distributed group. If set to `False`, this reduction is skipped and every
HPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
"""

margin: int = 0
interval: int = 16
fp8_format: str = "HYBRID"
amax_compute_algo: str = "most_recent"
amax_history_len: int = 1
reduce_amax: bool = False

def __post_init__(self):
self.fp8_format = self.fp8_format.upper()
assert self.fp8_format in ("E5M2", "HYBRID"), "Only E5M2 and HYBRID FP8 formats are currently supported."
assert self.amax_compute_algo in (
"max",
"most_recent",
), "Only max and most_recent `amax_compute_algo` modes are currently supported."
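For illustration, a hedged usage sketch of the new dataclass: the accelerator's `__init__` (see the diff above) accepts a `GaudiFP8RecipeKwargs` instance through `kwargs_handlers` and only consumes it when `state.is_fp8_enabled` is true (e.g. via the `--fp8 True` training argument in the README example). The `GaudiAccelerator` import path below is an assumption about the package layout, not something this PR shows.

```python
# Usage sketch (assumption, not part of this PR): customizing the FP8 recipe.
from optimum.habana.accelerate import GaudiAccelerator  # assumed import path
from optimum.habana.accelerate.utils import GaudiFP8RecipeKwargs

# Override the defaults defined in the dataclass above.
fp8_kwargs = GaudiFP8RecipeKwargs(
    margin=0,
    interval=16,
    fp8_format="hybrid",      # normalized to "HYBRID" by __post_init__
    amax_compute_algo="max",  # alternative to the default "most_recent"
    amax_history_len=16,
)

# The handler is only used when FP8 is enabled on the accelerator state
# (in the README example this happens through the --fp8 True flag).
accelerator = GaudiAccelerator(kwargs_handlers=[fp8_kwargs])
```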