retnet training config #64

Open
hanlinxuy opened this issue Sep 3, 2023 · 6 comments

@hanlinxuy

hanlinxuy commented Sep 3, 2023

Hello,

I have followed the training configuration introduced in #52 with the retnet_medium architecture, and I have a few questions. I would appreciate it if anyone could answer them.

The first question is about initialization. The RetNet paper (https://arxiv.org/abs/2307.08621) says that parameters were initialized following DeepNet, so I am wondering why deepnorm is set to False in RetNetConfig, and where I should set it to True. (https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/config.py#L239)

If I simply add "--deepnorm" on the command line, it is activated together with subln (https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/config.py#L240), and I found that the output of each layer grows larger and larger as the layer index increases.
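For reference, DeepNet (Wang et al., 2022) defines two constants for a decoder-only stack of N layers: the residual is up-weighted as x_{l+1} = LN(α·x_l + G(x_l)) with α = (2N)^{1/4}, and the value/output projection and FFN weights are initialized with Xavier gain β = (8N)^{-1/4}. The sketch below only computes those two constants; deepnet_decoder_constants is a hypothetical helper, not torchscale's code. Since α > 1, a layout that applied this residual scaling without re-normalizing after the addition would let activations compound by roughly α per layer; whether that is what happens in torchscale when both flags are set is an unverified assumption.

```python
from typing import Tuple


def deepnet_decoder_constants(num_layers: int) -> Tuple[float, float]:
    """Return (alpha, beta) for a decoder-only DeepNet with `num_layers` layers.

    alpha multiplies the residual input: x_{l+1} = LN(alpha * x_l + G(x_l)).
    beta is the Xavier gain for the value/output projections and FFN weights.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be positive")
    alpha = (2 * num_layers) ** 0.25
    beta = (8 * num_layers) ** -0.25
    return alpha, beta


if __name__ == "__main__":
    for n in (12, 24, 48):
        a, b = deepnet_decoder_constants(n)
        print(f"N={n:>2}  alpha={a:.4f}  beta={b:.4f}")
```

Both constants change slowly with depth (fourth root), which is what lets DeepNet keep the update size bounded as N grows.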

The second question is about the vocabulary. I am new to fairseq, so I am not sure how to handle a large dataset with fairseq_preprocess. I am trying to use MINIPILE, but the resulting dict.txt has 32,309,612 lines. That seems far too large, so I am wondering whether there is an official recommendation for this step.
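A dictionary with tens of millions of entries usually means the raw text was split on whitespace, so every distinct URL, number, and punctuation-attached word became its own symbol. A common fairseq practice (not an official recommendation from this thread) is to encode the corpus with a subword tokenizer such as GPT-2 BPE first and then pass a fixed dictionary to preprocessing with --srcdict, or to cap the vocabulary with --nwordssrc / --thresholdsrc. The sketch below illustrates only the capping idea in plain Python; build_capped_dict and to_dict_txt are hypothetical helpers that produce the "token count" lines that dict.txt uses.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def build_capped_dict(lines: Iterable[str], max_size: int, min_count: int = 1) -> List[Tuple[str, int]]:
    """Count whitespace tokens and keep at most `max_size` of the most frequent.

    Tokens seen fewer than `min_count` times are dropped; anything not kept
    would be mapped to <unk> when the data is binarized.
    """
    counts = Counter(tok for line in lines for tok in line.split())
    kept = [(tok, c) for tok, c in counts.most_common() if c >= min_count]
    return kept[:max_size]


def to_dict_txt(entries: List[Tuple[str, int]]) -> str:
    """Render entries as 'token count' lines, one token per line."""
    return "".join(f"{tok} {count}\n" for tok, count in entries)


corpus = [
    "the model is trained on the pile",
    "the retnet model uses retention",
    "visit https://example.com/a?b=1 for details",
]
entries = build_capped_dict(corpus, max_size=3)
print(to_dict_txt(entries), end="")
```

With a real corpus, a subword tokenizer keeps the vocabulary at a fixed size (tens of thousands of symbols) without mapping rare words to <unk>, which is why it is generally preferred over a plain frequency cap.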

The third question is about --share-decoder-input-output-embed: is it recommended? Sorry if I missed this in the paper.
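Tying the decoder's input embedding and output projection stores one vocab_size × embed_dim matrix instead of two, so the saving is a fixed V·d parameters and matters most when the rest of the model is small. A quick arithmetic sketch; the sizes (V = 32000, d = 1024, and the two non-embedding parameter counts) are purely hypothetical and not taken from retnet_medium:

```python
def embedding_params(vocab_size: int, embed_dim: int, tied: bool) -> int:
    """Parameters in the input embedding plus a bias-free output projection.

    With tying, the output projection reuses the embedding matrix, so only
    one vocab_size x embed_dim matrix is stored.
    """
    one_matrix = vocab_size * embed_dim
    return one_matrix if tied else 2 * one_matrix


def tying_savings_fraction(vocab_size: int, embed_dim: int, body_params: int) -> float:
    """Fraction of the untied model's total parameters removed by tying."""
    untied = embedding_params(vocab_size, embed_dim, tied=False)
    tied = embedding_params(vocab_size, embed_dim, tied=True)
    return (untied - tied) / (body_params + untied)


# Hypothetical sizes for illustration only.
V, D = 32_000, 1_024
for name, body in (("small", 100_000_000), ("large", 6_000_000_000)):
    print(f"{name}: tying saves {tying_savings_fraction(V, D, body):.1%} of parameters")
```

In PyTorch, tying is typically done by having the output projection share the embedding's weight Parameter (for example, output_proj.weight = embed.weight).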

Thank you guys in advance:)

@donglixp donglixp self-assigned this Sep 12, 2023
@simran-arora

Hi, is there any resolution to this question about the initialization, and are there recommended training configs to reproduce the paper results? I am also seeing some instability with the default configs.
Thanks so much!

@sunyt32
Contributor

sunyt32 commented Oct 11, 2023

  1. --share-decoder-input-output-embed saves model parameters, especially when the model is small, and the performance is almost the same. We enable it in our experiments.
  2. Don't activate --subln or --deepnorm. The current initialization is good enough.
  3. The training instability comes from the bias in Linear layers and the eps in LayerNorm. In our experiments, we set bias=False and eps=1e-5. RMSNorm also helps stability, so we made that modification as well.
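The two normalizations mentioned above differ in what they compute: LayerNorm subtracts the mean and divides by the standard deviation, while RMSNorm skips the mean subtraction and divides by the root mean square, keeping at most a learned gain and no bias. A minimal pure-Python sketch with eps = 1e-5, as recommended; the function names are illustrative and not torchscale's modules:

```python
import math
from typing import List, Optional


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """LayerNorm over one vector, without affine gain or bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    denom = math.sqrt(var + eps)
    return [(v - mean) / denom for v in x]


def rms_norm(x: List[float], gain: Optional[List[float]] = None, eps: float = 1e-5) -> List[float]:
    """RMSNorm: divide by the root mean square; no mean subtraction, no bias."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    g = gain if gain is not None else [1.0] * len(x)
    return [gi * v / rms for gi, v in zip(g, x)]


x = [3.0, 4.0, -1.0, 0.5]
print("layer_norm:", [round(v, 4) for v in layer_norm(x)])
print("rms_norm:  ", [round(v, 4) for v in rms_norm(x)])
```

Because RMSNorm has no bias term, it fits naturally with the bias=False setting described in the thread.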

@donglixp
Contributor

Hi, Is there any resolution to this question for the initialization and recommended training configs to reproduce the paper results? I am also seeing some instability with the default configs. Thanks so much!

@simran-arora @hanlinxuy

  • The LN eps was changed from 1e-6 to 1e-5 in commit d1fefe9.

  • RMSNorm is also used as of commit 5c89ffb, so that the effects of the LN eps are eliminated.

  • The RetNet implementation already integrates the initialization principle proposed in DeepNet, so the --subln and --deepnorm arguments should not be added.

  • Removing biases also improves training stability.

The latest released code has considered the above points.
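The advice in this thread can be collected into a simple checklist for comparing a run's settings against it. The setting names below are hypothetical labels chosen for this sketch, not torchscale's actual config fields or CLI flags; map them onto whatever your training script exposes.

```python
from typing import Any, Dict, List

# Hypothetical setting names for this sketch; NOT torchscale config fields or CLI flags.
RECOMMENDED: Dict[str, Any] = {
    "deepnorm": False,           # DeepNet-style init is already built into RetNet
    "subln": False,
    "linear_bias": False,        # bias=False in nn.Linear layers
    "norm_bias": False,          # no bias in the normalization layers either
    "norm_type": "rmsnorm",      # RMSNorm instead of LayerNorm
    "norm_eps": 1e-5,            # raised from 1e-6
    "share_decoder_input_output_embed": True,
}


def check_settings(settings: Dict[str, Any]) -> List[str]:
    """List deviations from the recommendations; keys absent from `settings` are skipped."""
    problems = []
    for key, expected in RECOMMENDED.items():
        if key in settings and settings[key] != expected:
            problems.append(f"{key}: got {settings[key]!r}, recommended {expected!r}")
    return problems


my_run = {"deepnorm": True, "subln": True, "norm_eps": 1e-6, "norm_type": "layernorm"}
for line in check_settings(my_run):
    print(line)
```

Skipping absent keys keeps the check usable against partial configs; a stricter variant could report missing keys as well.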

@simran-arora

Thanks so much! I had used LayerNorm and had not set bias=False. I will try switching these.

Adding the explicit DeepNorm initialization also improved stability in my downstream runs, but I will try the recommended techniques instead.

@sunyt32
Contributor

sunyt32 commented Oct 12, 2023

@simran-arora It's better to set bias=False both in layer norm and nn.Linear.

Besides, would you mind sharing your training details with us (e.g., corpus, model size, and hyperparameters)? We'd like to look into the setting where you see instability.

@hanlinxuy
Author

Thank you very much! I will try again later with this new information!
