Paper: https://openai.com/blog/better-language-models/
Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2
| Model | Parameters | Checkpoint size | URL |
|---|---|---|---|
| gpt2 | ~120 million | ~500 MB | https://huggingface.co/gpt2 |
| gpt2-medium | ~350 million | ~1.5 GB | https://huggingface.co/gpt2-medium |
| gpt2-large | ~800 million | ~3 GB | https://huggingface.co/gpt2-large |
| gpt2-xl | ~1.5 billion | ~6 GB | https://huggingface.co/gpt2-xl |
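All four checkpoints expose the same interface; only the `pretrained` argument changes, as in the full example below. A minimal sketch:

```python
import flaxmodels as fm

# Valid names: 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
# Larger checkpoints trade download size and memory for output quality.
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium')
```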
For more usage examples, check out this Colab.

The following example uses simple greedy decoding, which always picks the most likely next token. More sophisticated decoding strategies exist; see the sampling sketch after the code block.
```python
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Initialize tokenizer
tokenizer = fm.gpt2.get_tokenizer()

# Encode start sequence
generated = tokenizer.encode('The Manhattan bridge')
context = jnp.array([generated])
past = None

# Initialize model
# Models to choose from: ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
params = model.init(key, input_ids=context, past_key_values=past)

for i in range(20):
    # Predict next token in sequence
    output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
    token = jnp.argmax(output['logits'][..., -1, :])
    context = jnp.expand_dims(token, axis=0)
    # Add token to sequence
    generated += [token]
    # Update past keys and values
    past = output['past_key_values']

# Decode sequence of tokens
sequence = tokenizer.decode(generated)
print(sequence)
```
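Greedy argmax decoding tends to produce repetitive text. Below is a minimal sketch of top-k sampling as one of the more sophisticated alternatives; the helper `sample_top_k` and its default values are illustrative, not part of flaxmodels.

```python
import jax
import jax.numpy as jnp

def sample_top_k(key, logits, k=40, temperature=0.8):
    # Restrict the choice to the k highest-scoring tokens;
    # temperature < 1 sharpens the distribution, > 1 flattens it.
    top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
    # jax.random.categorical samples from unnormalized log-probabilities.
    choice = jax.random.categorical(key, top_logits)
    return top_indices[choice]

# In the loop above, replace the argmax line with:
#   key, subkey = jax.random.split(key)
#   token = sample_top_k(subkey, output['logits'][0, -1, :])
```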
The documentation can be found here.
The tokenizer is taken from Hugging Face.