Cloud TPU

SkyPilot supports running jobs on Google’s Cloud TPU, a specialized hardware accelerator for ML workloads.

Free TPUs via TPU Research Cloud (TRC)

ML researchers and students are encouraged to apply for free TPU access through the TPU Research Cloud (TRC) program.

Getting TPUs in one command

Use one command to quickly get TPU nodes for development:

sky launch --gpus tpu-v2-8
# Preemptible TPUs:
sky launch --gpus tpu-v2-8 --use-spot
# Change TPU type to tpu-v3-8:
sky launch --gpus tpu-v3-8
# Change the host VM type to n1-highmem-16:
sky launch --gpus tpu-v3-8 -t n1-highmem-16

After the command finishes, you will be dropped into a TPU host VM and can start developing code right away.

Below, we show examples of using SkyPilot to run MNIST training on (1) TPU VMs and (2) TPU Nodes.

TPU Architectures

Two different TPU architectures are available on GCP:

  • TPU VMs

  • TPU Nodes

Both are supported by SkyPilot. We recommend TPU VMs, the newer architecture that GCP encourages.

The two architectures differ as follows. For TPU VMs, you can SSH directly into the “TPU host” VM that is physically connected to the TPU device. For TPU Nodes, a user VM (an n1 instance) must be provisioned separately; it communicates over gRPC with a TPU host that you cannot access. More details can be found in the GCP documentation.

TPU VMs

To use TPU VMs, set the following in a task YAML’s resources field:

resources:
   accelerators: tpu-v2-8
   accelerator_args:
      runtime_version: tpu-vm-base  # optional

The accelerators field specifies the TPU type. The accelerator_args dict accepts two optional fields: tpu_vm, a bool that defaults to true (meaning a TPU VM is used), and runtime_version, the TPU runtime version. To list the supported TPU types, run sky show-gpus.

Here is a complete task YAML that runs MNIST training on a TPU VM using JAX.

name: mnist-tpu-vm

resources:
   accelerators: tpu-v2-8
   accelerator_args:
      tpu_vm: True
      runtime_version: tpu-vm-base

setup: |
   git clone https://github.com/google/flax.git

   conda activate flax
   if [ $? -eq 0 ]; then
      echo 'conda env exists'
   else
      conda create -n flax python=3.8 -y
      conda activate flax
      # Make sure to install TPU related packages in a conda env to avoid package conflicts.
      pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      pip install --upgrade clu
      pip install -e flax
      pip install tensorflow tensorflow-datasets
   fi

run: |
   conda activate flax
   cd flax/examples/mnist
   python3 main.py --workdir=/tmp/mnist \
   --config=configs/default.py \
   --config.learning_rate=0.05 \
   --config.num_epochs=10

This YAML lives under the SkyPilot repo (examples/tpu/tpuvm_mnist.yaml), or you can paste it into a local file.

Launch it with:

$ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster

You should see output like the following when the job finishes.

$ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster
...
(mnist-tpu-vm pid=10155) I0823 07:49:25.468526 139641357117440 train.py:146] epoch:  9, train_loss: 0.0120, train_accuracy: 99.64, test_loss: 0.0278, test_accuracy: 99.02
(mnist-tpu-vm pid=10155) I0823 07:49:26.966874 139641357117440 train.py:146] epoch: 10, train_loss: 0.0095, train_accuracy: 99.73, test_loss: 0.0264, test_accuracy: 99.19
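Before starting a longer run, it can help to confirm that JAX actually sees the TPU cores from inside the flax conda env. The sketch below is not part of the SkyPilot example: summarize_platforms and has_expected_tpu_cores are hypothetical helpers that only count platform strings, so they can be tried on any machine. The assumption that a tpu-v2-8 exposes eight TPU devices to JAX should be checked against your own setup.

```python
from collections import Counter
from typing import Dict, Iterable


def summarize_platforms(platforms: Iterable[str]) -> Dict[str, int]:
    """Count devices per platform string, e.g. {'tpu': 8}.

    Hypothetical helper (not a SkyPilot or JAX API).
    """
    return dict(Counter(platforms))


def has_expected_tpu_cores(platforms: Iterable[str], expected: int) -> bool:
    """Return True if exactly `expected` devices report the 'tpu' platform."""
    return summarize_platforms(platforms).get("tpu", 0) == expected


# Usage on the TPU VM, inside the conda env where JAX is installed:
#   import jax
#   platforms = [d.platform for d in jax.devices()]
#   print(summarize_platforms(platforms))
#   assert has_expected_tpu_cores(platforms, 8)  # assumption for tpu-v2-8
```

If the summary reports only cpu devices, the TPU-enabled JAX build is probably missing from the active environment.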

TPU Nodes

In a TPU Node, a normal CPU VM (an n1 instance) needs to be provisioned to communicate with the TPU host/device.

To use a TPU Node, set the following in a task YAML’s resources field:

resources:
   instance_type: n1-highmem-8
   accelerators: tpu-v2-8
   accelerator_args:
      runtime_version: 2.12.0  # optional, TPU runtime version.
      tpu_vm: False

The above YAML uses n1-highmem-8 as the host VM and tpu-v2-8 as the TPU Node resource. You can change either the host instance type or the TPU type.

Here is a complete task YAML that runs MNIST training on a TPU Node using TensorFlow.

name: mnist-tpu-node

resources:
   accelerators: tpu-v2-8
   accelerator_args:
      runtime_version: 2.12.0  # optional, TPU runtime version.
      tpu_vm: False

# TPU node requires loading data from a GCS bucket.
# We use SkyPilot bucket mounting to mount a GCS bucket to /dataset.
file_mounts:
   /dataset:
      name: mnist-tpu-node
      store: gcs
      mode: MOUNT

setup: |
   git clone https://github.com/tensorflow/models.git

   conda activate mnist
   if [ $? -eq 0 ]; then
      echo 'conda env exists'
   else
      conda create -n mnist python=3.8 -y
      conda activate mnist
      pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
   fi

run: |
   conda activate mnist
   cd models/official/legacy/image_classification/

   export STORAGE_BUCKET=gs://mnist-tpu-node
   export MODEL_DIR=${STORAGE_BUCKET}/mnist
   export DATA_DIR=${STORAGE_BUCKET}/data

   export PYTHONPATH=/home/gcpuser/sky_workdir/models

   python3 mnist_main.py \
      --tpu=${TPU_NAME} \
      --model_dir=${MODEL_DIR} \
      --data_dir=${DATA_DIR} \
      --train_epochs=10 \
      --distribution_strategy=tpu \
      --download

Note

A TPU Node requires loading data from a GCS bucket. The file_mounts spec above simplifies this by using SkyPilot bucket mounting to create a new bucket or mount an existing one. If you encounter a bucket Permission denied error, make sure that the bucket is in the same region as the host VM and TPU Node, and that the IAM permissions for Cloud TPU are set up correctly (follow the instructions here).

Note

The special environment variable $TPU_NAME is automatically set by SkyPilot at run time, so it can be used in the run commands.
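If the training script itself needs the TPU name, it can read the same variable from its environment. The following is a minimal sketch, assuming the script runs inside the task's run section; get_tpu_name is a hypothetical helper, not a SkyPilot API, and it fails early with a readable message when the variable is missing (for example, when the script is started by hand outside of SkyPilot).

```python
import os
from typing import Mapping, Optional


def get_tpu_name(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the TPU name from the TPU_NAME environment variable.

    Hypothetical helper. `env` defaults to os.environ and can be
    overridden with a plain dict for testing.
    """
    env = os.environ if env is None else env
    name = env.get("TPU_NAME", "")
    if not name:
        raise RuntimeError(
            "TPU_NAME is not set; run this script through the task's "
            "run commands so that SkyPilot can provide it."
        )
    return name


# Usage in a script launched from the run section:
#   tpu_name = get_tpu_name()
```

The returned value plays the same role as ${TPU_NAME} passed to --tpu in the run commands above.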

This YAML lives under the SkyPilot repo (examples/tpu/tpu_node_mnist.yaml). Launch it with:

$ sky launch examples/tpu/tpu_node_mnist.yaml -c mycluster
...
(mnist-tpu-node pid=28961) Epoch 9/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 19ms/step - loss: 0.1181 - sparse_categorical_accuracy: 0.9646 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9719
(mnist-tpu-node pid=28961) Epoch 10/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 20ms/step - loss: 0.1139 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.0831 - val_sparse_categorical_accuracy: 0.9742
...
(mnist-tpu-node pid=28961) {'accuracy_top_1': 0.9741753339767456, 'eval_loss': 0.0831054300069809, 'loss': 0.11388632655143738, 'training_accuracy_top_1': 0.9654667377471924}

Using TPU Pods

A TPU Pod is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training.

To use a TPU Pod, simply change the accelerators field in the task YAML (e.g., v2-8 -> v2-32).

resources:
   accelerators: tpu-v2-32  # Pods have > 8 cores (the last number)
   accelerator_args:
      runtime_version: tpu-vm-base

Note

Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs.

To show all available TPU Pod types, run sky show-gpus (more than 8 cores means Pods):

GOOGLE_TPU   AVAILABLE_QUANTITIES
tpu-v2-8     1
tpu-v2-32    1
tpu-v2-128   1
tpu-v2-256   1
tpu-v2-512   1
tpu-v3-8     1
tpu-v3-32    1
tpu-v3-64    1
tpu-v3-128   1
tpu-v3-256   1
tpu-v3-512   1
tpu-v3-1024  1
tpu-v3-2048  1

After a TPU Pod is created, multiple host VMs are launched (e.g., v2-32 comes with 4 host VMs). Normally, the user needs to SSH into all hosts (depending on the architecture used, either the n1 user VMs or the TPU host VMs) to prepare files and set up environments, and then launch the job on each host, which is a tedious and error-prone process.
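The number of hosts can be estimated from the accelerator name. The sketch below assumes eight cores per host VM, which matches the figures above (v2-8 is a single host, v2-32 has 4 hosts) but is an assumption for TPU types not listed on this page. estimate_tpu_hosts and is_tpu_pod are hypothetical helpers, not part of SkyPilot.

```python
def estimate_tpu_hosts(accelerator: str, cores_per_host: int = 8) -> int:
    """Estimate the number of host VMs for a name such as 'tpu-v2-32'.

    Hypothetical helper. cores_per_host=8 is an assumption that is
    consistent with v2-8 (1 host) and v2-32 (4 hosts).
    """
    # The core count is the number after the last dash.
    cores = int(accelerator.rsplit("-", 1)[-1])
    if cores < cores_per_host or cores % cores_per_host != 0:
        raise ValueError(f"unexpected core count in {accelerator!r}")
    return cores // cores_per_host


def is_tpu_pod(accelerator: str) -> bool:
    """A Pod has more than 8 cores, i.e. more than one host."""
    return estimate_tpu_hosts(accelerator) > 1


print(estimate_tpu_hosts("tpu-v2-32"))  # 4
print(is_tpu_pod("tpu-v2-8"))           # False
```

The estimate is the number of machines on which files must be synced and commands must be run.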

SkyPilot automates away this complexity. From your laptop, a single sky launch command will:

  • sync the workdir/file_mounts; and

  • execute the setup/run commands on every host of the pod.

Here is a task YAML for a CIFAR-10 training job on a v2-32 TPU Pod with JAX (code repo):

name: cifar-tpu-pod

resources:
   accelerators: tpu-v2-32
   accelerator_args:
      runtime_version: tpu-vm-base

setup: |
   git clone https://github.com/infwinston/tpu-example.git
   cd tpu-example
   pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
   pip install -r requirements.txt

run: |
   python -u tpu-example/train.py

Launch it with:

$ sky launch examples/tpu/cifar_pod.yaml -c mycluster

You should see the following output.

(node-0 pid=57977, ip=10.164.0.24) JAX process: 1 / 4
(node-3 pid=57963, ip=10.164.0.26) JAX process: 3 / 4
(node-2 pid=57922, ip=10.164.0.25) JAX process: 2 / 4
(node-1 pid=63223) JAX process: 0 / 4
...
(node-0 pid=57977, ip=10.164.0.24) [  1000/100000]      time  0.034 ( 0.063)    data  0.008 ( 0.008)    loss  1.215 ( 1.489)    acc 68.750 (46.163)

Note

By default, output from all hosts is shown with the node-<i> prefix. Use jax.process_index() to control which host prints messages.
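One way to do this is to gate print calls on the process index. The following is a minimal sketch, not the code used in the example repo: log_on_main_host is a hypothetical helper, and treating process 0 as the "main" host is an assumed convention. JAX is imported lazily so that the helper can also be tried on a machine without JAX by passing the index explicitly.

```python
from typing import Optional


def log_on_main_host(message: str, process_index: Optional[int] = None) -> bool:
    """Print `message` only on process 0; return True if it was printed.

    Hypothetical helper. If `process_index` is omitted, it is read from
    jax.process_index(), which requires JAX to be installed.
    """
    if process_index is None:
        import jax  # lazy import: only needed when no index is passed

        process_index = jax.process_index()
    if process_index == 0:
        print(message, flush=True)
        return True
    return False


# Usage inside the training script that runs on every host of the pod:
#   log_on_main_host("starting training")
```

With this in place, the other hosts stay silent for these messages, so the log shows each message only once.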

To submit more jobs to the same TPU Pod, use sky exec:

$ sky exec mycluster examples/tpu/cifar_pod.yaml