Better Language Models and Their Implications (GPT2)

Paper: https://openai.com/blog/better-language-models/
Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2

Table of Contents

1. Models
2. Basic Usage
3. Documentation
4. Acknowledgments
5. License

1. Models

| Model       | Parameters    | Size     | URL                                |
|-------------|---------------|----------|------------------------------------|
| gpt2        | ~ 120 Million | ~ 500 MB | https://huggingface.co/gpt2        |
| gpt2-medium | ~ 350 Million | ~ 1.5 GB | https://huggingface.co/gpt2-medium |
| gpt2-large  | ~ 800 Million | ~ 3 GB   | https://huggingface.co/gpt2-large  |
| gpt2-xl     | ~ 1.5 Billion | ~ 6 GB   | https://huggingface.co/gpt2-xl     |
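
The identifier in the first column is the value passed to the pretrained argument of fm.gpt2.GPT2LMHeadModel (see the usage example in the next section). As a minimal sketch based on that example, loading the medium checkpoint would look like this:

import flaxmodels as fm

# Any identifier from the table above can be used here; larger checkpoints
# require correspondingly more memory (see the Size column).
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium')
tokenizer = fm.gpt2.get_tokenizer()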

2. Basic Usage

For more usage examples, check out this Colab.

The example below uses simple greedy decoding, i.e. the most likely token is picked at every step. More sophisticated decoding strategies, such as sampling with temperature or top-k filtering, generally produce more varied text; a sketch of one such strategy follows the example.

import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Encode start sequence
generated = tokenizer.encode('The Manhattan bridge')

context = jnp.array([generated])
past = None

# Initialize model
# Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)

for i in range(20):
    # Predict next token in sequence
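    # With use_cache=True the model also returns past_key_values, so after the
    # first iteration only the newly generated token is passed as input_ids.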
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    token = jnp.argmax(output['logits'][..., -1, :])
    context = jnp.expand_dims(token, axis=0)
    # Add token to sequence
    generated += [token]
    # Update past keys and values
    past = output['past_key_values']

# Decode sequence of tokens
sequence = tokenizer.decode(generated)
print(sequence)
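
As noted above, greedy decoding just takes the argmax of the logits at every step. Below is a minimal sketch of top-k sampling with a temperature, reusing the imports, model, params, and tokenizer from the example above. It assumes the same output dictionary keys ('logits' and 'past_key_values'); the top-k value and temperature are illustrative choices, not part of the flaxmodels API.

def sample_top_k(key, logits, k=40, temperature=0.8):
    # Keep the k highest-scoring tokens and sample one of them.
    top_logits, top_indices = jax.lax.top_k(logits, k)
    idx = jax.random.categorical(key, top_logits / temperature)
    return top_indices[idx]

key = jax.random.PRNGKey(42)
generated = tokenizer.encode('The Manhattan bridge')
context = jnp.array([generated])
past = None

for i in range(20):
    key, subkey = jax.random.split(key)
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    # Sample the next token from the top-k logits at the last position
    token = sample_top_k(subkey, output['logits'][0, -1, :])
    context = jnp.expand_dims(token, axis=0)
    generated += [token]
    past = output['past_key_values']

print(tokenizer.decode(generated))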

3. Documentation

The documentation can be found here.

4. Acknowledgments

The tokenizer is taken from Hugging Face.

5. License

Apache-2.0 License