Commit 26360e6: Fixes
LysandreJik committed Apr 30, 2024
1 parent c3be141
Showing 25 changed files with 156 additions and 88 deletions.
76 changes: 5 additions & 71 deletions src/transformers/__init__.py

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion src/transformers/models/bart/modeling_tf_bart.py
@@ -1269,6 +1269,7 @@ def build(self, input_shape=None):
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
@register(backends=("tf",))
class TFBartModel(TFBartPretrainedModel):
_requires_load_weight_prefix = True

@@ -1382,6 +1383,7 @@ def call(self, x):
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING,
)
@register(backends=("tf",))
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
_keys_to_ignore_on_load_missing = [r"final_logits_bias"]
_requires_load_weight_prefix = True
@@ -1582,6 +1584,7 @@ def build(self, input_shape=None):
""",
BART_START_DOCSTRING,
)
@register(backends=("tf",))
class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1714,5 +1717,8 @@ def build(self, input_shape=None):
self.classification_head.build(None)

__all__ = [
"TFBartPretrainedModel"
"TFBartPretrainedModel",
"TFBartModel",
"TFBartForConditionalGeneration",
"TFBartForSequenceClassification"
]
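
The `@register(...)` decorator applied throughout this commit is imported from `...utils.import_utils`; its implementation is not shown in this diff. Below is a minimal sketch of the assumed pattern: the decorator records which backends a public class needs so the import machinery can substitute a dummy when a backend is missing. The `REGISTRY` name and `_backends` attribute are illustrative, not the library's actual internals.

from typing import Tuple

# Hypothetical registry mapping a public name to the backends it requires ("torch", "tf", ...).
REGISTRY: dict = {}

def register(backends: Tuple[str, ...] = ()):
    """Assumed behavior: tag the decorated object with its required backends and record it."""
    def wrapper(obj):
        obj._backends = backends           # illustrative attribute name
        REGISTRY[obj.__name__] = backends  # consumed by the (assumed) lazy-import machinery
        return obj
    return wrapper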
src/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -375,6 +375,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
return (vocab_file,)


@register()
class MecabTokenizer:
"""Runs basic tokenization with MeCab morphological parser."""

@@ -647,6 +648,7 @@ def tokenize(self, text, never_split=None, **kwargs):
return tokens


@register()
class CharacterTokenizer:
"""Runs Character tokenization."""

@@ -982,5 +984,7 @@ def tokenize(self, text):
return new_pieces

__all__ = [
"BertJapaneseTokenizer"
"BertJapaneseTokenizer",
"CharacterTokenizer",
"MecabTokenizer"
]
10 changes: 9 additions & 1 deletion src/transformers/models/dpr/modeling_dpr.py
@@ -103,6 +103,7 @@ class DPRQuestionEncoderOutput(ModelOutput):


@dataclass
@register(backends=("torch",))
class DPRReaderOutput(ModelOutput):
"""
Class for outputs of [`DPRQuestionEncoder`].
@@ -428,6 +429,7 @@ class DPRPretrainedReader(DPRPreTrainedModel):
"The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
DPR_START_DOCSTRING,
)
@register(backends=("torch",))
class DPRContextEncoder(DPRPretrainedContextEncoder):
def __init__(self, config: DPRConfig):
super().__init__(config)
@@ -509,6 +511,7 @@ def forward(
"The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
DPR_START_DOCSTRING,
)
@register(backends=("torch",))
class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
def __init__(self, config: DPRConfig):
super().__init__(config)
@@ -591,6 +594,7 @@ def forward(
"The bare DPRReader transformer outputting span predictions.",
DPR_START_DOCSTRING,
)
@register(backends=("torch",))
class DPRReader(DPRPretrainedReader):
def __init__(self, config: DPRConfig):
super().__init__(config)
@@ -666,5 +670,9 @@ def forward(
"DPRPreTrainedModel",
"DPRPretrainedContextEncoder",
"DPRPretrainedQuestionEncoder",
"DPRPretrainedReader"
"DPRPretrainedReader",
"DPRContextEncoder",
"DPRQuestionEncoder",
"DPRReader",
"DPRReaderOutput"
]
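
With `DPRContextEncoder`, `DPRQuestionEncoder`, `DPRReader`, and `DPRReaderOutput` added to `__all__`, these classes stay importable from the package root. A short usage sketch; the checkpoint name and input text are only illustrative:

import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"  # illustrative checkpoint
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
model = DPRQuestionEncoder.from_pretrained(name)

inputs = tokenizer("Is my dog cute?", return_tensors="pt")
with torch.no_grad():
    question_embedding = model(**inputs).pooler_output  # shape (1, hidden_size)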
8 changes: 7 additions & 1 deletion src/transformers/models/dpr/modeling_tf_dpr.py
@@ -539,6 +539,7 @@ class TFDPRPretrainedReader(TFPreTrainedModel):
"The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
TF_DPR_START_DOCSTRING,
)
@register(backends=("tf",))
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
@@ -628,6 +629,7 @@ def build(self, input_shape=None):
"The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
TF_DPR_START_DOCSTRING,
)
@register(backends=("tf",))
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
@@ -716,6 +718,7 @@ def build(self, input_shape=None):
"The bare DPRReader transformer outputting span predictions.",
TF_DPR_START_DOCSTRING,
)
@register(backends=("tf",))
class TFDPRReader(TFDPRPretrainedReader):
def __init__(self, config: DPRConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
@@ -796,5 +799,8 @@ def build(self, input_shape=None):
__all__ = [
"TFDPRPretrainedContextEncoder",
"TFDPRPretrainedQuestionEncoder",
"TFDPRPretrainedReader"
"TFDPRPretrainedReader",
"TFDPRContextEncoder",
"TFDPRQuestionEncoder",
"TFDPRReader"
]
1 change: 1 addition & 0 deletions src/transformers/models/efficientformer/__init__.py
@@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_efficientformer import *
from .image_processing_efficientformer import *
from .modeling_efficientformer import *
from .modeling_tf_efficientformer import *
else:
import sys
src/transformers/models/efficientformer/modeling_efficientformer.py
@@ -34,7 +34,7 @@
logging,
)
from .configuration_efficientformer import EfficientFormerConfig

from ...utils.import_utils import register

logger = logging.get_logger(__name__)

@@ -496,6 +496,7 @@ def forward(
)


@register(backends=('torch',))
class EfficientFormerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -548,6 +549,7 @@ def _init_weights(self, module: nn.Module):
"The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
EFFICIENTFORMER_START_DOCSTRING,
)
@register(backends=('torch',))
class EfficientFormerModel(EfficientFormerPreTrainedModel):
def __init__(self, config: EfficientFormerConfig):
super().__init__(config)
@@ -611,6 +613,7 @@ def forward(
""",
EFFICIENTFORMER_START_DOCSTRING,
)
@register(backends=('torch',))
class EfficientFormerForImageClassification(EfficientFormerPreTrainedModel):
def __init__(self, config: EfficientFormerConfig):
super().__init__(config)
@@ -741,6 +744,7 @@ class token).
""",
EFFICIENTFORMER_START_DOCSTRING,
)
@register(backends=('torch',))
class EfficientFormerForImageClassificationWithTeacher(EfficientFormerPreTrainedModel):
def __init__(self, config: EfficientFormerConfig):
super().__init__(config)
@@ -799,3 +803,11 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


__all__ = [
"EfficientFormerPreTrainedModel",
"EfficientFormerModel",
"EfficientFormerForImageClassification",
"EfficientFormerForImageClassificationWithTeacher"
]
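
These `__all__` entries work together with the `from .modeling_efficientformer import *` line added to the package `__init__.py` above: a star import only re-exports the names listed in `__all__`. A simplified sketch of that mechanism; the module and class names here are made up, and the real package layers a lazy-import structure on top of it:

# some_module.py
class PublicModel:
    """Re-exported: listed in __all__."""

class _PrivateHelper:
    """Not re-exported: omitted from __all__."""

__all__ = ["PublicModel"]

# consumer.py
# from some_module import *   # exposes PublicModel, but not _PrivateHelper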
1 change: 1 addition & 0 deletions src/transformers/models/fastspeech2_conformer/__init__.py
@@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_fastspeech2_conformer import *
from .modeling_fastspeech2_conformer import *
from .tokenization_fastspeech2_conformer import *
else:
import sys

src/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py
@@ -21,13 +21,14 @@

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging, requires_backends

from ...utils.import_utils import register

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}


@register()
class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
"""
Construct a FastSpeech2Conformer tokenizer.
@@ -182,3 +183,8 @@ def __setstate__(self, d):
"You need to install g2p-en to use FastSpeech2ConformerTokenizer. "
"See https://pypi.org/project/g2p-en/ for installation."
)


__all__ = [
"FastSpeech2ConformerTokenizer"
]
6 changes: 5 additions & 1 deletion src/transformers/models/fsmt/modeling_fsmt.py
@@ -1031,6 +1031,7 @@ def _get_shape(t):
"The bare FSMT Model outputting raw hidden-states without any specific head on top.",
FSMT_START_DOCSTRING,
)
@register(backends=("torch",))
class FSMTModel(PretrainedFSMTModel):
_tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]

@@ -1172,6 +1173,7 @@ def set_output_embeddings(self, value):
@add_start_docstrings(
"The FSMT Model with a language modeling head. Can be used for summarization.", FSMT_START_DOCSTRING
)
@register(backends=("torch",))
class FSMTForConditionalGeneration(PretrainedFSMTModel):
base_model_prefix = "model"
_tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]
@@ -1388,5 +1390,7 @@ def forward(
return super().forward(positions)

__all__ = [
"PretrainedFSMTModel"
"PretrainedFSMTModel",
"FSMTModel",
"FSMTForConditionalGeneration"
]
23 changes: 23 additions & 0 deletions src/transformers/models/mt5/tokenization_mt5.py
@@ -0,0 +1,23 @@
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 tokenization file"""
from transformers import T5Tokenizer
from transformers.utils.import_utils import register


@register(backends=("sentencepiece",))
class MT5Tokenizer(T5Tokenizer):
pass

23 changes: 23 additions & 0 deletions src/transformers/models/mt5/tokenization_mt5_fast.py
@@ -0,0 +1,23 @@
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 fast tokenization file"""
from transformers import T5TokenizerFast
from transformers.utils.import_utils import register


@register(backends=("tokenizers",))
class MT5TokenizerFast(T5TokenizerFast):
pass
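
Both new files only subclass the corresponding T5 tokenizers, so behavior is unchanged; the slow class is gated on the `sentencepiece` backend and the fast class on `tokenizers`. A hedged usage sketch, with an illustrative checkpoint name:

from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("google/mt5-small")                  # fast tokenizer (tokenizers backend)
slow = AutoTokenizer.from_pretrained("google/mt5-small", use_fast=False)  # slow tokenizer (sentencepiece backend)
print(fast("A short example sentence.").input_ids)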

6 changes: 5 additions & 1 deletion src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -211,6 +211,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length


@register(backends=("torch",))
class NllbMoeTop2Router(nn.Module):
"""
Router using tokens choose top-2 experts assignment.
@@ -388,6 +389,7 @@ def forward(self, hidden_states):
return hidden_states


@register(backends=("torch",))
class NllbMoeSparseMLP(nn.Module):
r"""
Implementation of the NLLB-MoE sparse MLP module.
@@ -1796,5 +1798,7 @@ def _reorder_cache(past_key_values, beam_idx):
__all__ = [
"NllbMoePreTrainedModel",
"NllbMoeModel",
"NllbMoeForConditionalGeneration"
"NllbMoeForConditionalGeneration",
"NllbMoeSparseMLP",
"NllbMoeTop2Router",
]
4 changes: 3 additions & 1 deletion src/transformers/models/pvt_v2/modeling_pvt_v2.py
@@ -631,6 +631,7 @@ def forward(
""",
PVT_V2_START_DOCSTRING,
)
@register(backends=("torch",))
class PvtV2Backbone(PvtV2Model, BackboneMixin):
def __init__(self, config: PvtV2Config):
super().__init__(config)
@@ -706,5 +707,6 @@ def forward(
__all__ = [
"PvtV2PreTrainedModel",
"PvtV2Model",
"PvtV2ForImageClassification"
"PvtV2ForImageClassification",
"PvtV2Backbone"
]
1 change: 1 addition & 0 deletions src/transformers/models/rag/__init__.py
@@ -22,6 +22,7 @@
from .modeling_rag import *
from .modeling_tf_rag import *
from .tokenization_rag import *
from .retrieval_rag import *
else:
import sys

8 changes: 7 additions & 1 deletion src/transformers/models/rag/retrieval_rag.py
@@ -26,7 +26,7 @@
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer

from ...utils.import_utils import register

if is_datasets_available():
from datasets import Dataset, load_dataset, load_from_disk
@@ -341,6 +341,7 @@ def init_index(self):
self._index_initialized = True


@register()
class RagRetriever:
"""
Retriever used to get documents from vector queries. It retrieves the documents embeddings as well as the documents
@@ -672,3 +673,8 @@ def __call__(
},
tensor_type=return_tensors,
)


__all__ = [
"RagRetriever"
]
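
With `retrieval_rag` now star-imported in the rag package `__init__.py`, `RagRetriever` remains part of the public surface. A sketch of instantiating it, assuming the `datasets` and `faiss` backends are installed; the dummy index keeps the example lightweight:

from transformers import RagRetriever

# Requires the datasets and faiss backends; use_dummy_dataset avoids downloading the full wiki index.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)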
1 change: 1 addition & 0 deletions src/transformers/models/realm/__init__.py
@@ -22,6 +22,7 @@
from .modeling_realm import *
from .tokenization_realm import *
from .tokenization_realm_fast import *
from .retrieval_realm import *
else:
import sys

