In AIStore, PyTorch integration is a growing set of datasets (both iterable and map-style), samplers, and dataloaders. This readme illustrates the taxonomy of the associated abstractions and provides API reference documentation.

For usage examples, please see:

PyTorch Structure

Base class for AIS Map Style Datasets

Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Class: AISBaseMapDataset

class AISBaseMapDataset(ABC, Dataset)

A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses should implement __getitem__, which fetches a sample given a key from the dataset, and can optionally override other methods from torch Dataset such as __len__ and __getitems__.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data
  • prefix_map Dict[AISSource, List[str]], optional - Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source
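To make the map-style contract concrete, here is a minimal sketch that needs neither torch nor an AIStore cluster. InMemorySource and SketchMapDataset are hypothetical stand-ins: the object store is a plain dict, the sketch takes only a list of sources, and it assumes that a source missing from prefix_map contributes all of its objects. It shows the shape of __len__ and __getitem__ with prefix filtering, not how AISBaseMapDataset is actually implemented.

```python
from typing import Dict, List, Optional, Tuple


class InMemorySource:
    """Hypothetical stand-in for an AISSource: a named collection of objects."""

    def __init__(self, name: str, objects: Dict[str, bytes]) -> None:
        self.name = name
        self.objects = objects


class SketchMapDataset:
    """Map-style sketch: implements __len__ and __getitem__ like a torch Dataset."""

    def __init__(
        self,
        sources: List[InMemorySource],
        prefix_map: Optional[Dict[InMemorySource, List[str]]] = None,
    ) -> None:
        prefix_map = prefix_map or {}
        self._samples: List[Tuple[str, bytes]] = []
        for source in sources:
            # Assumption: sources absent from prefix_map contribute all objects.
            prefixes = prefix_map.get(source)
            for name in sorted(source.objects):
                if prefixes is None or any(name.startswith(p) for p in prefixes):
                    self._samples.append((name, source.objects[name]))

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[str, bytes]:
        # Each sample is an (object_name, object_content) tuple.
        return self._samples[index]
```

Resolving the sample list once in the constructor is what makes random access by index cheap afterwards.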

Base class for AIS Iterable Style Datasets

Class: AISBaseIterDataset

class AISBaseIterDataset(ABC, IterableDataset)

A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses should implement __iter__, which returns the samples from the dataset, and can optionally override other methods from torch IterableDataset such as __len__.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data
  • prefix_map Dict[AISSource, List[str]], optional - Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source

__iter__

@abstractmethod
def __iter__() -> Iterator

Return iterator with samples in this dataset.

Returns:

  • Iterator - Iterator of samples
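An iterable-style subclass only has to provide __iter__. The sketch below uses a generator over hypothetical in-memory sources (plain dicts standing in for AISSource objects) and implements __len__ as a full O(N) pass over the samples. SketchIterDataset is illustrative only and is not part of the library.

```python
from typing import Dict, Iterator, List, Tuple


class SketchIterDataset:
    """Iterable-style sketch: samples are produced lazily by __iter__."""

    def __init__(self, sources: List[Dict[str, bytes]]) -> None:
        # Each hypothetical source is a plain mapping of object name -> content.
        self._sources = sources

    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
        # Nothing is materialized up front; samples are yielded one at a time.
        for source in self._sources:
            for name, content in source.items():
                yield name, content

    def __len__(self) -> int:
        # Counting requires walking every sample: O(N).
        return sum(1 for _ in self)
```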

__len__

def __len__()

Returns the length of the dataset. Note that calling this iterates through the entire dataset, taking O(N) time.

NOTE: If you are going to iterate through the dataset anyway, count the samples during that pass with for i, data in enumerate(dataset) instead of calling len() separately.
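As a minimal sketch of that advice, the function below counts samples during the single pass it already makes, instead of calling len() first and iterating a second time. consume_and_count is a hypothetical helper; dataset can be any iterable of (name, content) tuples, such as the samples an iterable AIS dataset yields.

```python
from typing import Iterable, Tuple


def consume_and_count(dataset: Iterable[Tuple[str, bytes]]) -> int:
    """Process every sample once and return how many were seen."""
    count = 0
    for i, (name, content) in enumerate(dataset):
        # Process name/content here; the running index doubles as the count.
        count = i + 1
    return count
```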

PyTorch Dataset for AIS.

Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.

Class: AISMapDataset

class AISMapDataset(AISBaseMapDataset)

A map-style dataset for objects in AIS. If etl_name is provided, that ETL must already exist on the AIStore cluster.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data
  • prefix_map Dict[AISSource, List[str]], optional - Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object

NOTE: Each object is represented as a tuple of object_name (str) and object_content (bytes).

Iterable Dataset for AIS

Class: AISIterDataset

class AISIterDataset(AISBaseIterDataset)

An iterable-style dataset that iterates over objects in AIS and yields samples represented as a tuple of object_name (str) and object_content (bytes). If etl_name is provided, that ETL must already exist on the AIStore cluster.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data
  • prefix_map Dict[AISSource, Union[str, List[str]]], optional - Map of AISSource objects to a prefix or list of prefixes; only objects with the specified prefixes are used from each source
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object
  • show_progress bool, optional - Enables a console progress indicator while reading the dataset

Yields:

Tuple[str, bytes]: Each item is a tuple where the first element is the name of the object and the second element is the byte representation of the object data.

AIS Shard Reader for PyTorch

PyTorch Dataset and DataLoader for AIS.

Class: AISShardReader

class AISShardReader(AISBaseIterDataset)

An iterable-style dataset that iterates over objects stored as WebDataset shards and yields samples represented as a tuple of basename (str) and contents (dictionary).

Arguments:

  • bucket_list Union[Bucket, List[Bucket]] - Single or list of Bucket objects to load data
  • prefix_map Dict[AISSource, Union[str, List[str]]], optional - Map of Bucket objects to a prefix or list of prefixes; only objects with the specified prefixes are used from each source
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object
  • show_progress bool, optional - Enables a console progress indicator while reading shards

Yields:

Tuple[str, Dict[str, bytes]]: Each item is a tuple where the first element is the basename of a sample (shared by that sample's files within the shard) and the second element is a dictionary mapping file extensions to bytes.
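The grouping that produces these (basename, contents) tuples can be sketched with the standard tarfile module. The sketch follows the usual WebDataset convention (the sample key is the member path up to the first dot in the file name, and the remainder is the extension). read_shard is hypothetical, and AISShardReader's own parsing may differ in details.

```python
import io
import tarfile
from collections import defaultdict
from typing import Dict, Iterator, Tuple


def read_shard(shard_bytes: bytes) -> Iterator[Tuple[str, Dict[str, bytes]]]:
    """Group tar members into samples keyed by basename, then by extension."""
    samples: Dict[str, Dict[str, bytes]] = defaultdict(dict)
    with tarfile.open(fileobj=io.BytesIO(shard_bytes)) as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            # "dir/sample0001.cls.txt" -> key "dir/sample0001", extension "cls.txt"
            head, _, tail = member.name.rpartition("/")
            key, _, ext = tail.partition(".")
            basename = f"{head}/{key}" if head else key
            samples[basename][ext] = tar.extractfile(member).read()
    for basename, contents in samples.items():
        yield basename, contents
```

Note that samples in one shard are not guaranteed to share the same set of extensions.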

__len__

def __len__()

Returns the length of the dataset. Note that calling this iterates through the entire dataset, taking O(N) time.

NOTE: If you are going to iterate through the dataset anyway, count the samples during that pass with for i, data in enumerate(dataset) instead of calling len() separately.

Class: ZeroDict

class ZeroDict(dict)

When collate_fn is called while using AISShardReader with a DataLoader, the content dictionaries of the samples in a batch are merged into a single dictionary with file extensions as keys and lists of contents as values. This means, however, that every sample in the batch must have a value for each file extension at iteration time, or else collation fails. To avoid forcing the user to pass in a custom collation function, we work around the default collation implementation.

As such, we define a dictionary that has a default value of b"" (zero bytes) for every key seen so far. We cannot use None because collation does not accept None. When we open a shard tar, we first collect every file type from its members (a pre-processing pass) and cache them. Then we read the shard files. Finally, before yielding a sample, we wrap its content dictionary in this custom dictionary to insert any keys it does not contain, ensuring consistent keys across samples.

NOTE: For our use case, defaultdict does not work because it would need a lambda as its default factory, and a lambda cannot be pickled, which is required when the dataset is used by multiple DataLoader workers.
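A minimal sketch of the same idea follows. ZeroDictSketch is a hypothetical name, not the library class, and collate_dicts is a simplified stand-in that mimics how PyTorch's default collation treats dict samples: keys are taken from the first sample and looked up in every other sample. Because ZeroDictSketch is an ordinary class rather than a lambda, it can be pickled when defined at module level.

```python
from typing import Dict, Iterable, List


class ZeroDictSketch(dict):
    """Dict pre-filled with b"" for every known key that the sample lacks."""

    def __init__(self, contents: Dict[str, bytes], known_keys: Iterable[str]) -> None:
        super().__init__(contents)
        for key in known_keys:
            # Zero bytes rather than None, since collation rejects None.
            self.setdefault(key, b"")


def collate_dicts(batch: List[Dict[str, bytes]]) -> Dict[str, List[bytes]]:
    """Simplified stand-in for default collation of dict samples."""
    # Keys come from the first sample; a missing key elsewhere raises KeyError.
    return {key: [sample[key] for sample in batch] for key in batch[0]}
```

Without the padding, a batch whose first sample has a "cls" entry and whose second does not fails with KeyError; with it, the second sample contributes b"".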

Worker Supported Request Client for PyTorch

This client gives PyTorch workers separate request sessions per thread. This is needed in order to use workers in a DataLoader, because neither the default RequestClient implementation nor requests is thread-safe.

Class: WorkerRequestClient

class WorkerRequestClient(RequestClient)

Extension of the internal client that buckets, objects, jobs, etc. use to make requests to an AIS cluster, adding support for PyTorch and multiple DataLoader workers.

Arguments:

  • client RequestClient - Existing RequestClient to replace

session

@property
def session()

Returns: Active request session acquired for a specific PyTorch dataloader worker
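The per-thread session idea can be sketched with threading.local. This is a swapped-in technique for illustration, not the WorkerRequestClient implementation (which ties sessions to a specific PyTorch DataLoader worker). PerThreadSession is hypothetical; in practice the factory would be something like requests.Session.

```python
import threading
from typing import Callable, Generic, TypeVar

S = TypeVar("S")


class PerThreadSession(Generic[S]):
    """Lazily creates one session per thread and reuses it on later calls."""

    def __init__(self, factory: Callable[[], S]) -> None:
        self._factory = factory
        self._local = threading.local()  # attributes are private to each thread

    @property
    def session(self) -> S:
        if not hasattr(self._local, "session"):
            self._local.session = self._factory()
        return self._local.session
```

Each thread sees only its own session, so no session object is ever shared between concurrent callers.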

Multishard Stream Dataset for AIS.

Class: AISMultiShardStream

class AISMultiShardStream(IterableDataset)

An iterable-style dataset that iterates over multiple shard streams and yields combined samples.

Arguments:

  • data_sources List[DataShard] - List of DataShard objects

Returns:

  • Iterable - Iterable over the combined samples, where each sample is a tuple containing the bytes of one object from each shard stream
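Combining shard streams as described amounts to stepping through all of them in lockstep. Below is a minimal sketch with zip. It assumes zip-like semantics in which iteration stops at the shortest stream; verify this against AISMultiShardStream if your shards differ in length. combine_streams and the image/label data in the usage below are hypothetical.

```python
from typing import Iterable, Iterator, Tuple


def combine_streams(*streams: Iterable[bytes]) -> Iterator[Tuple[bytes, ...]]:
    """Yield one tuple per step, holding the next object from every stream."""
    # zip stops at the shortest stream; extra objects in longer ones are dropped.
    return zip(*streams)
```

For example, pairing an image stream with a label stream yields (image_bytes, label_bytes) tuples.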

AIS IO Datapipe

Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

Class: AISFileListerIterDataPipe

@functional_datapipe("ais_list_files")
class AISFileListerIterDataPipe(IterDataPipe[str])

Iterable DataPipe that lists files from the AIStore backends with the given URL prefixes (functional name: ais_list_files). Acceptable prefixes include, but are not limited to, ais://bucket-name and ais://bucket-name/.

Notes:

  • This function also supports files from multiple backends (aws://.., gcp://.., etc.)
  • Input must be a list and direct URLs are not supported.
  • length is -1 by default; in that case, calls to len() are invalid because the items are not all known before iteration starts.
  • This internally uses the AIStore Python SDK.

Arguments:

  • source_datapipe IterDataPipe[str] - A DataPipe that contains URLs/URL prefixes to objects on AIS
  • length int - Length of the datapipe
  • url str - AIStore endpoint

Example:

    >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister
    >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws://bucket-name/folder/', 'ais://bucket-name/folder/', ...])
    >>> dp_ais_urls = AISFileLister(url='http://localhost:8080', source_datapipe=ais_prefixes)
    >>> for url in dp_ais_urls:
    ...     pass

Functional API

    >>> dp_ais_urls = ais_prefixes.ais_list_files(url='http://localhost:8080')
    >>> for url in dp_ais_urls:
    ...     pass
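Each prefix passed to the lister names a backend provider, a bucket, and an optional object-name prefix. The hypothetical helper split_ais_url shows that decomposition with urllib.parse; it is only an illustration, since the datapipe itself relies on the AIStore Python SDK to interpret URLs.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_ais_url(url: str) -> Tuple[str, str, str]:
    """Split 'provider://bucket/prefix' into (provider, bucket, prefix)."""
    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(f"expected provider://bucket[/prefix], got {url!r}")
    # The path keeps its leading slash ("/folder/"); drop it to get a name prefix.
    return parsed.scheme, parsed.netloc, parsed.path.lstrip("/")
```

A malformed prefix such as aws:bucket-name/folder/ (missing the //) has no bucket component and is rejected.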

Class: AISFileLoaderIterDataPipe

@functional_datapipe("ais_load_files")
class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]])

Iterable DataPipe that loads files from AIStore with the given URLs (functional name: ais_load_files). Iterates over all files and yields each one as a tuple (url, StreamWrapper), where the stream holds the file contents in BytesIO format.

Notes:

  • This function also supports files from multiple backends (aws://.., gcp://.., etc.)
  • Input must be a list and direct URLs are not supported.
  • This internally uses the AIStore Python SDK.
  • An etl_name can be provided to run an existing ETL on the AIS cluster. See https://github.com/NVIDIA/aistore/blob/main/docs/etl.md for more info on AIStore ETL.

Arguments:

  • source_datapipe IterDataPipe[str] - A DataPipe that contains URLs/URL prefixes to objects
  • length int - Length of the datapipe
  • url str - AIStore endpoint
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object

Example:

    >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister, AISFileLoader
    >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws://bucket-name/folder/', 'ais://bucket-name/folder/', ...])
    >>> dp_ais_urls = AISFileLister(url='http://localhost:8080', source_datapipe=ais_prefixes)
    >>> dp_cloud_files = AISFileLoader(url='http://localhost:8080', source_datapipe=dp_ais_urls)
    >>> for url, file in dp_cloud_files:
    ...     pass

Functional API

    >>> dp_cloud_files = dp_ais_urls.ais_load_files(url='http://localhost:8080')
    >>> for url, file in dp_cloud_files:
    ...     pass
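The loader's output shape, one (url, stream) pair per URL produced lazily, can be sketched with io.BytesIO. fetch is a hypothetical callable standing in for the SDK call that reads an object's bytes; per its signature, the real datapipe yields StreamWrapper objects rather than bare BytesIO streams.

```python
import io
from typing import Callable, Iterable, Iterator, Tuple


def load_files(
    urls: Iterable[str], fetch: Callable[[str], bytes]
) -> Iterator[Tuple[str, io.BytesIO]]:
    """Lazily fetch each URL and yield it with an in-memory stream of its bytes."""
    for url in urls:
        # Nothing is read until the consumer advances the generator.
        yield url, io.BytesIO(fetch(url))
```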

Class: AISSourceLister

@functional_datapipe("ais_list_sources")
class AISSourceLister(IterDataPipe[str])

__init__

def __init__(ais_sources: List[AISSource], prefix="", etl_name=None)

Iterable DataPipe over the full URLs for each of the provided AIS source object types

Arguments:

  • ais_sources List[AISSource] - List of objects implementing the AISSource interface (Bucket, ObjectGroup, Object, etc.)
  • prefix str, optional - Filter results to only include objects with names starting with this prefix
  • etl_name str, optional - Pre-existing ETL on AIS to apply to all selected objects on the cluster side