Roberta TFLite

!pip install tf-transformers

!pip install sentencepiece

!pip install tensorflow-text

!pip install transformers

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings

import tensorflow as tf
print("Tensorflow version", tf.__version__)

from tf_transformers.models import RobertaModel
Tensorflow version 2.7.0

Convert a Model to TFLite

The most important thing to note here is that, to convert a model to TFLite, we have to make sure the model's inputs have fixed (static) shapes rather than dynamic ones. This means fixing batch_size, sequence_length, and any other input dimensions the model of interest requires.
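Because every input dimension is fixed, real text has to be brought to exactly sequence_length tokens before it reaches the model. A minimal stdlib sketch of that step follows. The helper name pad_to_fixed_length is hypothetical, the token ids are illustrative, and pad_id=1 is an assumption based on the Hugging Face roberta-base tokenizer (check the pad id of the tokenizer you actually use). The truncation is naive: unlike a real tokenizer, it does not re-append the end-of-sequence token.

```python
from typing import List, Tuple


def pad_to_fixed_length(
    token_ids: List[int], sequence_length: int = 64, pad_id: int = 1
) -> Tuple[List[int], List[int]]:
    """Truncate or pad token ids to exactly `sequence_length` entries.

    Returns (input_ids, input_mask): the mask is 1 for real tokens and 0 for
    padding. pad_id=1 is an assumption taken from the Hugging Face
    roberta-base tokenizer.
    """
    ids = token_ids[:sequence_length]  # naive truncation of long inputs
    n_pad = sequence_length - len(ids)  # number of pad tokens needed
    input_ids = ids + [pad_id] * n_pad
    input_mask = [1] * len(ids) + [0] * n_pad
    return input_ids, input_mask


# Illustrative ids, padded to a short length so the result is easy to read.
ids, mask = pad_to_fixed_length([0, 11, 22, 2], sequence_length=8)
print(ids)   # [0, 11, 22, 2, 1, 1, 1, 1]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

With a Hugging Face tokenizer, passing padding='max_length', truncation=True and max_length=sequence_length yields the same fixed shape without a hand-written helper.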

Load Roberta Model

  1. Fix the input shapes when loading the model.

  2. We can always check the model's inputs and outputs using model.input and model.output.

  3. Here we use batch_size=1 and sequence_length=64.

model_name = 'roberta-base'
batch_size = 1
sequence_length = 64
model = RobertaModel.from_pretrained(model_name, batch_size=batch_size, sequence_length=sequence_length)
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /root/.cache/huggingface/hub/tftransformers--roberta-base-no-mlm.main.9e4aa91ba5936c6ac98586f85c152831e421d0ec/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/roberta-base-no-mlm

Verify Model Inputs and Outputs

print("Model inputs", model.input)
print("Model outputs", model.output)
Model inputs {'input_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_ids')>, 'input_mask': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_mask')>, 'input_type_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_type_ids')>}
Model outputs {'cls_output': <KerasTensor: shape=(1, 768) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'token_embeddings': <KerasTensor: shape=(1, 64, 768) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'token_logits': <KerasTensor: shape=(1, 64, 50265) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'last_token_logits': <KerasTensor: shape=(1, 50265) dtype=float32 (created by layer 'tf_transformers/roberta')>}
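The printed shapes contain no None entries, so every input dimension is fixed as required. A small sketch that makes this check explicit before converting is shown below; the helper dynamic_dims is hypothetical, and the "flexible" shapes are an illustration of what a model built without fixed sizes would typically report.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def dynamic_dims(shapes: Dict[str, Shape]) -> Dict[str, Tuple[int, ...]]:
    """For each input with a dynamic (None) dimension, return the positions of those dimensions."""
    return {
        name: tuple(i for i, dim in enumerate(shape) if dim is None)
        for name, shape in shapes.items()
        if any(dim is None for dim in shape)
    }


# Shapes as printed above for batch_size=1, sequence_length=64.
fixed = {"input_ids": (1, 64), "input_mask": (1, 64), "input_type_ids": (1, 64)}
# Illustrative shapes of a model loaded without fixed sizes.
flexible = {"input_ids": (None, None), "input_mask": (None, None), "input_type_ids": (None, None)}

print(dynamic_dims(fixed))     # {}
print(dynamic_dims(flexible))  # {'input_ids': (0, 1), 'input_mask': (0, 1), 'input_type_ids': (0, 1)}
```

On the real model, dynamic_dims({k: tuple(v.shape.as_list()) for k, v in model.input.items()}) should return an empty dict.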

Save Model as Serialized Version

We save the model with model.save(..., save_format='tf'), which writes a TensorFlow SavedModel directory. The TFLite converter then reads this SavedModel.

model.save("{}/saved_model".format(model_name), save_format='tf')
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, type_embeddings_layer_call_fn, type_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn while saving (showing 5 of 870). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: roberta-base/saved_model/assets
WARNING:absl:<tf_transformers.layers.attention.bert_attention.MultiHeadAttention object at 0x7f7edda72790> has the same name 'MultiHeadAttention' as a built-in Keras object. Consider renaming <class 'tf_transformers.layers.attention.bert_attention.MultiHeadAttention'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.

Convert SavedModel to TFLite

converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model".format(model_name)) # path to the SavedModel directory
converter.experimental_new_converter = True

tflite_model = converter.convert()

# Write the flatbuffer to disk; the context manager closes the file.
with open("{}/saved_model.tflite".format(model_name), "wb") as f:
    f.write(tflite_model)
print("TFLite conversion successful")
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
TFLite conversion successful

Load TFLite Model

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="{}/saved_model.tflite".format(model_name))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

Assert TFLite Model and Keras Model Outputs

After conversion, we compare the outputs of the TFLite model and the Keras model on the same inputs to make sure the conversion is correct.

  1. Create dummy inputs using tf.random.uniform.

  2. Run both models on these inputs and compare the outputs.

  3. Note: we need a much higher rtol (3.0) than the tf.debugging.assert_near default for the assertion to pass.
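The assertion below uses tf.debugging.assert_near, which treats x and y as close when, elementwise, |x - y| <= atol + rtol * |y|. With rtol=3.0, that rule would still pass if every value were three times its reference (same sign). A NumPy sketch of the same rule that also reports the largest absolute difference helps show how close the two models really are. The helper near_report is hypothetical, and it defaults to atol=0.0, whereas assert_near's default atol is a small nonzero multiple of the dtype's epsilon.

```python
import numpy as np


def near_report(actual, expected, rtol, atol=0.0):
    """Apply the closeness rule documented for tf.debugging.assert_near,
    |actual - expected| <= atol + rtol * |expected| elementwise, and
    return (all_close, max_abs_diff)."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    diff = np.abs(actual - expected)
    close = diff <= atol + rtol * np.abs(expected)
    return bool(close.all()), float(diff.max())


expected = np.array([0.5, -1.0, 2.0])
# Off by a factor of 3 everywhere: fails a tight tolerance...
print(near_report(expected * 3, expected, rtol=1e-5))  # (False, 4.0)
# ...but passes rtol=3.0, because |3y - y| = 2|y| <= 3|y|.
print(near_report(expected * 3, expected, rtol=3.0))   # (True, 4.0)
```

In the notebook, near_report(tflite_output, model_outputs['token_embeddings'].numpy(), rtol=3.0) would report the actual maximum deviation alongside the pass/fail result.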

# Dummy Examples 
input_ids = tf.random.uniform(minval=0, maxval=100, shape=(batch_size, sequence_length), dtype=tf.int32)
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.zeros_like(input_ids)


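# Feed inputs by index. This order (0: input_type_ids, 1: input_mask,
# 2: input_ids) matches this particular conversion; check
# input_details[i]['name'] if you convert a different model.
interpreter.set_tensor(input_details[0]['index'], input_type_ids)
interpreter.set_tensor(input_details[1]['index'], input_mask)
interpreter.set_tensor(input_details[2]['index'], input_ids)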

# Run inference
interpreter.invoke()
# Take the last output tensor; it is compared against token_embeddings below
tflite_output = interpreter.get_tensor(output_details[-1]['index'])

# Keras model outputs
model_inputs = {'input_ids': input_ids, 'input_mask': input_mask, 'input_type_ids': input_type_ids}
model_outputs = model(model_inputs)

# A much looser rtol than the default is needed for this assertion to pass
tf.debugging.assert_near(tflite_output, model_outputs['token_embeddings'], rtol=3.0)
print("Outputs asserted and successful: ✅")
Outputs asserted and successful: ✅
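The cell above feeds the interpreter by position (input_details[0], [1], [2]) and compares output_details[-1], which relies on the tensor order this particular conversion produced. A sketch of looking up input indices by name instead is shown below. The helper indices_by_name is hypothetical, and the 'serving_default_<input>:0' naming is an assumption about how the converter names SavedModel inputs, so print input_details to confirm it for your model. The fake details mirror the order implied by the notebook's indices.

```python
from typing import Dict, List


def indices_by_name(input_details: List[dict], names: List[str]) -> Dict[str, int]:
    """Map each Keras input name to its TFLite tensor index.

    Assumes detail names look like 'serving_default_input_ids:0'. Raises
    KeyError unless exactly one detail ends with '<name>:0', so a missing or
    ambiguous match fails loudly instead of feeding the wrong tensor.
    """
    mapping: Dict[str, int] = {}
    for name in names:
        matches = [d["index"] for d in input_details if d["name"].endswith(name + ":0")]
        if len(matches) != 1:
            raise KeyError(f"expected one TFLite input for {name!r}, found {len(matches)}")
        mapping[name] = matches[0]
    return mapping


# Hypothetical details, ordered as the notebook's indices imply.
fake_details = [
    {"name": "serving_default_input_type_ids:0", "index": 0},
    {"name": "serving_default_input_mask:0", "index": 1},
    {"name": "serving_default_input_ids:0", "index": 2},
]
print(indices_by_name(fake_details, ["input_ids", "input_mask", "input_type_ids"]))
# {'input_ids': 2, 'input_mask': 1, 'input_type_ids': 0}
```

With the real interpreter, one would loop over indices_by_name(input_details, list(model_inputs)).items() and call interpreter.set_tensor(idx, model_inputs[name]) for each pair. Newer TensorFlow releases also provide interpreter.get_signature_runner(), which takes inputs as keyword arguments and returns outputs by name; check that your version supports it.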