{ "cells": [ { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qjUYDME9kgWS", "outputId": "89cd7257-9656-4743-f458-09fd70a437c6" }, "source": [ "# Bert TFLite" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install tf-transformers\n", "\n", "!pip install sentencepiece\n", "\n", "!pip install tensorflow-text\n", "\n", "!pip install transformers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CV0Bh-eFlEot" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cy6shXrXlL_D", "outputId": "da3d5ff0-96f1-4912-f115-9310fc4cc68d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow version 2.7.0\n" ] } ], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings\n", "\n", "import tensorflow as tf\n", "print(\"Tensorflow version\", tf.__version__)\n", "\n", "from tf_transformers.models import BertModel" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0HEJnnnFlPxR" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "P89IVu5JlREX" }, "source": [ "## Convert a Model to TFlite\n", "\n", "To convert a model to ```tflite```, we have to ensure that the ```inputs``` to the model are **deterministic**: they must not have dynamic shapes. We have to fix **batch_size**, **sequence_length**, and any other related input constraints, depending on the model of interest.\n", "\n", "### Load Bert Model\n", "\n", "1. Fix the inputs.\n", "2. We can always check the ```model``` **inputs** and **outputs** by using ```model.input``` and ```model.output```.\n", "3. 
We use ```batch_size=1``` and ```sequence_length=64```." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 232 }, "id": "WV-Ygv4Mlnlo", "outputId": "453da3db-d8b0-48e6-8f11-8cbff79333ff" }, "outputs": [], "source": [ "# Load the model with fixed (deterministic) input shapes\n", "model_name = 'bert-base-cased'\n", "batch_size = 1\n", "sequence_length = 64\n", "\n", "model = BertModel.from_pretrained(model_name, batch_size=batch_size, sequence_length=sequence_length)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model inputs {'input_ids': , 'input_mask': , 'input_type_ids': }\n", "Model outputs {'cls_output': , 'token_embeddings': , 'token_logits': , 'last_token_logits': }\n" ] } ], "source": [ "print(\"Model inputs\", model.input)\n", "print(\"Model outputs\", model.output)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JN6QpCnznFNR" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k97vFtrSnMGd" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "SGBPiXy8nMjQ" }, "source": [ "## Save Model as Serialized Version\n", "\n", "We save the model with ```model.save```. The resulting ```SavedModel``` is what we convert to ```tflite```." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HeDAyaXznZiX", "outputId": "33d022bc-85fc-48cb-c5ee-8ab0d6d66157" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. 
`model.compile_metrics` will be empty until you train or evaluate the model.\n", "WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, type_embeddings_layer_call_fn, type_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn while saving (showing 5 of 870). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: bert-base-cased/saved_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: bert-base-cased/saved_model/assets\n", "WARNING:absl: has the same name 'MultiHeadAttention' as a built-in Keras object. Consider renaming to avoid naming conflicts when loading with `tf.keras.models.load_model`. 
If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.\n" ] } ], "source": [ "model.save(\"{}/saved_model\".format(model_name), save_format='tf')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "DGIccdmJnj_5" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-HOEyoodnvoU" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "XTLzBJAGnv2m" }, "source": [ "## Convert SavedModel to TFlite" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SqC2HJ10nywO", "outputId": "6cf948fe-df89-4375-9fc8-268570583365" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "TFlite conversion successful\n" ] } ], "source": [ "converter = tf.lite.TFLiteConverter.from_saved_model(\"{}/saved_model\".format(model_name)) # path to the SavedModel directory\n", "converter.experimental_new_converter = True\n", "\n", "tflite_model = converter.convert()\n", "\n", "with open(\"{}/saved_model.tflite\".format(model_name), \"wb\") as f:\n", "    f.write(tflite_model)\n", "print(\"TFlite conversion successful\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eBqTsjUToG7I" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FdIFext9ta6E" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "bDnheGa_tmZq" }, "source": [ "## Load TFlite Model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "vHzESzPAtpHd" }, "outputs": [], "source": [ "# Load the TFLite model and allocate tensors.\n", "interpreter = tf.lite.Interpreter(model_path=\"{}/saved_model.tflite\".format(model_name))\n", 
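"\n", "# Optional check (a sketch added here, not in the original notebook):\n", "# report the size of the converted flatbuffer on disk.\n", "tflite_size_mb = os.path.getsize(\"{}/saved_model.tflite\".format(model_name)) / 1e6\n", "print(\"TFlite model size (MB):\", round(tflite_size_mb, 2))\n", 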
"interpreter.allocate_tensors()\n", "\n", "# Get input and output tensors.\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "GTYdEgPatzVk" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C5fUX6ZqxefF" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "u1I0ZJ-XxfDg" }, "source": [ "## Assert TFlite Model and Keras Model outputs\n", "\n", "After conversion, we compare the outputs of the ```tflite``` model and the ```Keras``` model to verify the conversion.\n", "\n", "1. Create dummy examples using ```tf.random.uniform```.\n", "2. Check the outputs of both models.\n", "3. Note: We need a slightly higher ```rtol``` here for the assertion to pass." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QnYr9D5Ot6t4", "outputId": "ddb89066-9b37-47db-d13b-246f061b0582" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Outputs asserted and successful: ✅\n" ] } ], "source": [ "# Dummy examples\n", "input_ids = tf.random.uniform(minval=0, maxval=100, shape=(batch_size, sequence_length), dtype=tf.int32)\n", "input_mask = tf.ones_like(input_ids)\n", "input_type_ids = tf.zeros_like(input_ids)\n", "\n", "\n", "# input type ids\n", "interpreter.set_tensor(\n", "    input_details[0]['index'],\n", "    input_type_ids,\n", ")\n", "# input_mask\n", "interpreter.set_tensor(input_details[1]['index'], input_mask)\n", "\n", "# input ids\n", "interpreter.set_tensor(\n", "    input_details[2]['index'],\n", "    input_ids,\n", ")\n", "\n", "# Run inference\n", "interpreter.invoke()\n", "# Take last output\n", "tflite_output = interpreter.get_tensor(output_details[-1]['index'])\n", "\n", "# Keras model outputs\n", "model_inputs = {'input_ids': input_ids, 'input_mask': input_mask, 'input_type_ids': 
input_type_ids}\n", "model_outputs = model(model_inputs)\n", "\n", "# We need a slightly higher rtol here to assert :-)\n", "tf.debugging.assert_near(tflite_output, model_outputs['token_embeddings'], rtol=3.0)\n", "print(\"Outputs asserted and successful: ✅\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mh3bNREFQyk0" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "name": "bert_tflite.ipynb", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 1 }
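Appendix (an added sketch, not part of the original notebook): the assertion cell above feeds the TFLite interpreter by positional index (`input_details[0]`, `[1]`, `[2]`), which depends on the converter's ordering. A more robust pattern is to match each model input to its TFLite tensor by the `name` field in `input_details`. The helper below illustrates the idea with plain Python stand-ins; the `serving_default_*` names are assumed examples, not taken from the notebook's actual converter output.

```python
# Match TFLite input tensors to model inputs by name instead of relying on
# the positional order of interpreter.get_input_details().
def match_inputs_by_name(input_details, tensors):
    """Pair each TFLite input detail with the tensor whose key appears in its name."""
    pairs = []
    for detail in input_details:
        for key, value in tensors.items():
            if key in detail["name"]:
                pairs.append((detail["index"], value))
                break
    return pairs

# Toy illustration with plain Python objects (no TensorFlow required):
details = [
    {"name": "serving_default_input_type_ids:0", "index": 0},
    {"name": "serving_default_input_mask:0", "index": 1},
    {"name": "serving_default_input_ids:0", "index": 2},
]
tensors = {"input_ids": "IDS", "input_mask": "MASK", "input_type_ids": "TYPE"}
print(match_inputs_by_name(details, tensors))  # [(0, 'TYPE'), (1, 'MASK'), (2, 'IDS')]
```

With a real interpreter, one would pass `interpreter.get_input_details()` and the actual input tensors, then call `interpreter.set_tensor(index, tensor)` for each returned pair.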