{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "c52f2969", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "516bc9a5", "metadata": {}, "source": [ "# Writing and Reading TFRecords\n", "\n", "TensorFlow-Transformers (`tf_transformers`) has off-the-shelf support for writing and reading TFRecords.\n", "It also lets you shard, shuffle, and batch your data with minimal code.\n", "\n", "Here we will see how to use these utilities to write and read TFRecords.\n", "\n", "For this example, we use the [**SQuAD dataset**](https://huggingface.co/datasets/squad \"SQuAD Dataset\") and convert it into a text-to-text problem using the\n", "GPT-2 tokenizer. TFRecords are useful for building efficient training input pipelines." ] }, { "cell_type": "code", "execution_count": null, "id": "658b1da5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "cbd5587a", "metadata": {}, "outputs": [], "source": [ "from tf_transformers.data import TFWriter, TFReader\n", "from transformers import GPT2TokenizerFast\n", "\n", "from datasets import load_dataset\n", "\n", "import tempfile\n", "import json\n", "import glob" ] }, { "cell_type": "code", "execution_count": null, "id": "38f0169a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b7c52655", "metadata": {}, "source": [ "## Load Data and Tokenizer\n", "\n", "We load the dataset and the tokenizer, then define the maximum length of each part of an example.\n", "Make sure the combined length stays within the maximum sequence length the model supports (1024 tokens for GPT-2)."
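, "\n", "\n", "As a quick sanity check, here is a minimal sketch that computes the longest training sequence these limits can produce. It assumes GPT-2's context window of 1024 positions, one separator token appended to each of the three segments, and one position dropped when inputs and labels are shifted (as `parse_train` does later in this notebook). The names `gpt2_context_size` and `max_total_length` are illustrative.\n", "\n", "```python\n", "# Length limits used in this notebook\n", "max_passage_length = 384\n", "max_question_length = 64\n", "max_answer_length = 40\n", "\n", "# Assumption: GPT-2 supports at most 1024 positions\n", "gpt2_context_size = 1024\n", "\n", "# +1 per segment for its separator token, -1 for the one-position shift between inputs and labels\n", "max_total_length = (max_passage_length + 1) + (max_question_length + 1) + (max_answer_length + 1) - 1\n", "assert max_total_length <= gpt2_context_size\n", "print(max_total_length)  # 490\n", "```"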
] }, { "cell_type": "code", "execution_count": null, "id": "5814a467", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "17af7500", "metadata": {}, "outputs": [], "source": [ "# Load the SQuAD dataset and the GPT-2 tokenizer\n", "dataset = load_dataset(\"squad\")\n", "tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\n", "\n", "# Maximum number of tokens for each part of an example\n", "max_passage_length = 384\n", "max_question_length = 64\n", "max_answer_length = 40" ] }, { "cell_type": "code", "execution_count": null, "id": "b24f6776", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "bf02f3c0", "metadata": {}, "source": [ "## Write TFRecord\n", "\n", "To write a TFRecord, we need to provide a schema (**dict**) that maps each feature name to its feature type and data type. Supported data types are **int**, **float**, and **bytes**.\n", "\n", "**TFWriter** supports [**FixedLen**](https://www.tensorflow.org/api_docs/python/tf/io/FixedLenFeature) and\n", "[**VarLen**](https://www.tensorflow.org/api_docs/python/tf/io/VarLenFeature) feature types.\n", "\n", "**VarLen** is the recommended and easiest option: variable-length features are simple to write and read,\n", "and we can pad them as needed after reading."
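, "\n", "\n", "To make the VarLen-then-pad idea concrete, here is a minimal sketch that uses only core TensorFlow APIs (`tf.train.Example`, `tf.io.VarLenFeature`, `tf.data.Dataset.padded_batch`) instead of **TFWriter** and **TFReader**, whose internals may differ. The token ids and the helper names `to_example` and `parse` are hypothetical.\n", "\n", "```python\n", "import tensorflow as tf\n", "\n", "# Serialize one variable-length list of token ids as a tf.train.Example\n", "def to_example(input_ids):\n", "    feature = {'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids))}\n", "    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()\n", "\n", "serialized = [to_example([5, 6, 7]), to_example([8, 9])]\n", "feature_spec = {'input_ids': tf.io.VarLenFeature(tf.int64)}\n", "\n", "def parse(record):\n", "    parsed = tf.io.parse_single_example(record, feature_spec)\n", "    # A VarLenFeature is parsed into a SparseTensor; densify it before batching\n", "    return tf.sparse.to_dense(parsed['input_ids'])\n", "\n", "ds = tf.data.Dataset.from_tensor_slices(serialized).map(parse)\n", "# padded_batch pads every sequence with 0 up to the longest sequence in its batch\n", "batch = next(iter(ds.padded_batch(2)))\n", "print(batch.numpy().tolist())  # [[5, 6, 7], [8, 9, 0]]\n", "```\n", "\n", "Because the padding length is decided per batch, a batch of short examples is not padded to the global maximum length."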
] }, { "cell_type": "code", "execution_count": null, "id": "1e11fc56", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "bd878197", "metadata": {}, "outputs": [], "source": [ "def parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key):\n", "    \"\"\"Parse SQuAD examples into GPT-2 training examples.\n", "\n", "    Args:\n", "        dataset (:obj:`datasets.DatasetDict`): HF dataset\n", "        tokenizer (:obj:`tokenizer`): HF tokenizer\n", "        max_passage_length (:obj:`int`): Maximum passage (context) length\n", "        max_question_length (:obj:`int`): Maximum question length\n", "        max_answer_length (:obj:`int`): Maximum answer length\n", "        key (:obj:`str`): Split of the dataset (`train`, `validation`, etc.)\n", "\n", "    Yields:\n", "        :obj:`dict`: `input_ids`, `labels` and `labels_mask` of one example.\n", "    \"\"\"\n", "    # GPT-2 has no separator token; its BOS token (same id as EOS) marks the end of each segment\n", "    sep_id = tokenizer.bos_token_id\n", "    for item in dataset[key]:\n", "        question_input_ids = tokenizer(item['question'], max_length=max_question_length, truncation=True)['input_ids'] + [sep_id]\n", "        passage_input_ids = tokenizer(item['context'], max_length=max_passage_length, truncation=True)['input_ids'] + [sep_id]\n", "\n", "        # Input: Question + Context\n", "        # Mask these positions in the labels, as we don't want the model to predict the inputs\n", "        input_ids = question_input_ids + passage_input_ids\n", "        labels_mask = [0] * len(input_ids)\n", "\n", "        # Answer part (only the first reference answer is used)\n", "        answer_ids = tokenizer(item['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids'] + [sep_id]\n", "        input_ids = input_ids + answer_ids\n", "        labels_mask = labels_mask + [1] * len(answer_ids)\n", "\n", "        # Shift by one position: the label at position i is the token at position i + 1\n", "        labels = input_ids[1:]\n", "        labels_mask = labels_mask[1:]\n", "        input_ids = input_ids[:-1]\n", "\n", "        result = {}\n", "        result['input_ids'] = input_ids\n", "        result['labels'] = labels\n", "        result['labels_mask'] = labels_mask\n", "\n", "        yield result\n", "\n", "\n", "# Write using TFWriter\n", "# Each feature is stored as a variable-length list of integers\n", "schema = {\n", "    \"input_ids\": (\"var_len\", \"int\"),\n", "    \"labels\": (\"var_len\", \"int\"),\n", "    \"labels_mask\": (\"var_len\", \"int\"),\n", "}\n", "\n", "tfrecord_train_dir = tempfile.mkdtemp()\n", "tfrecord_filename = 'squad'\n", "\n", "tfwriter = TFWriter(schema=schema,\n", "                    file_name=tfrecord_filename,\n", "                    model_dir=tfrecord_train_dir,\n", "                    tag='train',\n", "                    overwrite=True\n", "                    )\n", "\n", "# Train dataset: `parse_train` is a generator, so examples are produced lazily while writing\n", "train_parser_fn = parse_train(dataset, tokenizer, max_passage_length, max_question_length, max_answer_length, key='train')\n", "tfwriter.process(parse_fn=train_parser_fn)" ] }, { "cell_type": "code", "execution_count": null, "id": "4af4b706", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "7146c966", "metadata": {}, "source": [ "## Read TFRecords\n", "\n", "To read TFRecords, we again provide a schema (**dict**). It supports the same data types: **int**, **float**, and **bytes**.\n", "\n", "**TFReader** supports [**FixedLen**](https://www.tensorflow.org/api_docs/python/tf/io/FixedLenFeature) and\n", "[**VarLen**](https://www.tensorflow.org/api_docs/python/tf/io/VarLenFeature) feature types.\n", "In a single call to `read_record`, we can also auto-batch (**auto_batch**), **shuffle**, and read only the keys we need (not every key stored in the TFRecords has to be read)."
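, "\n", "\n", "Conceptually, these options map onto standard `tf.data` steps. The sketch below reproduces shuffling, padded batching with `drop_remainder`, and the split into `x_keys` and `y_keys` on a few hypothetical in-memory examples; it is an assumption about the behavior for illustration, not **TFReader**'s actual implementation.\n", "\n", "```python\n", "import tensorflow as tf\n", "\n", "# Hypothetical toy examples with the same keys as the SQuAD schema\n", "examples = [\n", "    {'input_ids': [1, 2, 3], 'labels': [2, 3, 4], 'labels_mask': [0, 1, 1]},\n", "    {'input_ids': [5, 6], 'labels': [6, 7], 'labels_mask': [0, 1]},\n", "    {'input_ids': [8], 'labels': [9], 'labels_mask': [1]},\n", "    {'input_ids': [1, 5, 8, 2], 'labels': [5, 8, 2, 3], 'labels_mask': [0, 0, 1, 1]},\n", "    {'input_ids': [4, 4], 'labels': [4, 6], 'labels_mask': [0, 1]},\n", "]\n", "x_keys = ['input_ids']\n", "y_keys = ['labels', 'labels_mask']\n", "\n", "signature = {key: tf.TensorSpec(shape=[None], dtype=tf.int32) for key in x_keys + y_keys}\n", "ds = tf.data.Dataset.from_generator(lambda: iter(examples), output_signature=signature)\n", "\n", "def split_xy(batch):\n", "    # Route each key either to the model inputs (x) or to the targets (y)\n", "    return {key: batch[key] for key in x_keys}, {key: batch[key] for key in y_keys}\n", "\n", "ds = (ds.shuffle(buffer_size=len(examples), seed=0)\n", "        .padded_batch(2, drop_remainder=True)  # 5 examples -> 2 full batches; one leftover example is dropped\n", "        .map(split_xy))\n", "\n", "batches = list(ds)\n", "for batch_inputs, batch_labels in batches:\n", "    print(batch_inputs['input_ids'].shape, sorted(batch_labels))\n", "```"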
] }, { "cell_type": "code", "execution_count": null, "id": "f180b982", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "7f9d4110", "metadata": {}, "outputs": [], "source": [ "# Read TFRecord\n", "# Load the schema that TFWriter saved in the output directory\n", "with open(\"{}/schema.json\".format(tfrecord_train_dir)) as f:\n", "    schema = json.load(f)\n", "all_files = glob.glob(\"{}/*.tfrecord\".format(tfrecord_train_dir))\n", "tf_reader = TFReader(schema=schema,\n", "                     tfrecord_files=all_files)\n", "\n", "x_keys = ['input_ids']\n", "y_keys = ['labels', 'labels_mask']\n", "batch_size = 16\n", "train_dataset = tf_reader.read_record(auto_batch=True,\n", "                                      keys=x_keys + y_keys,  # read every key used as input or label\n", "                                      batch_size=batch_size,\n", "                                      x_keys=x_keys,\n", "                                      y_keys=y_keys,\n", "                                      shuffle=True,\n", "                                      drop_remainder=True\n", "                                      )" ] }, { "cell_type": "code", "execution_count": null, "id": "61165f1c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "797dab70", "metadata": {}, "outputs": [], "source": [ "# Inspect a single batch of (inputs, labels)\n", "for (batch_inputs, batch_labels) in train_dataset:\n", "    print(batch_inputs, batch_labels)\n", "    break" ] }, { "cell_type": "code", "execution_count": null, "id": "0d50e004", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "267d5ed2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }