Roberta TFLite¶
!pip install tf-transformers
!pip install sentencepiece
!pip install tensorflow-text
!pip install transformers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings
import tensorflow as tf
print("Tensorflow version", tf.__version__)
from tf_transformers.models import RobertaModel
Tensorflow version 2.7.0
Convert a Model to TFlite¶
The most important thing to notice here is that if we want to convert a model to tflite, we have to ensure that the inputs to the model are deterministic, meaning the inputs should not be dynamic. We have to fix batch_size, sequence_length, and other related input constraints, depending on the model of interest.
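To see why fixed shapes matter, here is a minimal, generic sketch with a toy Keras model (illustrative only, not tf-transformers specific): once every input dimension is static, the TFLite converter can plan the whole graph.
# Toy model with a fully static input shape: batch_size=1, sequence_length=64.
toy_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(64,), batch_size=1, dtype=tf.int32),
    tf.keras.layers.Embedding(input_dim=100, output_dim=8),
    tf.keras.layers.GlobalAveragePooling1D(),
])
# Conversion succeeds because no dimension is left as None.
toy_tflite = tf.lite.TFLiteConverter.from_keras_model(toy_model).convert()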
Load Roberta Model¶
Fix the inputs.
We can always check the model inputs and outputs by using model.input and model.output.
We use batch_size=1 and sequence_length=64.
model_name = 'roberta-base'
batch_size = 1
sequence_length = 64
model = RobertaModel.from_pretrained(model_name, batch_size=batch_size, sequence_length=sequence_length)
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /root/.cache/huggingface/hub/tftransformers--roberta-base-no-mlm.main.9e4aa91ba5936c6ac98586f85c152831e421d0ec/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/roberta-base-no-mlm
Verify Model inputs and outputs¶
print("Model inputs", model.input)
print("Model outputs", model.output)
Model inputs {'input_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_ids')>, 'input_mask': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_mask')>, 'input_type_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_type_ids')>}
Model outputs {'cls_output': <KerasTensor: shape=(1, 768) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'token_embeddings': <KerasTensor: shape=(1, 64, 768) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'token_logits': <KerasTensor: shape=(1, 64, 50265) dtype=float32 (created by layer 'tf_transformers/roberta')>, 'last_token_logits': <KerasTensor: shape=(1, 50265) dtype=float32 (created by layer 'tf_transformers/roberta')>}
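As a quick sanity check (a small sketch, not part of the library API), we can assert that no input dimension was left dynamic:
# Every input shape should be fully static for tflite conversion.
for name, tensor in model.input.items():
    assert tensor.shape.is_fully_defined(), "{} has a dynamic dimension: {}".format(name, tensor.shape)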
Save the Model as a SavedModel¶
We have to save the model using model.save. We use the SavedModel format for converting it to tflite.
model.save("{}/saved_model".format(model_name), save_format='tf')
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, type_embeddings_layer_call_fn, type_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn while saving (showing 5 of 870). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: roberta-base/saved_model/assets
WARNING:absl:<tf_transformers.layers.attention.bert_attention.MultiHeadAttention object at 0x7f7edda72790> has the same name 'MultiHeadAttention' as a built-in Keras object. Consider renaming <class 'tf_transformers.layers.attention.bert_attention.MultiHeadAttention'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
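Optionally, we can inspect the serving signature of the SavedModel before conversion. This is a hedged sketch; 'serving_default' is the signature name that model.save normally produces.
# Load the SavedModel and print its default serving signature.
loaded = tf.saved_model.load("{}/saved_model".format(model_name))
print(loaded.signatures["serving_default"].structured_input_signature)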
Convert SavedModel to TFlite¶
converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model".format(model_name)) # path to the SavedModel directory
converter.experimental_new_converter = True
tflite_model = converter.convert()
open("{}/saved_model.tflite".format(model_name), "wb").write(tflite_model)
print("TFlite conversion succesful")
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
TFlite conversion successful
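If model size matters on-device, dynamic-range quantization is worth trying. This is a sketch using the standard converter option; re-check output accuracy after quantizing.
# Optional: dynamic-range quantization, typically ~4x smaller model.
quant_converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model".format(model_name))
quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = quant_converter.convert()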
Load TFlite Model¶
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="{}/saved_model.tflite".format(model_name))
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
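The order of input_details is assigned by the converter and is not guaranteed to match the order of the Keras inputs, so it is worth printing the tensor names before feeding data:
# Inspect names, shapes and dtypes of the tflite input tensors.
for detail in input_details:
    print(detail['name'], detail['shape'], detail['dtype'])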
Assert TFlite Model and Keras Model outputs¶
After conversion, we have to assert the model outputs using the tflite model and the Keras model, to ensure the conversion is correct.
Create examples using tf.random.uniform.
Check outputs using both models.
Note: We need a slightly higher rtol here to assert.
# Dummy Examples
input_ids = tf.random.uniform(minval=0, maxval=100, shape=(batch_size, sequence_length), dtype=tf.int32)
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.zeros_like(input_ids)
# input type ids
interpreter.set_tensor(
input_details[0]['index'],
input_type_ids,
)
# input_mask
interpreter.set_tensor(input_details[1]['index'], input_mask)
# input ids
interpreter.set_tensor(
input_details[2]['index'],
input_ids,
)
# Invoke inputs
interpreter.invoke()
# Take last output
tflite_output = interpreter.get_tensor(output_details[-1]['index'])
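The hard-coded indices above match the ordering this particular conversion produced. A more robust pattern is to match on tensor names; this is a sketch, and the exact name strings in input_details vary across converter versions, so print them first to confirm.
# Feed by tensor name instead of fixed index (the name substrings are an
# assumption; print input_details to confirm for your converter version).
feed = {'input_ids': input_ids, 'input_mask': input_mask, 'input_type_ids': input_type_ids}
for detail in interpreter.get_input_details():
    for key, value in feed.items():
        if key in detail['name']:
            interpreter.set_tensor(detail['index'], value)
            break
interpreter.invoke()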
# Keras Model outputs
model_inputs = {'input_ids': input_ids, 'input_mask': input_mask, 'input_type_ids': input_type_ids}
model_outputs = model(model_inputs)
# We need a slightly higher rtol here to assert :-)
tf.debugging.assert_near(tflite_output, model_outputs['token_embeddings'], rtol=3.0)
print("Outputs asserted and succesful: ✅")
Outputs asserted and successful: ✅
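For a more informative comparison than a pass/fail assert, we can also report the largest deviation directly (a small sketch):
# Report the maximum absolute difference between Keras and tflite outputs.
max_abs_diff = tf.reduce_max(tf.abs(model_outputs['token_embeddings'] - tflite_output))
print("Max absolute difference:", max_abs_diff.numpy())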