{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "jbxNfeNfhrIt" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "3oU2Xas2ikcw" }, "source": [ "# Train (Masked Language Model) with tf-transformers in TPU\n", "\n", "This tutorial contains complete code to train MLM model on C4 EN 10K dataset.\n", "In addition to training a model, you will learn how to preprocess text into an appropriate format.\n", "\n", "In this notebook, you will:\n", "\n", "- Load the C4 (10k EN) dataset from HuggingFace\n", "- Load GPT2 style (configuration) Model using tf-transformers\n", "- Build train dataset (on the fly) feature preparation using\n", "tokenizer from tf-transformers.\n", "- Build a masked LM Model from GPT2 style configuration\n", "- Save your model\n", "- Use the base model for further tasks\n", "\n", "If you're new to working with the C4 dataset, please see [C4](https://www.tensorflow.org/datasets/catalog/c4) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5V5sXDHbi4Fr", "outputId": "8a2dcfd7-1b69-4182-a24f-f533fb9c41b5" }, "outputs": [], "source": [ "!pip install tf-transformers\n", "\n", "!pip install sentencepiece\n", "\n", "!pip install tensorflow-text\n", "\n", "!pip install transformers\n", "\n", "!pip install wandb\n", "\n", "!pip install datasets" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ub8L8158jHhI", "outputId": "cc274b0e-9867-4609-b8a8-bc8b6e68d4e9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow version 2.7.0\n", "Tensorflow text version 2.7.3\n", "Devices [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]\n" ] } ], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Supper TF warnings\n", "\n", "import tensorflow as tf\n", "import tensorflow_text as tf_text\n", "import datasets\n", "import wandb\n", "\n", "print(\"Tensorflow version\", tf.__version__)\n", "print(\"Tensorflow text version\", tf_text.__version__)\n", "print(\"Devices\", tf.config.list_physical_devices())\n", "\n", "from tf_transformers.models import GPT2Model, MaskedLMModel, AlbertTokenizerTFText\n", "from tf_transformers.core import Trainer\n", "from tf_transformers.optimization import create_optimizer\n", "from tf_transformers.text.lm_tasks import mlm_fn\n", "from tf_transformers.losses.loss_wrapper import get_lm_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Trainer has to be initialized before everything only in TPU (sometimes)." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yNgQDdsgj4Zw", "outputId": "d8d95b9e-047f-46e9-bd63-2f5a02495efc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:TPU system grpc://10.91.104.90:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:TPU system grpc://10.91.104.90:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initializing the TPU system: grpc://10.91.104.90:8470\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Initializing the TPU system: grpc://10.91.104.90:8470\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished initializing TPU system.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Finished initializing TPU system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Found TPU system:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Found TPU system:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores: 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Workers: 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Workers: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" ] } ], "source": [ "trainer = Trainer(distribution_strategy='tpu', num_gpus=0, tpu_address='colab')" ] }, { "cell_type": "markdown", "metadata": { "id": "7MRKsHoyj_wU" }, "source": [ "### Load Model, Optimizer , Trainer\n", "\n", "Our Trainer expects ```model```, ```optimizer``` and ```loss``` to be a function." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "JJ9KB3oCkAVP" }, "outputs": [], "source": [ "# Load Model\n", "def get_model(model_name, vocab_size, is_training, use_dropout, num_hidden_layers):\n", " \"\"\"Get Model\"\"\"\n", "\n", " def model_fn():\n", " config = GPT2Model.get_config(model_name)\n", " config['vocab_size'] = vocab_size\n", " model = GPT2Model.from_config(config, mask_mode='user_defined', num_hidden_layers=num_hidden_layers, return_layer=True)\n", " model = MaskedLMModel(\n", " model,\n", " use_extra_mlm_layer=False,\n", " hidden_size=config['embedding_size'],\n", " layer_norm_epsilon=config['layer_norm_epsilon'],\n", " ) \n", " return model.get_model()\n", " return model_fn\n", "\n", "# Load Optimizer\n", "def get_optimizer(learning_rate, examples, batch_size, epochs, use_constant_lr=False):\n", " \"\"\"Get optimizer\"\"\"\n", " steps_per_epoch = int(examples / batch_size)\n", " num_train_steps = steps_per_epoch * epochs\n", " warmup_steps = int(0.1 * num_train_steps)\n", "\n", " def optimizer_fn():\n", " optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps, use_constant_lr=use_constant_lr)\n", " return optimizer\n", "\n", " return optimizer_fn\n", "\n", "# Load trainer\n", "def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None):\n", " \"\"\"Get Trainer\"\"\"\n", " trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address)\n", " return trainer" ] }, { "cell_type": "markdown", "metadata": { "id": "FrbiylDkklJ1" }, "source": [ "### Prepare Data for Training\n", "\n", "We will make use of ```Tensorflow Text``` based tokenizer to do ```on-the-fly``` preprocessing, without having any\n", "overhead of pre prepapre the data in the form of ```pickle```, ```numpy``` or ```tfrecords```." ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "DPfht1IOklYA" }, "outputs": [], "source": [ "# Load dataset\n", "def load_dataset(dataset, tokenizer_layer, max_seq_len, max_predictions_per_seq, batch_size):\n", " \"\"\"\n", " Args:\n", " dataset; HuggingFace dataset\n", " tokenizer_layer: tf-transformers tokenizer\n", " max_seq_len: int (maximum sequence length of text)\n", " batch_size: int (batch_size)\n", " max_predictions_per_seq: int (Maximum number of words to mask)\n", " \"\"\"\n", " tfds_dict = dataset.to_dict()\n", " tfdataset = tf.data.Dataset.from_tensor_slices(tfds_dict)\n", "\n", " # MLM function\n", " masked_lm_map_fn = mlm_fn(tokenizer_layer, max_seq_len, max_predictions_per_seq)\n", "\n", " # MLM\n", " tfdataset = tfdataset.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE)\n", " # Batch\n", " tfdataset = tfdataset.batch(batch_size, drop_remainder=True).shuffle(50)\n", "\n", " # Auto SHARD\n", " options = tf.data.Options()\n", " options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO\n", " tfdataset = tfdataset.with_options(options)\n", " \n", " return tfdataset" ] }, { "cell_type": "markdown", "metadata": { "id": "Q_dFlI_MrDwG" }, "source": [ "### Prepare Dataset\n", "\n", "1. Set necessay hyperparameters.\n", "2. Prepare ```train dataset```\n", "3. Load ```model```, ```optimizer```, ```loss``` and ```trainer```." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 104, "referenced_widgets": [ "d5c1adfe09f7461c89e172b9743af081", "f729259d2bbb445888be1143cc4ffe4a", "c57a7dce76fe4235a9961df2b256ee62", "90b4ff5b71554b5fa21f304e18407da1", "55d6f70eef0c484ca3014e74e7520eb6", "9e7d5627b0ed4d2a953cc7a64f169817", "56d29164ccb643bf9879813c90281151", "7f567c87ab0d46b782796f6bdec74689", "9ba4bc085776410889c4bfb5cef623e5", "a7aebc9e539049ba9c86728a5d402cf2", "fdd587ac1984477485a714a9be9bb2a9" ] }, "id": "s_HE4DquolW2", "outputId": "3f22ea55-1572-4670-bd69-8702153f99ff" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Reusing dataset c4_en10k (/root/.cache/huggingface/datasets/stas___c4_en10k/plain_text/1.0.0/edbf1ff8b8ee35a9751a7752b5e93a4873cc7905ffae010ad334a2c96f81e1cd)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5c1adfe09f7461c89e172b9743af081", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00 {\n", " function loadScript(url) {\n", " return new Promise(function(resolve, reject) {\n", " let newScript = document.createElement(\"script\");\n", " newScript.onerror = reject;\n", " newScript.onload = resolve;\n", " document.body.appendChild(newScript);\n", " newScript.src = url;\n", " });\n", " }\n", " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", " const iframe = document.createElement('iframe')\n", " iframe.style.cssText = \"width:0;height:0;border:none\"\n", " document.body.appendChild(iframe)\n", " const handshake = new Postmate({\n", " container: iframe,\n", " url: 'https://wandb.ai/authorize'\n", " });\n", " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", " handshake.then(function(child) {\n", " child.on('authorize', data => {\n", " clearTimeout(timeout)\n", " resolve(data)\n", " });\n", " });\n", " })\n", " });\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" ] }, { "data": { "text/html": [ "\n", " Syncing run mlm_tpu to Weights & Biases (docs).
\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "project = \"TUTORIALS\"\n", "display_name = \"mlm_tpu\"\n", "wandb.init(project=project, name=display_name)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "LKZugOHYUmC5" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train :-)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Gexhy6NplTv7", "outputId": "6010d4c6-68ec-4ac2-af8c-e5cebf14bd3e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Make sure `steps_per_epoch` should be less than or equal to number of batches in dataset.\n", "INFO:absl:Policy: ----> float32\n", "INFO:absl:Strategy: ---> \n", "INFO:absl:Num TPU Devices: ---> 8\n", "INFO:absl:Create model from config\n", "INFO:absl:Using Constant learning rate\n", "INFO:absl:Using Adamw optimizer\n", "INFO:absl:No ❌❌ checkpoint found in gs://legacyai-bucket/sample_mlm_model\n", "Train: Epoch 1/4 --- Step 100/5000 --- total examples 0: 0%|\u001b[32m \u001b[0m| 0/50 [00:00" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_fn = get_model(model_name, vocab_size, is_training=False, use_dropout=False, num_hidden_layers=num_hidden_layers)\n", "\n", "model = model_fn()\n", "model.load_checkpoint(model_checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H8KYSFqqQcHV" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "_oDAmWjlQcK0" }, "source": [ "### Test Model performance \n", "\n", "1. We can assess model performance by checking how it predicts masked word on sample sentences.\n", "2. As we see the following result, its clear that model starts learning." ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bRoOpRV93Bun", "outputId": "52fa7b71-96a8-489a-815f-3e65a20dd3a9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input ----> Read the rest of this [MASK] to understand things in more detail.\n", "Predicted words ----> ['page', 'continent', 'means', 'window', 'website', 'post', 'tool', 'is', 'book', 'world']\n", "\n", "Input ----> I want to buy the [MASK] because it is so cheap.\n", "Predicted words ----> ['door', 'quote', 'electronics', 'house', 'review', 'website', 'graphics', 'property', 'doors', 'item']\n", "\n", "Input ----> The [MASK] was amazing.\n", "Predicted words ----> ['boys', 'turkey', 'epilogue', 'idea', 'project', 'answer', 'food', 'website', 'show', 'weather']\n", "\n", "Input ----> Sachin Tendulkar is one of the [MASK] palyers in the world.\n", "Predicted words ----> ['busiest', 'leading', 'english', 'latest', 'northern', 'coordinates', 'largest', 'international', 'state', 'registered']\n", "\n", "Input ----> [MASK] is the capital of France.\n", "Predicted words ----> ['this', 'there', 'india', 'it', 'what', 'here', 'below', 'he', 'france', 'that']\n", "\n", "Input ----> Machine Learning requires [MASK]\n", "Predicted words ----> ['.', 'the', 'that', 'for', 'a,', 'an', 'you', 'and', 'your']\n", "\n", "Input ----> He is working as a [MASK]\n", "Predicted words ----> ['field', 'real', 'great', 'chance', 'regular', 'business', 'team', 'facebook', 'freelance', 'sport']\n", "\n", "Input ----> She is working as a [MASK]\n", "Predicted words ----> ['field', 'real', 'path', 'facebook', 'chance', 'trip', 'strategic', 'great', 'regular', 'lot']\n", "\n" ] } ], "source": [ "from transformers import AlbertTokenizer\n", "tokenizer_hf = AlbertTokenizer.from_pretrained(\"albert-base-v2\")\n", "\n", "validation_sentences = [\n", " 'Read the rest of this [MASK] to understand things in more detail.',\n", " 'I want to buy the [MASK] because it is so cheap.',\n", " 'The [MASK] was amazing.',\n", " 'Sachin Tendulkar is one of the [MASK] palyers in the world.',\n", " '[MASK] is the capital of France.',\n", " 'Machine Learning requires [MASK]',\n", " 'He is working as a [MASK]',\n", " 'She is working as a [MASK]',\n", "]\n", "inputs = tokenizer_hf(validation_sentences, padding=True, return_tensors=\"tf\")\n", "\n", "inputs_tf = {}\n", "inputs_tf[\"input_ids\"] = inputs[\"input_ids\"]\n", "inputs_tf[\"input_mask\"] = inputs[\"attention_mask\"]\n", "seq_length = tf.shape(inputs_tf['input_ids'])[1]\n", "inputs_tf['masked_lm_positions'] = tf.zeros_like(inputs_tf[\"input_ids\"]) + tf.range(seq_length)\n", "\n", "\n", "top_k = 10 # topk similar words\n", "outputs_tf = model(inputs_tf)\n", "# Get masked positions from each sentence\n", "masked_positions = tf.argmax(tf.equal(inputs_tf[\"input_ids\"], tokenizer_hf.mask_token_id), axis=1)\n", "for i, logits in enumerate(outputs_tf['token_logits']):\n", " mask_token_logits = logits[masked_positions[i]]\n", " # 0 for probs and 1 for indexes from tf.nn.top_k\n", " top_words = tokenizer_hf.decode(tf.nn.top_k(mask_token_logits, k=top_k)[1].numpy())\n", " print(\"Input ----> {}\".format(validation_sentences[i]))\n", " print(\"Predicted words ----> {}\".format(top_words.split()))\n", " print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "80Y0ipAP4K5k" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [], "name": "mlm_tpu.ipynb", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 2 }