Issues lowering attention module to edge #3672

Open
ismaeelbashir03 opened this issue May 19, 2024 · 0 comments
Labels: module: exir (Issues related to Export IR)

ismaeelbashir03 commented May 19, 2024

I am trying to lower the following attention module:

import torch
from torch.export import Dim

# MIN_DIM and MAX_DIM are module-level constants defined elsewhere in my code.

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        return (torch.randn(1, 10, 5), torch.randn(1, 10, 5), torch.randn(1, 10, 5))

    def get_dynamic_shapes(self):
        dim1_q = Dim("Attention_dim1_q", min=MIN_DIM, max=MAX_DIM)
        dim2_q = Dim("Attention_dim2_q", min=MIN_DIM, max=MAX_DIM)

        return {
            "q": {1: dim1_q, 2: dim2_q},
            "k": {1: dim1_q, 2: dim2_q},
            "v": {1: dim1_q, 2: dim2_q},
        }

However, I am receiving the following error when exporting to edge:

raise SpecViolationError(
torch._export.verifier.SpecViolationError: Operator '<function sym_float at 0x10f9540d0>' is not an allowed operator type: (<class 'torch._ops.OpOverload'>, <class 'torch._ops.HigherOrderOperator'>)
Valid builtin ops: [<built-in function getitem>, <built-in function add>, <built-in function mul>, <built-in function sub>, <built-in function truediv>, <built-in function ge>, <built-in function le>, <built-in function gt>, <built-in function lt>, <built-in function eq>, <built-in function ne>, <built-in function floordiv>, <built-in function mod>, <built-in function and_>, <built-in function or_>, <built-in function not_>, <built-in function pow>, <built-in function neg>, <built-in function abs>, <built-in function ceil>, <built-in function floor>]
Valid torch functions: (<class 'torch.autograd.grad_mode.set_grad_enabled'>, <function sym_int at 0x10f954160>, <function sym_ite at 0x10f954940>, <function sym_max at 0x10f9541f0>, <function sym_min at 0x10f954280>, <function sym_not at 0x1038c7130>, <function _sym_sqrt at 0x10f9543a0>, <built-in function _set_grad_enabled>)
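Note that the allow-list above includes _sym_sqrt but not sym_float, so my guess (unconfirmed) is that the sym_float comes from SDPA's default softmax scale, 1 / sqrt(q.size(-1)): since dim 2 is marked dynamic, q.size(-1) is symbolic, and computing the sqrt routes through torch.sym_float. As a sanity check, the sketch below exports with only dim 1 dynamic so the scale stays a constant; the names (dim1, only_dim1_dynamic) and bounds are placeholders of mine, not from my actual code.

import torch
from torch.export import Dim, export

model = Attention().get_eager_model()
example_inputs = model.get_example_inputs()

# Keep the embedding dim (dim 2) static so the default SDPA scale
# 1 / sqrt(q.size(-1)) is a plain Python float, not a symbolic value.
dim1 = Dim("Attention_dim1", min=2, max=100)  # placeholder bounds
only_dim1_dynamic = {
    "q": {1: dim1},
    "k": {1: dim1},
    "v": {1: dim1},
}

ep = export(model, example_inputs, dynamic_shapes=only_dim1_dynamic)
print(ep.graph)  # expecting no torch.sym_float nodes here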

I am using the following code to export to edge:

import logging
from typing import Any, Dict, Optional, Tuple, Union

import executorch.exir as exir
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.exir.tracer import Value
from torch.export import export, ExportedProgram

_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
    _check_ir_validity=True,
)


def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    verbose=True,
) -> ExportedProgram:
    # post autograd export. eventually this will become .to_core_aten
    if not isinstance(model, torch.fx.GraphModule) and not isinstance(
        model, torch.nn.Module
    ):
        raise ValueError(
            f"Expected passed in model to be an instance of fx.GraphModule or nn.Module, got {type(model)}"
        )
    core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
    if verbose:
        logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
    return core_aten_ep


def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=None,
    verbose=True,
) -> EdgeProgramManager:
    if not edge_compile_config:
        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False,  # quant ops currently break ir verification
        )
    edge_manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=edge_compile_config,
    )
    if verbose:
        logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}")
    return edge_manager


def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    verbose=True,
) -> EdgeProgramManager:
    core_aten_ep = _to_core_aten(model, example_inputs, dynamic_shapes, verbose=verbose)
    return _core_aten_to_edge(
        core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
    )

# Driver code; `args` comes from argparse elsewhere in my script.
model = Attention().get_eager_model()
example_inputs = model.get_example_inputs()
dynamic_shapes = model.get_dynamic_shapes()

model = model.eval()
model = torch._export.capture_pre_autograd_graph(
    model, example_inputs, dynamic_shapes=dynamic_shapes
)

edge = export_to_edge(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
    edge_compile_config=EdgeCompileConfig(
        _check_ir_validity=not args.quantize,
    ),
)
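If the symbolic scale is indeed the culprit, one possible workaround is to pass an explicit constant scale to scaled_dot_product_attention (the scale kwarg has existed since PyTorch 2.1). This is only a sketch: the module name is hypothetical, and a hard-coded scale is only mathematically right when the real embedding dim is fixed (here 5, matching get_example_inputs).

import math

import torch

class AttentionExplicitScale(torch.nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # A constant scale keeps symbolic sizes out of the softmax scaling,
        # so no torch.sym_float should appear in the exported graph.
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=1.0 / math.sqrt(5)  # assumes last dim == 5
        )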
ismaeelbashir03 changed the title from "Lowering attention module" to "Issues lowering attention module to edge" on May 19, 2024
mergennachin added the module: exir (Issues related to Export IR) label on May 20, 2024