BERT TFLite

!pip install tf-transformers

!pip install sentencepiece

!pip install tensorflow-text

!pip install transformers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TF warnings

import tensorflow as tf
print("Tensorflow version", tf.__version__)

from tf_transformers.models import BertModel
TensorFlow version 2.7.0

Convert a Model to TFLite

The most important thing to note here is that, to convert a model to TFLite, we have to ensure that the inputs to the model are deterministic, meaning the input shapes must not be dynamic. We have to fix batch_size, sequence_length, and other related input constraints, depending on the model of interest.

Load Bert Model

  1. Fix the inputs.

  2. We can always check the model's inputs and outputs using model.input and model.output.

  3. We use batch_size=1 and sequence_length=64.

model_name = 'bert-base-cased'
batch_size = 1
sequence_length = 64
model = BertModel.from_pretrained(model_name, batch_size=batch_size, sequence_length=sequence_length)
INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /root/.cache/huggingface/hub/tftransformers--bert-base-cased-no-mlm.main.8558edc2f96edfb697bb90ec134aef6242e3166b/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/bert-base-cased-no-mlm

Verify Model Inputs and Outputs

print("Model inputs", model.input)
print("Model outputs", model.output)
Model inputs {'input_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_ids')>, 'input_mask': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_mask')>, 'input_type_ids': <KerasTensor: shape=(1, 64) dtype=int32 (created by layer 'input_type_ids')>}
Model outputs {'cls_output': <KerasTensor: shape=(1, 768) dtype=float32 (created by layer 'tf_transformers/bert')>, 'token_embeddings': <KerasTensor: shape=(1, 64, 768) dtype=float32 (created by layer 'tf_transformers/bert')>, 'token_logits': <KerasTensor: shape=(1, 64, 28996) dtype=float32 (created by layer 'tf_transformers/bert')>, 'last_token_logits': <KerasTensor: shape=(1, 28996) dtype=float32 (created by layer 'tf_transformers/bert')>}

Save Model as Serialized Version

We save the model using model.save with save_format='tf'. The resulting SavedModel is what we convert to TFLite.

model.save("{}/saved_model".format(model_name), save_format='tf')
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:absl:Found untraced functions such as word_embeddings_layer_call_fn, word_embeddings_layer_call_and_return_conditional_losses, type_embeddings_layer_call_fn, type_embeddings_layer_call_and_return_conditional_losses, positional_embeddings_layer_call_fn while saving (showing 5 of 870). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: bert-base-cased/saved_model/assets
WARNING:absl:<tf_transformers.layers.attention.bert_attention.MultiHeadAttention object at 0x7fa8a9b31850> has the same name 'MultiHeadAttention' as a built-in Keras object. Consider renaming <class 'tf_transformers.layers.attention.bert_attention.MultiHeadAttention'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.

Convert SavedModel to TFLite

converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model".format(model_name)) # path to the SavedModel directory
converter.experimental_new_converter = True

tflite_model = converter.convert()

open("{}/saved_model.tflite".format(model_name), "wb").write(tflite_model)
print("TFlite conversion succesful")
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
TFLite conversion successful
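
The converter also supports post-training optimizations. As an optional extra (not part of the run above), enabling the default optimization applies dynamic-range quantization, which typically shrinks a BERT-sized model to roughly a quarter of its float32 size; the exact size and accuracy trade-off should be verified for your model. A minimal sketch:

# Optional: post-training dynamic-range quantization (a sketch, not run above).
quant_converter = tf.lite.TFLiteConverter.from_saved_model("{}/saved_model".format(model_name))
quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enables dynamic-range quantization
quantized_tflite_model = quant_converter.convert()
open("{}/saved_model_quantized.tflite".format(model_name), "wb").write(quantized_tflite_model)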

Load TFLite Model

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="{}/saved_model.tflite".format(model_name))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
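
Note that the ordering of input_details is not guaranteed to match the order of the Keras model's inputs, so it is worth printing the tensor details and, if needed, looking tensors up by name. The helper below (tensor_index) is illustrative, not part of tf.lite, and the exact tensor names depend on the exported signature:

# Inspect the TFLite tensors; build a name -> index lookup instead of
# relying on positional order, which can differ between environments.
for detail in input_details:
    print(detail['name'], detail['shape'], detail['dtype'])

def tensor_index(details, key):
    # Return the index of the first tensor whose name contains `key`.
    for detail in details:
        if key in detail['name']:
            return detail['index']
    raise KeyError(key)

The cells below use positional indices, matching the order observed in this run.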

Assert TFLite Model and Keras Model Outputs

After conversion, we assert that the TFLite model and the Keras model produce matching outputs, to ensure the conversion was done properly.

  1. Create dummy examples using tf.random.uniform.

  2. Check the outputs of both models.

  3. Note: we need a slightly higher rtol here for the assertion to pass.

# Dummy Examples 
input_ids = tf.random.uniform(minval=0, maxval=100, shape=(batch_size, sequence_length), dtype=tf.int32)
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.zeros_like(input_ids)


# input type ids
interpreter.set_tensor(
    input_details[0]['index'],
    input_type_ids,
)
# input_mask
interpreter.set_tensor(input_details[1]['index'], input_mask)

# input ids
interpreter.set_tensor(
    input_details[2]['index'],
    input_ids,
)

# Invoke inputs
interpreter.invoke()
# Take last output
tflite_output = interpreter.get_tensor(output_details[-1]['index'])

# Keras Model outputs
model_inputs = {'input_ids': input_ids, 'input_mask': input_mask, 'input_type_ids': input_type_ids}
model_outputs = model(model_inputs)

# We need a slightly higher rtol here to assert :-)
tf.debugging.assert_near(tflite_output, model_outputs['token_embeddings'], rtol=3.0)
print("Outputs asserted and succesful:  ✅")
Outputs asserted and succesful:  ✅
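
Finally, to run the TFLite model on real text rather than random ids, we can tokenize a sentence with the matching bert-base-cased tokenizer (from the transformers package installed earlier) and pad it to the fixed sequence_length of 64. This is a minimal sketch, assuming the same positional input order as the cells above (input_type_ids, input_mask, input_ids):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(model_name)
encoded = tokenizer(
    "TFLite makes on-device BERT inference possible.",
    padding='max_length',
    max_length=sequence_length,
    return_tensors='tf',
)

# Same positional order as the assertion cell above (assumed; verify on your setup).
interpreter.set_tensor(input_details[0]['index'], tf.cast(encoded['token_type_ids'], tf.int32))
interpreter.set_tensor(input_details[1]['index'], tf.cast(encoded['attention_mask'], tf.int32))
interpreter.set_tensor(input_details[2]['index'], tf.cast(encoded['input_ids'], tf.int32))
interpreter.invoke()

token_embeddings = interpreter.get_tensor(output_details[-1]['index'])
print("Token embeddings shape:", token_embeddings.shape)  # (1, 64, 768)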