From f496c569009d2e9d05e853c3bb814c6db1808dde Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Thu, 13 Jun 2024 09:04:31 +0530
Subject: [PATCH] Import TE module conditionally

---
 .../accelerate/utils/transformer_engine.py | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py
index 500cdf900..823da61d5 100755
--- a/optimum/habana/accelerate/utils/transformer_engine.py
+++ b/optimum/habana/accelerate/utils/transformer_engine.py
@@ -18,16 +18,23 @@
 
 import torch
 
-try:
-    import habana_frameworks.torch.hpex.experimental.transformer_engine as te
-    from habana_frameworks.torch.hpex.experimental.transformer_engine.distributed import activation_checkpointing
+has_transformer_engine = False
 
-    has_transformer_engine = True
-except ImportError:
-    has_transformer_engine = False
+
+def import_te():
+    global te, has_transformer_engine
+    try:
+        import habana_frameworks.torch.hpex.experimental.transformer_engine as te
+
+        has_transformer_engine = True
+
+    except ImportError:
+        has_transformer_engine = False
 
 
 def is_fp8_available():
+    if not has_transformer_engine:
+        import_te()
     return has_transformer_engine
 
 
@@ -101,6 +108,8 @@ def get_fp8_recipe(fp8_recipe_handler):
     Creates transformer engine FP8 recipe object.
     Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309
     """
+    if not is_fp8_available():
+        raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.")
     kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {}
     if "fp8_format" in kwargs:
         kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"])
@@ -137,7 +146,7 @@ def _gradient_checkpointing_wrap(func, *args, **kwargs):
     below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
     """
     _args = list(args)
-    _args[0] = activation_checkpointing()(_args[0])
+    _args[0] = te.distributed.activation_checkpointing()(_args[0])
    args = tuple(_args)
    return func(*args, **kwargs)
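
Note (not part of the patch): the change replaces a module-level try/except import with a lazy one. Previously, importing optimum/habana/accelerate/utils/transformer_engine.py itself attempted to import Habana's transformer_engine module; after this patch the attempt is deferred to the first call of is_fp8_available(), and te only becomes a module-level global once import_te() succeeds, which is also why activation_checkpointing is now reached through te.distributed rather than a top-level from-import.

Below is a minimal, self-contained sketch of the same lazy-import pattern. The names heavy_dep, import_heavy_dep, and is_heavy_dep_available, as well as the json stand-in module, are hypothetical illustrations, not part of the patch; the real code imports habana_frameworks.torch.hpex.experimental.transformer_engine.

has_heavy_dep = False


def import_heavy_dep():
    # Mirror of import_te(): try the import once and cache the outcome
    # in module-level globals instead of failing at module import time.
    global heavy_dep, has_heavy_dep
    try:
        import json as heavy_dep  # stand-in for the optional dependency

        has_heavy_dep = True
    except ImportError:
        has_heavy_dep = False


def is_heavy_dep_available():
    # Mirror of is_fp8_available(): the first call triggers the import,
    # later calls just return the cached flag.
    if not has_heavy_dep:
        import_heavy_dep()
    return has_heavy_dep


if __name__ == "__main__":
    print(is_heavy_dep_available())  # True if the stand-in dependency imports

The design choice here is that consumers which never touch FP8 can import the module on machines without transformer_engine installed, while FP8 entry points such as get_fp8_recipe() fail fast with an explicit ImportError when the dependency is genuinely needed.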