{ "cells": [ { "cell_type": "markdown", "id": "e0f1ee00", "metadata": {}, "source": [ "# GPT2 for QA using Squad V1 ( Causal LM )\n", "\n", "This tutorial contains complete code to fine-tune GPT2 for Question Answering using Squad V1 data.\n", "In addition to training a model, you will learn how to preprocess text into an appropriate format.\n", "\n", "In this notebook, you will:\n", "\n", "- Load the Squad v1 dataset from HuggingFace\n", "- Load GPT2 Model using tf-transformers\n", "- Build model using ```causal``` (default) and ```prefix``` masking.\n", "- Prepare train and validation dataset features using the\n", "tokenizer from transformers.\n", "- Train your own model, fine-tuning GPT2\n", "- Save your model and use it for QA\n", "- Use the end-to-end (inference) model in a production setup\n", "\n", "If you're new to working with the SQuAD dataset, please see [SQUAD](https://huggingface.co/datasets/squad) for more details." ] }, { "cell_type": "code", "execution_count": null, "id": "51bb7f1c", "metadata": {}, "outputs": [], "source": [ "!pip install tf-transformers\n", "\n", "!pip install transformers\n", "\n", "!pip install wandb\n", "\n", "!pip install datasets" ] }, { "cell_type": "code", "execution_count": null, "id": "8ac430fd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 26, "id": "7069b428", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow version 2.7.0\n", "Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]\n" ] } ], "source": [ "import tensorflow as tf\n", "import random\n", "import collections\n", "import wandb\n", "import tempfile\n", "import tqdm\n", "import json\n", "\n", "import os\n", "import numpy as np\n", "\n", "print(\"Tensorflow version\", tf.__version__)\n", "print(\"Devices\", 
tf.config.list_physical_devices())\n", "\n", "from tf_transformers.models import GPT2Model\n", "from tf_transformers.core import Trainer\n", "from tf_transformers.optimization import create_optimizer\n", "from tf_transformers.data import TFWriter, TFReader\n", "from tf_transformers.losses.loss_wrapper import get_lm_loss\n", "from tf_transformers.text import TextDecoder\n", "\n", "\n", "from datasets import load_dataset\n", "\n", "\n", "from transformers import GPT2Tokenizer" ] }, { "cell_type": "code", "execution_count": null, "id": "5317c9b6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "e1493dbf", "metadata": {}, "source": [ "### Load Data, Tokenizer" ] }, { "cell_type": "code", "execution_count": null, "id": "b40f1d34", "metadata": {}, "outputs": [], "source": [ "model_name = 'gpt2'\n", "\n", "# Load Dataset\n", "dataset = load_dataset(\"squad\")\n", "tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n", "\n", "# Define length for examples\n", "max_sequence_length = 384\n", "max_question_length = 64\n", "max_answer_length = 40\n", "batch_size = 32" ] }, { "cell_type": "markdown", "id": "188d4b46", "metadata": {}, "source": [ "### Prepare Training TFRecords and Validation TFRecords using Squad ( causal and prefix )\n", "\n", "* 1. We concatenate ```(question + context + answer)``` into a single sequence.\n", "* 2. For ```mask_mode=causal```, we don't need any mask. For ```mask_mode=prefix```, we need ```input_mask```.\n", "* 3. For ```prefix```, ```input_mask``` is 1 over ```question + context``` and 0 over ```answer```: the answer has to be generated, so it stays causal.\n", "* 4. Note how ```labels_mask``` is the opposite of ```input_mask```: it is 1 only over the ```answer``` tokens, so the loss is computed on the answer alone." 
] }, { "cell_type": "code", "execution_count": 11, "id": "a0146899", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Total individual observations/examples written is 87599 in 173.0701344013214 seconds\n", "INFO:absl:All writer objects closed\n" ] } ], "source": [ "def parse_train(dataset, tokenizer, max_question_length, max_passage_length, max_answer_length, key):\n", "    \"\"\"Generator that yields one training example (question + context + answer) at a time\n", "\n", "    Args:\n", "        dataset (:obj:`dataset`): HF dataset\n", "        tokenizer (:obj:`tokenizer`): HF Tokenizer\n", "        max_question_length (:obj:`int`): Question Length\n", "        max_passage_length (:obj:`int`): Passage Length\n", "        max_answer_length (:obj:`int`): Answer Length\n", "        key (:obj:`str`): Key of dataset (`train`, `validation` etc)\n", "    \"\"\"\n", "    for f in dataset[key]:\n", "\n", "        question_ids = tokenizer('Question: ' + f['question'], max_length=max_question_length, truncation=True)['input_ids']\n", "        context_ids = tokenizer('Context: ' + f['context'], max_length=max_passage_length, truncation=True)['input_ids']\n", "        answer_ids = tokenizer('answer: ' + f['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids']\n", "        # add EOS (GPT2 uses the same token for BOS and EOS, so bos_token_id serves as EOS here)\n", "        context_ids = context_ids + [tokenizer.bos_token_id]\n", "        answer_ids = answer_ids + [tokenizer.bos_token_id] # EOS token\n", "\n", "        # input_ids\n", "        input_ids = (question_ids + context_ids + answer_ids)\n", "\n", "        # input_mask: 1 for question + context, 0 for answer\n", "        input_mask = ([1] * len(question_ids)) + ([1] * len(context_ids)) + ([0] * len(answer_ids))\n", "        # labels_mask is the opposite of input_mask, as the loss is computed only on answer_ids\n", "        labels_mask = ([0] * len(question_ids)) + ([0] * len(context_ids)) + ([1] * len(answer_ids))\n", "        result = {}\n", "        # Drop the last token\n", "        result['input_ids'] = input_ids[:-1]\n", "        result['input_mask'] = input_mask[:-1]\n", "\n", "        # Shift by one token, so that each position predicts the next token\n", "        result['labels'] = input_ids[1:]\n", "        result['labels_mask'] = labels_mask[1:]\n", "\n", "        yield result\n", "\n", "# Write using TF Writer\n", "schema = {\n", "    \"input_ids\": (\"var_len\", \"int\"),\n", "    \"input_mask\": (\"var_len\", \"int\"),\n", "    \"labels\": (\"var_len\", \"int\"),\n", "    \"labels_mask\": (\"var_len\", \"int\")\n", "}\n", "\n", "tfrecord_train_dir = tempfile.mkdtemp()\n", "tfrecord_filename = 'squad'\n", "\n", "tfwriter = TFWriter(schema=schema,\n", "                    file_name=tfrecord_filename,\n", "                    model_dir=tfrecord_train_dir,\n", "                    tag='train',\n", "                    overwrite=True\n", "                    )\n", "\n", "# Train dataset\n", "train_parser_fn = parse_train(dataset, tokenizer, max_question_length, max_sequence_length, max_answer_length, key='train')\n", "tfwriter.process(parse_fn=train_parser_fn)" ] },
{ "cell_type": "code", "execution_count": 4, "id": "b6814ebc", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "9055937d", "metadata": {}, "source": [ "### Prepare Validation TFRecords" ] }, { "cell_type": "code", "execution_count": 12, "id": "221b7444", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Total individual observations/examples written is 10570 in 19.614187002182007 seconds\n", "INFO:absl:All writer objects closed\n" ] } ], "source": [ "def parse_dev(dataset, tokenizer, max_question_length, max_passage_length, max_answer_length, key):\n", "    \"\"\"Generator that yields one validation example (question + context, without the answer) at a time\n", "\n", "    Args:\n", "        dataset (:obj:`dataset`): HF dataset\n", "        tokenizer (:obj:`tokenizer`): HF Tokenizer\n", "        max_question_length (:obj:`int`): Question Length\n", "        max_passage_length (:obj:`int`): Passage Length\n", "        max_answer_length (:obj:`int`): Answer Length\n", "        key (:obj:`str`): Key of dataset (`train`, `validation` etc)\n", "    \"\"\"\n", "    for f in dataset[key]:\n", "\n", "        question_ids = tokenizer('Question: ' + f['question'], max_length=max_question_length, truncation=True)['input_ids']\n", "        context_ids = tokenizer('Context: ' + f['context'], max_length=max_passage_length, truncation=True)['input_ids']\n", "        answer_ids = tokenizer('answer: ' + f['answers']['text'][0], max_length=max_answer_length, truncation=True)['input_ids']\n", "        # add EOS (GPT2 uses the same token for BOS and EOS, so bos_token_id serves as EOS here)\n", "        context_ids = context_ids + [tokenizer.bos_token_id]\n", "\n", "        # input_ids\n", "        input_ids = (question_ids + context_ids)\n", "\n", "        # input_mask\n", "        input_mask = ([1] * len(question_ids)) + ([1] * len(context_ids))\n", "        # Fresh dict per example, so that yielded examples never share state\n", "        result = {}\n", "        result['input_ids'] = input_ids\n", "        result['input_mask'] = input_mask\n", "        result['original_answer'] = f['answers']['text'][0]\n", "\n", "        yield result\n", "\n", "tfrecord_validation_dir = tempfile.mkdtemp()\n", "tfrecord_validation_filename = 'squad'\n", "\n", "validation_schema = {\n", "    \"input_ids\": (\"var_len\", \"int\"),\n", "    \"input_mask\": (\"var_len\", \"int\"),\n", "    \"original_answer\": (\"var_len\", \"bytes\")\n", "}\n", "tfwriter = TFWriter(schema=validation_schema,\n", "                    file_name=tfrecord_validation_filename,\n", "                    model_dir=tfrecord_validation_dir,\n", "                    tag='eval',\n", "                    overwrite=True\n", "                    )\n", "\n", "# Validation dataset\n", "validation_parser_fn = parse_dev(dataset, tokenizer, max_question_length, max_sequence_length, max_answer_length, key='validation')\n", "tfwriter.process(parse_fn=validation_parser_fn)" ] },
{ "cell_type": "markdown", "id": "8fa10521", "metadata": {}, "source": [ "### Wandb Configuration" ] }, { "cell_type": "code", "execution_count": null, "id": "d831dc81", "metadata": {}, "outputs": [], "source": [ "project = \"TUTORIALS\"\n", "display_name = 'causal_mask'\n", "wandb.init(project=project, name=display_name)" ] }, { "cell_type": "markdown", "id": "4bfaf6b2", "metadata": {}, "source": [ "### Load Model, Optimizer, Trainer\n", "\n", "Our Trainer expects ```model```, ```optimizer``` and ```loss``` to each be passed as a function." 
] }, { "cell_type": "code", "execution_count": 9, "id": "c230d395", "metadata": {}, "outputs": [], "source": [ "# Load Model\n", "def get_model(model_name, is_training, use_dropout, mask_mode='causal'):\n", "    \"\"\"Get Model\"\"\"\n", "\n", "    def model_fn():\n", "        model = GPT2Model.from_pretrained(model_name, mask_mode=mask_mode) #causal by default\n", "        return model\n", "    return model_fn\n", "\n", "# Load Optimizer\n", "def get_optimizer(learning_rate, examples, batch_size, epochs, use_constant_lr=False):\n", "    \"\"\"Get optimizer\"\"\"\n", "    steps_per_epoch = int(examples / batch_size)\n", "    num_train_steps = steps_per_epoch * epochs\n", "    warmup_steps = int(0.1 * num_train_steps)\n", "\n", "    def optimizer_fn():\n", "        optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps, use_constant_lr=use_constant_lr)\n", "        return optimizer\n", "\n", "    return optimizer_fn\n", "\n", "# Load trainer\n", "def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None):\n", "    \"\"\"Get Trainer\"\"\"\n", "    trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address)\n", "    return trainer\n", "\n", "# Load loss fn\n", "def get_loss():\n", "    loss_fn = get_lm_loss(label_column='labels',\n", "                          label_weights_column='labels_mask')\n", "    return loss_fn" ] }, { "cell_type": "code", "execution_count": null, "id": "6cb70797", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "3d9b343f", "metadata": {}, "source": [ "### Set Hyperparameters and Configs\n", "\n", "1. Set necessary hyperparameters.\n", "2. Prepare ```train dataset```, ```validation dataset```.\n", "3. Load ```model```, ```optimizer```, ```loss``` and ```trainer```." 
] }, { "cell_type": "code", "execution_count": 14, "id": "9099bbe8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n" ] } ], "source": [ "# Model configs\n", "learning_rate = 2e-5\n", "epochs = 5\n", "model_checkpoint_dir = 'MODELS/gpt2_squad_causal'\n", "\n", "# Train dataset\n", "schema = json.load(open(\"{}/schema.json\".format(tfrecord_train_dir)))\n", "total_train_examples = json.load(open(\"{}/stats.json\".format(tfrecord_train_dir)))['total_records']\n", "\n", "\n", "all_files = tf.io.gfile.glob(\"{}/*.tfrecord\".format(tfrecord_train_dir))\n", "tf_reader = TFReader(schema=schema,\n", "                     tfrecord_files=all_files)\n", "\n", "x_keys = ['input_ids']\n", "y_keys = ['labels', 'labels_mask']\n", "train_dataset = tf_reader.read_record(auto_batch=True,\n", "                                      batch_size=batch_size,\n", "                                      x_keys=x_keys,\n", "                                      y_keys=y_keys,\n", "                                      shuffle=True\n", "                                      )\n", "\n", "# Steps per epoch\n", "steps_per_epoch = total_train_examples // batch_size\n", "\n", "# model\n", "model_fn = get_model(model_name, is_training=True, use_dropout=True, mask_mode='causal')\n", "# optimizer\n", "optimizer_fn = get_optimizer(learning_rate, total_train_examples, batch_size, epochs)\n", "# loss\n", "loss_fn = get_loss()\n", "# trainer (multi gpu strategy)\n", "trainer = get_trainer(distribution_strategy='mirrored', num_gpus=2)" ] }, { "cell_type": "code", "execution_count": null, "id": "5b79b7d4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "04c657db", "metadata": {}, "source": [ "### Train GPT2 Causal :-)\n", "\n", "* 1. 
The loss starts coming down within the first epoch.\n", "* 2. The evaluation results clearly indicate how well the model has learned." ] }, { "cell_type": "code", "execution_count": 17, "id": "c9fb93d8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.\n", "INFO:absl:Policy: ----> float32\n", "INFO:absl:Strategy: ---> \n", "INFO:absl:Num GPU Devices: ---> 2\n", "INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /home/jovyan/.cache/huggingface/hub/tftransformers__gpt2.main.8843a828e80c53bb121d7e395d07e3821ba88ea5/ckpt-1\n", "INFO:absl:Successful ✅: Loaded model from tftransformers/gpt2\n", "INFO:absl:Using linear optimization warmup\n", "INFO:absl:Using Adamw optimizer\n", "INFO:absl:No ❌❌ checkpoint found in MODELS/gpt2_squad_causal\n", "Train: Epoch 1/6 --- Step 1/2737 --- total examples 0 , trainable variables 148: 0%|\u001b[32m \u001b[0m| 0/2737 [00:00