Writing and Reading TFRecords¶
Tensorflow-Transformers has off-the-shelf support for writing and reading TFRecords with ease. It also lets you shard, shuffle, and batch your data with minimal code.
Here we will see how we can make use of these utilities to write and read TFRecords.
For this example, we will use the SQuAD dataset and convert it to a text-to-text problem using the GPT2 tokenizer. TFRecords are useful for building efficient training pipelines.
from tf_transformers.data import TFWriter, TFReader
from transformers import GPT2TokenizerFast
from datasets import load_dataset
import tempfile
import json
import glob
Load Data and Tokenizer¶
We will load the dataset and the tokenizer, and then define the maximum lengths for the examples. It is important to keep the lengths within the allowed limit of each model.
# Load Dataset
dataset = load_dataset("squad")
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
# Define length for examples
max_passage_length = 384
max_question_length = 64
max_answer_length = 40
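As a quick optional check (not part of the original pipeline), the combined length budget can be compared against the tokenizer's model_max_length, which is 1024 tokens for GPT2:

# Optional: make sure the combined budget fits within GPT2's context window (1024 tokens)
total_length = max_passage_length + max_question_length + max_answer_length
assert total_length <= tokenizer.model_max_length, "Reduce the per-part lengths"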
Write TFRecord¶
To write a TFRecord, we need to provide a schema (dict). The schema supports int, float, and bytes types.
TFWriter supports FixedLen and VarLen feature types.
The recommended and easiest option is VarLen: it is fast and easy to write and read, and we can pad the features as needed after reading.
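Since VarLen features come back as variable-length tensors, one way to pad them yourself is tf.data's padded_batch. This is a minimal sketch only; raw_dataset stands for a hypothetical un-batched dataset of parsed examples (in the code below, auto_batch takes care of batching for us):

# Sketch only: pad variable-length features per batch when not using auto_batch.
# `raw_dataset` is a hypothetical un-batched tf.data.Dataset of parsed examples.
padded_dataset = raw_dataset.padded_batch(
    batch_size=16,
    padded_shapes={'input_ids': [None], 'labels': [None], 'labels_mask': [None]},
)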
def parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key):
    """Function to parse examples

    Args:
        dataset (:obj:`dataset`): HF dataset
        tokenizer (:obj:`tokenizer`): HF Tokenizer
        max_passage_length (:obj:`int`): Passage Length
        max_question_length (:obj:`int`): Question Length
        max_answer_length (:obj:`int`): Answer Length
        key (:obj:`str`): Key of dataset (`train`, `validation` etc)
    """
    for item in dataset[key]:
        question_input_ids = tokenizer(item['question'], max_length=max_question_length, truncation=True)['input_ids'] + \
            [tokenizer.bos_token_id]
        passage_input_ids = tokenizer(item['context'], max_length=max_passage_length, truncation=True)['input_ids'] + \
            [tokenizer.bos_token_id]

        # Input: Question + Context
        # Mask these positions in the labels, as we don't want the model to predict the inputs
        input_ids = question_input_ids + passage_input_ids
        labels_mask = [0] * len(input_ids)

        # Answer part
        answer_ids = tokenizer(item['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids'] + \
            [tokenizer.bos_token_id]
        input_ids = input_ids + answer_ids
        labels_mask = labels_mask + [1] * len(answer_ids)

        # Shift positions by one to make proper (input, label) training pairs
        labels = input_ids[1:]
        labels_mask = labels_mask[1:]
        input_ids = input_ids[:-1]

        result = {}
        result['input_ids'] = input_ids
        result['labels'] = labels
        result['labels_mask'] = labels_mask
        yield result
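# Optional sanity check: parse_train is a generator, so a single example can be
# inspected before writing; input_ids, labels and labels_mask should have equal lengths.
example = next(parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key='train'))
print({k: len(v) for k, v in example.items()})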
# Write using TF Writer
schema = {
"input_ids": ("var_len", "int"),
"labels": ("var_len", "int"),
"labels_mask": ("var_len", "int"),
}
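# Note: besides "int", the data type can also be "float" or "bytes".
# Hypothetical keys below, shown only to illustrate the schema format (not used here):
# extra_schema = {
#     "score": ("var_len", "float"),
#     "raw_text": ("var_len", "bytes"),
# }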
tfrecord_train_dir = tempfile.mkdtemp()
tfrecord_filename = 'squad'
tfwriter = TFWriter(schema=schema,
file_name=tfrecord_filename,
model_dir=tfrecord_train_dir,
tag='train',
overwrite=True
)
# Train dataset
train_parser_fn = parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key='train')
tfwriter.process(parse_fn=train_parser_fn)
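The same pattern can be reused to write the validation split. The sketch below assumes a separate directory and a different tag value ('eval' is an assumption here, not something used above); the parse function is simply called with key='validation':

# Validation split (sketch): same schema, separate directory and tag
tfrecord_eval_dir = tempfile.mkdtemp()
tfwriter_eval = TFWriter(schema=schema,
                         file_name=tfrecord_filename,
                         model_dir=tfrecord_eval_dir,
                         tag='eval',
                         overwrite=True
                         )
eval_parser_fn = parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key='validation')
tfwriter_eval.process(parse_fn=eval_parser_fn)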
Read TFRecords¶
To read a TFRecord, we need to provide a schema (dict). The schema supports int, float, and bytes types.
TFReader supports FixedLen and VarLen feature types. In a single function we can also auto-batch, shuffle, and choose which keys to read (not all keys in the TFRecords may be required while reading).
# Read TFRecord
schema = json.load(open("{}/schema.json".format(tfrecord_train_dir)))
all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema,
tfrecord_files=all_files)
x_keys = ['input_ids']
y_keys = ['labels', 'labels_mask']
batch_size = 16
train_dataset = tf_reader.read_record(auto_batch=True,
                                      keys=x_keys,
                                      batch_size=batch_size,
                                      x_keys=x_keys,
                                      y_keys=y_keys,
                                      shuffle=True,
                                      drop_remainder=True
                                      )
for (batch_inputs, batch_labels) in train_dataset:
print(batch_inputs, batch_labels)
break
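To confirm the (x, y) structure produced by x_keys and y_keys, the dataset's element spec can also be printed:

# The element spec shows the x/y split and the batched, variable-length shapes
print(train_dataset.element_spec)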