
Extend multimodal/speech_llm with lhotse, t5 and bestow supports #9169

Merged · 470 commits · Jun 7, 2024

Commits
0c7b399
Fixes
pzelasko Dec 8, 2023
3b282aa
Docs fix
pzelasko Dec 8, 2023
5034d77
Add support for custom NeMo fields in Lhotse-NeMo adapters (attach to…
pzelasko Dec 11, 2023
31b1973
Add support for custom NeMo fields in Lhotse-NeMo adapters (attach to…
pzelasko Dec 11, 2023
0880d44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
30ce202
Merge branch 'feature/lhotse-integration' of https://github.com/pzela…
pzelasko Dec 11, 2023
5f11fdb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
02f0f0a
support distributed_fused_adam
zhehuaichen Dec 11, 2023
378af7c
support distributed_fused_adam
zhehuaichen Dec 13, 2023
35412fb
Add support for sharded NeMo manifest files
pzelasko Dec 13, 2023
1f2acde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
5b58e69
support megatron_amp_O2
zhehuaichen Dec 13, 2023
37cabcc
Support heterogeneous sampling rates in non tarred NeMo manifests
pzelasko Dec 13, 2023
1270609
migrate to PTL2.0
stevehuang52 Dec 14, 2023
6df13f1
clean up
stevehuang52 Dec 14, 2023
fa0493a
update manifest util
stevehuang52 Dec 14, 2023
22e3bff
Support multiple tokenizer/parser types, aggregate tokenizers, and cu…
pzelasko Dec 15, 2023
60cdea6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
96020e6
fix
pzelasko Dec 15, 2023
949fbbc
Merge branch 'feature/lhotse-integration' of https://github.com/pzela…
pzelasko Dec 15, 2023
5630ad4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
fc13c42
fix
pzelasko Dec 15, 2023
cee170f
Merge branch 'feature/lhotse-integration' of https://github.com/pzela…
pzelasko Dec 15, 2023
6eb16fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
fa73e72
agg and normal tokenizers actually work
pzelasko Dec 15, 2023
0fe901c
Merge branch 'feature/lhotse-integration' of https://github.com/pzela…
pzelasko Dec 15, 2023
c014c85
Support weights for NeMo tarred manifests
pzelasko Dec 15, 2023
034f55f
Temporarily hardcoded pnc stripping/lowercasing
pzelasko Dec 15, 2023
1526dcb
fix
pzelasko Dec 15, 2023
94fcb1f
Merge branch 'feature/lhotse-integration' into canary
pzelasko Dec 15, 2023
0e589a5
make pnc hack configurable from the config and disabled by default
pzelasko Dec 15, 2023
31e5bf7
fix the hack
pzelasko Dec 15, 2023
2eb320a
migrate to ptl2.1 to support multiple dataloaders
stevehuang52 Dec 15, 2023
8437bdd
support encoder overwrite
zhehuaichen Dec 18, 2023
3fc0db6
update misc
stevehuang52 Dec 19, 2023
4f947ce
fix eval and clean up
stevehuang52 Dec 20, 2023
4b70343
Merge branch 'heh/modular_speechlm_tmp' into modular_speechllm_clean_…
zhehuaichen Dec 20, 2023
63131d0
support add_sep for perception model
zhehuaichen Dec 22, 2023
318f784
fix https://github.com/Lightning-AI/pytorch-lightning/issues/18803
zhehuaichen Dec 22, 2023
59c0d4d
add_bos
zhehuaichen Dec 25, 2023
72cbc94
Transformer decoder with conditioning for canary (#8091)
krishnacpuvvada Dec 27, 2023
dc45efc
Option to limit the number of open streams (#8095)
pzelasko Dec 29, 2023
401efed
audio signal support in multi
zhehuaichen Dec 30, 2023
12487c2
update asr evaluator
stevehuang52 Jan 2, 2024
0535fdd
fix from
zhehuaichen Jan 2, 2024
03255be
transcribe fn for Canary models (#8110)
krishnacpuvvada Jan 2, 2024
6148198
update for evaluation
stevehuang52 Jan 2, 2024
6a237c9
update for eval
stevehuang52 Jan 2, 2024
15d162b
update for evaluation
stevehuang52 Jan 2, 2024
faeaac3
fix bleu
stevehuang52 Jan 2, 2024
9fc3ae5
fix typo
stevehuang52 Jan 3, 2024
1e7cfd6
Add missing audio_filepath validation for Canary (#8119)
pzelasko Jan 3, 2024
8b04025
add default concat_sampling_probabilities
zhehuaichen Jan 3, 2024
abaa3b0
Merge branch 'modular_speechllm_clean_cross_ptn2.1' into canary_speec…
zhehuaichen Jan 3, 2024
0ddedd4
support lhotse dataset in speechllm
zhehuaichen Jan 4, 2024
467fb24
bypass get_iterator_k_split
zhehuaichen Jan 5, 2024
1b169fa
tmp fix
zhehuaichen Jan 5, 2024
5dc8660
try to use fixed batch with megatron
zhehuaichen Jan 5, 2024
c0f5f0c
add batch logging
zhehuaichen Jan 6, 2024
6d6be8a
support unfrozen llm
zhehuaichen Jan 11, 2024
209f752
Create README.md
stevehuang52 Jan 12, 2024
01dd0d6
Update README.md
stevehuang52 Jan 12, 2024
528d1bf
Update README.md
stevehuang52 Jan 12, 2024
d94f9dd
update
stevehuang52 Jan 12, 2024
dbad4ac
rename
stevehuang52 Jan 12, 2024
02e91d3
add llama prompt template
zhehuaichen Jan 12, 2024
73736ad
update and refactor
stevehuang52 Jan 15, 2024
0d5c6da
support sample alpha
zhehuaichen Jan 16, 2024
18b27d1
support lhotse validation set and canary pretrained ckpt with pseudo …
zhehuaichen Jan 17, 2024
c12044a
make sure backward compatibility
zhehuaichen Jan 17, 2024
76be5ce
remove pad
zhehuaichen Jan 18, 2024
efa862a
make sure asr_model is frozen
zhehuaichen Jan 18, 2024
3ca65df
support greedy decoding
zhehuaichen Jan 18, 2024
4e17ced
valid on lhotse
zhehuaichen Jan 18, 2024
a7aeddc
fix multi dataloader in val case for lhotse SALM; add default data
zhehuaichen Jan 19, 2024
4ccc271
remove the bruteforce _keep_special_tokens implementation
zhehuaichen Jan 19, 2024
6817833
decoding_ratio and convert_canary_prompt_to_text support
zhehuaichen Jan 19, 2024
f437770
canary_tokens_augment_ratio
zhehuaichen Jan 19, 2024
bf51ad4
debug
zhehuaichen Jan 19, 2024
1aa6fe8
bug fix
zhehuaichen Jan 19, 2024
a38488d
fix lhotse based eval of llama canary model
zhehuaichen Jan 22, 2024
430c5bf
support some overwrite for eval
zhehuaichen Jan 23, 2024
a7dcafe
support zero shot prompt in training
zhehuaichen Jan 23, 2024
62a0cf0
support cross attention based SALM
zhehuaichen Jan 24, 2024
f17e8fa
Merge branch 'canary_speechllm1' of github.com:zhehuaichen/NeMo into …
zhehuaichen Jan 24, 2024
42d74bf
support cross attention based SALM
zhehuaichen Jan 24, 2024
50571f2
fix for batch train/valid of cross
zhehuaichen Jan 24, 2024
7160027
Merge branch 'canary_speechllm1_cross' of github.com:zhehuaichen/NeMo…
zhehuaichen Jan 24, 2024
2fe48aa
support learnable gate and plotting
zhehuaichen Jan 24, 2024
39545b7
support using pseudo label in prompt rather than cross att
zhehuaichen Jan 25, 2024
2cec2f5
bug fix for perception cfg and context tokens shift
zhehuaichen Jan 25, 2024
fa57fb2
DentityConnectorsAdd
zhehuaichen Jan 25, 2024
666aa44
fix ckpt saving
zhehuaichen Jan 25, 2024
fa8e00e
Support RnnGatedCrossAttention
zhehuaichen Jan 26, 2024
175b66e
add include_ffw and fix _optimizer_param_groups for all unfrozen run
zhehuaichen Jan 29, 2024
dcb5084
support grad acc when using bucket
zhehuaichen Feb 1, 2024
6c0a798
support TransformerCrossAttention
zhehuaichen Feb 2, 2024
e9935dc
support ProjectTransformerCrossAttention
zhehuaichen Feb 4, 2024
45932a5
support ++model.use_am_tokenizer ++model.override_vocab_size ++model.…
zhehuaichen Feb 7, 2024
d7cc642
support question set on val without canary
zhehuaichen Feb 12, 2024
eeaad1f
support load_audio_encoder and wip in optim_param_groups
zhehuaichen Feb 13, 2024
2504a0a
minor fix for audio pretrain model init
zhehuaichen Feb 16, 2024
a9478ef
simplify canary_tokens_augment
zhehuaichen Feb 16, 2024
b754b61
use question in the manifest if it exists
zhehuaichen Feb 21, 2024
c94a632
support dataset weighting for non tar
zhehuaichen Feb 21, 2024
94bd346
Update SpeechLLM code (#8475)
stevehuang52 Feb 21, 2024
8afd277
Update README.md
stevehuang52 Feb 21, 2024
78c1e8e
update speechllm (#8486)
stevehuang52 Feb 22, 2024
2e74cd1
clean up
stevehuang52 Feb 22, 2024
95ee03c
for now bypass asr_model init in perception since that causes issues …
zhehuaichen Feb 23, 2024
5ff28a1
update doc and infer
stevehuang52 Feb 23, 2024
80f7439
https://github.com/NVIDIA/NeMo/pull/8464/files
zhehuaichen Feb 23, 2024
e1e825f
update doc
stevehuang52 Feb 23, 2024
99fb448
update doc
stevehuang52 Feb 23, 2024
446c6d9
update doc
stevehuang52 Feb 23, 2024
3d78dd7
update doc
stevehuang52 Feb 23, 2024
70ed539
add a debug script
zhehuaichen Feb 26, 2024
13f03a2
support text-only training and speech and text joint training
zhehuaichen Feb 29, 2024
c0260c6
always require text only data has question field in the data and use it
zhehuaichen Feb 29, 2024
85ba4f6
support prepend_to_exist_question
zhehuaichen Mar 4, 2024
fc185d5
support random_context_prob
zhehuaichen Mar 5, 2024
3247583
apply random_context_prob for w/ and w/o canary
zhehuaichen Mar 5, 2024
268bb70
guard random context
zhehuaichen Mar 5, 2024
9d69f2e
protect the case where answer is empty
zhehuaichen Mar 6, 2024
b812226
fix for ++model.pretrained_canary_model=$ASR_MODEL
zhehuaichen Mar 16, 2024
ae9adf0
support unfreeze_emb
zhehuaichen Mar 17, 2024
0916850
minor update
stevehuang52 Mar 18, 2024
db542b4
fix import
stevehuang52 Mar 18, 2024
fe7214b
clean up
stevehuang52 Mar 18, 2024
b133332
support t5 + lhotse
zhehuaichen Mar 20, 2024
3f5fd1b
add xattn
zhehuaichen Mar 20, 2024
2291706
CrossAttendModularizedAudioT5Model is WIP and replaced by audio_promp…
zhehuaichen Mar 20, 2024
e6cdebf
support distributed adam
zhehuaichen Mar 20, 2024
916324e
clean up
stevehuang52 Mar 20, 2024
98f86b5
fix pretrained info
stevehuang52 Mar 20, 2024
2876d41
support with_distributed_adam
zhehuaichen Mar 21, 2024
cf6deb2
fix distributed adam
zhehuaichen Mar 21, 2024
b8fc008
add local_batch_size
zhehuaichen Mar 22, 2024
b6ef6a5
support mt5
zhehuaichen Mar 22, 2024
555a007
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 22, 2024
8f524e3
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 22, 2024
619d75d
update dockerfile
stevehuang52 Mar 22, 2024
c0b9d0c
support mt5 and bypass bos_id=-1
zhehuaichen Mar 22, 2024
9a4861b
support configurating legacy_tokenizer for mt5 models
zhehuaichen Mar 23, 2024
c3ca938
update for merging main
stevehuang52 Mar 25, 2024
76db149
fix for merge main
stevehuang52 Mar 25, 2024
f7afea1
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 25, 2024
c99ad43
clean up docs
stevehuang52 Mar 25, 2024
7c9ded7
clean up
stevehuang52 Mar 25, 2024
4c4ac20
clean up
stevehuang52 Mar 25, 2024
afbc212
clean up
stevehuang52 Mar 25, 2024
6bce450
refactor
stevehuang52 Mar 25, 2024
b3f6156
clean up
stevehuang52 Mar 25, 2024
f63b8b8
update
stevehuang52 Mar 25, 2024
9dd72b6
clean up
stevehuang52 Mar 26, 2024
11facc7
fix speechlm test
stevehuang52 Mar 26, 2024
3da8282
update doc
stevehuang52 Mar 26, 2024
179fafd
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 Mar 26, 2024
14c1334
refactor
stevehuang52 Mar 26, 2024
98a0143
refactor
stevehuang52 Mar 27, 2024
7dbe84d
refactor
stevehuang52 Mar 27, 2024
3a039f5
fix multi-layer feat
stevehuang52 Mar 27, 2024
55c9e04
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Mar 27, 2024
073212b
update for webdataset
stevehuang52 Mar 27, 2024
edcf401
support setting dropout and label smoothing
zhehuaichen Mar 28, 2024
d3a04e0
make sure the updated cfg is passed to frozen_model
zhehuaichen Mar 28, 2024
3762632
mv model paths
zhehuaichen Mar 29, 2024
ba86fb9
refactor
stevehuang52 Apr 3, 2024
fdfe7b5
force str to avoid bugs with implicit conversion of str to bool type
stevehuang52 Apr 4, 2024
18b2921
Update examples/multimodal/speech_llm/README.md
stevehuang52 Apr 5, 2024
fef24dc
Update examples/multimodal/speech_llm/README.md
stevehuang52 Apr 5, 2024
c532150
refactor
stevehuang52 Apr 5, 2024
21d4261
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 5, 2024
c2f6b78
refactor
stevehuang52 Apr 5, 2024
7744144
Merge branch 'canary_speechllm1_cross_t5_pr' into canary_speechllm1_c…
zhehuaichen Apr 5, 2024
647e184
update for saving nemo
stevehuang52 Apr 5, 2024
36df825
update eval and ngc ckpt
stevehuang52 Apr 5, 2024
f6a90d1
Update nemo/collections/multimodal/speech_llm/data/audio_text_qa_data…
stevehuang52 Apr 8, 2024
d73a684
Update nemo/collections/multimodal/speech_llm/modules/common/audio_te…
stevehuang52 Apr 8, 2024
3dea3ce
Update tests/collections/multimodal/test_speechllm_models.py
stevehuang52 Apr 8, 2024
aa4f85b
refactor and remove nlp adapter mixin assert
stevehuang52 Apr 8, 2024
9e10694
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 8, 2024
360acd4
remove random context augmentation
stevehuang52 Apr 8, 2024
6449924
fix docstring
stevehuang52 Apr 8, 2024
52617f9
add docstring
stevehuang52 Apr 8, 2024
7c78165
minor refactor
stevehuang52 Apr 11, 2024
ed29843
refactor
stevehuang52 Apr 11, 2024
19b3d9f
fixes to be compatible with 24.01
zhehuaichen Apr 12, 2024
5a4be92
refactor and fix missing import
stevehuang52 Apr 12, 2024
03b9e60
fix for unfreeze llm
zhehuaichen Apr 13, 2024
35f0b03
for unfreeze am
zhehuaichen Apr 13, 2024
c991e5b
Merge branch 'main' into heh/modular_speechllm_pr
pablo-garay Apr 13, 2024
79156fc
major refactor on input format and minor update
stevehuang52 Apr 16, 2024
0268898
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 Apr 16, 2024
b6cac3d
fix codeQL
stevehuang52 Apr 17, 2024
8b19dc5
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 Apr 17, 2024
960f958
clean up
stevehuang52 Apr 17, 2024
fac3a4e
fix for canary prompt
zhehuaichen Apr 22, 2024
89f0a42
fix for canary prompt and support t5
zhehuaichen Apr 22, 2024
790359d
Merge branch 'canary_speechllm1_cross_t5_pr2' of github.com:zhehuaich…
zhehuaichen Apr 23, 2024
2e18366
Merge remote-tracking branch 'origin/main' into heh/modular_speechllm_pr
stevehuang52 Apr 24, 2024
2bf9b07
configurable random_context_positive_percent
zhehuaichen Apr 24, 2024
566ee5a
update default random_context_num to 8 to reduce seq len
zhehuaichen Apr 24, 2024
f5e4af3
inference support
zhehuaichen Apr 25, 2024
a63e35d
support TP>1
zhehuaichen May 1, 2024
e0f5189
fix for salm decode
zhehuaichen May 5, 2024
8043262
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 6, 2024
55f8231
update for NGC ckpt and refactor
stevehuang52 May 6, 2024
d9e2788
clean up
stevehuang52 May 6, 2024
60843db
support output metainfo with audio_filepath
zhehuaichen May 7, 2024
3cd12e9
Merge branch 'main' into heh/modular_speechllm_pr
ericharper May 7, 2024
30a583a
Merge branch 'main' into heh/modular_speechllm_pr
ericharper May 7, 2024
b9fc1bc
Merge remote-tracking branch 'upstream/heh/modular_speechllm_pr' into…
zhehuaichen May 8, 2024
e4cad0c
revert unrelated changes
zhehuaichen May 8, 2024
226c605
revert unrelated changes
zhehuaichen May 8, 2024
f9e2f94
some fixes for t5
zhehuaichen May 8, 2024
d4a6fd8
clean up and test inference
zhehuaichen May 8, 2024
90887cf
move dataset code to one place
zhehuaichen May 8, 2024
d5265bd
verify train and inference for bestow+gpt and salm+t5
zhehuaichen May 8, 2024
55b270b
Merge branch 'main' into heh/modular_speechllm_pr
nithinraok May 8, 2024
1c4cbd7
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 9, 2024
3e88457
skip speechlm test until data moved to CI machines
stevehuang52 May 9, 2024
0700cdb
use pad_id for pad and add eos_id when enabled
zhehuaichen May 9, 2024
17ab55b
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
6cae145
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
4cfaa30
refactor and update to avoid changing nlp_adapter_mixin
stevehuang52 May 10, 2024
27e33ee
Merge branch 'heh/modular_speechllm_pr' of https://github.com/NVIDIA/…
stevehuang52 May 10, 2024
67ecaa1
Merge branch 'main' into heh/modular_speechllm_pr
stevehuang52 May 10, 2024
89926fa
Apply isort and black reformatting
stevehuang52 May 10, 2024
fb8914d
Merge remote-tracking branch 'upstream/heh/modular_speechllm_pr' into…
zhehuaichen May 11, 2024
9499f2e
minor edit
zhehuaichen May 11, 2024
e601135
Merge remote-tracking branch 'upstream/main' into canary_speechllm1_c…
zhehuaichen May 11, 2024
3cc0432
Apply isort and black reformatting
zhehuaichen May 11, 2024
11407b2
fixes per Piotr and Steve's comments
zhehuaichen May 28, 2024
db2166e
WIP in getting rid of canary specific things in dataset
zhehuaichen Jun 3, 2024
d6c23a5
Merge remote-tracking branch 'upstream/main' into canary_speechllm1_c…
zhehuaichen Jun 3, 2024
9c2c4af
remove canary specific design; bugfix for asr/models/aed_multitask_mo…
zhehuaichen Jun 3, 2024
24c0f9f
remove random_context and submit it later by rewriting with augmenter
zhehuaichen Jun 3, 2024
1999298
remove canary specific stuffs in dataloading; use input_cfg in lhotse…
zhehuaichen Jun 4, 2024
1f10bd7
fix for https://github.com/NVIDIA/NeMo/pull/9169/#pullrequestreview-2…
zhehuaichen Jun 4, 2024
e96da9b
Merge remote-tracking branch 'upstream/main' into canary_speechllm1_c…
zhehuaichen Jun 4, 2024
0aa4179
minor fix
zhehuaichen Jun 5, 2024
24daa2e
make sure NGC inference and fix CodeQL https://github.com/NVIDIA/NeMo…
zhehuaichen Jun 5, 2024
1239b35
Merge branch 'main' into canary_speechllm1_cross_t5_pr3
zhehuaichen Jun 5, 2024
eb7e00d
add back the assert in nlp collection and add a enforce_divisible_bat…
zhehuaichen Jun 6, 2024
e4e7802
nit
zhehuaichen Jun 6, 2024
0b4451b
fixes per Som s comments https://github.com/NVIDIA/NeMo/pull/9169#pul…
zhehuaichen Jun 6, 2024
9d362e5
Merge branch 'main' into canary_speechllm1_cross_t5_pr3
zhehuaichen Jun 6, 2024
f12fa74
nit
zhehuaichen Jun 6, 2024
18bdbe7
fix split_list
zhehuaichen Jun 6, 2024
338 changes: 338 additions & 0 deletions examples/multimodal/speech_llm/conf/modular_audio_t5_config.yaml
@@ -0,0 +1,338 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: megatron_audio_t5_salm_lhotse

trainer:
  devices: 1
  accelerator: gpu
  num_nodes: 1
  precision: bf16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: 9999
  max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  limit_train_batches: 1000
  log_every_n_steps: 10 # frequency with which training steps are logged
  val_check_interval: 1.0 # if an int n > 1, runs val every n training steps; if a float in 0.0-1.0, runs val every epoch fraction, e.g. 0.25 runs val every quarter epoch
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1

model_target: nemo.collections.multimodal.speech_llm.models.modular_models_t5.ModularizedAudioT5Model

exp_manager:
  # explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.validation_ds.metric.name}
    save_top_k: 1
    mode: min
    save_nemo_on_train_end: True
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: False
    save_best_model: True
  create_early_stopping_callback: False
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
    min_delta: 0.001
    patience: 10
    verbose: True
    strict: False # Should be False to avoid a runtime error where EarlyStopping says the monitor is unavailable, which sometimes happens with resumed training.


model:
  virtual_prompt_style: 'no-prompts' # make cls happy
  seed: 1234
  tensor_model_parallel_size: 1 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism

  pretrained_audio_model: stt_en_fastconformer_transducer_large
  freeze_llm: True
  freeze_audio_encoder: False
  freeze_modality_adapter: False
  load_audio_encoder: True

  global_batch_size: 128
  micro_batch_size: 4
  language_model_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: False # Saves an inference-ready .nemo file every time a checkpoint is saved during training.
  sync_batch_comm: False
  megatron_amp_O2: False

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially.
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: False

  ## Activation Checkpointing
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: null # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  answer_only_loss: True
  gradient_as_bucket_view: False

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0

  # use_am_tokenizer: True
  # override_vocab_size: 1024

  lora_tuning:
    kqv_adapter_dim: 128
    kv_adapter_dim: 64
    q_adapter_dim: 32
    adapter_dropout: 0.0
    column_init_method: 'xavier' # IGNORED if linear_adapter is used; options: xavier, zero or normal
    row_init_method: 'zero' # IGNORED if linear_adapter is used; options: xavier, zero or normal

  peft:
    peft_scheme: "adapter" # can be either adapter, ia3, or ptuning
    restore_from_path: null

    # Used for adapter peft training
    adapter_tuning:
      type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
      adapter_dim: 32
      adapter_dropout: 0.0
      norm_position: 'pre' # This can be set to 'pre' or 'post'; 'pre' is normally what is used.
      column_init_method: 'xavier' # IGNORED if linear_adapter is used; options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used; options: xavier, zero or normal
      norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used; options are ['layernorm', 'mixedfusedlayernorm']

    # Used for p-tuning peft training
    p_tuning:
      virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence
      bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck
      embedding_dim: 1024 # the size of the prompt encoder embeddings
      init_std: 0.023

  perception:
    target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule
    use_multi_layer_feat: false

    modality_adapter:
      _target_: nemo.collections.asr.modules.ConformerEncoder
      feat_in: 1024
      feat_out: -1 # you may set it if you need a different output size than the default d_model
      n_layers: 2
      d_model: 512

      # Sub-sampling parameters
      subsampling: dw_striding # vggnet, striding, stacking, stacking_norm, or dw_striding
      subsampling_factor: 8 # must be a power of 2 for striding and vggnet
      subsampling_conv_channels: 256 # set to -1 to make it equal to d_model
      causal_downsampling: false

      # Reduction parameters: can be used to add another subsampling layer at a given position.
      # A 2x reduction will speed up training and inference while keeping similar WER.
      # Adding it at the end gives the best WER, while adding it at the beginning gives the best speedup.
      reduction: null # pooling, striding, or null
      reduction_position: null # encoder block index, or -1 for subsampling at the end of the encoder
      reduction_factor: 1

      # Feed forward module's params
      ff_expansion_factor: 4

      # Multi-headed attention module's params
      self_attention_model: rel_pos # rel_pos or abs_pos
      n_heads: 8 # may need to be lower for smaller d_models
      # [left, right] specifies the number of steps to be seen from the left and right of each step in self-attention
      att_context_size: [-1, -1] # -1 means unlimited context
      att_context_style: regular # regular or chunked_limited
      xscaling: true # scales up the input embeddings by sqrt(d_model)
      untie_biases: true # unties the biases of the TransformerXL layers
      pos_emb_max_len: 5000

      # Convolution module's params
      conv_kernel_size: 9
      conv_norm_type: 'batch_norm' # batch_norm, layer_norm, or groupnormN (N specifies the number of groups)
      # conv_context_size can be "causal" or a list of two integers such that conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
      # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
      conv_context_size: null

      ### Regularization
      dropout: 0.1 # the dropout used in most of the Conformer modules
      dropout_pre_encoder: 0.1 # the dropout used before the encoder
      dropout_emb: 0.0 # the dropout used for embeddings
      dropout_att: 0.1 # the dropout for multi-headed attention modules

      # set to non-zero to enable stochastic depth
      stochastic_depth_drop_prob: 0.0
      stochastic_depth_mode: linear # linear or uniform
      stochastic_depth_start_layer: 1

    spec_augment:
      _target_: nemo.collections.asr.modules.SpectrogramAugmentation
      freq_masks: 2 # set to zero to disable it
      time_masks: 10 # set to zero to disable it
      freq_width: 27
      time_width: 0.05

    # the following are read from the pretrained AM:
    # output_dim: null
    # encoder: null
    # preprocessor: null

  data:
    train_ds:
      # Example of how to specify paths to multiple datasets
      # manifest_filepath:
      #   - /path/to/squad.jsonl
      #   - /path/to/mnli.jsonl
      #   - /path/to/boolq.jsonl
      # Example of how each dataset is formatted:
      # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'}
      # The 'answer' field can also be 'text', and a default 'question' field is added if missing in manifests, so as to work with ASR manifests.
      manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: True
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: True
      # Notably, the data weights are controlled by either bucketing_weights
      # or concat_sampling_probabilities, depending on the dataset type (tarred or
      # non-tarred).
      # See audio_text_qa_dataset.py for details.
      concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'
      context_key: 'input'
      label_key: 'output'
      add_eos: True
      # add_eos: False
      add_sep: True
      add_bos: False
      separate_prompt_and_response_with_newline: False
      truncation_field: "context" # Options: ['context', 'answer']
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: "Q: {input}\nA: {output}" # fstring to use for the assistant prompt, e.g. "Q: {input}\nA: {output}"
      # ASR configs
      sample_rate: 16000 # ${model.audio_encoder.preprocessor.sample_rate}
      max_duration: 24 # set for LibriSpeech; you may need to update it for your dataset
      min_duration: 0.1
      # tarred datasets
      is_tarred: false
      tarred_audio_filepaths: null
      shuffle_n: 2048
      # bucketing params
      bucketing_strategy: "fully_randomized"
      bucketing_batch_size: null
      # sample_alpha: 0.1
      use_lhotse: True
      duration_bins: [2,4,6,8,10,12,14,16,18]
      lhotse:
        text_field: "text"
        batch_duration: 80 # 0
        quadratic_duration: 30
        max_open_streams: 50
        num_buckets: 30
        buffer_size: 10000
        shuffle_buffer_size: 10000
        duration_bins: [2.92,3.474,3.924,4.335,4.728,5.11,5.487,5.872,6.288,6.696,7.128,7.62,8.208,8.934,9.883,10.56,11.22,11.88,12.51,13.05,13.59,14.13,14.64,15.17875,15.81,16.54,17.37,18.241,19.18]
        # sample_alpha: 0.1

    validation_ds:
      manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: False
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: False
      context_key: ${model.data.train_ds.context_key}
      label_key: ${model.data.train_ds.label_key}
      add_eos: ${model.data.train_ds.add_eos}
      add_sep: ${model.data.train_ds.add_sep}
      add_bos: ${model.data.train_ds.add_bos}
      separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
      write_predictions_to_file: False
      output_file_path_prefix: null # Prefix of the file to write predictions to.
      truncation_field: "context" # Options: ['context', 'answer']
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for the assistant prompt, e.g. "Q: {input}\nA: {output}"
      tokens_to_generate: 128
      # ASR configs
      sample_rate: 16000 # ${model.audio_encoder.preprocessor.sample_rate}

      log_every_n_steps: 1
      metric:
        name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy', etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null

    # make model init happy
    num_workers: 0
    # test_ds:
    #   manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
    #   names: null # Names of the corresponding datasets used to log metrics.
    #   global_batch_size: ${model.global_batch_size}
    #   micro_batch_size: ${model.micro_batch_size}
    #   shuffle: False
    #   num_workers: 4
    #   pin_memory: True
    #   max_seq_length: 2048
    #   min_seq_length: 1
    #   drop_last: False
    #   context_key: 'input'
    #   label_key: 'output'
    #   add_eos: ${model.data.train_ds.add_eos}
    #   add_sep: ${model.data.train_ds.add_sep}
    #   add_bos: ${model.data.train_ds.add_bos}
    #   separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
    #   write_predictions_to_file: False
    #   output_file_path_prefix: null # Prefix of the file to write predictions to.
    #   truncation_field: "context" # Options: ['context', 'answer']
    #   index_mapping_dir: null # Path to a directory to write index mapping files.
    #   prompt_template: ${model.data.train_ds.prompt_template}
    #   # ASR configs
    #   sample_rate: 16000 # ${model.audio_encoder.preprocessor.sample_rate}

    #   metric:
    #     name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
    #     average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy', etc. Refer to torchmetrics for metrics where this is supported.
    #     num_classes: null

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 50
      min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1
      constant_steps: 0 # constant_steps should also be 0 when min_lr=0
      monitor: val_loss
      reduce_on_plateau: false
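The `train_ds` comments above describe the expected data format: a JSONL manifest with one JSON object per line, where `answer` may instead be stored as `text` and a default `question` is filled in for plain ASR manifests. A minimal sketch of producing and reading such a manifest with only the standard library (the default question string and file name here are illustrative, not taken from NeMo's loader):

```python
import json

def write_manifest(path, entries):
    """Write one JSON object per line (JSONL), as NeMo-style manifests expect."""
    with open(path, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")

def read_manifest(path, default_question="transcribe this audio"):
    """Read a JSONL manifest, mirroring the config comment: 'answer' may be
    stored as 'text', and a default 'question' is added when missing."""
    entries = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            entry = json.loads(line)
            if "answer" not in entry and "text" in entry:
                entry["answer"] = entry["text"]
            entry.setdefault("question", default_question)
            entries.append(entry)
    return entries

rows = [
    {"audio_filepath": "audio1.wav", "offset": 0.0, "duration": 12.3,
     "question": "transcribe this audio", "answer": "I have a dream..."},
    # a plain ASR-manifest row: 'text' instead of 'answer', no 'question'
    {"audio_filepath": "audio2.wav", "duration": 3.1, "text": "hello world"},
]
write_manifest("train_manifest.jsonl", rows)
loaded = read_manifest("train_manifest.jsonl")
print(loaded[1]["answer"])    # hello world
print(loaded[1]["question"])  # transcribe this audio
```

Each element of `manifest_filepath` in the config would point at a file written this way.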
8 changes: 6 additions & 2 deletions examples/multimodal/speech_llm/modular_audio_gpt_train.py
@@ -18,7 +18,7 @@
 from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel
 from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
 from nemo.core.config import hydra_runner
-from nemo.utils import logging
+from nemo.utils import logging, model_utils
 from nemo.utils.exp_manager import exp_manager

 mp.set_start_method("spawn", force=True)
@@ -61,7 +61,11 @@ def main(cfg) -> None:
     # update resume from checkpoint found by exp_manager
     logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')

-    model = ModularAudioGPTModel.restore_from_pretrained_models(cfg, trainer=trainer)
+    if hasattr(cfg, 'model_target'):
+        imported_cls = model_utils.import_class_by_path(cfg.model_target)
+    else:
+        imported_cls = ModularAudioGPTModel
+    model = imported_cls.restore_from_pretrained_models(cfg, trainer=trainer)

     trainer.fit(model)
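The training-script change above makes the model class configurable: when the YAML sets `model_target`, the class is resolved from its dotted path via `model_utils.import_class_by_path` instead of being hardcoded to `ModularAudioGPTModel`. A minimal sketch of how such a path-based import can work with `importlib` (this is an illustration of the pattern, not NeMo's exact implementation; `collections.OrderedDict` stands in for the model class):

```python
import importlib

def import_class_by_path(path: str):
    """Resolve a class from a dotted path like 'package.module.ClassName'."""
    module_path, _, class_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# e.g. the class named by cfg.model_target in the YAML config above
cls = import_class_by_path("collections.OrderedDict")  # stdlib stand-in
print(cls.__name__)  # OrderedDict
```

This keeps one training script usable for both the GPT and T5 variants: only the config decides which model class is instantiated.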