Text Generation using GPT2

  • This tutorial shows how to use GPT2 for text-generation tasks.

  • No training is involved; the pretrained model is used as-is.

!pip install tf-transformers

!pip install transformers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings

import tensorflow as tf
print("Tensorflow version", tf.__version__)
print("Devices", tf.config.list_physical_devices())

from tf_transformers.models import GPT2Model
from tf_transformers.text import TextDecoder
from transformers import GPT2Tokenizer
Tensorflow version 2.7.0
Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Load GPT2 Model

    1. Note the use_auto_regressive=True argument. It is required for any model that will be used for text generation.

model_name = 'gpt2'

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name, use_auto_regressive=True)
Metal device set to: Apple M1
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. Either the Trackable object references in the Python program have changed in an incompatible way, or the checkpoint was generated in an incompatible program.

Two checkpoint references resolved to different objects (<tf_transformers.models.gpt2.gpt2.GPT2Encoder object at 0x2908a54c0> and <keras.engine.input_layer.InputLayer object at 0x2979a3a30>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__gpt2.main.8843a828e80c53bb121d7e395d07e3821ba88ea5/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/gpt2

Serialize and load

  • The recommended way to use a TensorFlow model is to serialize it and then load the serialized version.

  • The speedup, especially for text generation, can be up to 50x.

# Save as serialized
model_dir = 'MODELS/gpt2'
model.save_transformers_serialized(model_dir)

# Load
loaded = tf.saved_model.load(model_dir)
model = loaded.signatures['serving_default']
WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn, positional_embeddings_layer_call_and_return_conditional_losses, dropout_layer_call_fn while saving (showing 5 of 740). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: MODELS/gpt2/assets
INFO:tensorflow:Assets written to: MODELS/gpt2/assets
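
The speedup figure above depends on hardware and model size, so it is worth measuring on your own machine. The following is a minimal sketch of a timing helper; `average_seconds` is a hypothetical name, not part of tf-transformers, and the workload shown is only a stand-in.

```python
import time

def average_seconds(fn, repeats=5):
    """Return the average wall-clock seconds per call of fn."""
    fn()  # warm-up call, so one-off setup or tracing cost is excluded
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

# Stand-in workload. In the tutorial you would instead time a decode call
# once for the Keras model and once for the serialized model.
elapsed = average_seconds(lambda: sum(range(10_000)))
print(f"{elapsed:.6f} s per call")
```

Timing the same decode call with both model variants and comparing the two averages gives the speedup for your setup.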

Text-Generation

  • We can also pass a tf.keras.Model to TextDecoder, but the serialized model is recommended.

  • GPT2-like (encoder-only) models require -1 as the padding token.

decoder = TextDecoder(model=loaded)
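
To make the -1 padding concrete, here is a minimal pure-Python sketch of what right-padding a ragged batch does, which is the effect of `tf.ragged.constant(...).to_tensor(-1)`. The helper `pad_batch` and the token ids are hypothetical and only for illustration.

```python
def pad_batch(sequences, pad_id=-1):
    """Right-pad variable-length token-id lists into a rectangular batch."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]

# Made-up token ids, not real GPT2 tokenizer output
batch = pad_batch([[40, 561, 588], [22648, 468]])
print(batch)  # [[40, 561, 588], [22648, 468, -1]]
```

Every row ends up with the length of the longest sequence, and the shorter rows are filled with -1 on the right.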

Greedy Decoding

texts = ['I would like to walk with my cat', 
         'Music has been very soothing']

input_ids = tf.ragged.constant(tokenizer(texts)['input_ids']).to_tensor(-1) # GPT2-style models need -1 as the padding token

inputs = {'input_ids': input_ids}
predictions = decoder.decode(inputs, 
                             mode='greedy', 
                             max_iterations=32)
print(tokenizer.batch_decode(tf.squeeze(predictions['predicted_ids'], axis=1)))
[", but I don't want to be a burden to my family. I want to be a part of the community. I want to be a part of the", " to me. I've been able to get through a lot of things, but I've also been able to get through a lot of things that I've never"]
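
Greedy decoding picks the single most likely next token at every step. The sketch below illustrates the idea with a toy bigram table standing in for the model; `greedy_decode`, `toy_scores` and `BIGRAMS` are hypothetical names, and this is not how TextDecoder is implemented internally.

```python
def greedy_decode(next_token_scores, prompt, max_iterations):
    """Repeatedly append the highest-scoring next token to the sequence."""
    tokens = list(prompt)
    generated = []
    for _ in range(max_iterations):
        scores = next_token_scores(tokens)  # dict: token -> score
        best = max(scores, key=scores.get)
        tokens.append(best)
        generated.append(best)
    return generated

# Toy "model": the next-token scores depend only on the last token
BIGRAMS = {
    "cat": {"sat": 0.6, "ran": 0.4},
    "sat": {"down": 0.7, "up": 0.3},
    "down": {"cat": 0.9, "sat": 0.1},
}

def toy_scores(tokens):
    return BIGRAMS[tokens[-1]]

print(greedy_decode(toy_scores, ["cat"], 4))  # ['sat', 'down', 'cat', 'sat']
```

Because the choice is deterministic, greedy decoding can fall into loops, which is consistent with the repeated phrases in the GPT2 output above.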

Beam Decoding

inputs = {'input_ids': input_ids}
predictions = decoder.decode(inputs, 
                             mode='beam',
                             num_beams=3,
                             max_iterations=32)
# Index 0 on axis 1 selects the first returned beam for each input
print(tokenizer.batch_decode(predictions['predicted_ids'][:, 0, :]))
['. I would like to walk with my cat. I would like to walk with my cat. I would like to walk with my cat. I would like to', ' to me, and I\'m glad to be able to share it with you."\n\n"I\'m glad to be able to share it with you."\n']
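
Beam decoding keeps the `num_beams` best partial sequences at each step instead of only one. The following is a minimal sketch over a toy probability table; `beam_search`, `toy_probs` and `TABLE` are hypothetical names, and real implementations add details such as length penalties and end-of-sequence handling.

```python
import math

def beam_search(next_token_probs, prompt, num_beams, max_iterations):
    """Return the num_beams best continuations as (log-prob, tokens) pairs."""
    beams = [(0.0, [])]  # (cumulative log-probability, generated tokens)
    for _ in range(max_iterations):
        candidates = []
        for score, generated in beams:
            probs = next_token_probs(list(prompt) + generated)
            for token, p in probs.items():
                candidates.append((score + math.log(p), generated + [token]))
        # Keep only the highest-scoring candidates for the next step
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    return beams

# Toy "model": next-token probabilities depend only on the last token
TABLE = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"d": 0.5, "e": 0.5},
    "c": {"f": 0.9, "g": 0.1},
}

def toy_probs(tokens):
    return TABLE[tokens[-1]]

beams = beam_search(toy_probs, ["a"], num_beams=2, max_iterations=2)
print(beams[0][1])  # ['c', 'f']
```

Greedy decoding would start with "b" here (probability 0.6) and end with a total probability of 0.3, while the beam keeps "c" alive and finds the sequence with probability 0.36.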

Top-K and Top-P (Nucleus) Sampling

inputs = {'input_ids': input_ids}
predictions = decoder.decode(inputs, 
                             mode='top_k_top_p',
                             top_k=50,
                             top_p=0.7,
                             num_return_sequences=3,
                             max_iterations=32)
# Print only the first of the returned sequences for each input
print(tokenizer.batch_decode(predictions['predicted_ids'][:, 0, :]))
[", but I don't want to be a burden to my family. I want to be a part of the community. I want to be a part of the", " to me. I've been able to get through a lot of things, but I've also been able to get through a lot of things that I've never"]
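
Top-k and top-p sampling restrict the candidate tokens before drawing one at random. The sketch below shows one common way to combine them on a toy distribution; `top_k_top_p_filter` and `sample` are hypothetical helpers, and the exact order of filtering and renormalization in TextDecoder may differ.

```python
import random

def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix of them
    whose cumulative probability reaches top_p, and renormalize."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}

def sample(probs, rng):
    """Draw one token according to the given probabilities."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]

# Toy next-token distribution
toy = {"the": 0.5, "a": 0.25, "cat": 0.15, "dog": 0.07, "xyz": 0.03}
filtered = top_k_top_p_filter(toy, top_k=4, top_p=0.7)
print(sorted(filtered))  # ['a', 'the']

rng = random.Random(0)
print(sample(filtered, rng))
```

With `top_p=0.7`, only "the" and "a" survive because together they already cover 0.75 of the probability mass; lowering `top_p` makes the output more conservative, raising it makes it more varied.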