From 8b9af70792a40c093287303ddacae1f1153aab8d Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Tue, 25 Jun 2024 09:08:18 +0300 Subject: [PATCH 1/3] Integrate TE-FP8 to optimum-habana Signed-off-by: Sanju C Sudhakaran --- examples/language-modeling/README.md | 40 +++++ optimum/habana/accelerate/accelerator.py | 55 ++---- optimum/habana/accelerate/utils/__init__.py | 8 +- .../habana/accelerate/utils/dataclasses.py | 46 ++++- .../accelerate/utils/transformer_engine.py | 162 ++++++++++-------- optimum/habana/transformers/trainer.py | 24 ++- optimum/habana/transformers/training_args.py | 8 - 7 files changed, 216 insertions(+), 127 deletions(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 8c77f0e81..ef7f1d249 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -574,6 +574,46 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --flash_attention_causal_mask True ``` +- Multi-card finetuning of Llama2-70B with DeepSpeed ZeRO-3 optimization, LoRA and FP8 precision: + + > The following command requires Habana DeepSpeed 1.13.0 or later. + +```bash +PT_HPU_MAX_COMPOUND_OP_SIZE=10 \ +python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ + --model_name_or_path meta-llama/Llama-2-70b-hf \ + --deepspeed llama2_ds_zero3_config.json \ + --dataset_name tatsu-lab/alpaca \ + --bf16 True \ + --output_dir ./lora_out \ + --num_train_epochs 2 \ + --max_seq_len 2048 \ + --per_device_train_batch_size 10 \ + --per_device_eval_batch_size 10 \ + --gradient_checkpointing \ + --evaluation_strategy epoch \ + --eval_delay 2 \ + --save_strategy no \ + --learning_rate 0.0018 \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --dataset_concatenation \ + --attn_softmax_bf16 True \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --pipelining_fwd_bwd \ + --throughput_warmup_steps 3 \ + --lora_rank 4 \ + --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ + --validation_split_percentage 10 \ + --use_flash_attention True \ + --flash_attention_causal_mask True \ + --fp8 True +``` + - Multi-card finetuning of Llama2-70B with FSDP and LoRA: ```bash diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index fcfd47c31..d826908ee 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -37,7 +37,6 @@ DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, - FP8RecipeKwargs, GradientAccumulationPlugin, GradScalerKwargs, InitProcessGroupKwargs, @@ -73,12 +72,11 @@ from .utils import ( GaudiDistributedType, GaudiDynamoBackend, + GaudiFP8RecipeKwargs, GaudiFullyShardedDataParallelPlugin, GaudiTorchDynamoPlugin, - te_forward_convert, - te_setup_fp8_recipe_handler, - te_wrap_fp8, - te_wrap_fp8_forward_convert, + convert_model, + get_fp8_recipe, ) @@ -113,7 +111,6 @@ def __init__( dynamo_backend: GaudiDynamoBackend | str | None = None, distribution_strategy: str = None, force_autocast: bool = False, - fp8_recipe_format: str = None, ): self.trackers = [] if project_config is not None: @@ -181,7 +178,6 @@ def __init__( self.scaler_handler = None self.init_handler = None self.fp8_recipe_handler = None - self.fp8_recipe_format = None self.autocast_handler = None if kwargs_handlers is not None: for handler in kwargs_handlers: @@ -203,9 +199,9 @@ def __init__( raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.") 
else: self.init_handler = handler - elif isinstance(handler, FP8RecipeKwargs): + elif isinstance(handler, GaudiFP8RecipeKwargs): if self.fp8_recipe_handler is not None: - raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.") + raise ValueError("You can only pass one `GaudiFP8RecipeKwargs` in `kwargs_handler`.") else: self.fp8_recipe_handler = handler elif isinstance(handler, AutocastKwargs): @@ -225,8 +221,14 @@ def __init__( _from_accelerator=True, **kwargs, ) - if self.fp8_recipe_handler is None and self.state.is_fp8_enabled: - self.fp8_recipe_handler = te_setup_fp8_recipe_handler(self.fp8_recipe_format) + + if self.state.is_fp8_enabled: + if self.fp8_recipe_handler is None: + self.fp8_recipe_handler = GaudiFP8RecipeKwargs() + # Handling FP8 recipe creation in init since both `prepare_model` and `_prepare_deepspeed` require it. + # (Base accelerator handles this in `prepare_model` function) + self.fp8_recipe_handler = get_fp8_recipe(self.fp8_recipe_handler) + trackers = filter_trackers(log_with, self.logging_dir) if len(trackers) < 1 and log_with is not None: warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.") @@ -349,31 +351,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) else: model.forward = convert_outputs_to_fp32(new_forward) - elif self.state.is_fp8_enabled: - model = te_wrap_fp8_forward_convert(model, self.fp8_recipe_handler) - # FP8 is not supported on Gaudi2 yet - # elif self.mixed_precision == "fp8": - # if not has_transformer_engine_layers(model): - # with torch.no_grad(): - # convert_model(model) - # model._converted_to_transformer_engine = True - # model._original_forward = model.forward - - # kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {} - # if "fp8_format" in kwargs: - # kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) - # fp8_recipe = te_recipe.DelayedScaling(**kwargs) - # cuda_device_capacity = torch.cuda.get_device_capability() - # fp8_enabled = cuda_device_capacity[0] >= 9 or ( - # cuda_device_capacity[0] == 8 and cuda_device_capacity[1] >= 9 - # ) - # if not fp8_enabled: - # logger.warn( - # f"The current device has compute capability of {cuda_device_capacity} which is " - # "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace " - # "or higher, compute capability of 8.9 or higher). Will use FP16 instead." 
- # ) - # model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward) + if self.state.is_fp8_enabled: + model = convert_model(model) if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( model, "hf_device_map", False @@ -469,7 +448,7 @@ def _prepare_deepspeed(self, *args): result = [ self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) - else te_wrap_fp8(obj) + else convert_model(obj) if isinstance(obj, torch.nn.Module) and self.state.is_fp8_enabled else obj for obj in args @@ -685,8 +664,6 @@ def _prepare_deepspeed(self, *args): result[i] = scheduler # pointing for deepspeed_engine_wrapped.backward() self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine) - if self.state.is_fp8_enabled: - model = te_forward_convert(engine, self.fp8_recipe_handler) self._models.append(engine) if optimizer is not None: self._optimizers.append(optimizer) diff --git a/optimum/habana/accelerate/utils/__init__.py b/optimum/habana/accelerate/utils/__init__.py index 37181aee9..ee25954b9 100755 --- a/optimum/habana/accelerate/utils/__init__.py +++ b/optimum/habana/accelerate/utils/__init__.py @@ -1,12 +1,12 @@ from .dataclasses import ( GaudiDistributedType, GaudiDynamoBackend, + GaudiFP8RecipeKwargs, GaudiFullyShardedDataParallelPlugin, GaudiTorchDynamoPlugin, ) from .transformer_engine import ( - te_forward_convert, - te_setup_fp8_recipe_handler, - te_wrap_fp8, - te_wrap_fp8_forward_convert, + FP8ContextWrapper, + convert_model, + get_fp8_recipe, ) diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index eaf5f0915..fce2c06c8 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -20,7 +20,7 @@ import torch from accelerate.utils import FullyShardedDataParallelPlugin from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH -from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin +from accelerate.utils.dataclasses import BaseEnum, KwargsHandler, TorchDynamoPlugin from accelerate.utils.environment import str_to_bool @@ -144,3 +144,47 @@ def __post_init__(self): if self.sync_module_states: device = torch.device("hpu") self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) + + +@dataclass +class GaudiFP8RecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `transformer-engine`. + + Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/dataclasses.py#L180 + + Args: + margin (`int`, *optional*, defaults to 0): + The margin to use for the scaling factor computation. + interval (`int`, *optional*, defaults to 16): + The interval to use for how often the scaling factor is recomputed. + fp8_format (`str`, *optional*, defaults to "HYBRID"): + The format to use for the FP8 recipe. Must be one of `E5M2` or `HYBRID`. + amax_history_len (`int`, *optional*, defaults to 1): + The length of the history to use for the scaling factor computation + amax_compute_algo (`str`, *optional*, defaults to "most_recent"): + The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. 
+ reduce_amax (`bool`, *optional*, defaults to "False"): + By default, if `torch.distributed` is initialized, the `amax` value for FP8 + tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` + call). This keeps the amaxes and scaling factors synced across the given + distributed group. If set to `False`, this reduction is skipped and every + HPU maintains local amaxes and scaling factors. To ensure results are + numerically identical across checkpointing boundaries in this case, all + ranks must checkpoint in order to store the local tensors. + """ + + margin: int = 0 + interval: int = 16 + fp8_format: str = "HYBRID" + amax_compute_algo: str = "most_recent" + amax_history_len: int = 1 + reduce_amax: bool = False + + def __post_init__(self): + self.fp8_format = self.fp8_format.upper() + assert self.fp8_format in ("E5M2", "HYBRID"), "Only E5M2 and HYBRID FP8 formats are currently supported." + assert self.amax_compute_algo in ( + "max", + "most_recent", + ), "Only max and most_recent `amax_compute_algo` modes are currently supported." diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 07aa71aa6..500cdf900 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -13,61 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - - -te = None +import functools +import torch -class SwitchableForwardMaker: - def __init__(self, module, fp8_recipe_handler): - self.original_forward = module.forward - self.fp8_forward = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe_handler)(module.forward) - self.module = module - module.forward = self.forward - - def forward(self, *args, **kwargs): - if self.module.training: - return self.fp8_forward(*args, **kwargs) - else: - return self.original_forward(*args, **kwargs) - @staticmethod - def convert(module, fp8_recipe_handler): - SwitchableForwardMaker(module, fp8_recipe_handler) +try: + import habana_frameworks.torch.hpex.experimental.transformer_engine as te + from habana_frameworks.torch.hpex.experimental.transformer_engine.distributed import activation_checkpointing + has_transformer_engine = True +except ImportError: + has_transformer_engine = False -def get_te(): - global te - if te is None: - try: - import habana_frameworks.torch.hpex.experimental.transformer_engine as te - te = te - except ImportError: - te = None +def is_fp8_available(): + return has_transformer_engine -def convert_model(model, to_transformer_engine=True, _convert_linear=True): +def _convert_model(model, to_transformer_engine=True, _convert_linear=True): """ - Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart. + Recursively converts the linear layer of a model to their `transformers_engine` counterpart. 
""" - if te is None: + if not is_fp8_available(): raise ImportError("Using `convert_model` requires transformer_engine to be installed.") - from peft.tuners.lora.layer import Linear as PEFTLinear - - from optimum.habana.peft.layer import LoRALinear - for name, module in model.named_children(): - if type(module) == PEFTLinear and to_transformer_engine and _convert_linear: - LoRALinear.replace_forward(module) - if ( - isinstance(module, torch.nn.Linear) - and not type(module) == PEFTLinear - and to_transformer_engine - and _convert_linear - ): + if isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear: has_bias = module.bias is not None + # Initializing TE linear without weights and biases and shallow copying them from the original module. te_module = te.Linear( module.in_features, module.out_features, @@ -81,11 +54,14 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True): te_module.bias = module.bias setattr(model, name, te_module) - elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear: has_bias = module.bias is not None new_module = torch.nn.Linear( - module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + module.in_features, + module.out_features, + bias=has_bias, + dtype=module.weight.dtype, + device=module.weight.device, ) new_module.weight.copy_(module.weight) if has_bias: @@ -93,14 +69,14 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True): setattr(model, name, new_module) else: - convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) + _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear) def has_transformer_engine_layers(model): """ Returns whether a given model has some `transformer_engine` layer or not. """ - if te is None: + if not is_fp8_available(): raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.") for m in model.modules(): if isinstance(m, (te.Linear)): @@ -108,38 +84,78 @@ def has_transformer_engine_layers(model): return False -def te_setup_fp8_recipe_handler(fp8_recipe_format): - get_te() - fp8_format = te.recipe.Format.E5M2 - if fp8_recipe_format == "E4M3": - fp8_format = te.recipe.Format.E4M3 - elif fp8_recipe_format == "HYBRID": - fp8_format = te.recipe.Format.HYBRID - fp8_recipe_handler = te.recipe.DelayedScaling( - fp8_format=fp8_format, - margin=0, - interval=16, - amax_history_len=1, - amax_compute_algo="most_recent", - reduce_amax=False, - ) - fp8_recipe_handler.backend = "TE" - return fp8_recipe_handler - - -def te_wrap_fp8(model): +def convert_model(model): + """ + Converts torch.nn.Linear modules to `transformers_engine` Linear modules. + Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1303 + """ if not has_transformer_engine_layers(model): with torch.no_grad(): - convert_model(model) + _convert_model(model) model._converted_to_transformer_engine = True return model -def te_wrap_fp8_forward_convert(model, fp8_recipe_handler): - model = te_wrap_fp8(model) - SwitchableForwardMaker.convert(model, fp8_recipe_handler) - return model +def get_fp8_recipe(fp8_recipe_handler): + """ + Creates transformer engine FP8 recipe object. 
+ Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309 + """ + kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {} + if "fp8_format" in kwargs: + kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"]) + fp8_recipe_handler = te.recipe.DelayedScaling(**kwargs) + fp8_recipe_handler.backend = "TE" + return fp8_recipe_handler + + +class FP8ContextWrapper: + """ + Helper class for FP8 context related operations. + """ + + def __init__(self, ctx, fp8_recipe): + self.ctx = ctx + self.fp8_ctx = self.create_fp8_context(fp8_recipe) + def __enter__(self): + self.ctx.__enter__() + self.fp8_ctx.__enter__() + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.fp8_ctx.__exit__(exc_type, exc_value, exc_traceback) + self.ctx.__exit__(exc_type, exc_value, exc_traceback) + + @staticmethod + def create_fp8_context(fp8_recipe): + return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) + + @staticmethod + def _gradient_checkpointing_wrap(func, *args, **kwargs): + """ + `_gradient_checkpointing_func` always takes the function to be recomputed as the first argument. The function + below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context. + """ + _args = list(args) + _args[0] = activation_checkpointing()(_args[0]) + args = tuple(_args) + + return func(*args, **kwargs) + + @staticmethod + def gradient_checkpointing_wrap(model): + """ + Wrap `_gradient_checkpointing_func` in the model with `transformer_engine`'s `activation_checkpointing` context. + This context is used to signal the `transformer_engine` modules whether they have been called with activation checkpointing enabled or not. + """ + if hasattr(model, "gradient_checkpointing") and model.gradient_checkpointing: + model._gradient_checkpointing_func = functools.partial( + FP8ContextWrapper._gradient_checkpointing_wrap, model._gradient_checkpointing_func + ) + return -def te_forward_convert(model, fp8_recipe_handler): - SwitchableForwardMaker.convert(model, fp8_recipe_handler) + for module in model.modules(): + if hasattr(module, "gradient_checkpointing") and module.gradient_checkpointing: + module._gradient_checkpointing_func = functools.partial( + FP8ContextWrapper._gradient_checkpointing_wrap, module._gradient_checkpointing_func + ) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 5fa006873..5966790d4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -96,7 +96,7 @@ from optimum.utils import logging from ..accelerate import GaudiAccelerator -from ..accelerate.utils import GaudiDistributedType +from ..accelerate.utils import FP8ContextWrapper, GaudiDistributedType from ..utils import ( HabanaProfile, get_hpu_memory_stats, @@ -692,6 +692,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): transformers.modeling_utils.checkpoint = lazy_mode_checkpointing self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + + # Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context. 
+ if self.accelerator.state.is_fp8_enabled: + FP8ContextWrapper.gradient_checkpointing_wrap(self.model) else: # Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable` if hasattr(self.model, "gradient_checkpointing_disable"): @@ -1518,6 +1522,11 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): else: ctx_manager = contextlib.nullcontext() + # Merge autocast context and `fp8_autocast` context if FP8 is enabled. + # Currently FP8 is enabled only for training. + if self.accelerator.state.is_fp8_enabled and self.model.training: + ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler) + return ctx_manager def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: @@ -1551,6 +1560,9 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te self.htcore.mark_step() if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA: + assert not ( + self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing + ), "FP8 precision with gradient_checkpointing is currently not supported with PeftType.ADALORA" if self.is_deepspeed_enabled and not is_deepspeed_zero3_enabled(): self.accelerator.deepspeed_engine_wrapped.engine.backward(loss) self.model.base_model.update_and_allocate(self.state.global_step) @@ -1559,7 +1571,15 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te self.accelerator.backward(loss) self.model.base_model.update_and_allocate(self.state.global_step) else: - self.accelerator.backward(loss) + if self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing: + # The precision used in backward pass should be same as the one used in forward pass. + # However when training with gradient_checkpointing and FP8 precision, recompute forward + # in backward does not automatically run with FP8 precision. 
In order to handle this, + # the backward is run in `fp8_autocast` context + with FP8ContextWrapper.create_fp8_context(self.accelerator.fp8_recipe_handler): + self.accelerator.backward(loss) + else: + self.accelerator.backward(loss) return loss.detach() / self.args.gradient_accumulation_steps def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index eff3c1ede..c280e4688 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -291,14 +291,6 @@ class GaudiTrainingArguments(TrainingArguments): metadata={"help": "Whether to use fp8 for training."}, ) - fp8_recipe_format: Optional[str] = field( - default="E5M2", - metadata={ - "help": "Which fp8 format to use for fp8 training.", - "choices": ["E5M2", "E4M3", "HYBRID"], - }, - ) - def __post_init__(self): if self.use_hpu_graphs: warnings.warn( From f496c569009d2e9d05e853c3bb814c6db1808dde Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Thu, 13 Jun 2024 09:04:31 +0530 Subject: [PATCH 2/3] Import TE module conditionally --- .../accelerate/utils/transformer_engine.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py index 500cdf900..823da61d5 100755 --- a/optimum/habana/accelerate/utils/transformer_engine.py +++ b/optimum/habana/accelerate/utils/transformer_engine.py @@ -18,16 +18,23 @@ import torch -try: - import habana_frameworks.torch.hpex.experimental.transformer_engine as te - from habana_frameworks.torch.hpex.experimental.transformer_engine.distributed import activation_checkpointing +has_transformer_engine = False - has_transformer_engine = True -except ImportError: - has_transformer_engine = False + +def import_te(): + global te, has_transformer_engine + try: + import habana_frameworks.torch.hpex.experimental.transformer_engine as te + + has_transformer_engine = True + + except ImportError: + has_transformer_engine = False def is_fp8_available(): + if not has_transformer_engine: + import_te() return has_transformer_engine @@ -101,6 +108,8 @@ def get_fp8_recipe(fp8_recipe_handler): Creates transformer engine FP8 recipe object. Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309 """ + if not is_fp8_available(): + raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.") kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {} if "fp8_format" in kwargs: kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"]) @@ -137,7 +146,7 @@ def _gradient_checkpointing_wrap(func, *args, **kwargs): below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context. 
""" _args = list(args) - _args[0] = activation_checkpointing()(_args[0]) + _args[0] = te.distributed.activation_checkpointing()(_args[0]) args = tuple(_args) return func(*args, **kwargs) From c9345ef957f41556987927b98acf826b3058d009 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Wed, 26 Jun 2024 08:25:30 +0300 Subject: [PATCH 3/3] Remove the special handling of LoRA layers in TE layer conversion --- examples/language-modeling/README.md | 4 +-- optimum/habana/peft/layer.py | 54 ---------------------------- 2 files changed, 2 insertions(+), 56 deletions(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index ef7f1d249..a6fa898c0 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -589,7 +589,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --num_train_epochs 2 \ --max_seq_len 2048 \ --per_device_train_batch_size 10 \ - --per_device_eval_batch_size 10 \ + --per_device_eval_batch_size 1 \ --gradient_checkpointing \ --evaluation_strategy epoch \ --eval_delay 2 \ @@ -608,7 +608,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --throughput_warmup_steps 3 \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ - --validation_split_percentage 10 \ + --validation_split_percentage 4 \ --use_flash_attention True \ --flash_attention_causal_mask True \ --fp8 True diff --git a/optimum/habana/peft/layer.py b/optimum/habana/peft/layer.py index dacafb115..b61eebada 100755 --- a/optimum/habana/peft/layer.py +++ b/optimum/habana/peft/layer.py @@ -1,7 +1,6 @@ from typing import Any import torch -from peft.utils.other import transpose def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: @@ -32,56 +31,3 @@ def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwarg result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * (scaling / ranknum) return result - - -class LoRALinear: - def __init__(self, module): - has_bias = module.bias is not None - self.module = module - import habana_frameworks.torch.hpex.experimental.transformer_engine as te - - self.module.te_linear = te.Linear( - module.in_features, - module.out_features, - bias=has_bias, - params_dtype=module.weight.dtype, - skip_weight_param_allocation=True, - ) - - def _linear(self, input: torch.Tensor) -> torch.Tensor: - # TODO: to check if bias is removed from lora linear - if hasattr(self.module, "bias"): - return self.module.te_linear( - input, transpose(self.module.weight, self.module.fan_in_fan_out), bias=self.module.bias - ) - else: - return self.module.te_linear(input, transpose(self.module.weight, self.module.fan_in_fan_out)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - previous_dtype = x.dtype - - if self.module.disable_adapters: - if self.module.merged: - self.module.unmerge() - result = self._linear(x) - elif self.module.merged: - result = self._linear(x) - else: - result = self._linear(x) - for active_adapter in self.module.active_adapters: - if active_adapter not in self.module.lora_A.keys(): - continue - lora_A = self.module.lora_A[active_adapter] - lora_B = self.module.lora_B[active_adapter] - dropout = self.module.lora_dropout[active_adapter] - scaling = self.module.scaling[active_adapter] - x = x.to(lora_A.weight.dtype) - result = result.clone() + lora_B(lora_A(dropout(x))) * scaling - - result = result.to(previous_dtype) - return result - - @staticmethod - def 
replace_forward(module): - lora_linear = LoRALinear(module) - module.forward = lora_linear.forward
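
As a usage sketch of the pieces added in this series: the FP8 recipe is now configured through `GaudiFP8RecipeKwargs` and picked up by `GaudiAccelerator` as a regular kwargs handler. The names and defaults below come from the patches above; how FP8 itself gets switched on (the `fp8` training argument / `accelerator.state.is_fp8_enabled`) is assumed to be configured elsewhere, so treat this as a minimal sketch rather than a complete recipe.

```python
from optimum.habana.accelerate import GaudiAccelerator
from optimum.habana.accelerate.utils import GaudiFP8RecipeKwargs

# Unset fields fall back to the dataclass defaults: HYBRID format, margin=0,
# interval=16, amax_history_len=1, amax_compute_algo="most_recent", reduce_amax=False.
fp8_kwargs = GaudiFP8RecipeKwargs(fp8_format="HYBRID", amax_compute_algo="most_recent")

accelerator = GaudiAccelerator(kwargs_handlers=[fp8_kwargs])
# When accelerator.state.is_fp8_enabled is True, __init__ turns the handler into a
# te.recipe.DelayedScaling recipe, and prepare()/prepare_model() swaps eligible
# torch.nn.Linear layers for transformer_engine Linear layers via convert_model().
```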
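At a lower level, the new `transformer_engine` helpers can be exercised directly. A sketch, assuming `habana_frameworks` with the experimental transformer_engine module is available at runtime; the toy model is illustrative only.

```python
import torch

from optimum.habana.accelerate.utils import GaudiFP8RecipeKwargs, convert_model, get_fp8_recipe

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 16))

# Recursively replaces each torch.nn.Linear with a te.Linear (weights and biases are
# shallow-copied from the original modules) and marks the model with
# `_converted_to_transformer_engine`.
model = convert_model(model)

# Translates the dataclass into a te.recipe.DelayedScaling object; the fp8_format
# string is resolved to te.recipe.Format.HYBRID and `backend` is set to "TE".
recipe = get_fp8_recipe(GaudiFP8RecipeKwargs(fp8_format="HYBRID"))
```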
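Finally, a condensed, hypothetical view of how the trainer changes fit together during a training step. The function name and arguments are placeholders (`inputs` is assumed to be a dict for a Hugging Face style model), and `GaudiTrainer` wires the equivalent logic up internally; this is a sketch of the control flow, not a drop-in replacement.

```python
import contextlib

from optimum.habana.accelerate.utils import FP8ContextWrapper


def fp8_training_step(accelerator, model, inputs, gradient_checkpointing: bool):
    """Condensed sketch of the FP8 handling added to GaudiTrainer in this series."""
    if gradient_checkpointing and accelerator.state.is_fp8_enabled:
        # Wraps model._gradient_checkpointing_func so the recomputed function runs
        # inside transformer_engine's activation_checkpointing() context.
        FP8ContextWrapper.gradient_checkpointing_wrap(model)

    ctx = contextlib.nullcontext()  # stands in for autocast_smart_context_manager()
    if accelerator.state.is_fp8_enabled and model.training:
        # The forward pass runs under the merged autocast + te.fp8_autocast context.
        ctx = FP8ContextWrapper(ctx, accelerator.fp8_recipe_handler)

    with ctx:
        loss = model(**inputs).loss

    if accelerator.state.is_fp8_enabled and gradient_checkpointing:
        # Activation recomputation during backward does not pick up FP8 on its own,
        # so the backward pass is wrapped in its own fp8_autocast context.
        with FP8ContextWrapper.create_fp8_context(accelerator.fp8_recipe_handler):
            accelerator.backward(loss)
    else:
        accelerator.backward(loss)
    return loss.detach()
```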