-
Notifications
You must be signed in to change notification settings - Fork 0
/
e2e_asr_conformer.py
74 lines (62 loc) · 2.53 KB
/
e2e_asr_conformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Conformer speech recognition model (pytorch).
It is a fusion of `e2e_asr_transformer.py`
Refer to: https://arxiv.org/abs/2005.08100
"""
from conformer.encoder import Encoder
from transformer.e2e_asr_transformer import E2E as E2ETransformer
from conformer.argument import (
add_arguments_conformer_common, # noqa: H301
)
class E2E(E2ETransformer):
    """Conformer end-to-end ASR module.

    Extends the Transformer E2E model by swapping its encoder for a
    Conformer encoder (https://arxiv.org/abs/2005.08100).

    :param Namespace args: argument Namespace containing options;
        input/output dimensions are read from ``args.fbank_dim`` and
        ``args.odim`` rather than passed explicitly
    :param int ignore_id: padding index ignored by the loss (default: -1)
    """

    @staticmethod
    def add_arguments(parser):
        """Add Transformer and Conformer command-line arguments.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        E2ETransformer.add_arguments(parser)
        E2E.add_conformer_arguments(parser)
        return parser

    @staticmethod
    def add_conformer_arguments(parser):
        """Add arguments specific to the Conformer model.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        group = parser.add_argument_group("conformer model specific setting")
        group = add_arguments_conformer_common(group)
        return parser

    def __init__(self, args, ignore_id=-1):
        """Construct an E2E object.

        :param Namespace args: argument Namespace containing options
        :param int ignore_id: padding index ignored by the loss
        """
        # Dimensions come from the args namespace (see class docstring).
        self.idim = args.fbank_dim
        self.odim = args.odim
        super().__init__(self.idim, self.odim, args, ignore_id)
        # Default the attention dropout to the generic dropout rate when unset.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        # Replace the Transformer encoder built by the parent with a
        # Conformer encoder (macaron-style FFN + convolution module).
        self.encoder = Encoder(
            idim=self.idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            activation_type=args.transformer_encoder_activation_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
        )
        # Re-initialize parameters since the encoder was rebuilt.
        self.reset_parameters(args)