einorm

An einops-style generalized normalization layer.

Installation

You need torch >= 1.13 (or functorch) installed:

pip install einorm

Usage

While nn.LayerNorm is great, it only normalizes tensors over their last few dimensions. That design is necessary for performance, but it often forces you to rearrange dimensions beforehand. This is where einops semantics come in handy.
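
For example, normalizing over a middle dimension with plain nn.LayerNorm typically means moving that dimension to the end and back again. A rough sketch using einops.rearrange (the shapes here are made up for illustration):

import torch
from torch import nn
from einops import rearrange

x = torch.randn(8, 16, 1024)  # (b, n, d)
norm = nn.LayerNorm(16)       # we want to normalize over n

# nn.LayerNorm only works on trailing dimensions,
# so n has to be moved to the end and back.
y = rearrange(x, "b n d -> b d n")
y = norm(y)
y = rearrange(y, "b d n -> b n d")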

The Einorm module can be used as a drop-in replacement for nn.LayerNorm:

from einorm import Einorm

# Equivalent to nn.LayerNorm(1024)
Einorm("b n d", "d", d=1024)

Of course, you can normalize over any dimensions you want:

Einorm("a b c d e", "b d", b=12, d=34)

Caveat: Einorm internally relies on nn.functional.layer_norm. If you are not normalizing over the last dimensions, permute and contiguous calls are required, which may incur a slight performance cost. If you are normalizing over the last dimensions, Einorm skips the permute, so performance is identical to nn.LayerNorm.

Grouped Layer Normalization

According to Scaling Vision Transformers to 22 Billion Parameters, normalizing queries and keys in a head-wise fashion can help stabilize training dynamics. However, since nn.LayerNorm computes the mean and standard deviation only over the last few dimensions and normalizes all of them with the same statistics, this behavior can be tricky to implement.

This can be achieved by providing additional grouping arguments to Einorm:

Einorm("b h n d", "d", "h", h=16, d=64)  # num_heads=16, head_dim=64

Here, Einorm normalizes over the last dimension using per-head statistics and parameters.

This differs from the following, where Einorm normalizes over the h and d dimensions using shared statistics and parameters:

Einorm("b h n d", "h d", h=16, d=64)

Einorm leverages functorch.vmap to support this behavior with optimal performance. Therefore, you will need functorch or torch >= 1.13, which supports vmap natively.

Without bias

For some reason, nn.LayerNorm does not provide an option to disable its bias. With Einorm, you can simply turn it off via the bias option:

Einorm("b c h w", "h w", h=256, w=256, bias=False)