PyTorch

The PyTorch project is a Python package that provides GPU-accelerated tensor computation and high-level functionality for building deep learning networks. For licensing details, see the PyTorch license document on GitHub.
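As a minimal sketch of the GPU-accelerated tensor computation described above, the following function multiplies two matrices on a GPU when CUDA is available and on the CPU otherwise. The function name `matmul_on_best_device` is hypothetical. The code avoids APIs added after PyTorch 0.3 (such as `torch.device`), so it is intended to run both on the version recommended below and on later releases.

```python
import torch

def matmul_on_best_device(rows=3, inner=4, cols=5):
    """Multiply two all-ones matrices, on the GPU if CUDA is available (hypothetical helper)."""
    a = torch.ones(rows, inner)
    b = torch.ones(inner, cols)
    if torch.cuda.is_available():
        # Move both operands to the default GPU so the product is computed there.
        a, b = a.cuda(), b.cuda()
    c = torch.mm(a, b)
    # Copy the result back to host memory so callers can inspect it on any machine.
    return c.cpu()

result = matmul_on_best_device()
```

Because both inputs are all ones, every entry of the 3 x 5 result equals the inner dimension, 4, whichever device performed the multiplication.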

In the sections below, we provide guidance on installing PyTorch on Azure Databricks and give an example of running PyTorch programs. See Integrating Deep Learning Libraries with Apache Spark for an example of integrating a deep learning library with Spark.

Note

This guide is not a comprehensive guide to PyTorch. For complete documentation, refer to the PyTorch website.

Install PyTorch

You can install PyTorch as an Azure Databricks library from PyPI, which makes it available on all cluster nodes. Azure Databricks recommends PyTorch version 0.3.0.post4. To install this version, provide the wheel URL from the PyTorch previous versions page as the PyPI package name.

  • Python 2: install PyTorch from the URL http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl, then install torchvision.
  • Python 3: install PyTorch from the URL http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl, then install torchvision.
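The wheel file names encode which interpreter each build targets: `cu80` means the build uses CUDA 8.0, `cp27mu` means CPython 2.7 (UCS-4 builds), and `cp35m` means CPython 3.5 specifically, not any Python 3 release. The following sketch uses only the standard library to pick the matching URL from the two listed above. The names `TORCH_030_WHEELS` and `torch_wheel_url` are hypothetical and are not part of any Databricks or PyTorch API.

```python
import sys

# The two wheel URLs listed above, keyed by (major, minor) Python version.
TORCH_030_WHEELS = {
    (2, 7): "http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl",
    (3, 5): "http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl",
}

def torch_wheel_url(version_info=None):
    """Return the PyTorch 0.3.0.post4 wheel URL for a Python version (hypothetical helper)."""
    if version_info is None:
        version_info = sys.version_info  # default to the running interpreter
    key = (version_info[0], version_info[1])
    if key not in TORCH_030_WHEELS:
        # Only CPython 2.7 and 3.5 wheels are listed in this guide.
        raise ValueError("no PyTorch 0.3.0.post4 wheel listed for Python %d.%d" % key)
    return TORCH_030_WHEELS[key]
```

For example, `torch_wheel_url((3, 5, 2))` returns the `cp35m` URL, while a Python 3.6 cluster raises `ValueError` because no matching wheel is listed.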

Use PyTorch on a single node

To test and migrate single-machine PyTorch workflows, you can start with a driver-only cluster on Azure Databricks by setting the number of workers to zero. Although Apache Spark is not functional in this configuration, a driver-only cluster is a cost-effective way to run single-machine PyTorch workflows.
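A single-machine workflow of this kind runs entirely on the driver. The following minimal sketch fits a line with a one-parameter `nn.Linear` model and plain SGD, using the driver's GPU if CUDA is available. The function `fit_line`, the synthetic data (`y = 3x + 1`), and the hyperparameters are illustrative assumptions. Inputs are wrapped in `torch.autograd.Variable`, which PyTorch 0.3 requires for autograd; in 0.4 and later this wrapper simply returns the tensor.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

def fit_line(steps=500, lr=0.1):
    """Fit y = 3x + 1 with a linear model on the driver (hypothetical example)."""
    torch.manual_seed(0)  # make the random weight initialization reproducible
    # 64 evenly spaced inputs in [-1, 1], shaped (N, 1) as nn.Linear expects.
    x = torch.linspace(-1, 1, 64).view(-1, 1)
    y = 3 * x + 1
    model = nn.Linear(1, 1)
    if torch.cuda.is_available():
        # Keep the model and the data on the same device.
        model, x, y = model.cuda(), x.cuda(), y.cuda()
    x, y = Variable(x), Variable(y)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()         # clear gradients from the previous step
        loss = loss_fn(model(x), y)   # mean squared error over the batch
        loss.backward()               # compute gradients
        optimizer.step()              # apply the SGD update
    # Read the learned slope and intercept back as Python floats.
    w = float(model.weight.data.cpu().view(-1)[0])
    b = float(model.bias.data.cpu().view(-1)[0])
    return w, b

w, b = fit_line()
```

Because the inputs are symmetric around zero, the slope and intercept updates are decoupled and both converge quickly, so after 500 steps `w` and `b` should be very close to 3 and 1.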