Load Data from TFRecord Files with TensorFlow

The TFRecord file format is a simple record-oriented binary format for ML training data. The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline.
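To make "simple record-oriented binary format" concrete, here is a stdlib-only sketch of the on-disk framing that TensorFlow documents for TFRecord files. Each record is a little-endian uint64 length, a masked CRC-32C of those 8 length bytes, the payload bytes, and a masked CRC-32C of the payload. The helper names (crc32c, masked_crc, write_records, read_records) are hypothetical and exist only for illustration. In a real pipeline you would write files with tf.io.TFRecordWriter and read them with tf.data.TFRecordDataset rather than parse them by hand.

```python
import os
import struct
import tempfile


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            # Shift right; XOR in the polynomial only when the low bit was set.
            crc = (crc >> 1) ^ (0x82F63B78 & -(crc & 1))
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits, then add the TFRecord mask constant (mod 2**32)."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, records):
    """Write byte strings using the TFRecord framing."""
    with open(path, "wb") as f:
        for record in records:
            length = struct.pack("<Q", len(record))          # uint64 length
            f.write(length)
            f.write(struct.pack("<I", masked_crc(length)))   # uint32 masked CRC of length
            f.write(record)                                  # payload bytes
            f.write(struct.pack("<I", masked_crc(record)))   # uint32 masked CRC of payload


def read_records(path):
    """Yield each record's payload, verifying both checksums."""
    with open(path, "rb") as f:
        while True:
            length = f.read(8)
            if not length:
                return  # clean end of file
            if len(length) != 8:
                raise ValueError("truncated length header")
            (length_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc(length):
                raise ValueError("corrupted length header")
            (n,) = struct.unpack("<Q", length)
            payload = f.read(n)
            (payload_crc,) = struct.unpack("<I", f.read(4))
            if len(payload) != n or payload_crc != masked_crc(payload):
                raise ValueError("corrupted record payload")
            yield payload


path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")
write_records(path, [b"first", b"", b"third record"])
print(list(read_records(path)))  # [b'first', b'', b'third record']
```

Because every record carries its own length and checksums, a reader can stream records one at a time without loading the whole file, which is what lets TFRecordDataset stream large files efficiently.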

Note

This article is not a comprehensive guide to importing data with TensorFlow. For full details, see the TensorFlow API Guide.

To load data from TFRecord files, the workflow is as follows:

Step 1: Create a tf.data.TFRecordDataset from one or more TFRecord filenames.
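A minimal sketch of Step 1, assuming TensorFlow 2.x. To keep it self-contained, it first writes two small hypothetical shard files of raw byte-string records (the directory, the train-*.tfrecord naming, and the record contents are all made up for illustration), then builds one TFRecordDataset over both files.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical setup: write two small TFRecord shards of raw byte-string records.
data_dir = tempfile.mkdtemp()
for shard in range(2):
    shard_path = os.path.join(data_dir, f"train-{shard:05d}.tfrecord")
    with tf.io.TFRecordWriter(shard_path) as writer:
        for i in range(3):
            writer.write(f"shard{shard}-record{i}".encode("utf-8"))

# Step 1: collect the filenames (a glob pattern is one option) ...
filenames = sorted(tf.io.gfile.glob(os.path.join(data_dir, "train-*.tfrecord")))

# ... and create a dataset that streams the records of all files, one file after another.
dataset = tf.data.TFRecordDataset(filenames)

for record in dataset.take(2):
    print(record.numpy())  # b'shard0-record0', then b'shard0-record1'
```

Each element is a scalar tf.string tensor holding one serialized record. TFRecordDataset also accepts compression_type (for example "GZIP") for compressed files and num_parallel_reads to read several files concurrently, in which case records from different files are interleaved rather than read file by file.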

Step 2: Define a decoder function that parses each serialized record. A decoder performs the following steps:

  1. Parse the record with tf.parse_single_example, passing a dictionary that maps feature keys to FixedLenFeature or VarLenFeature configurations. The parser returns a dictionary that maps each feature key to a Tensor (for a FixedLenFeature) or a SparseTensor (for a VarLenFeature).
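  2. Convert the data from raw byte strings back to their original types. Use tf.decode_raw to reinterpret a string Tensor as a Tensor of the numeric type given by its out_type argument (for example, tf.uint8). For features that were not stored as strings, such as int64 or float features, use tf.cast to convert them to the type you need.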
  3. Reshape the data. Because serializing a tensor to a byte string discards its shape, tf.decode_raw returns a flat 1-D tensor; use tf.reshape to restore the original shape.
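The three decoder steps above can be sketched as a single function passed to Dataset.map. This assumes TensorFlow 2.x, where tf.parse_single_example, tf.decode_raw, and FixedLenFeature live under tf.io. The feature keys image_raw and label, the 4x4x3 image shape, and the generated data are hypothetical; your feature dictionary must match whatever keys and types were used when the TFRecord files were written.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

IMAGE_SHAPE = (4, 4, 3)  # hypothetical shape of each stored image


def make_example(image, label):
    """Serialize one image as raw bytes plus an int64 label (hypothetical schema)."""
    feature = {
        "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Hypothetical setup: write five small uint8 images to a TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "images.tfrecord")
images = np.arange(5 * 48, dtype=np.uint8).reshape((5,) + IMAGE_SHAPE)
with tf.io.TFRecordWriter(path) as writer:
    for i, img in enumerate(images):
        writer.write(make_example(img, i % 2).SerializeToString())

feature_description = {
    "image_raw": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}


def decode(serialized):
    # 1. Parse: returns a dict of dense Tensors because both features are FixedLenFeature.
    parsed = tf.io.parse_single_example(serialized, feature_description)
    # 2. Convert: reinterpret the byte string as uint8; cast the int64 label.
    image = tf.io.decode_raw(parsed["image_raw"], tf.uint8)
    label = tf.cast(parsed["label"], tf.int32)
    # 3. Reshape: decode_raw produced a flat vector of 48 values.
    image = tf.reshape(image, IMAGE_SHAPE)
    return image, label


dataset = tf.data.TFRecordDataset([path]).map(decode)
```

If a feature is declared with VarLenFeature instead, the parsed value is a SparseTensor, which you can turn into a dense Tensor with tf.sparse.to_dense before further processing.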

Step 3: Preprocess the data. You can define your own preprocessing function, for example one that normalizes the data, and apply it with map, or call methods that tf.data.TFRecordDataset inherits from tf.data.Dataset, such as shuffle, repeat, batch, or shard.
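A minimal sketch of Step 3, assuming TensorFlow 2.x. An in-memory dataset built with from_tensor_slices stands in for the decoded (image, label) dataset from Step 2, and NUM_WORKERS and WORKER_INDEX are hypothetical values for a two-worker setup; the normalize function is one possible custom preprocessing step, not a built-in.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the decoded (image, label) dataset produced in Step 2 (hypothetical data).
images = (np.arange(8 * 48) % 256).astype(np.uint8).reshape((8, 4, 4, 3))
labels = np.arange(8, dtype=np.int64)
decoded = tf.data.Dataset.from_tensor_slices((images, labels))

NUM_WORKERS = 2   # hypothetical number of training workers
WORKER_INDEX = 0  # hypothetical index of this worker


def normalize(image, label):
    """Custom preprocessing: scale uint8 pixels to float32 values in [0, 1]."""
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


dataset = (
    decoded
    .shard(num_shards=NUM_WORKERS, index=WORKER_INDEX)  # keep every 2nd element: 0, 2, 4, 6
    .map(normalize)                                       # apply the custom function
    .shuffle(buffer_size=100, seed=42)                    # shuffle within each epoch
    .repeat(2)                                            # two passes over the data
    .batch(2)                                             # group elements into batches of 2
)
```

Calling shuffle before repeat keeps epoch boundaries intact, so each epoch sees every element exactly once. The tf.data documentation also recommends sharding as early as possible; with TFRecord input that usually means sharding the list of filenames before reading records, which avoids having every worker read every file.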

Step 4: Create an iterator. The most common way to consume values from a Dataset is to make an iterator that yields one element (for example, one batch) at a time. For an in-depth explanation, see the TensorFlow Importing Data guide.
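Here is a sketch of both ways to consume a dataset, assuming TensorFlow 2.x. A small range dataset stands in for the preprocessed pipeline from Step 3. In eager mode a Dataset is an ordinary Python iterable; the second half shows the TF 1.x-style graph pattern, using the tf.compat.v1 one-shot iterator and a session, which is what older notebooks written against TensorFlow 1.x typically do.

```python
import tensorflow as tf

# Stand-in for the preprocessed, batched dataset from Step 3 (hypothetical data).
dataset = tf.data.Dataset.range(6).batch(2)

# TF 2.x (eager): iterate directly over the dataset ...
eager_batches = [batch.numpy().tolist() for batch in dataset]

# ... or pull elements one at a time with an explicit Python iterator.
iterator = iter(dataset)
first_batch = next(iterator).numpy().tolist()

# TF 1.x-style graph mode: create a one-shot iterator and evaluate get_next() in a session.
graph_batches = []
with tf.Graph().as_default():
    graph_dataset = tf.data.Dataset.range(6).batch(2)
    one_shot = tf.compat.v1.data.make_one_shot_iterator(graph_dataset)
    next_element = one_shot.get_next()
    with tf.compat.v1.Session() as sess:
        while True:
            try:
                graph_batches.append(sess.run(next_element).tolist())
            except tf.errors.OutOfRangeError:
                break  # the iterator is exhausted

print(eager_batches)  # [[0, 1], [2, 3], [4, 5]]
```

In graph mode the end of the data is signaled by tf.errors.OutOfRangeError rather than StopIteration, so training loops written in that style catch it to stop cleanly.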

The example notebook below demonstrates how to load TFRecord files for ML training. Before running the notebook, you must:

  1. Prepare storage mounts for distributed data loading.
  2. Configure your FUSE_MOUNT_LOCATION in the notebook.