Skip to content
This repository has been archived by the owner on May 16, 2022. It is now read-only.

Commit

Permalink
Merge pull request #1 from Fraser-Greenlee/test_change
Browse files Browse the repository at this point in the history
Test Change
  • Loading branch information
Fraser-Greenlee committed Dec 18, 2020
2 parents d686412 + bdc3109 commit 356b01c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ activate:
# source t5_vae_env/bin/activate

# Run the code-style check (black, 120-col, py37 target) and then the unit-test suite.
test:
	black --check -l 120 -t py37 .
	python -m pytest -s -v ./tests/

test-one-case:
Expand Down
12 changes: 9 additions & 3 deletions transformer_vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def __init__(self, config: Transformer_VAE_Config):
self.config.use_reg_loss,
)

def get_input_embeddings(self):
    """Return the model's input token embeddings.

    The base model does not know where its wrapped transformer keeps the
    embedding table, so concrete subclasses must override this accessor.
    """
    raise NotImplementedError

def resize_token_embeddings(self, *args, **kwargs):
    """Resize the vocabulary embeddings.

    Delegates to the superclass first, then resizes the wrapped
    ``self.transformer``'s embeddings so the two stay in sync.
    NOTE(review): the superclass call's return value (presumably the
    resized embedding module, per the usual ``transformers`` convention)
    is discarded here — confirm no caller expects it.
    """
    super().resize_token_embeddings(*args, **kwargs)
    self.transformer.resize_token_embeddings(*args, **kwargs)

def get_input_embeddings(self):
    """Return the shared input-embedding table of the wrapped transformer."""
    embedding_table = self.transformer.shared
    return embedding_table

def set_input_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` on the wrapped transformer.

    Passes the wrapped transformer's return value straight through.
    """
    result = self.transformer.set_input_embeddings(new_embeddings)
    return result

Expand Down Expand Up @@ -277,6 +277,9 @@ class T5_VAE_Model(Transformer_VAE_Base_Model):
"""
config_class = T5_VAE_Config

def get_input_embeddings(self):
    """Expose the wrapped T5 backbone's shared token-embedding matrix."""
    transformer = self.transformer
    return transformer.shared

def _shift_input_right(self, input_ids):
start_token_id = self.transformer.config.eos_token_id
pad_token_id = self.config.transformer_decoder.pad_token_id
Expand Down Expand Up @@ -380,6 +383,9 @@ def forward(


class Funnel_VAE_Model_Base(Transformer_VAE_Base_Model):
def get_input_embeddings(self):
    """Return the word-embedding table of the wrapped Funnel encoder."""
    funnel_embeddings = self.transformer.funnel.embeddings
    return funnel_embeddings.word_embeddings

def _get_encoder_outputs(
self,
input_ids=None,
Expand Down
2 changes: 0 additions & 2 deletions transformer_vae/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataloader import DataLoader
import inspect
import time

import datasets
from transformers import trainer as trainer_script
from transformers.integrations import (
WandbCallback,
Expand Down

0 comments on commit 356b01c

Please sign in to comment.