Skip to content

Commit

Permalink
Import TE module conditionally
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjuCSudhakaran committed Jun 25, 2024
1 parent 8b9af70 commit f496c56
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions optimum/habana/accelerate/utils/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,23 @@
import torch


try:
import habana_frameworks.torch.hpex.experimental.transformer_engine as te
from habana_frameworks.torch.hpex.experimental.transformer_engine.distributed import activation_checkpointing
has_transformer_engine = False

has_transformer_engine = True
except ImportError:
has_transformer_engine = False

def import_te():
global te, has_transformer_engine
try:
import habana_frameworks.torch.hpex.experimental.transformer_engine as te

has_transformer_engine = True

except ImportError:
has_transformer_engine = False


def is_fp8_available():
if not has_transformer_engine:
import_te()
return has_transformer_engine


Expand Down Expand Up @@ -101,6 +108,8 @@ def get_fp8_recipe(fp8_recipe_handler):
Creates transformer engine FP8 recipe object.
Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309
"""
if not is_fp8_available():
raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.")
kwargs = fp8_recipe_handler.to_dict() if fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te.recipe.Format, kwargs["fp8_format"])
Expand Down Expand Up @@ -137,7 +146,7 @@ def _gradient_checkpointing_wrap(func, *args, **kwargs):
below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
"""
_args = list(args)
_args[0] = activation_checkpointing()(_args[0])
_args[0] = te.distributed.activation_checkpointing()(_args[0])
args = tuple(_args)

return func(*args, **kwargs)
Expand Down

0 comments on commit f496c56

Please sign in to comment.