{ "cells": [ { "cell_type": "markdown", "id": "a10b8406", "metadata": {}, "source": [ "# Sentence Transformer in tf-transformers\n", "\n", "* This is a simple tutorial demonstrating how ```SentenceTransformer``` models have been integrated into ```tf-transformers``` and how to use them.\n", "* The following tutorial applies to all supported ```SentenceTransformer``` models." ] }, { "cell_type": "markdown", "id": "e69042f8", "metadata": {}, "source": [ "### Load a Sentence-T5 model" ] }, { "cell_type": "code", "execution_count": 11, "id": "789ef733", "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tf_transformers.models import SentenceTransformer" ] }, { "cell_type": "code", "execution_count": 3, "id": "f1521514", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ac56f7c27b74d998861500c3f5a25d4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/1.25k [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ " physical PluggableDevice (device: 0, name: METAL, pci bus id: )\n", "INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__sentence-t5-base-sentence-transformers.main.d64dbdc4c8c15637da4215b81f38af99d48a586c/ckpt-1\n", "INFO:absl:Successful ✅: Loaded model from tftransformers/sentence-t5-base-sentence-transformers\n" ] } ], "source": [ "# Any supported SentenceTransformer checkpoint name can be used here\n", "model_name = 'sentence-transformers/sentence-t5-base'\n", "model = SentenceTransformer.from_pretrained(model_name)" ] }, { "cell_type": "markdown", "id": "8e288ba8", "metadata": {}, "source": [ "### What's my model input?\n", "\n", "* All models in ```tf-transformers``` are built as end-to-end connected Keras graphs, so their inputs can be inspected directly. Use ```model.input``` if it's a ```LegacyModel/tf.keras.Model```, or ```model.model_inputs``` if it's a ```LegacyLayer/tf.keras.layers.Layer```." ] },
{ "cell_type": "code", "execution_count": 5, "id": "b84a7d5f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': ,\n", " 'input_mask': }" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.input" ] }, { "cell_type": "markdown", "id": "bab5feb6", "metadata": {}, "source": [ "### What's my model output?\n", "\n", "* All models in ```tf-transformers``` are built as end-to-end connected Keras graphs, so their outputs can be inspected directly. Use ```model.output``` if it's a ```LegacyModel/tf.keras.Model```, or ```model.model_outputs``` if it's a ```LegacyLayer/tf.keras.layers.Layer```." ] }, { "cell_type": "code", "execution_count": 6, "id": "f94721b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'sentence_vector': }" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.output" ] }, { "cell_type": "markdown", "id": "34c829dd", "metadata": {}, "source": [ "### Sentence vectors" ] }, { "cell_type": "code", "execution_count": 10, "id": "37b6a87c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sentence vector (2, 768)\n" ] } ], "source": [ "from transformers import AutoTokenizer\n", "\n", "# Load the Hugging Face tokenizer that matches the checkpoint\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", "text = ['This is a sentence to get vector', 'This one too']\n", "inputs = tokenizer(text, return_tensors='tf', padding=True)\n", "\n", "# tf-transformers names the attention mask 'input_mask' instead of 'attention_mask'\n", "inputs_tf = {'input_ids': inputs['input_ids'], 'input_mask': inputs['attention_mask']}\n", "outputs_tf = model(inputs_tf)\n", "print(\"Sentence vector\", outputs_tf['sentence_vector'].shape)" ] }, { "cell_type": "markdown", "id": "ca373133", "metadata": {}, "source": [ "### Serialize as usual and load it\n", "\n", "* Serialize the model, load it back, and assert that its outputs match those of the non-serialized ```tf.keras.Model```." ] },
{ "cell_type": "code", "execution_count": 12, "id": "7c251fe5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-02 17:09:27.002497: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n", "WARNING:absl:Found untraced functions such as tf_transformers/t5_encoder_layer_call_fn, tf_transformers/t5_encoder_layer_call_and_return_conditional_losses, grt5_dense_layer_layer_call_fn, grt5_dense_layer_layer_call_and_return_conditional_losses, dropout_2_layer_call_fn while saving (showing 5 of 880). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: MODELS/sentence_t5/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: MODELS/sentence_t5/assets\n", "2022-04-02 17:09:32.964144: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n", "2022-04-02 17:09:32.965446: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.\n" ] } ], "source": [ "model_dir = 'MODELS/sentence_t5'\n", "model.save_transformers_serialized(model_dir)\n", "\n", "# Load the SavedModel and use its default serving signature\n", "loaded = tf.saved_model.load(model_dir)\n", "model_serialized = loaded.signatures['serving_default']\n", "\n", "# Signatures take keyword arguments and return a dict of tensors\n", "outputs_tf_serialized = model_serialized(**inputs_tf)\n", "\n", "# The reloaded model should reproduce the in-memory model's sentence vectors\n", "tf.debugging.assert_near(outputs_tf['sentence_vector'], outputs_tf_serialized['sentence_vector'])" ] } ], "metadata": { "jupytext": { "formats":
"ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }