mlp_only_layers is more flexible than decoder_sparse_step #30552

Merged 11 commits on May 10, 2024
6 changes: 6 additions & 0 deletions src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
@@ -95,6 +95,10 @@ class Qwen2MoeConfig(PretrainedConfig):
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
The list contains layer indices, from 0 to num_layers - 1 if the model has num_layers layers.
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.

```python
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
@@ -139,6 +143,7 @@ def __init__(
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=[],
**kwargs,
):
self.vocab_size = vocab_size
@@ -168,6 +173,7 @@ def __init__(
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = mlp_only_layers

super().__init__(
tie_word_embeddings=tie_word_embeddings,
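The new option composes with the existing `decoder_sparse_step` rule rather than replacing it. A minimal sketch of how it would be set, assuming a `transformers` build that includes this change; the layer indices below are purely illustrative:

```python
from transformers import Qwen2MoeConfig

# Force the first two decoder layers to use the dense Qwen2MoeMLP,
# regardless of decoder_sparse_step; every other layer still follows
# the (layer_idx + 1) % decoder_sparse_step == 0 rule.
config = Qwen2MoeConfig(mlp_only_layers=[0, 1])
print(config.mlp_only_layers)  # [0, 1]
```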
7 changes: 5 additions & 2 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -17,7 +17,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Qwen2MoE model."""
"""PyTorch Qwen2MoE model."""

import inspect
import math
import warnings
@@ -866,7 +867,9 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int):

self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0:
if (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen2MoeSparseMoeBlock(config)
else:
self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
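To make the interaction between `mlp_only_layers` and `decoder_sparse_step` explicit, here is a standalone sketch of the selection rule above; the function name and arguments are illustrative and not part of the library API:

```python
from typing import List


def uses_sparse_moe(layer_idx: int, num_experts: int,
                    decoder_sparse_step: int, mlp_only_layers: List[int]) -> bool:
    """Return True if the decoder layer at `layer_idx` gets a Qwen2MoeSparseMoeBlock."""
    if layer_idx in mlp_only_layers:
        return False  # explicitly forced to the dense Qwen2MoeMLP
    return num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0


# With mlp_only_layers=[0, 2], layers 0 and 2 fall back to the dense MLP even
# though decoder_sparse_step=1 would otherwise make every layer sparse.
print([uses_sparse_moe(i, num_experts=60, decoder_sparse_step=1, mlp_only_layers=[0, 2])
       for i in range(4)])  # [False, True, False, True]
```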