Model Inference Performance Tuning Guide

This section provides tips for debugging and tuning the performance of model inference on Azure Databricks. For an overview, see the inference workflow.

Model inference typically has two main parts: the data input pipeline and the model inference itself. The data input pipeline is I/O-heavy, while model inference is compute-heavy. Finding out which part is the bottleneck is straightforward. Here are two approaches:

  • Replace the model with a trivial model and measure the number of examples processed per second. If the end-to-end time of the full model differs only slightly from that of the trivial model, the data input pipeline is likely the bottleneck; otherwise, model inference is the bottleneck.
  • If you run model inference on a GPU, check the GPU utilization metrics. If GPU utilization is not consistently high, the data input pipeline may be the bottleneck.
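The trivial-model comparison in the first approach can be sketched in plain Python. The helpers below (examples_per_second, trivial_model, find_bottleneck, make_batch_source, slow_model) are hypothetical names for illustration, not a Databricks API, and the simulated data source and slow model only stand in for a real pipeline and network. The 0.8 cutoff for a "minimal" difference is an assumption; choose a threshold that suits your workload.

```python
import time
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

Batch = List[int]


def examples_per_second(predict: Callable[[Sequence[int]], object],
                        batches: Iterable[Batch]) -> float:
    """Run `predict` on every batch and return end-to-end examples per second.

    Batches are produced lazily inside the loop, so the timing covers both
    the input pipeline and the model call.
    """
    count = 0
    start = time.perf_counter()
    for batch in batches:
        predict(batch)
        count += len(batch)
    elapsed = time.perf_counter() - start
    return count / elapsed if elapsed > 0 else float("inf")


def trivial_model(batch: Sequence[int]) -> List[int]:
    # Does almost no work, so its throughput approximates the input pipeline alone.
    return [0] * len(batch)


def find_bottleneck(full_model: Callable[[Sequence[int]], object],
                    make_batches: Callable[[], Iterator[Batch]],
                    threshold: float = 0.8) -> Tuple[float, float, str]:
    """Compare full-model throughput with trivial-model throughput.

    `threshold` (an assumed cutoff) is the fraction of trivial-model
    throughput above which the difference is treated as minimal.
    """
    full = examples_per_second(full_model, make_batches())
    trivial = examples_per_second(trivial_model, make_batches())
    verdict = "data input pipeline" if full >= threshold * trivial else "model inference"
    return full, trivial, verdict


def make_batch_source(num_batches: int, batch_size: int,
                      load_delay: float) -> Callable[[], Iterator[Batch]]:
    # Hypothetical stand-in for reading data: sleeps `load_delay` seconds per batch.
    def generate() -> Iterator[Batch]:
        for i in range(num_batches):
            time.sleep(load_delay)
            yield list(range(i * batch_size, (i + 1) * batch_size))
    return generate


def slow_model(batch: Sequence[int]) -> int:
    # Hypothetical compute-heavy model: 20 ms per batch regardless of loading cost.
    time.sleep(0.02)
    return sum(batch)


# I/O-bound case: loading dominates, so a fast model is barely slower than the trivial one.
_, _, io_verdict = find_bottleneck(sum, make_batch_source(10, 100, load_delay=0.02))
# Compute-bound case: loading is nearly free, so the slow model is far slower than trivial.
_, _, compute_verdict = find_bottleneck(slow_model, make_batch_source(10, 100, load_delay=0.0))
print(io_verdict, "|", compute_verdict)
```

Passing a factory (make_batches) rather than a single iterator lets each measurement re-read the data, so both runs pay the same input cost and differ only in the model.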

Optimize data input pipeline

GPUs can substantially speed up model inference. As GPUs and other accelerators become faster, it is important that the data input pipeline keeps up with demand. The data input pipeline reads the data into Spark DataFrames, transforms it, and loads it as the input for model inference. If data input is the bottleneck, here are some tips to increase I/O throughput:

  • Set the maximum number of records per batch. A larger value reduces the overhead of invoking the UDF, as long as each batch fits in memory. To set the batch size, use the following config:

    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")
  • Load the data in batches and prefetch it when preprocessing the input data in the Pandas UDF.

    For TensorFlow or Keras, Azure Databricks recommends using the tf.data API. You can run the parsing function in parallel by setting num_parallel_calls in a map call, and call prefetch and batch for prefetching and batching:

    # dataset and parse_example are placeholders for your tf.data.Dataset and parsing function
    dataset.map(parse_example, num_parallel_calls=num_process).prefetch(prefetch_size).batch(batch_size)

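    For PyTorch, Azure Databricks recommends using the DataLoader class. You can set batch_size for batching and num_workers for parallel data loading:

    from torch.utils.data import DataLoader

    # dataset is a placeholder for your torch.utils.data.Dataset
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_process)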
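The batching and prefetching tips above can be illustrated without Spark, TensorFlow, or PyTorch. The sketch below is a framework-agnostic, standard-library approximation of the idea behind tf.data's prefetch and DataLoader's parallel loading; it is not how either library is implemented (for example, DataLoader with num_workers greater than 0 uses worker processes, while this sketch uses a single thread). The helpers batched and prefetch, the simulated slow_batches source, and the run_inference model are hypothetical names chosen for illustration. Here batch_size plays a role similar to maxRecordsPerBatch, and buffer_size a role similar to prefetch_size.

```python
import queue
import threading
import time
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the producer's stream


def batched(records: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group records into lists of at most `batch_size` items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    batch: List[T] = []
    for record in records:
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def prefetch(items: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Pull items on a background thread, keeping up to `buffer_size` ready.

    While the consumer works on one item, the producer thread is already
    loading the next ones, so loading overlaps with computation.
    """
    buf: "queue.Queue[object]" = queue.Queue(maxsize=buffer_size)
    errors: List[BaseException] = []

    def producer() -> None:
        try:
            for item in items:
                buf.put(item)
        except Exception as exc:  # forward loading errors to the consumer
            errors.append(exc)
        finally:
            buf.put(_DONE)

    # Daemon thread: if the consumer stops early, it does not block interpreter exit.
    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        item = buf.get()
        if item is _DONE:
            break
        yield item
    thread.join()
    if errors:
        raise errors[0]


def slow_batches(num_batches: int, delay: float) -> Iterator[List[int]]:
    # Hypothetical I/O-bound source: each batch takes `delay` seconds to load.
    for i in range(num_batches):
        time.sleep(delay)
        yield [i]


def run_inference(batches: Iterable[List[int]], delay: float) -> float:
    # Hypothetical model: `delay` seconds of "inference" per batch; returns wall time.
    start = time.perf_counter()
    for _ in batches:
        time.sleep(delay)
    return time.perf_counter() - start


sequential_time = run_inference(slow_batches(4, 0.05), 0.05)
prefetched_time = run_inference(prefetch(slow_batches(4, 0.05)), 0.05)
print(f"sequential {sequential_time:.2f}s, prefetched {prefetched_time:.2f}s")
```

Without prefetching, each batch pays its loading time and its inference time back to back; with prefetching, the next batch loads while the current one is being processed, so the total approaches the larger of the two costs rather than their sum. Because time.sleep releases the GIL, a thread is enough to overlap the simulated I/O here; CPU-heavy Python preprocessing would generally need separate processes instead.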