{ "cells": [ { "cell_type": "markdown", "id": "a5c88f3a", "metadata": {}, "source": [ "# Text Generation using GPT2\n", "\n", "* This tutorial is intended to provide, a familiarity in how to use ```GPT2``` for text-generation tasks.\n", "* No training is involved in this." ] }, { "cell_type": "code", "execution_count": null, "id": "16cd8c45", "metadata": {}, "outputs": [], "source": [ "!pip install tf-transformers\n", "\n", "!pip install transformers\n" ] }, { "cell_type": "code", "execution_count": null, "id": "15450927", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "id": "9a37c55e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow version 2.7.0\n", "Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n" ] } ], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Supper TF warnings\n", "\n", "import tensorflow as tf\n", "print(\"Tensorflow version\", tf.__version__)\n", "print(\"Devices\", tf.config.list_physical_devices())\n", "\n", "from tf_transformers.models import GPT2Model\n", "from tf_transformers.text import TextDecoder\n", "from transformers import GPT2Tokenizer" ] }, { "cell_type": "code", "execution_count": null, "id": "8d057c45", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "a0cc782b", "metadata": {}, "source": [ "### Load GPT2 Model \n", "\n", "* 1. Note `use_auto_regressive=True`, argument. This is required for any models to enable text-generation." ] }, { "cell_type": "code", "execution_count": 3, "id": "6b0e5790", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "908d508d110b4f76a90003c1d8d8cb57", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/0.99M [00:00 and ).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__gpt2.main.8843a828e80c53bb121d7e395d07e3821ba88ea5/ckpt-1\n", "INFO:absl:Successful ✅: Loaded model from tftransformers/gpt2\n" ] } ], "source": [ "model_name = 'gpt2'\n", "\n", "tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n", "model = GPT2Model.from_pretrained(model_name, use_auto_regressive=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "3e695918", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "54d5cedb", "metadata": {}, "source": [ "### Serialize and load\n", "\n", "* The most recommended way of using a Tensorflow model is to load it after serializing.\n", "* The speedup, especially for text generation is up to 50x times." ] }, { "cell_type": "code", "execution_count": 4, "id": "9412db40", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn, positional_embeddings_layer_call_and_return_conditional_losses, dropout_layer_call_fn while saving (showing 5 of 740). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: MODELS/gpt2/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: MODELS/gpt2/assets\n" ] } ], "source": [ "# Save as serialized\n", "model_dir = 'MODELS/gpt2'\n", "model.save_transformers_serialized(model_dir)\n", "\n", "# Load\n", "loaded = tf.saved_model.load(model_dir)\n", "model = loaded.signatures['serving_default']" ] }, { "cell_type": "code", "execution_count": null, "id": "63782df5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "5cfeb2b3", "metadata": {}, "source": [ "### Text-Generation\n", "\n", "* . We can pass ```tf.keras.Model``` also to ```TextDecoder```, but this is recommended\n", "* . GPT2 like (Encoder) only models require ```-1``` as padding token." ] }, { "cell_type": "code", "execution_count": 5, "id": "8a2120e7", "metadata": {}, "outputs": [], "source": [ "decoder = TextDecoder(model=loaded)" ] }, { "cell_type": "markdown", "id": "91adf307", "metadata": {}, "source": [ "### Greedy Decoding" ] }, { "cell_type": "code", "execution_count": 34, "id": "4808b798", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[\", but I don't want to be a burden to my family. I want to be a part of the community. I want to be a part of the\", \" to me. I've been able to get through a lot of things, but I've also been able to get through a lot of things that I've never\"]\n" ] } ], "source": [ "texts = ['I would like to walk with my cat', \n", " 'Music has been very soothing']\n", "\n", "input_ids = tf.ragged.constant(tokenizer(texts)['input_ids']).to_tensor(-1) # Padding GPT2 style models needs -1\n", "\n", "inputs = {'input_ids': input_ids}\n", "predictions = decoder.decode(inputs, \n", " mode='greedy', \n", " max_iterations=32)\n", "print(tokenizer.batch_decode(tf.squeeze(predictions['predicted_ids'], axis=1)))" ] }, { "cell_type": "markdown", "id": "7c3d5d17", "metadata": {}, "source": [ "### Beam Decoding" ] }, { "cell_type": "code", "execution_count": 40, "id": "3ee3035e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['. I would like to walk with my cat. I would like to walk with my cat. I would like to walk with my cat. I would like to', ' to me, and I\\'m glad to be able to share it with you.\"\\n\\n\"I\\'m glad to be able to share it with you.\"\\n']\n" ] } ], "source": [ "inputs = {'input_ids': input_ids}\n", "predictions = decoder.decode(inputs, \n", " mode='beam',\n", " num_beams=3,\n", " max_iterations=32)\n", "print(tokenizer.batch_decode(predictions['predicted_ids'][:, 0, :]))" ] }, { "cell_type": "markdown", "id": "07aae733", "metadata": {}, "source": [ "### Top K Nucleus Sampling" ] }, { "cell_type": "code", "execution_count": 41, "id": "c22c4dfe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[\", but I don't want to be a burden to my family. I want to be a part of the community. I want to be a part of the\", \" to me. I've been able to get through a lot of things, but I've also been able to get through a lot of things that I've never\"]\n" ] } ], "source": [ "inputs = {'input_ids': input_ids}\n", "predictions = decoder.decode(inputs, \n", " mode='top_k_top_p',\n", " top_k=50,\n", " top_p=0.7,\n", " num_return_sequences=3,\n", " max_iterations=32)\n", "print(tokenizer.batch_decode(predictions['predicted_ids'][:, 0, :]))" ] }, { "cell_type": "code", "execution_count": null, "id": "da62fc7e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }