Move interactive API examples to openfl-contrib
    As we deprecate the interactive API in favor of the more generalized workflow API, we
    decided to keep the examples in this repo.
vrancurel committed Nov 6, 2024
1 parent c762e88 commit 42e53c3
Showing 274 changed files with 25,045 additions and 0 deletions.
106 changes: 106 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/README.md
@@ -0,0 +1,106 @@
# Federated FLAX CIFAR-10 CNN Tutorial

### 1. About dataset

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Set the parameter below in the `envoy_config.yaml` config to shard the dataset across participants/envoys:
- `rank_worldsize` (e.g. `rank_worldsize: 1, 2`, as in the envoy config included in this change, makes an envoy serve shard 1 of a 2-envoy federation)

### 2. About model

A simple multi-layer CNN is used, with XLA-compiled, autograd-based parameter updates.
The definition is provided in the notebook; an illustrative sketch follows.
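
Below is a minimal sketch of what such a Flax CNN module can look like. It is illustrative only: the layer widths and pooling choices here are assumptions, not the notebook's exact definition.

```python
# Illustrative Flax CNN for 32x32x3 CIFAR-10 inputs (not the notebook's exact model).
from flax import linen as nn


class CNN(nn.Module):
    """Simple multi-layer CNN with a 10-way classification head."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
```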

### 3. Notebook Overview

1. Class `CustomTrainState` - subclasses `flax.training.train_state.TrainState` and adds:
    - Variable `opt_vars` to keep track of generic optimizer variables.
    - Method `update_state` to update the state registered with the OpenFL `ModelInterface` using the `new_state` returned by the training loop registered with the `TaskInterface`.

2. Method `create_train_state`: Creates a new `TrainState` by encapsulating model layer definitions, random model parameters, and optax optimizer state.

3. Method `apply_model` (an `@jax.jit`-decorated function): takes a `TrainState`, images, and labels as parameters, and computes and returns the gradients, loss, and accuracy. The gradients are applied to a given state by the `update_model` method (also `@jax.jit`-decorated), which returns a new `TrainState` instance. A minimal sketch of these helpers is shown below.
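
The following sketch shows how `create_train_state`, `apply_model`, and `update_model` can fit together, assuming the illustrative `CNN` module from section 2. It is a hedged sketch, not the notebook's actual code; in particular, `CustomTrainState` and its OpenFL `ModelInterface`/`TaskInterface` integration are omitted.

```python
# Illustrative sketch of the training helpers described above (not the notebook's exact code).
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


def create_train_state(rng, learning_rate=0.01):
    """Create a TrainState from the model definition, random parameters and an optax optimizer."""
    model = CNN()  # the Flax module sketched in section 2
    params = model.init(rng, jnp.ones([1, 32, 32, 3]))['params']
    tx = optax.sgd(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


@jax.jit
def apply_model(state, images, labels):
    """Compute gradients, loss and accuracy for one batch."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    """Apply the gradients and return a new TrainState."""
    return state.apply_gradients(grads=grads)
```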

### 4. How to run this tutorial (without TLS and locally as a simulation):

0. Prerequisites:

- Nvidia Driver >= 495.29.05
- CUDA >= 11.1.105
- cuDNN >= 8

Activate a virtual environment (Python 3.8.10) and install the packages from `requirements.txt`.

Set the variable `DEFAULT_DEVICE` to `'CPU'` or `'GPU'` in `start_envoy.sh` and in the notebook to control the execution platform.

```sh
cd Flax_CNN_CIFAR
pip install -r requirements.txt
```

1. Run director:

```sh
cd director
./start_director.sh
```

2. Run envoy:

```sh
cd envoy
./start_envoy.sh "envoy_identifier" envoy_config.yaml
```

Optional: start a second envoy:

- Copy the `envoy` folder to another location and follow the same process as above:

```sh
./start_envoy.sh "envoy_identifier_2" envoy_config_2.yaml
```

3. Run the `FLAX_CIFAR10_CNN.ipynb` Jupyter notebook:

```sh
cd workspace
jupyter lab FLAX_CIFAR10_CNN.ipynb
```

4. Visualization:

```sh
tensorboard --logdir logs/
```


### 5. Known issues

1. #### `CUDA_ERROR_OUT_OF_MEMORY` exception - JAX XLA pre-allocates 90% of GPU memory at startup

- Set `XLA_PYTHON_CLIENT_PREALLOCATE=false` to start with a small memory footprint:
```
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
```
OR

- Set the flag below to restrict the maximum GPU allocation to 50%:
```
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.5
```


2. #### TensorFlow pre-allocates 90% of GPU memory (potential OOM errors)

- Set `TF_FORCE_GPU_ALLOW_GROWTH=true` to start with a small memory footprint:
```
%env TF_FORCE_GPU_ALLOW_GROWTH=true
```

3. #### DNN library not found error

- Make sure the jaxlib (CUDA build), NVIDIA driver, CUDA, and cuDNN versions are mutually compatible, as per the documentation.
- References:
  - CUDA and cuDNN compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html
  - Official JAX-compatible CUDA releases: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
6 changes: 6 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/director_config.yaml
@@ -0,0 +1,6 @@
settings:
  listen_host: localhost
  listen_port: 50055
  sample_shape: ['32', '32', '3'] # [[shape], channel]
  target_shape: ['1']
  envoy_health_check_period: 5 # in seconds
4 changes: 4 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/start_director.sh
@@ -0,0 +1,4 @@
#!/bin/bash
set -e

fx director start --disable-tls -c director_config.yaml
111 changes: 111 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/cifar10_shard_descriptor.py
@@ -0,0 +1,111 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""CIFAR10 Shard Descriptor (using `TFDS` API)"""
import jax.numpy as jnp
import logging
import tensorflow as tf
import tensorflow_datasets as tfds

from typing import List, Tuple
from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor

logger = logging.getLogger(__name__)


class CIFAR10ShardDescriptor(ShardDescriptor):
    """
    CIFAR10 Shard Descriptor
    This example is based on `tfds` data loader.
    Note that the ingestion of any model/task requires an iterable dataloader.
    Hence, it is possible to utilize these pipelines without explicit need of a
    new interface.
    """

    def __init__(
            self,
            rank_worldsize: str = '1, 1',
            **kwargs
    ) -> None:
        """Download/Prepare CIFAR10 dataset"""
        self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(','))

        # Load dataset
        train_ds, valid_ds = self._download_and_prepare_dataset(self.rank, self.worldsize)

        # Set attributes
        self._sample_shape = train_ds['image'].shape[1:]
        self._target_shape = tf.expand_dims(train_ds['label'], -1).shape[1:]

        self.splits = {
            'train': train_ds,
            'valid': valid_ds
        }

    def _download_and_prepare_dataset(self, rank: int, worldsize: int) -> Tuple[dict]:
        """
        Download, Cache CIFAR10 and prepare `tfds` builder.
        Provide `rank` and `worldsize` to virtually split dataset across shards
        uniquely for each client for simulation purposes.
        Returns:
            Tuple (train_dict, test_dict) of dictionary with JAX DeviceArray (image and label)
            dict['image'] -> DeviceArray float32
            dict['label'] -> DeviceArray int32
            {'image' : DeviceArray(...), 'label' : DeviceArray(...)}
        """

        dataset_builder = tfds.builder('cifar10')
        dataset_builder.download_and_prepare()

        datasets = dataset_builder.as_dataset()

        train_shard_size = int(len(datasets['train']) / worldsize)
        test_shard_size = int(len(datasets['test']) / worldsize)

        self.train_segment = f'train[{train_shard_size * (rank - 1)}:{train_shard_size * rank}]'
        self.test_segment = f'test[{test_shard_size * (rank - 1)}:{test_shard_size * rank}]'
        train_dataset = dataset_builder.as_dataset(split=self.train_segment, batch_size=-1)
        test_dataset = dataset_builder.as_dataset(split=self.test_segment, batch_size=-1)
        train_ds = tfds.as_numpy(train_dataset)
        test_ds = tfds.as_numpy(test_dataset)

        train_ds['image'] = jnp.float32(train_ds['image']) / 255.
        test_ds['image'] = jnp.float32(test_ds['image']) / 255.
        train_ds['label'] = jnp.int32(train_ds['label'])
        test_ds['label'] = jnp.int32(test_ds['label'])

        return train_ds, test_ds

    def get_shard_dataset_types(self) -> List[str]:
        """Get available split names"""
        return list(self.splits)

    def get_split(self, name: str) -> tf.data.Dataset:
        """Return a shard dataset by type."""
        if name not in self.splits:
            raise Exception(f'Split name `{name}` not found.'
                            f' Expected one of {list(self.splits.keys())}')
        return self.splits[name]

    @property
    def sample_shape(self) -> List[str]:
        """Return the sample shape info."""
        return list(map(str, self._sample_shape))

    @property
    def target_shape(self) -> List[str]:
        """Return the target shape info."""
        return list(map(str, self._target_shape))

    @property
    def dataset_description(self) -> str:
        """Return the dataset description."""
        n_train = len(self.splits['train']['label'])
        n_test = len(self.splits['valid']['label'])
        return (f'CIFAR10 dataset, Shard Segments {self.train_segment}/{self.test_segment}, '
                f'rank/world {self.rank}/{self.worldsize}.'
                f'\n num_samples [Train/Valid]: [{n_train}/{n_test}]')
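
A quick way to sanity-check the sharding logic above is the hedged usage sketch below; it is hypothetical example code, and assumes the `tfds` CIFAR-10 data can be downloaded or is already cached locally.

```python
# Illustrative check of the shard descriptor's rank/worldsize splitting (hypothetical usage).
from cifar10_shard_descriptor import CIFAR10ShardDescriptor

sd = CIFAR10ShardDescriptor(rank_worldsize='1, 2')  # envoy 1 of a 2-envoy federation
print(sd.dataset_description)   # shard segments and [Train/Valid] sample counts
print(sd.sample_shape)          # ['32', '32', '3']
print(sd.target_shape)          # ['1']
train = sd.get_split('train')   # dict with 'image' (float32) and 'label' (int32) arrays
print(train['image'].shape)     # (25000, 32, 32, 3) with worldsize 2
```
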
9 changes: 9 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/envoy_config.yaml
@@ -0,0 +1,9 @@
params:
  cuda_devices: []

optional_plugin_components: {}

shard_descriptor:
  template: cifar10_shard_descriptor.CIFAR10ShardDescriptor
  params:
    rank_worldsize: 1, 2
17 changes: 17 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/start_envoy.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
ENVOY_NAME=$1
ENVOY_CONF=$2

DEFAULT_DEVICE='CPU'

if [[ $DEFAULT_DEVICE == 'CPU' ]]
then
    export JAX_PLATFORMS="cpu" # Force XLA to use CPU
    export CUDA_VISIBLE_DEVICES='-1' # Force TF to use CPU
else
    export XLA_PYTHON_CLIENT_PREALLOCATE=false
    export TF_FORCE_GPU_ALLOW_GROWTH=true
fi

fx envoy start -n "$ENVOY_NAME" --disable-tls --envoy-config-path "$ENVOY_CONF" -dh localhost -dp 50055
5 changes: 5 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/requirements.txt
@@ -0,0 +1,5 @@
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax
jaxlib
tensorflow==2.13
tensorflow-datasets==4.6.0