Text Generation using T5

  • This tutorial is intended to familiarize you with using T5 for text-generation tasks.

  • No training is involved; we only run inference with a pretrained checkpoint.

!pip install tf-transformers

!pip install transformers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TF warnings

import tensorflow as tf
print("Tensorflow version", tf.__version__)
print("Devices", tf.config.list_physical_devices())

from tf_transformers.models import T5Model, T5TokenizerTFText
from tf_transformers.core import TextGenerationChainer
from tf_transformers.text import TextDecoder, TextDecoderSerializable
Tensorflow version 2.7.0
Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Load T5 Model

    1. Note the use_auto_regressive=True argument. It is required for any model to enable text generation.

model_name = 't5-small'

tokenizer = T5TokenizerTFText.from_pretrained(model_name, dynamic_padding=True, truncate=True, max_length=256)
model = T5Model.from_pretrained(model_name, use_auto_regressive=True)
INFO:absl:Saving t5-small tokenizer to /var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tftransformers_tokenizer_cache/t5-small
INFO:absl:Loading t5-small tokenizer to /var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tftransformers_tokenizer_cache/t5-small/spiece.model
Metal device set to: Apple M1
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. Either the Trackable object references in the Python program have changed in an incompatible way, or the checkpoint was generated in an incompatible program.

Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x17b752e20> and <keras.engine.input_layer.InputLayer object at 0x17b752340>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__t5-small.main.699b12fe9601feda4892ca82c07e800f3c1da440/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/t5-small

Serialize and load

  • The recommended way to use a TensorFlow model is to serialize it as a SavedModel and then load it back.

  • The speedup, especially for text generation, can be up to 50x.

# Save as serialized
model_dir = 'MODELS/t5'
model.save_transformers_serialized(model_dir)

# Load
loaded = tf.saved_model.load(model_dir)
Metal device set to: Apple M1

Text-Generation

  • TextDecoder also accepts a tf.keras.Model, but passing the loaded SavedModel is recommended.

decoder = TextDecoder(model=loaded)

Greedy Decoding

texts = ['translate English to German: The house is wonderful and we wish to be here :)', 
         'translate English to French: She is beautiful']

inputs = tokenizer({'text': tf.constant(texts)})

predictions = decoder.decode(inputs, 
                             mode='greedy', 
                             max_iterations=64, 
                             eos_id=tokenizer.eos_token_id)
print(tokenizer._tokenizer.detokenize(tf.squeeze(predictions['predicted_ids'], axis=1)))
tf.Tensor(
[b'Das Haus ist wunderbar und wir m\xc3\xb6chten hier sein :)'
 b'Elle est belle::::'], shape=(2,), dtype=string)
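Greedy decoding is autoregressive: at each step it appends the single most likely next token and stops once eos_id is produced or max_iterations is reached. The sketch below illustrates only that loop in plain Python. The next-token table TOY_PROBS and the helper names are hypothetical stand-ins for the T5 decoder, not how TextDecoder is implemented internally.

```python
from typing import Dict, List

EOS_ID = 1

# Hypothetical toy "model": maps the last generated token id to a probability
# distribution over a 5-token vocabulary. A real T5 decoder conditions on the
# encoder output and the whole prefix instead.
TOY_PROBS: Dict[int, List[float]] = {
    0: [0.05, 0.05, 0.60, 0.20, 0.10],  # after the decoder start token (id 0)
    2: [0.05, 0.10, 0.05, 0.70, 0.10],
    3: [0.05, 0.60, 0.05, 0.05, 0.25],
    4: [0.10, 0.70, 0.10, 0.05, 0.05],
}


def greedy_decode(start_id: int, eos_id: int, max_iterations: int) -> List[int]:
    """Append the argmax token at every step until EOS or the step limit."""
    prefix = [start_id]
    for _ in range(max_iterations):
        probs = TOY_PROBS[prefix[-1]]
        best = max(range(len(probs)), key=lambda i: probs[i])
        prefix.append(best)
        if best == eos_id:  # stop as soon as the end-of-sequence token appears
            break
    return prefix[1:]  # drop the decoder start token


print(greedy_decode(0, EOS_ID, max_iterations=64))
```

Because only one hypothesis is kept, greedy decoding is cheap, but an early locally-best choice can never be revisited.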

Beam Decoding

predictions = decoder.decode(inputs, 
                             mode='beam',
                             num_beams=3,
                             max_iterations=64,
                             eos_id=tokenizer.eos_token_id)
print(tokenizer._tokenizer.detokenize(predictions['predicted_ids'][:, 0, :]))
tf.Tensor(
[b'Das Haus ist wunderbar und wir m\xc3\xb6chten hier sein :)'
 b'Elle est belle::::'], shape=(2,), dtype=string)
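Beam decoding keeps the num_beams highest-scoring partial sequences (by cumulative log-probability) at every step instead of just one. Below is a minimal plain-Python sketch, assuming the same hypothetical toy table as a stand-in for the model; it applies no length penalty, and TextDecoder may score or order hypotheses differently. In this sketch results come back best-first, so index 0 is the top hypothesis, comparable to how the tutorial reads predictions['predicted_ids'][:, 0, :].

```python
import math
from typing import Dict, List, Tuple

EOS_ID = 1

# Hypothetical toy next-token table (stand-in for the T5 decoder).
TOY_PROBS: Dict[int, List[float]] = {
    0: [0.05, 0.05, 0.60, 0.20, 0.10],
    2: [0.05, 0.10, 0.05, 0.70, 0.10],
    3: [0.05, 0.60, 0.05, 0.05, 0.25],
    4: [0.10, 0.70, 0.10, 0.05, 0.05],
}


def beam_decode(start_id: int, eos_id: int, num_beams: int,
                max_iterations: int) -> List[List[int]]:
    """Return up to num_beams sequences, best (highest log-prob) first."""
    beams: List[Tuple[List[int], float]] = [([start_id], 0.0)]
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_iterations):
        # Expand every live beam by every vocabulary token.
        candidates = []
        for tokens, score in beams:
            for tok, p in enumerate(TOY_PROBS[tokens[-1]]):
                candidates.append((tokens + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        # Keep the best num_beams unfinished candidates; park EOS ones.
        beams = []
        for tokens, score in candidates:
            if tokens[-1] == eos_id:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
            if len(beams) == num_beams:
                break
        if not beams:
            break
    finished.extend(beams)  # fall back to unfinished beams at the step limit
    finished.sort(key=lambda c: c[1], reverse=True)
    return [tokens[1:] for tokens, _ in finished[:num_beams]]


print(beam_decode(0, EOS_ID, num_beams=3, max_iterations=10))
```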

Top-K and Nucleus (Top-p) Sampling

predictions = decoder.decode(inputs, 
                             mode='top_k_top_p',
                             top_k=50,
                             top_p=0.7,
                             num_return_sequences=3,
                             max_iterations=64,
                             eos_id=tokenizer.eos_token_id)
print(tokenizer._tokenizer.detokenize(predictions['predicted_ids'][:, 0, :]))
tf.Tensor(
[b'Das Haus ist wunderbar und wir m\xc3\xb6chten hier sein :)'
 b'Elle est belle::::'], shape=(2,), dtype=string)
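mode='top_k_top_p' combines two filters before sampling: top_k keeps only the k most likely tokens, and top_p keeps the smallest set of those whose cumulative probability reaches p. The plain-Python sketch below applies top-k first and then top-p, which is a common convention but an assumption about this library's order; the function names are hypothetical. Asking for num_return_sequences=3 corresponds to drawing several independent samples per input.

```python
import random
from typing import Dict, List


def top_k_top_p_filter(probs: List[float], top_k: int, top_p: float) -> Dict[int, float]:
    """Return {token_id: renormalized probability} after top-k, then top-p filtering."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = ranked[:top_k]  # top-k: keep the k most likely tokens
    kept: List[int] = []
    cumulative = 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:  # nucleus: smallest prefix whose mass reaches top_p
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def sample_next_token(probs: List[float], top_k: int, top_p: float,
                      rng: random.Random) -> int:
    """Draw one token id from the filtered, renormalized distribution."""
    filtered = top_k_top_p_filter(probs, top_k, top_p)
    ids = list(filtered)
    return rng.choices(ids, weights=[filtered[i] for i in ids], k=1)[0]


rng = random.Random(0)  # seeded so runs are reproducible
probs = [0.5, 0.25, 0.125, 0.125]
print(top_k_top_p_filter(probs, top_k=50, top_p=0.7))
print([sample_next_token(probs, top_k=50, top_p=0.7, rng=rng) for _ in range(5)])
```

With top_p=0.7 only tokens 0 and 1 survive here, so the low-probability tail can never be sampled, while the choice between the survivors stays random.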

Advanced Serialization (Preprocessing and Decoding Together)

  • What if we could bundle tokenization, the model, and decoding into a single model and serialize it?

model_dir = 'MODELS/t5_serialized/'
# Load Auto Regressive Version
model = T5Model.from_pretrained(model_name=model_name, use_auto_regressive=True)
# Assume we are doing beam decoding
text_generation_kwargs = {
    'mode': 'beam',
    'num_beams': 3,
    'max_iterations': 32,
    'eos_id': tokenizer.eos_token_id,
}
# TextDecoderSerializable - makes decoding serializable
decoder = TextDecoderSerializable(model=model, **text_generation_kwargs)
# TextGenerationChainer - joins tokenizer + TextDecoderSerializable
model_fully_serialized = TextGenerationChainer(tokenizer.get_model(), decoder)
model_fully_serialized = model_fully_serialized.get_model()
# Save as saved_model
model_fully_serialized.save_serialized(model_dir, overwrite=True)
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. Either the Trackable object references in the Python program have changed in an incompatible way, or the checkpoint was generated in an incompatible program.

Two checkpoint references resolved to different objects (<tf_transformers.models.encoder_decoder.encoder_decoder.EncoderDecoder object at 0x38790b400> and <keras.engine.input_layer.InputLayer object at 0x3827cc880>).
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__t5-small.main.699b12fe9601feda4892ca82c07e800f3c1da440/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/t5-small
Using default `decoder_start_token_id` 0 from the model
WARNING:absl:Found untraced functions such as tf_transformers/t5_encoder_layer_call_fn, tf_transformers/t5_encoder_layer_call_and_return_conditional_losses, tf_transformers/t5_decoder_layer_call_fn, tf_transformers/t5_decoder_layer_call_and_return_conditional_losses, word_embeddings_layer_call_fn while saving (showing 5 of 1140). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: MODELS/t5_serialized/assets

Load Advanced Model and Generate text

  • Everything now happens inside the model (tokenization, generation, and detokenization), with no extra preprocessing or postprocessing code on our side.

    1. TextDecoderSerializable: an advanced, serializable decoder written in pure TensorFlow ops.

    2. TextGenerationChainer: a very simple tf.keras.layers.Layer wrapper that joins the tokenizer and the decoder.
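Conceptually, the chained model is just a composition: raw text goes in, token ids flow through the decoder, and decoded text comes out under a single 'decoded_text' key. The plain-Python sketch below shows only that data flow with hypothetical toy components (ToyTextGenerationChain, a five-word VOCAB, and an "echo" decoder). Unlike the real TextGenerationChainer, which is built from TensorFlow ops so the whole pipeline can be saved as one SavedModel, these Python callables cannot be serialized that way.

```python
from typing import Callable, Dict, List

TokenIds = List[List[int]]

# Hypothetical toy vocabulary; id 1 plays the role of EOS.
VOCAB = ["<pad>", "</s>", "she", "is", "beautiful"]
WORD_TO_ID = {w: i for i, w in enumerate(VOCAB)}
EOS_ID = 1


def toy_tokenize(texts: List[str]) -> TokenIds:
    """Whitespace tokenizer over the toy vocabulary."""
    return [[WORD_TO_ID[w] for w in t.lower().split()] for t in texts]


def toy_decode(input_ids: TokenIds) -> TokenIds:
    """Placeholder 'generation': echo the input and append EOS."""
    return [ids + [EOS_ID] for ids in input_ids]


def toy_detokenize(batch: TokenIds) -> List[str]:
    """Map ids back to words, cutting each sequence at its first EOS."""
    out = []
    for ids in batch:
        if EOS_ID in ids:
            ids = ids[:ids.index(EOS_ID)]
        out.append(" ".join(VOCAB[i] for i in ids))
    return out


class ToyTextGenerationChain:
    """Hypothetical stand-in for TextGenerationChainer: tokenize -> decode -> detokenize."""

    def __init__(self,
                 tokenize: Callable[[List[str]], TokenIds],
                 decode: Callable[[TokenIds], TokenIds],
                 detokenize: Callable[[TokenIds], List[str]]) -> None:
        self.tokenize = tokenize
        self.decode = decode
        self.detokenize = detokenize

    def __call__(self, texts: List[str]) -> Dict[str, List[str]]:
        return {"decoded_text": self.detokenize(self.decode(self.tokenize(texts)))}


chain = ToyTextGenerationChain(toy_tokenize, toy_decode, toy_detokenize)
print(chain(["She is beautiful"]))
```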

loaded = tf.saved_model.load(model_dir)
model = loaded.signatures['serving_default']
texts = ['translate English to German: The house is wonderful and we wish to be here :)', 
         'translate English to French: She is beautiful']

predictions = model(**{'text': tf.constant(texts)})
print(predictions['decoded_text'])
tf.Tensor(
[b'Das Haus ist wunderbar und wir m\xc3\xb6chten hier sein :)'
 b'Elle est belle::::'], shape=(2,), dtype=string)