GPT2 for QA using SQuAD v1 (Causal LM)¶
This tutorial contains complete code to fine-tune GPT2 for Question Answering on the SQuAD v1 dataset. In addition to training a model, you will learn how to preprocess text into an appropriate format.
In this notebook, you will:

- Load the SQuAD v1 dataset from HuggingFace
- Load the GPT2 model using tf-transformers
- Build the model using causal (default) and prefix masking
- Build train and validation dataset feature preparation using the tokenizer from transformers
- Train your own model, fine-tuning GPT2
- Save your model and use it for QA
- Use the end-to-end (inference) model in a production setup
If you’re new to working with the SQuAD dataset, please see SQUAD for more details.
!pip install tf-transformers
!pip install transformers
!pip install wandb
!pip install datasets
import tensorflow as tf
import random
import collections
import wandb
import tempfile
import tqdm
import json
import os
import numpy as np
print("Tensorflow version", tf.__version__)
print("Devices", tf.config.list_physical_devices())
from tf_transformers.models import GPT2Model
from tf_transformers.core import Trainer
from tf_transformers.optimization import create_optimizer
from tf_transformers.data import TFWriter, TFReader
from tf_transformers.losses.loss_wrapper import get_lm_loss
from tf_transformers.text import TextDecoder
from datasets import load_dataset
from transformers import GPT2Tokenizer
Tensorflow version 2.7.0
Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]
Load Data, Tokenizer¶
model_name = 'gpt2'
# Load Dataset
dataset = load_dataset("squad")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Define length for examples
max_sequence_length = 384
max_question_length = 64
max_answer_length = 40
batch_size = 32
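Before writing features, it can help to peek at one raw example to see the fields used below ('question', 'context', 'answers'); an illustrative check:

# Illustrative: inspect one raw SQuAD training example
example = dataset['train'][0]
print(example['question'])
print(example['context'][:200])
print(example['answers']['text'][0])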
Prepare Training and Validation TFRecords using SQuAD (causal and prefix)¶
We combine (question + context + answer) into a single sequence.

For mask_mode=causal, we don’t need any mask. For mask_mode=prefix, we need input_mask.

For prefix, we mask only question + context (input_mask=1); as the answer is supposed to be generated, we shouldn’t mask it, which means it stays causal (input_mask=0).

Note how labels_mask is prepared and how it differs from input_mask.
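To make the two masks concrete, here is a toy sketch with made-up token ids (a 2-token question, 3-token context and 2-token answer); this mirrors the feature preparation below:

# Toy illustration with made-up token ids (not from the tokenizer)
question_ids = [10, 11]
context_ids = [20, 21, 22]
answer_ids = [30, 31]
# prefix: attend freely over question + context, answer stays causal
input_mask = [1, 1] + [1, 1, 1] + [0, 0]
# loss is computed only on the answer positions
labels_mask = [0, 0] + [0, 0, 0] + [1, 1]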
def parse_train(dataset, tokenizer, max_question_length, max_passage_length, max_answer_length, key):
"""Function to parse examples which are is_duplicate=1
Args:
dataset (:obj:`dataset`): HF dataset
tokenizer (:obj:`tokenizer`): HF Tokenizer
max_question_length (:obj:`int`): Question Length
max_passage_length (:obj:`int`): Passage Length
max_answer_length (:obj:`int`): Answer Length
key (:obj:`str`): Key of dataset (`train`, `validation` etc)
"""
for f in dataset[key]:
question_ids = tokenizer('Question: ' + f['question'], max_length=max_question_length, truncation=True)['input_ids']
context_ids = tokenizer('Context: ' + f['context'], max_length=max_passage_length, truncation=True)['input_ids']
answer_ids = tokenizer('answer: ' + f['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids']
# add EOS (for GPT2, bos_token_id and eos_token_id are the same <|endoftext|> token)
context_ids = context_ids + [tokenizer.bos_token_id]
answer_ids = answer_ids + [tokenizer.bos_token_id] # EOS token
# input_ids
input_ids = (question_ids + context_ids + answer_ids)
# input_mask
input_mask = ([1] * len(question_ids)) + ([1] * len(context_ids)) + ([0] * len(answer_ids))
# labels_mask is the opposite of input_mask, as we compute loss only on answer ids
labels_mask = ([0] * len(question_ids)) + ([0] * len(context_ids)) + ([1] * len(answer_ids))
result = {}
# Except last word
result['input_ids'] = input_ids[:-1]
result['input_mask'] = input_mask[:-1]
# Shift one word next
result['labels'] = input_ids[1:]
result['labels_mask'] = labels_mask[1:]
yield result
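The slicing above implements standard next-token prediction (teacher forcing): inputs are all tokens except the last, labels are the same tokens shifted left by one. For made-up ids:

# Illustration of the shift with made-up ids [5, 6, 7, 8]:
# input_ids -> [5, 6, 7] (all except last)
# labels    -> [6, 7, 8] (shifted one position)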
# Write using TF Writer
schema = {
"input_ids": ("var_len", "int"),
"input_mask": ("var_len", "int"),
"labels": ("var_len", "int"),
"labels_mask": ("var_len", "int")
}
tfrecord_train_dir = tempfile.mkdtemp()
tfrecord_filename = 'squad'
tfwriter = TFWriter(schema=schema,
file_name=tfrecord_filename,
model_dir=tfrecord_train_dir,
tag='train',
overwrite=True
)
# Train dataset
train_parser_fn = parse_train(dataset, tokenizer, max_question_length, max_sequence_length, max_answer_length, key='train')
tfwriter.process(parse_fn=train_parser_fn)
INFO:absl:Total individual observations/examples written is 87599 in 173.0701344013214 seconds
INFO:absl:All writer objects closed
Prepare Validation TFRecords¶
def parse_dev(dataset, tokenizer, max_question_length, max_passage_length, max_answer_length, key):
"""Function to parse examples
Args:
dataset (:obj:`dataset`): HF dataset
tokenizer (:obj:`tokenizer`): HF Tokenizer
max_question_length (:obj:`int`): Question Length
max_passage_length (:obj:`int`): Passage Length
max_answer_length (:obj:`int`): Answer Length
key (:obj:`str`): Key of dataset (`train`, `validation` etc)
"""
for f in dataset[key]:
result = {}  # fresh dict per example, so each yield is independent
question_ids = tokenizer('Question: ' + f['question'], max_length=max_question_length, truncation=True)['input_ids']
context_ids = tokenizer('Context: ' + f['context'], max_length=max_passage_length, truncation=True)['input_ids']
answer_ids = tokenizer('answer: ' + f['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids']
# add EOS (same <|endoftext|> token as BOS for GPT2)
context_ids = context_ids + [tokenizer.bos_token_id]
# input_ids
input_ids = (question_ids + context_ids)
# input_mask
input_mask = ([1] * len(question_ids)) + ([1] * len(context_ids))
result['input_ids'] = input_ids
result['input_mask'] = input_mask
result['original_answer'] = f['answers']['text'][0]
yield result
tfrecord_validation_dir = tempfile.mkdtemp()
tfrecord_validation_filename = 'squad'
validation_schema = {
"input_ids": ("var_len", "int"),
"input_mask": ("var_len", "int"),
"original_answer": ("var_len", "bytes")
}
tfwriter = TFWriter(schema=validation_schema,
file_name=tfrecord_validation_filename,
model_dir=tfrecord_validation_dir,
tag='eval',
overwrite=True
)
# Validation dataset
validation_parser_fn = parse_dev(dataset, tokenizer, max_question_length, max_sequence_length, max_answer_length, key='validation')
tfwriter.process(parse_fn=validation_parser_fn)
INFO:absl:Total individual observations/examples written is 10570 in 19.614187002182007 seconds
INFO:absl:All writer objects closed
Wandb Configuration¶
project = "TUTORIALS"
display_name = 'causal_mask'
wandb.init(project=project, name=display_name)
Load Model, Optimizer, Trainer¶
Our Trainer expects model, optimizer and loss to be functions.
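A plausible reason for this factory pattern (an assumption about the Trainer internals, for intuition only) is that, under a distribution strategy, model and optimizer variables generally have to be created inside the strategy scope:

# Rough sketch (assumption, for intuition only): the Trainer can build the
# objects inside its strategy scope, which is why it takes factories.
# with strategy.scope():
#     model = model_fn()
#     optimizer = optimizer_fn()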
# Load Model
def get_model(model_name, is_training, use_dropout, mask_mode='causal'):
"""Get Model"""
def model_fn():
model = GPT2Model.from_pretrained(model_name, mask_mode=mask_mode) #causal by default
return model
return model_fn
# Load Optimizer
def get_optimizer(learning_rate, examples, batch_size, epochs, use_constant_lr=False):
"""Get optimizer"""
steps_per_epoch = int(examples / batch_size)
num_train_steps = steps_per_epoch * epochs
warmup_steps = int(0.1 * num_train_steps)
def optimizer_fn():
optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps, use_constant_lr=use_constant_lr)
return optimizer
return optimizer_fn
# Load trainer
def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None):
"""Get Trainer"""
trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address)
return trainer
# Load loss fn
def get_loss():
loss_fn = get_lm_loss(label_column='labels',
label_weights_column='labels_mask')
return loss_fn
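Conceptually, this is a masked cross-entropy over the shifted labels, normalized by the number of unmasked label positions; a rough equivalent sketch (an assumption, the actual get_lm_loss may differ in details):

# Rough conceptual equivalent of the masked LM loss (assumption; the real
# get_lm_loss implementation may differ)
def masked_lm_loss(labels, logits, labels_mask):
    per_token = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    mask = tf.cast(labels_mask, per_token.dtype)
    return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)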
Set Hyperparameters and Configs¶
Set necessary hyperparameters.

- Prepare train dataset and validation dataset.
- Load model, optimizer, loss and trainer.
# Model configs
learning_rate = 2e-5
epochs = 5
model_checkpoint_dir = 'MODELS/gpt2_squad_causal'
# Train dataset
schema = json.load(open("{}/schema.json".format(tfrecord_train_dir)))
total_train_examples = json.load(open("{}/stats.json".format(tfrecord_train_dir)))['total_records']
all_files = tf.io.gfile.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema,
tfrecord_files=all_files)
x_keys = ['input_ids']
y_keys = ['labels', 'labels_mask']
train_dataset = tf_reader.read_record(auto_batch=True,
batch_size=batch_size,
x_keys = x_keys,
y_keys = y_keys,
shuffle=True
)
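Optionally, peek at one batch to verify the feature shapes before training (an illustrative sanity check using standard tf.data; not required):

# Illustrative sanity check: inspect the shapes of one training batch
for (batch_inputs, batch_labels) in train_dataset.take(1):
    print({k: v.shape for k, v in batch_inputs.items()})
    print({k: v.shape for k, v in batch_labels.items()})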
# Total train examples
steps_per_epoch = total_train_examples // batch_size
# model
model_fn = get_model(model_name, is_training=True, use_dropout=True, mask_mode='causal')
# optimizer
optimizer_fn = get_optimizer(learning_rate, total_train_examples, batch_size, epochs)
# loss
loss_fn = get_loss()
# trainer (multi gpu strategy)
trainer = get_trainer(distribution_strategy='mirrored', num_gpus=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Train GPT2 Causal :-)¶
Loss comes down within the first epoch itself.
The evaluation results below clearly indicate how well the model has learned.
history = trainer.run(
model_fn=model_fn,
optimizer_fn=optimizer_fn,
train_dataset=train_dataset,
train_loss_fn=loss_fn,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
model_checkpoint_dir=model_checkpoint_dir,
batch_size=batch_size,
steps_per_call=1,
wandb=wandb
)
INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.
INFO:absl:Policy: ----> float32
INFO:absl:Strategy: ---> <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f86c42ad4d0>
INFO:absl:Num GPU Devices: ---> 2
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /home/jovyan/.cache/huggingface/hub/tftransformers__gpt2.main.8843a828e80c53bb121d7e395d07e3821ba88ea5/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/gpt2
INFO:absl:Using linear optimization warmup
INFO:absl:Using Adamw optimizer
INFO:absl:No ❌❌ checkpoint found in MODELS/gpt2_squad_causal
Train: Epoch 1/6 --- Step 1/2737 --- total examples 0 , trainable variables 148: 0%| | 0/2737 [00:00<?, ?batch /s]
INFO:tensorflow:batch_all_reduce: 147 all-reduces with algorithm = nccl, num_packs = 1
WARNING:tensorflow:Efficient allreduce is not supported for 1 IndexedSlices
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Train: Epoch 1/6 --- Step 2737/2737 --- total examples 87552 , trainable variables 148: 100%|██████████| 2737/2737 [29:30<00:00, 1.55batch /s, _runtime=1816, _timestamp=1.65e+9, learning_rate=1.6e-5, loss=0.293]
INFO:absl:Model saved at epoch 1 at MODELS/gpt2_squad_causal/ckpt-1
Train: Epoch 2/6 --- Step 2737/2737 --- total examples 175136 , trainable variables 148: 100%|██████████| 2737/2737 [28:52<00:00, 1.58batch /s, _runtime=3553, _timestamp=1.65e+9, learning_rate=1.2e-5, loss=0.226]
INFO:absl:Model saved at epoch 2 at MODELS/gpt2_squad_causal/ckpt-2
Train: Epoch 3/6 --- Step 2737/2737 --- total examples 262720 , trainable variables 148: 100%|██████████| 2737/2737 [28:53<00:00, 1.58batch /s, _runtime=5290, _timestamp=1.65e+9, learning_rate=8e-6, loss=0.111]
INFO:absl:Model saved at epoch 3 at MODELS/gpt2_squad_causal/ckpt-3
Train: Epoch 4/6 --- Step 2737/2737 --- total examples 350304 , trainable variables 148: 100%|██████████| 2737/2737 [28:53<00:00, 1.58batch /s, _runtime=7028, _timestamp=1.65e+9, learning_rate=4e-6, loss=0.0581]
INFO:absl:Model saved at epoch 4 at MODELS/gpt2_squad_causal/ckpt-4
Train: Epoch 5/6 --- Step 2737/2737 --- total examples 437888 , trainable variables 148: 100%|██████████| 2737/2737 [28:53<00:00, 1.58batch /s, _runtime=8765, _timestamp=1.65e+9, learning_rate=1.19e-12, loss=0.209]
INFO:absl:Model saved at epoch 5 at MODELS/gpt2_squad_causal/ckpt-5
Evaluation Script (SQuAD v1) - Exact Match, F1 Score¶
from collections import Counter
import string
import re
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
"""Squad evaluate
dataset: Huggingface Dataset
predictions: List of predictions
"""
f1 = exact_match = total = 0
for item in dataset:
ground_truths = item['answers']['text'] # list of answers
prediction = predictions[total]
exact_match += metric_max_over_ground_truths(
exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(
f1_score, prediction, ground_truths)
total += 1
exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total
return {'exact_match': exact_match, 'f1': f1}
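A quick sanity check of the two metrics (illustrative strings; values computed by hand):

# Illustrative: normalization strips case, punctuation and articles
print(exact_match_score('The Denver Broncos', 'denver broncos'))  # True
# Token-overlap F1: precision 2/3, recall 1.0 -> F1 = 0.8
print(round(f1_score('Denver Broncos players', 'Denver Broncos'), 2))  # 0.8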
Evaluate (exact match and F1 score) on all checkpoints - GPT2 Causal¶
def split_by_id(predicted_ids, eos_id):
"""Split by EOS_ID to make decoding proper"""
all_ids = []
for per_example_id in predicted_ids:
index = len(per_example_id)  # keep the full sequence if no EOS is found
if eos_id in per_example_id:
index = per_example_id.index(eos_id)
sliced_ids = per_example_id[:index]
all_ids.append(sliced_ids)
return all_ids
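For example (made-up ids, using GPT2's <|endoftext|> id 50256 as EOS):

# Illustrative: truncate each sequence at the first EOS id
# split_by_id([[5, 6, 50256, 7], [8, 9]], 50256) -> [[5, 6], [8, 9]]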
def get_serialized_model_from_checkpoint(model_checkpoint_dir, checkpoint_number):
"""Load serialized model checkpoint"""
model = GPT2Model.from_pretrained(model_name, use_auto_regressive=True)
model.load_checkpoint(checkpoint_path='{}/ckpt-{}'.format(model_checkpoint_dir, checkpoint_number))
model.save_transformers_serialized('{}/saved_model'.format(model_checkpoint_dir), overwrite=True)
loaded = tf.saved_model.load('{}/saved_model'.format(model_checkpoint_dir))
return loaded
# Validation dataset
validation_files = tf.io.gfile.glob("{}/*.tfrecord".format(tfrecord_validation_dir))
tf_reader = TFReader(schema=validation_schema,
tfrecord_files=validation_files)
x_keys = ['input_ids']
y_keys = ['original_answer'] # not necessarily required
validation_dataset = tf_reader.read_record(auto_batch=True,
batch_size=batch_size,
x_keys = x_keys,
y_keys = y_keys,
shuffle=False,
padded_values={'input_ids': tf.constant(-1)}
)
validation_results = []
for checkpoint_number in range(1, epochs+1):
# get serialized model
loaded = get_serialized_model_from_checkpoint(model_checkpoint_dir, checkpoint_number)
# Load decoder
decoder = TextDecoder(model=loaded)
# greedy decoding
predicted_answers = []
for (batch_inputs, batch_labels) in tqdm.tqdm(validation_dataset):
predictions = decoder.decode(batch_inputs,
mode='greedy',
max_iterations=max_answer_length,
eos_id=tokenizer.bos_token_id)
predicted_ids = tf.squeeze(predictions['predicted_ids'], axis=1).numpy().tolist()
# Truncate each prediction at the first EOS
predicted_ids = split_by_id(predicted_ids, tokenizer.bos_token_id)
# Decode
predicted_answers_batch = tokenizer.batch_decode(predicted_ids)
predicted_answers.extend(predicted_answers_batch)
# generated text starts with 'answer: '; strip that prefix
predicted_answers = [answer.replace('answer: ', '') for answer in predicted_answers]
# Exact match and f1 score
val_result = evaluate(dataset['validation'], predicted_answers)
validation_results.append(val_result)
for checkpoint_number, result in enumerate(validation_results):
print("Checkpoint {} , {}".format(checkpoint_number+1, result))
Checkpoint 1 , {'exact_match': 66.66982024597918, 'f1': 75.18278422439874}
Checkpoint 2 , {'exact_match': 69.59318826868495, 'f1': 78.53187632449746}
Checkpoint 3 , {'exact_match': 70.13245033112582, 'f1': 78.74015684895157}
Checkpoint 4 , {'exact_match': 69.93377483443709, 'f1': 78.93703082679076}
Checkpoint 5 , {'exact_match': 69.2336802270577, 'f1': 78.31863252452612}
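Finally, as promised in the introduction, the same serialized model works for end-to-end inference in a production setup. A minimal single-example sketch mirroring the evaluation loop above (the question and context strings are made up, and decoder is the last checkpoint loaded above):

# Minimal single-example inference sketch (made-up question/context;
# reuses the `decoder` built from the last checkpoint above)
question_ids = tokenizer('Question: Who wrote Hamlet?')['input_ids']
context_ids = tokenizer('Context: Hamlet is a tragedy written by William Shakespeare.')['input_ids'] + [tokenizer.bos_token_id]
inputs = {'input_ids': tf.constant([question_ids + context_ids])}
predictions = decoder.decode(inputs, mode='greedy', max_iterations=max_answer_length, eos_id=tokenizer.bos_token_id)
predicted_ids = tf.squeeze(predictions['predicted_ids'], axis=1).numpy().tolist()
predicted_ids = split_by_id(predicted_ids, tokenizer.bos_token_id)
print(tokenizer.batch_decode(predicted_ids)[0].replace('answer: ', ''))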