Read and Write Images as TFRecords

Tensorflow-Transformers makes it easy to write data of any type to TFRecords and to read it back. Here we will see how to use it to write images as TFRecords and then read them back.

import json
import glob
import numpy as np 
import tensorflow as tf
import matplotlib.pyplot as plt
import math
from tf_transformers.data import TFWriter, TFReader
from datasets import load_dataset

Load CelebA dataset from HF

# Load the CelebA faces dataset from the Hugging Face Hub. The first call
# downloads and caches the data; subsequent calls reuse the cached copy.
dataset = load_dataset("nielsr/CelebA-faces")
WARNING:datasets.builder:Using custom data configuration nielsr--CelebA-faces-00908b91f44a46a2
Downloading and preparing dataset image_folder/default (download: 1.29 GiB, generated: 1.06 GiB, post-processed: Unknown size, total: 2.35 GiB) to /root/.cache/huggingface/datasets/nielsr___parquet/nielsr--CelebA-faces-00908b91f44a46a2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...
Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/nielsr___parquet/nielsr--CelebA-faces-00908b91f44a46a2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.

Write TFRecords

  • We save each image as a byte string along with its original height, width and number of channels, which we need to restore the image's shape when we read it back

def parse_train():
    """Yield one serializable example per image in the ``train`` split.

    Every image is converted to a NumPy array and flattened to raw bytes.
    The array's height, width and channel count are emitted with it so the
    original shape can be rebuilt when the record is read back.

    Yields:
        dict with keys ``image`` (bytes), ``height``, ``width`` and
        ``channels`` (ints).
    """
    for record in dataset['train']:
        pixels = np.asarray(record['image'])
        n_rows, n_cols, n_channels = pixels.shape
        example = dict(
            image=pixels.tobytes(),
            height=n_rows,
            width=n_cols,
            channels=n_channels,
        )
        yield example
        
# Write the examples with TFWriter.

# Each feature maps to a (feature kind, value type) pair; the three shape
# fields share the same integer spec.
schema = {"image": ("var_len", "bytes")}
schema.update(
    {field: ("var_len", "int") for field in ("height", "width", "channels")}
)

# Output directory and base file name for the generated TFRecords.
tfrecord_train_dir = 'TFRECORDS/celeba'
tfrecord_filename = 'celeba'

tfwriter = TFWriter(
    schema=schema,
    file_name=tfrecord_filename,
    model_dir=tfrecord_train_dir,
    tag='train',
    overwrite=True,
    verbose_counter=10000,  # progress is logged every 10000 records
)

# Serialize the train split: hand the example generator to the writer.
train_parser_fn = parse_train()
tfwriter.process(parse_fn=train_parser_fn)
INFO:absl:Wrote 10000 tfrecods
INFO:absl:Wrote 20000 tfrecods
INFO:absl:Wrote 30000 tfrecods
INFO:absl:Wrote 40000 tfrecods
INFO:absl:Wrote 50000 tfrecods
INFO:absl:Wrote 60000 tfrecods
INFO:absl:Wrote 70000 tfrecods
INFO:absl:Wrote 80000 tfrecods
INFO:absl:Wrote 90000 tfrecods
INFO:absl:Wrote 100000 tfrecods
INFO:absl:Wrote 110000 tfrecods
INFO:absl:Wrote 120000 tfrecods
INFO:absl:Wrote 130000 tfrecods
INFO:absl:Wrote 140000 tfrecods
INFO:absl:Wrote 150000 tfrecods

Read TFRecords

# Read TFRecord

# Load the schema stored as schema.json in the TFRecord directory. The
# context manager guarantees the file handle is closed after parsing.
with open(f"{tfrecord_train_dir}/schema.json") as schema_file:
    schema = json.load(schema_file)

# Collect every TFRecord shard in the output directory.
all_files = glob.glob(f"{tfrecord_train_dir}/*.tfrecord")
tf_reader = TFReader(schema=schema, tfrecord_files=all_files)

# Features to read back from each record.
x_keys = ['image', 'height', 'width', 'channels']

def decode_img(item):
    """Rebuild an image tensor from one serialized example.

    Args:
        item: dict holding the ``image`` bytes and the ``height``, ``width``
            and ``channels`` values of a single record. Only the first
            element of each feature is used.

    Returns:
        dict with a single key ``image``: a ``tf.uint8`` tensor reshaped to
        ``(height, width, channels)``.
    """
    # The target shape comes from the dimensions stored next to the bytes.
    target_shape = (item['height'][0], item['width'][0], item['channels'][0])

    # Raw bytes -> flat uint8 vector -> image with its original shape.
    flat_pixels = tf.io.decode_raw(item['image'][0], tf.uint8)
    restored = tf.reshape(flat_pixels, target_shape)
    # Optional: resize here if a fixed size is needed, e.g.
    # restored = tf.cast(tf.image.resize(restored, (64, 64)), tf.uint8)
    return {'image': restored}

batch_size = 16

# Read the selected features back as a shuffled dataset.
train_dataset = tf_reader.read_record(keys=x_keys, shuffle=True)

# Decode every record into an image, then group the images into padded
# batches, dropping the final incomplete batch.
train_dataset = (
    train_dataset
    .map(decode_img, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size, drop_remainder=True)
)

# Keep only the first batch in `item` for inspection.
for item in train_dataset.take(1):
    pass

Plot images after reading

def display_images(images, cols=5):
    """Display the given images in a grid, titling each with its index.

    Args:
        images: Sized sequence of images that ``imshow`` can render.
        cols: Number of columns in the grid; the row count is derived from
            it and the number of images.

    Returns:
        None. The images are drawn on a newly created matplotlib figure.
        When ``images`` is empty, no figure is created.
    """
    count = len(images)
    if count == 0:
        # Nothing to draw; avoid creating an empty zero-height figure.
        return

    rows = math.ceil(count / cols)
    fig = plt.figure()
    # Reserve a 3x3 inch cell for every image.
    fig.set_size_inches(cols * 3, rows * 3)
    for index, image in enumerate(images):
        # Draw on explicit axes instead of relying on pyplot's implicit
        # "current axes" state.
        ax = fig.add_subplot(rows, cols, index + 1)
        ax.axis('off')
        ax.imshow(image)
        ax.set_title(str(index))

# Maximum number of images to plot from the batch.
NUM_IMAGES = 16
# Extract each image individually for plotting, capped at NUM_IMAGES so the
# constant actually controls the size of the grid.
batch_images = [im.numpy() for im in item['image'][:NUM_IMAGES]]
display_images(batch_images)
../_images/303030987c830afffdbae7d51973b498c2f8ac32924be147b802b5203fd1d762.png