AIStore & ETL: Using WebDataset to train on a sharded dataset (post #3)
Deprecated -- WDTransform is no longer included as part of the AIS client, so this post only remains for educational purposes. ETL is in development and additional transformation tools will be included in future posts.
Background
In our previous post, we built a basic PyTorch data loader and used it to load transformed images from AIStore (AIS), which ran torchvision transformations on the Oxford-IIIT Pet Dataset images.
Now, we’ll look further into training that involves sharded datasets. We will utilize WebDataset, an iterable PyTorch dataset that provides a number of important features. For demonstration purposes, we’ll be using ImageNet - a sharded version of ImageNet, to be precise, where the original images are assembled into .tar archives (aka shards).
For background on WebDataset and AIStore, and the benefits of sharding very large datasets, please see Efficient PyTorch I/O library for Large Datasets, Many Files, Many GPUs.
On the Python side, we’ll use the AIS Python client. The client is a thin layer on top of the AIStore API/SDK that provides easy operations on remote datasets. We’ll be using it to offload image transformations to the AIS cluster.
The remainder of this text is structured as follows:
- introduce sharded ImageNet;
- load a single shard and apply assorted torchvision transformations;
- run the same exact transformation in the cluster (in other words, offload this specific ETL to AIS);
- operate on multiple (brace-expansion defined) shards.
The first step, though, is to install the required dependencies (e.g., from your Jupyter notebook), as follows:
! pip install webdataset aistore torch torchvision matplotlib
The Dataset
Pre-sharded ImageNet will be stored in a Google Cloud bucket that we’ll also call sharded-imagenet. Shards can be inspected with the ais command-line tool - on average, in our case, any given shard will contain about 1000 original (.jpg) ImageNet images and their corresponding (.cls) classes:
! ais get gcp://sharded-imagenet/imagenet-train-000000.tar - | tar -tvf - | head -n 10
-r--r--r-- bigdata/bigdata 3 2020-06-25 17:41 0911032.cls
-r--r--r-- bigdata/bigdata 92227 2020-06-25 17:41 0911032.jpg
-r--r--r-- bigdata/bigdata 3 2020-06-25 17:41 1203092.cls
-r--r--r-- bigdata/bigdata 15163 2020-06-25 17:41 1203092.jpg
-r--r--r-- bigdata/bigdata 3 2020-06-25 17:41 0403282.cls
-r--r--r-- bigdata/bigdata 139179 2020-06-25 17:41 0403282.jpg
-r--r--r-- bigdata/bigdata 3 2020-06-25 17:41 0267084.cls
-r--r--r-- bigdata/bigdata 200458 2020-06-25 17:41 0267084.jpg
-r--r--r-- bigdata/bigdata 3 2020-06-25 17:41 1026057.cls
-r--r--r-- bigdata/bigdata 159009 2020-06-25 17:41 1026057.jpg
Thus, in terms of its internal structure, this dataset is identical to what we’ve had in the previous article, with one distinct difference: the images and classes are now packed into shards (formatted as .tar files).
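To make the shard layout concrete, here is a minimal sketch that uses only the Python standard library. It builds a tiny in-memory .tar with the same naming scheme as the listing above and groups members by their base name, which is what defines a sample. The helper names (`make_toy_shard`, `group_by_key`) and the payload bytes are made up for illustration; splitting on the last dot is a simplification, not WebDataset’s actual grouping code.

```python
import io
import tarfile
from collections import defaultdict


def make_toy_shard(samples):
    """Build an in-memory .tar shard from {key: (jpg_bytes, cls_bytes)}."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for key, (jpg, cls) in samples.items():
            for ext, payload in (("cls", cls), ("jpg", jpg)):
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    return buf


def group_by_key(shard):
    """Group shard members by base name: one group per training sample."""
    groups = defaultdict(dict)
    with tarfile.open(fileobj=shard, mode="r") as tar:
        for member in tar.getmembers():
            key, _, ext = member.name.rpartition(".")
            groups[key][ext] = tar.extractfile(member).read()
    return dict(groups)


# Hypothetical payloads; a real shard holds JPEG bytes and class indices.
toy = make_toy_shard({
    "0911032": (b"fake-jpeg-bytes", b"281"),
    "1203092": (b"more-fake-bytes", b"007"),
})
grouped = group_by_key(toy)
for key, fields in grouped.items():
    print(key, sorted(fields))
```

Each key ends up with exactly two fields, `jpg` and `cls`, mirroring the pairs visible in the `tar -tvf` output.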
Further, we assume (and require) that AIStore can “see” this GCP bucket. Covering the corresponding AIStore configuration would be outside the scope of this post, but the main point is that AIS self-populates on demand. When getting user data from any remote location, AIS always stores it (i.e., the data), acting simultaneously as a fast-cache tier and a high-performance reliable-and-scalable storage.
Client-side transformation with WebDataset, and with AIStore acting as a traditional (dumb) storage
Next in the plan is to have WebDataset running transformations on the client side. Eventually, we’ll move this entire ETL code onto AIS. But first, let’s go over the conventional way of doing things:
from torchvision import transforms
import webdataset as wds
from aistore.sdk import Client

client = Client("http://aistore-sample-proxy:51080") # AIS IP address or hostname

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# PyTorch transform to apply (compare with the [previous post](https://aiatscale.org/blog/2021/10/22/ais-etl-2))
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# use AIS client to get the http URL for the shard
shard_url = client.object_url("gcp://sharded-imagenet", "imagenet-train-000000.tar")

dataset = (
    wds.WebDataset(shard_url, handler=wds.handlers.warn_and_continue)
    .decode("pil")
    .to_tuple("jpg;png;jpeg cls", handler=wds.handlers.warn_and_continue)
    .map_tuple(train_transform, lambda x: x)
)

loader = wds.WebLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=1,
)
Comments to the code above
- AIS Python client helps with WebDataset data loader initialization. In this case, WebDataset loads a single .tar shard with jpg images, and transforms each image in the batch.
- .decode("pil") decodes each image into a PIL object, which is the input that torchvision data augmentation expects (for details, see WebDataset docs on decoding).
- .map_tuple does the actual heavy lifting, applying torchvision transforms to the image and passing the class through unchanged.
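The field selection and per-field mapping described in these comments can be sketched in plain Python. The two functions below only mimic the observable behavior of `.to_tuple` and `.map_tuple` on a single sample dictionary; they are illustrative stand-ins, not WebDataset’s implementation, and the sample contents are placeholders.

```python
def to_tuple(sample, spec):
    """Pick fields from a sample dict; ';' separates acceptable alternatives."""
    out = []
    for field in spec.split():
        for ext in field.split(";"):
            if ext in sample:
                out.append(sample[ext])
                break
        else:
            raise KeyError(f"none of {field!r} found in sample keys {sorted(sample)}")
    return tuple(out)


def map_tuple(values, *funcs):
    """Apply the i-th function to the i-th element of the tuple."""
    return tuple(f(v) for f, v in zip(funcs, values))


# Placeholder sample: a string stands in for the decoded PIL image.
sample = {"__key__": "0911032", "jpg": "<decoded image>", "cls": 281}

# str.upper stands in for `train_transform`; the class passes through as-is.
image, label = map_tuple(to_tuple(sample, "jpg;png;jpeg cls"), str.upper, lambda x: x)
print(image, label)
```

The spec "jpg;png;jpeg cls" therefore reads as: take whichever of the three image extensions is present, then take the class.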
Let’s visually compare original images with their loaded-and-transformed counterparts:
from utils import display_shard_images, display_loader_images
display_shard_images(client, "gcp://sharded-imagenet", "imagenet-train-000000.tar", objects=4)
display_loader_images(loader, objects=4)
Source code for display_shard_images and display_loader_images can be found here.
WDTransform with ETL in the cluster (deprecated)
We are now ready to create a data loader that relies on AIStore for image transformations. For this, we introduce the WDTransform class, the full source of which is available as part of the AIS Python client:
from aistore.client.transform import WDTransform
train_etl = WDTransform(client, wd_transform, transform_name="imagenet-train", verbose=True)
In our example, WDTransform takes the following transform_func as an input:
def wd_transform(sample):
    # `train_transform` was declared in the previous section.
    sample["npy"] = train_transform(sample.pop("jpg")).permute(1, 2, 0).numpy().astype("float32")
    return sample
The function above returns a transformed NumPy array after applying precisely the same torchvision transformations that we described in the previous section. The input is a dictionary containing a single training tuple - an image (key="jpg") and a corresponding class (key="cls").
Putting it all together, here is the code that loads a single transformed shard, with AIS nodes carrying out the actual transformations:
# NOTE:
# AIS Python client handles initialization of ETL on AIStore cluster - if a given named ETL
# does not exist in the cluster the client will simply initialize it on the fly:
etl_object_url = client.object_url("gcp://sharded-imagenet", "imagenet-train-000000.tar", transform_id=train_etl.uuid)
to_tensor = transforms.Compose([transforms.ToTensor()])
etl_dataset = (
    wds.WebDataset(etl_object_url, handler=wds.handlers.warn_and_continue)
    .decode("rgb")
    .to_tuple("npy cls", handler=wds.handlers.warn_and_continue)
    .map_tuple(to_tensor, lambda x: x)
)

etl_loader = wds.WebLoader(
    etl_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=1,
)
Discussion
- We transform all images from a randomly selected shard (e.g., imagenet-train-000000.tar above).
- AIS Python client handles ETL initialization in the cluster.
- The data loader (etl_dataset) is very similar, almost identical, to the WebDataset loader from the previous section. One major difference, of course, is that it is AIS that runs transformations.
- The client-side part of the training pipeline handles already transformed images (represented as NumPy arrays).
- Which is exactly why we’ve set the decoder to “rgb” and added to_tensor (NumPy array to PyTorch tensor) conversion.
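The array layouts involved in this hand-off can be illustrated with NumPy alone. In the sketch below, `np.transpose` stands in for both `permute(1, 2, 0)` on the cluster side and the height-width-channel to channel-height-width axis swap that `ToTensor` performs on NumPy input. The random array is a placeholder for a transformed image, and the variable names are chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder for the output of `train_transform`: channels, height, width.
chw = rng.random((3, 224, 224), dtype=np.float32)

# Cluster side: the equivalent of .permute(1, 2, 0), giving height, width, channels.
hwc = np.transpose(chw, (1, 2, 0))

# Client side: the axis swap that brings the array back to channels-first.
restored = np.transpose(hwc, (2, 0, 1))

print(hwc.shape, restored.shape)
```

The two transpositions cancel out, so the tensor that reaches the training loop has the same channels-first layout as the one produced on the client in the previous section.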
When we run this snippet in the notebook, we will first see:
Initializing ETL...
ETL imagenet-train successfully initialized
indicating that our (user-defined) ETL is now ready to execute.
To further confirm that it does work, we again call display_loader_images:
display_loader_images(etl_loader, objects=4)
So yes indeed, the results of loading imagenet-train-000000.tar look virtually identical to the post-transform images we have seen in the previous section.
Iterating through multiple shards
The one and, practically, the only difference between single-shard and multi-shard operation is that for the latter we specify a list or a range of shards (to operate upon). AIStore supports multiple list/range formats; in this section, we’ll show just one that utilizes the familiar Bash brace expansion notation.
For instance, brace expansion “imagenet-val-{000000..000005}.tar” translates as a range of up to 6 (six) shards named accordingly.
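As an illustration of what such a template stands for, the following sketch expands a single zero-padded numeric range into object names. The function `expand_range` is hypothetical and handles only one `{start..end}` range per template; it is not the expansion logic used by AIStore or Bash.

```python
import re

# Matches one numeric range such as {000000..000005}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_range(template):
    """Expand a single {start..end} range, preserving zero padding."""
    match = _RANGE.search(template)
    if match is None:
        return [template]
    start, end = match.group(1), match.group(2)
    width = len(start)
    prefix, suffix = template[:match.start()], template[match.end():]
    return [
        f"{prefix}{str(i).zfill(width)}{suffix}"
        for i in range(int(start), int(end) + 1)
    ]


shards = expand_range("imagenet-val-{000000..000005}.tar")
for name in shards:
    print(name)
```

Both ends of the range are inclusive, which is why {000000..000005} yields six names rather than five.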
For this purpose, we’ll keep using WDTransform, but this time with a validation dataset and the transformation function that looks as follows:
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

def wd_val_transform(sample):
    sample["npy"] = val_transform(sample.pop("jpg")).permute(1, 2, 0).numpy().astype("float32")
    return sample

val_etl = WDTransform(client, wd_val_transform, transform_name="imagenet-val", verbose=True)
The code to iterate over an arbitrary range (e.g., {000000..000005}) with torchvision transformations performed by AIS nodes - in parallel and concurrently for all shards in a batch:
# Loading multiple shards using template.
val_objects = "imagenet-val-{000000..000005}.tar"
val_urls = client.expand_object_urls("gcp://sharded-imagenet", transform_id=val_etl.uuid, template=val_objects)
val_dataset = (
    wds.WebDataset(val_urls, handler=wds.handlers.warn_and_continue)
    .decode("rgb")
    .to_tuple("npy cls", handler=wds.handlers.warn_and_continue)
    .map_tuple(to_tensor, lambda x: x)
)

val_loader = wds.WebLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=1,
)
Compare this code with the single-shard example from the previous section.
Now, as before, to make sure that our validation data loader does work, we display a random image, or images:
display_loader_images(val_loader, objects=4)
Remarks
So far we have shown how to use (WebDataset + AIStore) to offload compute and I/O intensive transformations to a dedicated cluster.
Overall, the topic - large-scale inline and offline ETL - is vast. And we’ve only barely scratched the surface. The hope is, though, that this text provides at least a few useful insights.
The complete end-to-end PyTorch training example that we have used here is available.
Other references include:
- AIStore & ETL:
- GitHub:
- Documentation, blogs, videos: