Skip to content

Commit

Permalink
Merge pull request #104 from laekov/faster-doc
Browse files Browse the repository at this point in the history
Documents for FasterMoE
  • Loading branch information
laekov committed Apr 2, 2022
2 parents a6a8c4a + 33895a0 commit 59bcec8
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 1 deletion.
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,20 @@ FastMoE's model parallel requires sophiscated parallel strategies that neither P
Megatron-LM provides. The `fmoe.DistributedGroupedDataParallel` module is
introduced to replace PyTorch's DDP module.

#### Faster Performance Features

From a PPoPP'22 paper, _FasterMoE: modeling and optimizing training of
large-scale dynamic pre-trained models_, we have adopted techniques to make
FastMoE's model parallel much more efficient.

These optimizations are named as **Faster Performance Features**, and can be
enabled via several environment variables. Their usage and constraints are
detailed in [a separate document](doc/fastermoe).

## Citation

For the core FastMoE system.

```
@article{he2021fastmoe,
title={FastMoE: A Fast Mixture-of-Expert Training System},
Expand All @@ -110,6 +122,27 @@ introduced to replace PyTorch's DDP module.
}
```

For the [faster performance features](doc/fastermoe).

```
@inproceedings{he2022fastermoe,
author = {He, Jiaao and Zhai, Jidong and Antunes, Tiago and Wang, Haojie and Luo, Fuwen and Shi, Shangfeng and Li, Qin},
title = {FasterMoE: Modeling and Optimizing Training of Large-Scale Dynamic Pre-Trained Models},
year = {2022},
isbn = {9781450392044},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
url = {https://doi.org/10.1145/3503221.3508418},
doi = {10.1145/3503221.3508418},
booktitle = {Proceedings of the 27th ACM SIGPLAN Symposium on Principles and Practice of Parallel Programming},
pages = {120–134},
numpages = {15},
keywords = {parallelism, distributed deep learning, performance modeling},
location = {Seoul, Republic of Korea},
series = {PPoPP '22}
}
```

## Troubleshootings / Discussion

If you have any problem using FastMoE, or you are interested in getting involved in developing FastMoE, feel free to join [our slack channel](https://join.slack.com/t/fastmoe/shared_invite/zt-mz0ai6ol-ggov75D62YsgHfzShw8KYw).
98 changes: 98 additions & 0 deletions doc/fastermoe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
Boost the Performance by FasterMoE
===

一个中文版见[这篇博客](https://laekov.com.cn/view/181401#howto)

There are three main optimizations in the PPoPP'22 paper _FasterMoE: Modeling
and Optimizing Training of Large-scale Dynamic Pre-trained Models_. Thanks to
the contributions of authors of the article, their optimizations are now
integrated into FastMoE, and can be enabled via switches of environment
variables. These optimizations can greatly increase the training efficiency of
FastMoE.

## Smart Scheduling

Recall that in an MoE layer, two `all-to-all`s are performed with the experts'
computation in-between. In FasterMoE, the `all-to-all`s are broken down using
a _group-wise exchange_ algorithm. And then, the expert can instantly start
its jobs as long as a part of input, e.g. tokens from one other worker, is
ready.

Its effectiveness is revealed in the following timeline. `S` and `R` stand for
the components of the `all-to-all`s, and `C` stands for computation of the
expert.

![](smartsch.png)

In FastMoE, to enable smart scheduling, set the environment variable `
FMOE_FASTER_SCHEDULE_ENABLE` to `1` or `ON`, and it is now by default off.

Please note that there are a few constraints for smart scheduling in the
current version of FastMoE. `num_expert` has to be `1`, which means only one
expert can reside on each worker. The input and output features have to be of
the same length for the experts. This is because the developers of FasterMoE
only implement this on their prototype, and they are looking for the
community's efforts to have other cases supported.

To fine-tune the performance of smart scheduling, the environment variable
`FMOE_FASTER_GROUP_SIZE` stands for the size of worker groups in the
_Group-wise Exchange_ algorithm. In other words, it is the granularity of the
schedule. It should be set to a proper value that balance between pipeline
bubbles and inefficient undersized computation granularity.

## Expert Shadowing

According to observations when training real models, when no limitation is
placed over expert selection, it follows a skew distribution, which means a few
experts are much more popular than others. This introduces significant
performance issue of load imbalance when using FastMoE's model parallel mode.

The authors of FasterMoE proposes the solution that for the hot experts, their
parameters are broadcast to all workers, namely shadows. With the shadows,
computation of the hot experts can be performed locally on all workers,
avoiding the bottleneck of sending so much workload to the workers containing
the hot experts. Besides, a performance predictor, together with a shadow
selection algorithm, is used to determine which experts to be shadowed before
each iteration.

In FastMoE, this feature is enabled by the environment variable
`FMOE_FASTER_SHADOW_ENABLE`. For simplicity, this feature is only available
when smart scheduling is enabled. Besides the constraints of smart scheduling,
this feature requires the experts to be identical in structure, so that
parameters can be copied between experts.

A default shadow selection policy is located at
`fmoe/fastermoe/shadow_policy.py`. If you want to alter the policy, please code
there and re-install FastMoE. For the default policy, we assume that the
experts are two-layer MLPs. A few parameters of the policy can be specified by
the following environment variables for better effectiveness of the shadowing
mechanism.

* `FMOE_FASTER_GLBPLC_NETBW` is the bandwidth of the interconnection between
workers, measured by `GBps`.
* `FMOE_FASTER_GLBPLC_GPUTP` is the GeMM throughput of the GPUs, measured by
`FLOPs`, e.g. `13e12` for NVIDIA V100 PCIe GPUs using fp32.
* `FMOE_FASTER_GLBPLC_ALPHA` is the fraction of the activation length in the
middle of the MLP to the input and output feature length, commonly seen to be
`2` or `4` in transformers.
* `FMOE_FASTER_GLBPLC_DMODEL` is the feature length of input and output of the
experts. This parameter can be set automatically by FastMoE.

## Topology-aware Gate

The two optimizations above do not change the behavior of the model, while this
one does. To reduce network congestion when training in distributed system
with hierarchical network topology, e.g. many GPUs in each of many nodes, the
number of samples transmitted through the slower upper-level network is
limited. The overfilling tokens select experts within the same lower-level
network to reduce the communication overhead.

The example topology-aware gate is implemented as `FasterGate` among FastMoE's
gates. However, note that it may influence the accuracy of the model. And for
different training hardware, different topology-aware gates shall be designed
according to the specific case.

The environment variable `FMOE_TOPO_GPUS_PER_NODE` represents number of GPUs in
each local network, e.g. each node. And `FMOE_TOPO_OUTGOING_FRACTION` controls
the fraction of tokens that are allowed to be sent across the upper-level
network.
Binary file added doc/fastermoe/smartsch.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions doc/readme-cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ FastMoE 的模型并行模式需要专门的并行策略, 而 PyTorch 和 Megatr
都不支持这样的策略. 因此, 需要使用 `fmoe.DistributedGroupedDataParallel`
模块来代替 PyTorch 的 DDP 模块.

### 如何训练得更快

在 PPoPP'22 会议上有一篇论文: _FasterMoE: modeling and optimizing training of
large-scale dynamic pre-trained models_. 我们将文中的技术集成到了 FastMoE 系统中,
从而提升其模型并行的效率.

这些新特性被命名为 **Faster Performance Features**, 并通过一些环境变量来控制是否
启用它们. 详见[这篇单独的文档](doc/fastermoe).

## 答疑 / 讨论

如果您在使用 FastMoE 的过程中有任何疑问, 或您有兴趣参与 FastMoE 的相关工作,
Expand Down
16 changes: 16 additions & 0 deletions doc/release-note.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
## v1.0.0

### FasterMoE

* The new performance boosting features in the PPoPP'22 paper FasterMoE, detailed in the document.
* Expert Shadowing.
* Smart Scheduling.
* Topology-aware gate.

### Bug fixes

* Transformer-XL examples.
* Compatibility to PyTorch versions.
* Megatron-LM documents.
* GShardGate.

## v0.3.0

### FMoE core
Expand Down
2 changes: 2 additions & 0 deletions fmoe/fastermoe/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def stash_fn(params, idx):
out = _local_gather(local_output_buf, pos_g, out_batch_size,
maybe_overlap=False)

# gib and local_input_buf are necessary, because ctx.gibs are created
# based on their memory
variables = (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models, gib, local_input_buf)

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'Tiago Antunes',
'Jinjun Peng',
'Qin Li',
'Mingshu Zhai'
]

is_rocm_pytorch = False
Expand All @@ -37,7 +38,7 @@
if __name__ == '__main__':
setuptools.setup(
name='fastmoe',
version='0.3.0',
version='1.0.0',
description='An efficient Mixture-of-Experts system for PyTorch',
author=', '.join(authors),
author_email='[email protected]',
Expand Down
7 changes: 7 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FastMoE test
===

To run unit test, directly run `pytest` in this directory.

`test.sh` is a wrapper script to execute single tests without pytest for
debugging purpose.

0 comments on commit 59bcec8

Please sign in to comment.