From 54fcf217faaacd23499c4b3c3ffe7520b054ea1d Mon Sep 17 00:00:00 2001 From: Fraser Greenlee Date: Thu, 17 Dec 2020 21:55:17 +0000 Subject: [PATCH 1/5] add newline --- transformer_vae/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_vae/utils.py b/transformer_vae/utils.py index b9b86d2..60fcb66 100644 --- a/transformer_vae/utils.py +++ b/transformer_vae/utils.py @@ -1,3 +1,4 @@ + def assertEqual(actual, expected, msg, first="Got", second="Expected"): if actual != expected: raise ValueError(msg + f' {first}: "{actual}" {second}: "{expected}"') From b97e8241ca6df5ea97eb6423f8e538679950aedb Mon Sep 17 00:00:00 2001 From: Fraser Greenlee Date: Thu, 17 Dec 2020 22:19:02 +0000 Subject: [PATCH 2/5] small fixes --- Makefile | 1 + transformer_vae/trainer.py | 2 -- transformer_vae/utils.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 1f167df..2d1d1dd 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ activate: # source t5_vae_env/bin/activate test: + black --check -l 120 -t py37 . python -m pytest -s -v ./tests/ test-one-case: diff --git a/transformer_vae/trainer.py b/transformer_vae/trainer.py index ad5e0d1..56c0e3c 100644 --- a/transformer_vae/trainer.py +++ b/transformer_vae/trainer.py @@ -5,10 +5,8 @@ from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import RandomSampler from torch.utils.data.dataloader import DataLoader -import inspect import time -import datasets from transformers import trainer as trainer_script from transformers.integrations import ( WandbCallback, diff --git a/transformer_vae/utils.py b/transformer_vae/utils.py index 60fcb66..b9b86d2 100644 --- a/transformer_vae/utils.py +++ b/transformer_vae/utils.py @@ -1,4 +1,3 @@ - def assertEqual(actual, expected, msg, first="Got", second="Expected"): if actual != expected: raise ValueError(msg + f' {first}: "{actual}" {second}: "{expected}"') From 30a1682045734b1a9ff4b796cc6ea79ae032a3df Mon Sep 17 00:00:00 2001 From: Fraser Greenlee Date: Thu, 17 Dec 2020 22:23:28 +0000 Subject: [PATCH 3/5] fix get embeddings --- transformer_vae/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_vae/model.py b/transformer_vae/model.py index 2c9e089..98692aa 100644 --- a/transformer_vae/model.py +++ b/transformer_vae/model.py @@ -199,9 +199,6 @@ def resize_token_embeddings(self, *args, **kwargs): super().resize_token_embeddings(*args, **kwargs) self.transformer.resize_token_embeddings(*args, **kwargs) - def get_input_embeddings(self): - return self.transformer.shared - def set_input_embeddings(self, new_embeddings): return self.transformer.set_input_embeddings(new_embeddings) @@ -277,6 +274,9 @@ class T5_VAE_Model(Transformer_VAE_Base_Model): """ config_class = T5_VAE_Config + def get_input_embeddings(self): + return self.transformer.shared + def _shift_input_right(self, input_ids): start_token_id = self.transformer.config.eos_token_id pad_token_id = self.config.transformer_decoder.pad_token_id @@ -441,6 +441,9 @@ class Funnel_VAE_Model(Funnel_VAE_Model_Base): """ config_class = Funnel_VAE_Config + def get_input_embeddings(self): + return self.transformer.embeddings.word_embeddings + def forward( self, input_ids=None, From a45a6eac2e22241d2c183e7386f0f23575ddf657 Mon Sep 17 00:00:00 2001 From: Fraser Greenlee Date: Thu, 17 Dec 2020 22:29:30 +0000 Subject: [PATCH 4/5] shuffle --- transformer_vae/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_vae/model.py b/transformer_vae/model.py index 98692aa..8ddbd46 100644 --- a/transformer_vae/model.py +++ b/transformer_vae/model.py @@ -380,6 +380,9 @@ def forward( class Funnel_VAE_Model_Base(Transformer_VAE_Base_Model): + def get_input_embeddings(self): + return self.transformer.funnel.embeddings.word_embeddings + def _get_encoder_outputs( self, input_ids=None, @@ -441,9 +444,6 @@ class Funnel_VAE_Model(Funnel_VAE_Model_Base): """ config_class = Funnel_VAE_Config - def get_input_embeddings(self): - return self.transformer.embeddings.word_embeddings - def forward( self, input_ids=None, From bdc3109ab861d732badaeba96b7248f2bec46697 Mon Sep 17 00:00:00 2001 From: Fraser Greenlee Date: Fri, 18 Dec 2020 08:30:43 +0000 Subject: [PATCH 5/5] x --- transformer_vae/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_vae/model.py b/transformer_vae/model.py index 8ddbd46..3e8873a 100644 --- a/transformer_vae/model.py +++ b/transformer_vae/model.py @@ -195,6 +195,9 @@ def __init__(self, config: Transformer_VAE_Config): self.config.use_reg_loss, ) + def get_input_embeddings(self): + raise NotImplementedError() + def resize_token_embeddings(self, *args, **kwargs): super().resize_token_embeddings(*args, **kwargs) self.transformer.resize_token_embeddings(*args, **kwargs)