Train a Masked Language Model (MLM) with tf-transformers on TPU

This tutorial contains complete code to train an MLM model on the C4 EN 10K dataset. In addition to training a model, you will learn how to preprocess text into an appropriate format.

In this notebook, you will:

  • Load the C4 (10k EN) dataset from HuggingFace

  • Load a GPT2-style model configuration using tf-transformers

  • Build the train dataset with on-the-fly feature preparation, using a tokenizer from tf-transformers

  • Build a masked LM model from the GPT2-style configuration

  • Save your model

  • Use the base model for further tasks

If you’re new to working with the C4 dataset, please see C4 for more details.

!pip install tf-transformers

!pip install sentencepiece

!pip install tensorflow-text

!pip install transformers

!pip install wandb

!pip install datasets
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings

import tensorflow as tf
import tensorflow_text as tf_text
import datasets
import wandb

print("Tensorflow version", tf.__version__)
print("Tensorflow text version", tf_text.__version__)
print("Devices", tf.config.list_physical_devices())

from tf_transformers.models import GPT2Model, MaskedLMModel, AlbertTokenizerTFText
from tf_transformers.core import Trainer
from tf_transformers.optimization import create_optimizer
from tf_transformers.text.lm_tasks import mlm_fn
from tf_transformers.losses.loss_wrapper import get_lm_loss
Tensorflow version 2.7.0
Tensorflow text version 2.7.3
Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]

On TPU, the Trainer sometimes has to be initialized before everything else.

trainer = Trainer(distribution_strategy='tpu', num_gpus=0, tpu_address='colab')
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
WARNING:tensorflow:TPU system grpc://10.91.104.90:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
INFO:tensorflow:Initializing the TPU system: grpc://10.91.104.90:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Load Model, Optimizer, and Trainer

Our Trainer expects the model, optimizer, and loss to be passed as functions (callables), so that each can be created inside the distribution strategy's scope.

# Load Model
def get_model(model_name, vocab_size, is_training, use_dropout, num_hidden_layers):
  """Get Model"""

  def model_fn():
    config = GPT2Model.get_config(model_name)
    config['vocab_size'] = vocab_size
    model = GPT2Model.from_config(config, mask_mode='user_defined', num_hidden_layers=num_hidden_layers, return_layer=True)
    model = MaskedLMModel(
              model,
              use_extra_mlm_layer=False,
              hidden_size=config['embedding_size'],
              layer_norm_epsilon=config['layer_norm_epsilon'],
          )    
    return model.get_model()
  return model_fn

# Load Optimizer
def get_optimizer(learning_rate, examples, batch_size, epochs, use_constant_lr=False):
    """Get optimizer"""
    steps_per_epoch = int(examples / batch_size)
    num_train_steps = steps_per_epoch * epochs
    warmup_steps = int(0.1 * num_train_steps)

    def optimizer_fn():
        optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps, use_constant_lr=use_constant_lr)
        return optimizer

    return optimizer_fn

# Load trainer
def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None):
    """Get Trainer"""
    trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address)
    return trainer
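
The loss is also supplied as a callable; later in this notebook we use get_lm_loss from tf-transformers. For intuition, a masked LM loss is essentially sparse cross-entropy computed only at the masked positions. The sketch below is illustrative rather than the library's implementation, and the 'masked_lm_labels' / 'masked_lm_weights' key names are assumptions:

# Illustrative sketch of a masked-LM loss (not tf-transformers' implementation).
# Assumed keys: 'masked_lm_labels' (target token ids at masked positions),
# 'masked_lm_weights' (1.0 for real masked slots, 0.0 for padding),
# and 'token_logits' from the MLM model outputs.
def sketch_masked_lm_loss(labels_dict, model_outputs):
    labels = labels_dict['masked_lm_labels']
    weights = tf.cast(labels_dict['masked_lm_weights'], tf.float32)
    logits = model_outputs['token_logits']
    per_position = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    # Average over the real masked positions; the epsilon guards against an all-padding batch.
    return tf.reduce_sum(per_position * weights) / (tf.reduce_sum(weights) + 1e-6)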

Prepare Data for Training

We will make use of a TensorFlow Text based tokenizer to do on-the-fly preprocessing, without the overhead of pre-preparing the data as pickle, numpy, or TFRecord files. We sanity-check one batch of this pipeline after the dataset is built below.

# Load dataset
def load_dataset(dataset, tokenizer_layer, max_seq_len, max_predictions_per_seq, batch_size):
    """
    Args:
      dataset: HuggingFace dataset
      tokenizer_layer: tf-transformers tokenizer
      max_seq_len: int (maximum sequence length of text)
      max_predictions_per_seq: int (maximum number of tokens to mask)
      batch_size: int (batch size)
    """
    tfds_dict = dataset.to_dict()
    tfdataset = tf.data.Dataset.from_tensor_slices(tfds_dict)

    # MLM function
    masked_lm_map_fn = mlm_fn(tokenizer_layer, max_seq_len, max_predictions_per_seq)

    # MLM
    tfdataset = tfdataset.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    # Batch
    tfdataset = tfdataset.batch(batch_size, drop_remainder=True).shuffle(50)

    # Auto SHARD
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO
    tfdataset = tfdataset.with_options(options)
    
    return tfdataset

Prepare Dataset

  1. Set necessary hyperparameters.

  2. Prepare train dataset

  3. Load model, optimizer, loss and trainer.

# Data configs
dataset_name = 'stas/c4-en-10k'
model_name  = 'gpt2'
max_seq_len = 128
max_predictions_per_seq = 20
batch_size  = 128

# Model configs
learning_rate = 0.0005
epochs = 3
model_checkpoint_dir = 'gs://legacyai-bucket/sample_mlm_model'  # If using TPU, provide a GCS bucket for storing model checkpoints

# Load HF dataset
dataset = datasets.load_dataset(dataset_name)
# Load tokenizer from tf-transformers
tokenizer_layer = AlbertTokenizerTFText.from_pretrained("albert-base-v2")
# Train Dataset
train_dataset = load_dataset(dataset['train'], tokenizer_layer, max_seq_len, max_predictions_per_seq, batch_size)

# Total train examples
total_train_examples = dataset['train'].num_rows
steps_per_epoch = 5000
num_hidden_layers = 8

# model
vocab_size = tokenizer_layer.vocab_size.numpy()
model_fn =  get_model(model_name, vocab_size, is_training=True, use_dropout=True, num_hidden_layers=num_hidden_layers)
# optimizer
optimizer_fn = get_optimizer(learning_rate, total_train_examples, batch_size, epochs, use_constant_lr=True)
# trainer
# trainer = get_trainer(distribution_strategy='tpu', num_gpus=0, tpu_address='colab')
# loss
loss_fn = get_lm_loss(loss_type=None)
WARNING:datasets.builder:Reusing dataset c4_en10k (/root/.cache/huggingface/datasets/stas___c4_en10k/plain_text/1.0.0/edbf1ff8b8ee35a9751a7752b5e93a4873cc7905ffae010ad334a2c96f81e1cd)
INFO:absl:Loading albert-base-v2 tokenizer to /tmp/tftransformers_tokenizer_cache/albert-base-v2/spiece.model
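
Before wiring up logging and training, it can help to sanity-check the on-the-fly MLM pipeline by pulling a single batch from train_dataset. The exact structure (a single dict versus a (features, labels) tuple) and the key names depend on mlm_fn, so treat this purely as an inspection sketch:

# Peek at one batch to verify the shapes/keys produced by the on-the-fly MLM pipeline.
# mlm_fn may emit a single dict or a (features, labels) tuple; handle both.
for batch in train_dataset.take(1):
    parts = batch if isinstance(batch, tuple) else (batch,)
    for part in parts:
        for name, tensor in part.items():
            print(name, tensor.shape, tensor.dtype)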

Wandb configuration

project = "TUTORIALS"
display_name = "mlm_tpu"
wandb.init(project=project, name=display_name)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Syncing run mlm_tpu to Weights & Biases (docs).

Train :-)

history = trainer.run(
    model_fn=model_fn,
    optimizer_fn=optimizer_fn,
    train_dataset=train_dataset,
    train_loss_fn=loss_fn,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    model_checkpoint_dir=model_checkpoint_dir,
    batch_size=batch_size,
    wandb=wandb
)
INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.
INFO:absl:Policy: ----> float32
INFO:absl:Strategy: ---> <tensorflow.python.distribute.tpu_strategy.TPUStrategyV2 object at 0x7fab94155e10>
INFO:absl:Num TPU Devices: ---> 8
INFO:absl:Create model from config
INFO:absl:Using Constant learning rate
INFO:absl:Using Adamw optimizer
INFO:absl:No ❌❌ checkpoint found in gs://legacyai-bucket/sample_mlm_model
Train: Epoch 1/4 --- Step 100/5000 --- total examples 0:   0%|          | 0/50 [00:00<?, ?batch /s]/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['input_type_ids'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
Train: Epoch 1/4 --- Step 5000/5000 --- total examples 627200: 100%|██████████| 50/50 [07:02<00:00,  8.44s/batch , learning_rate=0.0005, loss=3.11]
INFO:absl:Model saved at epoch 1 at gs://legacyai-bucket/sample_mlm_model/ckpt-1

Train: Epoch 2/4 --- Step 5000/5000 --- total examples 1267200: 100%|██████████| 50/50 [06:28<00:00,  7.77s/batch , learning_rate=0.0005, loss=1.33]
INFO:absl:Model saved at epoch 2 at gs://legacyai-bucket/sample_mlm_model/ckpt-2

Train: Epoch 3/4 --- Step 5000/5000 --- total examples 1907200: 100%|██████████| 50/50 [06:29<00:00,  7.78s/batch , learning_rate=0.0005, loss=0.785]
INFO:absl:Model saved at epoch 3 at gs://legacyai-bucket/sample_mlm_model/ckpt-3

Load the Model from checkpoint

model_fn =  get_model(model_name, vocab_size, is_training=False, use_dropout=False, num_hidden_layers=num_hidden_layers)

model = model_fn()
model.load_checkpoint(model_checkpoint_dir)
INFO:absl:Create model from config
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from gs://legacyai-bucket/sample_mlm_model/ckpt-3
<tensorflow.python.training.tracking.util.Checkpoint at 0x7fab8fd104d0>
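
To cover the "Save your model" step from the introduction, the weights of the restored model can also be exported outside the Trainer checkpoints. A minimal sketch, assuming the object returned by get_model() is a standard tf.keras.Model and using an example path:

# Export the restored weights Keras-style for later reuse.
# 'saved_mlm_weights/ckpt' is only an example path (local or GCS both work).
model.save_weights('saved_mlm_weights/ckpt')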

Test Model performance

  1. We can assess model performance by checking how it predicts the masked word in sample sentences.

  2. As the following results show, it is clear that the model has started learning.

from transformers import AlbertTokenizer
tokenizer_hf = AlbertTokenizer.from_pretrained("albert-base-v2")

validation_sentences = [
    'Read the rest of this [MASK] to understand things in more detail.',
    'I want to buy the [MASK] because it is so cheap.',
    'The [MASK] was amazing.',
    'Sachin Tendulkar is one of the [MASK] players in the world.',
    '[MASK] is the capital of France.',
    'Machine Learning requires [MASK]',
    'He is working as a [MASK]',
    'She is working as a [MASK]',
]
inputs = tokenizer_hf(validation_sentences, padding=True, return_tensors="tf")

inputs_tf = {}
inputs_tf["input_ids"] = inputs["input_ids"]
inputs_tf["input_mask"] = inputs["attention_mask"]
# Request predictions at every position (0 .. seq_length-1) so we can read off the [MASK] slots
seq_length = tf.shape(inputs_tf['input_ids'])[1]
inputs_tf['masked_lm_positions'] = tf.zeros_like(inputs_tf["input_ids"]) + tf.range(seq_length)


top_k = 10 # number of top predictions to show
outputs_tf = model(inputs_tf)
# Get masked positions from each sentence
masked_positions = tf.argmax(tf.equal(inputs_tf["input_ids"], tokenizer_hf.mask_token_id), axis=1)
for i, logits in enumerate(outputs_tf['token_logits']):
    mask_token_logits = logits[masked_positions[i]]
    # tf.nn.top_k returns (values, indices); [1] selects the indices
    top_words = tokenizer_hf.decode(tf.nn.top_k(mask_token_logits, k=top_k)[1].numpy())
    print("Input ----> {}".format(validation_sentences[i]))
    print("Predicted words ----> {}".format(top_words.split()))
    print()
Input ----> Read the rest of this [MASK] to understand things in more detail.
Predicted words ----> ['page', 'continent', 'means', 'window', 'website', 'post', 'tool', 'is', 'book', 'world']

Input ----> I want to buy the [MASK] because it is so cheap.
Predicted words ----> ['door', 'quote', 'electronics', 'house', 'review', 'website', 'graphics', 'property', 'doors', 'item']

Input ----> The [MASK] was amazing.
Predicted words ----> ['boys', 'turkey', 'epilogue', 'idea', 'project', 'answer', 'food', 'website', 'show', 'weather']

Input ----> Sachin Tendulkar is one of the [MASK] players in the world.
Predicted words ----> ['busiest', 'leading', 'english', 'latest', 'northern', 'coordinates', 'largest', 'international', 'state', 'registered']

Input ----> [MASK] is the capital of France.
Predicted words ----> ['this', 'there', 'india', 'it', 'what', 'here', 'below', 'he', 'france', 'that']

Input ----> Machine Learning requires [MASK]
Predicted words ----> ['.', 'the', 'that', 'for', 'a,', 'an', 'you', 'and', 'your']

Input ----> He is working as a [MASK]
Predicted words ----> ['field', 'real', 'great', 'chance', 'regular', 'business', 'team', 'facebook', 'freelance', 'sport']

Input ----> She is working as a [MASK]
Predicted words ----> ['field', 'real', 'path', 'facebook', 'chance', 'trip', 'strategic', 'great', 'regular', 'lot']
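
Use the base model for further tasks

The introduction also promised reusing the base model. A speculative sketch is shown below: rebuild the plain GPT2 model with the same configuration used for pre-training and point it at the MLM checkpoint directory. Whether every variable matches depends on how MaskedLMModel wraps the encoder, so treat this as a starting point rather than a guaranteed recipe.

# Rebuild the base GPT2 model (same config as pre-training) and try restoring
# the MLM checkpoint; variable names may only partially match the MLM wrapper.
config = GPT2Model.get_config(model_name)
config['vocab_size'] = vocab_size
base_model = GPT2Model.from_config(config, mask_mode='user_defined', num_hidden_layers=num_hidden_layers)
base_model.load_checkpoint(model_checkpoint_dir)
# base_model can now serve as the encoder for downstream tasks (classification, generation, etc.)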