HorovodEstimator: Distributed Deep Learning with Horovod and Spark

HorovodEstimator is an MLlib-style estimator API that leverages Uber’s Horovod framework. It facilitates distributed, multi-GPU training of deep neural networks on Spark DataFrames, simplifying the integration of ETL in Spark with model training in TensorFlow. Specifically, HorovodEstimator simplifies launching distributed training with Horovod by:

  • Distributing training code & data to each machine on your cluster
  • Enabling passwordless SSH between the driver and workers, and launching training via MPI
  • Handling data-ingest & model-export logic that would otherwise require custom code
  • Simultaneously running model training & evaluation


The HorovodEstimator API is experimental and subject to change; we’d love to hear your feedback as we improve the API and consider new features.


HorovodEstimator requires Databricks Runtime ML (Beta).

Cluster types

You can run HorovodEstimator on clusters of two or more CPU or GPU-enabled machines; we recommend running on GPU instances if possible.

HorovodEstimator expects all GPUs on the current cluster to be available; thus we do not recommend using the API on shared clusters.

If you are using GPUs, do not open any other TensorFlow sessions on the cluster you are using with HorovodEstimator. Opening a TensorFlow session causes the Python REPL running your notebook to claim a GPU, which prevents HorovodEstimator from running. If this happens, you may need to detach and reattach your notebook, then rerun your HorovodEstimator code without running any TensorFlow code beforehand.

Using HorovodEstimator

HorovodEstimator is a Spark MLlib Estimator and can be used with the Spark MLlib Pipelines API, although estimator persistence is not yet supported.

Fitting a HorovodEstimator returns an MLlib Transformer (a TFTransformer) that can be used for distributed inference on a DataFrame. It also writes the following to the specified model directory: model checkpoints, which can be used to resume training; event files, which contain metrics logged during training; and a tf.SavedModel, which can be used to apply the model for inference outside Spark.
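The fit/transform contract described above follows the usual MLlib Estimator pattern: fit() trains, writes artifacts to a model directory, and hands back a separate Transformer object used for inference. The following is a minimal pure-Python sketch of that pattern only, not the HorovodEstimator API. `ToyEstimator`, `ToyTransformer`, and the `checkpoint.json` file are hypothetical stand-ins; a list of numbers plays the role of a DataFrame and a single learned mean plays the role of the trained network.

```python
import json
import os
import tempfile


class ToyTransformer:
    """Hypothetical stand-in for the TFTransformer returned by fit()."""

    def __init__(self, mean):
        self.mean = mean

    def transform(self, rows):
        # Inference: pair each input row with a "prediction".
        return [(x, x - self.mean) for x in rows]


class ToyEstimator:
    """Hypothetical stand-in for an MLlib-style estimator that writes
    its training artifacts into a model directory."""

    def __init__(self, model_dir):
        self.model_dir = model_dir

    def fit(self, rows):
        os.makedirs(self.model_dir, exist_ok=True)
        mean = sum(rows) / len(rows)  # the "training" step
        # Persist a checkpoint so the learned state outlives this call.
        with open(os.path.join(self.model_dir, "checkpoint.json"), "w") as f:
            json.dump({"mean": mean}, f)
        # fit() returns a new Transformer rather than mutating the estimator.
        return ToyTransformer(mean)


model_dir = tempfile.mkdtemp()
model = ToyEstimator(model_dir).fit([1.0, 2.0, 3.0])
print(model.transform([4.0]))  # [(4.0, 2.0)]
```

Keeping the estimator and the fitted model as separate objects is what lets either one slot into a Spark MLlib Pipeline stage.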

HorovodEstimator makes no fault-tolerance guarantees. If an error occurs during training, HorovodEstimator does not attempt to recover, although you can rerun fit() to resume training from the latest checkpoint.
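Resuming from the latest checkpoint amounts to finding the highest saved step in the model directory and continuing training from there. The sketch below illustrates that idea in plain Python under stated assumptions: the `ckpt-<step>` file naming, the `latest_step` and `train` helpers, and the `fail_at` crash simulation are all hypothetical and do not reflect HorovodEstimator's or TensorFlow's actual checkpoint format.

```python
import os
import re
import tempfile

# Hypothetical checkpoint naming scheme: one empty file per saved step.
CKPT_PATTERN = re.compile(r"^ckpt-(\d+)$")


def latest_step(model_dir):
    """Return the highest saved step in model_dir, or 0 if none exist."""
    steps = [int(m.group(1)) for name in os.listdir(model_dir)
             if (m := CKPT_PATTERN.match(name))]
    return max(steps, default=0)


def train(model_dir, max_steps, fail_at=None, every=10):
    """Run steps after the latest checkpoint up to max_steps, saving a
    checkpoint every `every` steps. fail_at simulates a crash; nothing
    is recovered automatically. Returns the step training resumed from."""
    start = latest_step(model_dir)
    for step in range(start + 1, max_steps + 1):
        if step == fail_at:
            raise RuntimeError(f"worker failed at step {step}")
        if step % every == 0:
            open(os.path.join(model_dir, f"ckpt-{step}"), "w").close()
    return start


model_dir = tempfile.mkdtemp()
try:
    train(model_dir, max_steps=50, fail_at=35)
except RuntimeError:
    pass  # the first run dies after saving ckpt-10, ckpt-20, ckpt-30
resumed_from = train(model_dir, max_steps=50)  # rerun, like calling fit() again
print(resumed_from)  # 30
```

Work done after the last checkpoint (steps 31 to 34 here) is lost and repeated on the rerun, which is the trade-off of checkpoint-based resumption without automatic recovery.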

In the example notebook below, we demonstrate how to use HorovodEstimator to train a deep neural network on the MNIST dataset.

HorovodEstimator Notebook