Add a get_inputs_kernel_bias method to the DenseGeneral layer.
This method returns the inputs, kernel, bias, dot_general, dot dimension numbers, and output shape. This is useful for users who want to use the DenseGeneral layer with a custom kernel that performs the dot_general step itself.

PiperOrigin-RevId: 629615743
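For illustration, a minimal sketch of the intended usage under this commit's API: a subclass fetches the pieces from get_inputs_kernel_bias and runs the dot_general step itself. The CastingDenseGeneral name and the bfloat16 cast are hypothetical, not part of this change:

import jax.numpy as jnp
from jax import random
from flax import linen as nn

class CastingDenseGeneral(nn.DenseGeneral):
  """Hypothetical DenseGeneral that casts operands before contracting."""

  def __call__(self, inputs):
    inputs, kernel, bias, dot_general, dims, out_shape = (
        self.get_inputs_kernel_bias(inputs)
    )
    # Custom step: run dot_general on bfloat16 operands.
    out = dot_general(
        inputs.astype(jnp.bfloat16),
        kernel.astype(jnp.bfloat16),
        dims,
        precision=self.precision,
    )
    if self.use_bias:
      out += jnp.reshape(bias, out_shape)  # broadcast bias over batch dims
    return out

layer = CastingDenseGeneral(features=8)
x = jnp.ones((4, 16))
params = layer.init(random.PRNGKey(0), x)
y = layer.apply(params, x)  # shape (4, 8)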
Flax Team committed May 1, 2024
1 parent ae5d66d commit 7ee6ba6
Showing 1 changed file with 44 additions and 8 deletions.
flax/linen/linear.py
@@ -114,15 +114,28 @@ class DenseGeneral(Module):
   dot_general: Optional[DotGeneralT] = None
   dot_general_cls: Any = None
 
+  _dot_dimension_numbers = tuple[
+      tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]
+  ]
+
   @compact
-  def __call__(self, inputs: Array) -> Array:
-    """Applies a linear transformation to the inputs along multiple dimensions.
+  def get_inputs_kernel_bias(
+      self, inputs: Array
+  ) -> tuple[Array, Array, Array, Any, _dot_dimension_numbers, Sequence[int]]:
+    """Gets the inputs, kernel, and bias arrays in user-defined dtypes.
 
     Args:
       inputs: The nd-array to be transformed.
 
     Returns:
-      The transformed input.
+      A tuple of (inputs, kernel, bias, dot_general, dot_dimension_numbers,
+      output_shape), where:
+      dot_dimension_numbers: the dimension numbers for dot_general -- a tuple
+        of tuples of sequences of ints of the form ``((lhs_contracting_dims,
+        rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))``.
+      output_shape: a sequence of ints representing the output shape.
     """
     features = _canonicalize_tuple(self.features)
     axis = _canonicalize_tuple(self.axis)
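For reference, the dimension-number format described in the new docstring is the standard jax.lax.dot_general convention. A small self-contained example for the common case of contracting the input's last axis with the kernel's first (shapes here are illustrative only):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 16))       # batch of 4 vectors, 16 features each
kernel = jnp.ones((16, 8))  # contract x's last axis with kernel's first

# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
dims = (((1,), (0,)), ((), ()))
out = lax.dot_general(x, kernel, dims)
assert out.shape == (4, 8)
assert jnp.allclose(out, x @ kernel)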
@@ -194,16 +207,39 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
       dot_general = self.dot_general
     else:
       dot_general = lax.dot_general
 
+    return (
+        inputs,
+        kernel,
+        bias,
+        dot_general,
+        ((axis, contract_ind), (batch_dims, batch_ind)),
+        expanded_batch_shape + features,
+    )
+
+  def __call__(self, inputs: Array) -> Array:
+    """Applies a linear transformation to the inputs along multiple dimensions.
+
+    Args:
+      inputs: The nd-array to be transformed.
+
+    Returns:
+      The transformed input.
+    """
+    inputs, kernel, bias, dot_general, dot_dimension_nums, out_shape = (
+        self.get_inputs_kernel_bias(inputs)
+    )
+
     out = dot_general(
-        inputs,
-        kernel,
-        ((axis, contract_ind), (batch_dims, batch_ind)),
-        precision=self.precision,
+        inputs,
+        kernel,
+        dot_dimension_nums,
+        precision=self.precision,
     )
     # dot_general output has shape [batch_dims/group_dims] + [feature_dims]
     if self.use_bias:
       # expand bias shape to broadcast bias over batch dims.
-      bias = jnp.reshape(bias, expanded_batch_shape + features)
+      bias = jnp.reshape(bias, out_shape)
       out += bias
     return out
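As a quick shape check of the refactored layer (a sketch; the feature and input shapes below are chosen only for illustration), the output is still the batch dims followed by the feature dims, as the comment in the diff notes:

import jax.numpy as jnp
from jax import random
from flax import linen as nn

layer = nn.DenseGeneral(features=(3, 5), axis=-1)
x = jnp.ones((2, 7))                       # kernel will have shape (7, 3, 5)
params = layer.init(random.PRNGKey(0), x)
y = layer.apply(params, x)
assert y.shape == (2, 3, 5)                # [batch_dims] + [feature_dims]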
