Although a LanceDB Table by itself can be treated
as input to a data loader, this is typically limited. A Permutation can be created to control which rows are accessed
and in what order. For an even more complete solution, LanceDB also provides a StreamingDataset, which adapts the lower
level Permutation API into a simple iterable dataset supporting prefetching, elastic determinism, resumability,
and multi-threaded transformations.
Basic Data Loading
Most model training frameworks iterate through data in batches and feed this data into the model. This process is often referred to as data loading. The simplest way to load data into a model is to iterate a LanceDB table in a loop and feed the data into the model.
Advanced Data Loading
The StreamingDataset wraps a LanceDB Table and, by default, simply adds prefetching and transformation from
Arrow format to Python. However, it can be configured to handle more advanced scenarios. To illustrate, we
will consider a model trained with stochastic gradient descent (SGD) and distributed data parallelism (DDP). In
this example we need to feed batches of data to multiple GPUs across multiple servers. After each batch is
processed, the GPUs synchronize their gradients and the next batch is loaded. This introduces a number of concepts, and we
will use terms from PyTorch in our examples:
- World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and each server has 4 GPUs then the world size is 8.
- Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is
an integer in the range [0, world_size). This is important for data loading because each rank should get its own portion of the data (e.g. rank 3 and rank 4 will see different rows).
- Global batch size - The global batch size is the number of rows, across all GPUs, that we process in each step of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load 128 rows onto each GPU for each step. The global batch size must be divisible by the world size.
- Batch size - The batch size is the number of rows, for a single GPU, that we process in each step of the SGD algorithm. Once again, if we have 8 GPUs and the global batch size is 1024 then the batch size is 128.
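The arithmetic that connects these terms is simple. Below is a minimal plain-Python sketch. The helper name per_rank_batch_size is hypothetical and is not part of LanceDB or PyTorch.

```python
def per_rank_batch_size(global_batch_size: int, world_size: int) -> int:
    """Rows each rank (GPU) loads per SGD step, given the global batch size."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by world size {world_size}"
        )
    return global_batch_size // world_size


# 2 servers x 4 GPUs each -> world size 8; each process gets a rank in [0, 8).
world_size = 2 * 4
ranks = list(range(world_size))
print(per_rank_batch_size(1024, world_size))  # 128
```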
Prefetching
PyTorch datasets were originally built around in-memory structures like a Pandas DataFrame. When they are iterated they yield a single sample at a time. This makes sense for simple in-memory structures, but if we try to access a (potentially remote) database one row at a time, the per-call overhead will typically be far too expensive. To work around this, the StreamingDataset fetches and transforms data in batches. The read_batch_size parameter
controls how many rows we read per call to the underlying Table.
In addition to batching up requests to the database, the prefetching mechanism will read ahead in the background.
While the first batch is being transformed and processed by the GPU, a StreamingDataset will also be reading the
next batch of data. The prefetch_batches parameter controls how many batches of data we will read ahead. This
should typically be at least 2. A larger value can provide more buffering against jittery workloads but will
require more RAM.
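The following toy illustration of batched reads with bounded read-ahead uses only the Python standard library. The helpers read_batches and prefetch are hypothetical, and this is not how StreamingDataset is implemented internally. The bounded queue plays the role of prefetch_batches: it limits how much data sits in RAM while the consumer is busy.

```python
import queue
import threading
from typing import Iterator, List, Sequence

_DONE = object()  # sentinel marking the end of the stream


def read_batches(rows: Sequence[int], read_batch_size: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for table reads: one 'call' returns read_batch_size rows."""
    for start in range(0, len(rows), read_batch_size):
        yield list(rows[start:start + read_batch_size])


def prefetch(batches: Iterator[List[int]], prefetch_batches: int = 2) -> Iterator[List[int]]:
    """Read ahead on a background thread, buffering at most prefetch_batches batches."""
    buffer: "queue.Queue[object]" = queue.Queue(maxsize=prefetch_batches)

    def producer() -> None:
        try:
            for batch in batches:
                buffer.put(batch)  # blocks once prefetch_batches batches are waiting
            buffer.put(_DONE)
        except Exception as exc:  # hand read errors to the consumer instead of hanging
            buffer.put(exc)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _DONE:
            return
        if isinstance(item, Exception):
            raise item
        yield item


for batch in prefetch(read_batches(list(range(10)), read_batch_size=4), prefetch_batches=2):
    print(batch)  # the GPU step would run here while the next batch is being read
```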
Transformation
Many model training workloads require a transformation step between loading the data and training the model. For example, we may need to decode images, tokenize text, or normalize data. A transformation function can be provided using the transform parameter. Transformations can be expensive, and we often want to use multiple CPUs to apply
these transformations. By default, transformations will be applied with a ThreadPoolExecutor with a number of workers
equal to the number of CPUs.
Transformations are applied on batches of data, not individual samples, to allow transformations to amortize per-batch
overhead. A transformation function will receive an Arrow record batch and should return an iterable of samples (one
sample per row). The StreamingDataset does not care what format these samples take but they should match what your
data loader expects. For example, the default PyTorch DataLoader’s collation function can accept a variety of different
sample types, with a Python dictionary being one of the most common. The default transformation function converts
the Arrow record batch into an iterable of Python dictionaries without doing any processing of the data itself.
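Here is a sketch of the batch-in, samples-out contract, assuming a thread pool sized to the CPU count. To keep it dependency-free, a plain dict of columns stands in for an Arrow record batch. With pyarrow, RecordBatch.to_pylist() does the same row-wise conversion. The names rows_as_dicts, lowercase_text and transform_batches are hypothetical.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, Iterator, List

# Stand-in for an Arrow record batch: column name -> list of values.
ColumnBatch = Dict[str, list]
Sample = Dict[str, object]


def rows_as_dicts(batch: ColumnBatch) -> List[Sample]:
    """Default-style transform: one dict per row, no processing of the values."""
    names = list(batch)
    num_rows = len(batch[names[0]]) if names else 0
    return [{name: batch[name][i] for name in names} for i in range(num_rows)]


def lowercase_text(batch: ColumnBatch) -> List[Sample]:
    """Hypothetical custom transform: normalize the 'text' column of a whole batch."""
    lowered = [text.lower() for text in batch["text"]]
    return [{"id": row_id, "text": text} for row_id, text in zip(batch["id"], lowered)]


def transform_batches(
    batches: Iterable[ColumnBatch],
    transform: Callable[[ColumnBatch], Iterable[Sample]],
    max_workers: int = os.cpu_count() or 1,
) -> Iterator[Sample]:
    """Apply `transform` to whole batches on a thread pool and yield samples in input order.

    Executor.map submits every batch up front; a real pipeline would bound this.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for samples in pool.map(transform, batches):
            yield from samples
```

The threads only help when the transform releases the GIL, as the next paragraph explains.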
Worker Info
The thread-based transformation model that StreamingDataset uses by default is only effective when the transform
function releases the GIL. This is true for most Python scientific libraries (e.g. numpy, pandas, arrow, torchvision),
but some libraries may not do this. For these cases, PyTorch supports launching multiple
worker processes per rank (the num_workers parameter of the DataLoader). The StreamingDataset can handle this
scenario: it will call torch.utils.data.get_worker_info to determine the worker id and the total number of workers
and adjust accordingly. However, we find this multiprocessing to be inefficient (it adds pickling and transfer overhead
and significantly increases the amount of RAM required), and we suggest starting with num_workers=1 and only using
a higher value when you’ve confirmed an unavoidable GIL bottleneck.
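Every (rank, worker) pair needs its own portion of the data. One plausible mapping numbers the splits rank-major, as shown below. This layout is an assumption for illustration only and is not necessarily the one LanceDB uses. In a real PyTorch worker, torch.utils.data.get_worker_info() returns None in the main process, and otherwise it returns an object with id and num_workers fields.

```python
from typing import Tuple


def split_for(rank: int, world_size: int, worker_id: int = 0, num_workers: int = 1) -> Tuple[int, int]:
    """Return (split_index, num_splits) for one DataLoader worker on one rank.

    Hypothetical layout: rank r owns splits r*num_workers .. r*num_workers + num_workers - 1.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return rank * num_workers + worker_id, world_size * num_workers


# Inside a PyTorch worker the worker values would come from:
#   info = torch.utils.data.get_worker_info()
#   worker_id, num_workers = (0, 1) if info is None else (info.id, info.num_workers)
print(split_for(rank=3, world_size=8, worker_id=1, num_workers=2))  # (7, 16)
```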
Observability & performance
Optimizing data loader performance is tricky because it can be difficult to locate the bottleneck. What is often blamed on I/O ends up being a CPU bottleneck in the transform stage (or vice versa). To assist developers, the StreamingDataset offers a number of observability controls. The raw_queue_depth can be polled on a regular
basis to determine the number of rows that have been loaded (I/O finished) but not yet transformed. The
prefetch_queue_depth can be polled to determine the number of rows that have been transformed and are waiting
to be consumed by the GPU. As long as these queues are non-empty, the GPU should be operating at capacity.
If the prefetch_queue_depth is consistently zero but the raw_queue_depth is not, then you have a CPU
transformation bottleneck. You should investigate GIL bottlenecks or look for ways to optimize your transformation.
This can often be done by batching the compute work. If both the prefetch_queue_depth and raw_queue_depth are
consistently zero, then you are bottlenecked by I/O. A larger read batch size or clumped shuffling could help to
reduce the I/O bottleneck.
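The rules above can be turned into a small helper that you feed with periodically polled samples of raw_queue_depth and prefetch_queue_depth. How you read those values depends on the StreamingDataset API and is not shown here. The fraction cutoff that defines "consistently zero" is a hypothetical choice.

```python
from typing import Sequence


def mostly_zero(depths: Sequence[int], fraction: float) -> bool:
    """True if at least `fraction` of the polled samples were zero."""
    if not depths:
        raise ValueError("need at least one polled sample")
    return sum(1 for d in depths if d == 0) / len(depths) >= fraction


def diagnose(raw_depths: Sequence[int], prefetch_depths: Sequence[int], fraction: float = 0.8) -> str:
    """Classify the bottleneck from polled queue depths (in rows)."""
    if not mostly_zero(prefetch_depths, fraction):
        return "ok"   # transformed rows are usually waiting for the GPU
    if not mostly_zero(raw_depths, fraction):
        return "cpu"  # rows are loaded but the transform stage can't keep up
    return "io"       # nothing loaded or transformed: reads are the bottleneck


print(diagnose(raw_depths=[512, 480, 600, 0, 300], prefetch_depths=[0, 0, 0, 0, 0]))  # cpu
```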
Filtering data
By default the streaming dataset will load all rows and all columns. LanceDB is a columnar database that also supports efficient random access. Reducing the number of columns you load will have a direct impact on I/O performance. Reducing the number of rows you load will also have an impact on I/O performance, especially if you have a very selective filter, are loading large values, or the data is local or in the LanceDB enterprise cache. You can use the columns parameter to specify which columns to load. You can use the filter parameter to
specify which rows to load. Filtered rows are not loaded from storage. LanceDB will first calculate the row
ids that match the filter, then divide the matching row ids into splits for loading.
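Here is an in-memory model of that order of operations: first compute the matching row ids, then divide them into splits, and then read only the requested columns for a split. The function names are hypothetical. In LanceDB the filtering and column selection happen before data leaves storage, which this toy version cannot show.

```python
from typing import Callable, Dict, List, Sequence

Row = Dict[str, object]


def matching_row_ids(rows: Sequence[Row], predicate: Callable[[Row], bool]) -> List[int]:
    """Step 1: compute the ids of rows that pass the filter."""
    return [row_id for row_id, row in enumerate(rows) if predicate(row)]


def divide_into_splits(row_ids: List[int], num_splits: int) -> List[List[int]]:
    """Step 2: divide the matching ids into num_splits contiguous, near-equal splits."""
    base, extra = divmod(len(row_ids), num_splits)
    splits, start = [], 0
    for i in range(num_splits):
        size = base + (1 if i < extra else 0)
        splits.append(row_ids[start:start + size])
        start += size
    return splits


def load_split(rows: Sequence[Row], split: List[int], columns: List[str]) -> List[Row]:
    """Step 3: read only the requested columns of the rows in one split."""
    return [{c: rows[row_id][c] for c in columns} for row_id in split]


rows = [{"id": i, "label": i % 2, "image": f"img{i}"} for i in range(10)]
ids = matching_row_ids(rows, lambda r: r["label"] == 1)
print(divide_into_splits(ids, 2))  # [[1, 3, 5], [7, 9]]
```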
Shuffling rows
By default, the streaming dataset will access the data in the order the data is stored in the table. This can cause our model to learn artifacts specific to the order of the data. This is one of many ways we can “overfit” our model to our data. To avoid this, we typically want to shuffle the data before training. This is done by setting the shuffle parameter to True. If this is not set, then the data will be divided into splits sequentially.
By default, the shuffle seed will be a combination of the provided shuffle_seed and the provided epoch which
will ensure that each epoch has a different ordering. If you wish for all epochs to have the same ordering then
you can set the epoch parameter to 0 (it is only used to determine the shuffle seed). If you do not want
to provide a shuffle_seed then you can set it to None and a random seed will be used instead.
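The sketch below shows how a seed and an epoch can be combined so that each epoch gets a reproducible but different ordering. Hashing the pair is an assumed combination rule and not necessarily LanceDB's. The clump_size argument mirrors the shuffle_clump_size parameter described in the next paragraph: contiguous clumps of rows move as a unit, so fewer, larger reads are needed.

```python
import hashlib
import random
from typing import List, Optional


def epoch_seed(shuffle_seed: Optional[int], epoch: int) -> int:
    """Derive a per-epoch seed from shuffle_seed and epoch (hypothetical combination rule)."""
    if shuffle_seed is None:
        return random.SystemRandom().randrange(2**63)  # no seed given: pick a random one
    digest = hashlib.sha256(f"{shuffle_seed}:{epoch}".encode()).digest()
    return int.from_bytes(digest[:8], "little")


def shuffled_order(num_rows: int, seed: int, clump_size: int = 1) -> List[int]:
    """Shuffle row ids; with clump_size > 1, contiguous clumps move as a unit."""
    if clump_size < 1:
        raise ValueError("clump_size must be >= 1")
    clumps = [list(range(start, min(start + clump_size, num_rows)))
              for start in range(0, num_rows, clump_size)]
    random.Random(seed).shuffle(clumps)
    return [row for clump in clumps for row in clump]


# Same seed and epoch -> same order; passing epoch=0 every time repeats one ordering.
order = shuffled_order(10, epoch_seed(42, epoch=1), clump_size=4)
```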
Shuffling can have a significant impact on I/O performance, especially if you are loading data from cloud storage.
In many cases the GPU pipeline is slow enough that this penalty will not be noticeable. However, if it is noticeable,
you can use the shuffle_clump_size parameter to shuffle the data in clumps (small contiguous groups of rows
that stay together while the clumps themselves are shuffled). This reduces the randomness of the shuffle somewhat,
but will significantly improve I/O performance.
Data splits and elasticity
The data is divided across a number of processes based on the world size and the number of workers per rank. These groups are called “splits” and by default the dataset will create world_size * num_workers splits.
This is simple but can lead to problems if you need to rerun the training with a different number of GPUs (i.e.
different world size). For example, if we have 8 GPUs and 1 worker per rank then we have 8 splits. If we
later train with 4 GPUs then we will only have 4 splits and the data will be divided differently. This could
lead to a different model being trained which can make deterministic training harder to reproduce.
To work around this, you can manually specify the number of splits used. This allows you to select a larger
number of splits than the default and can provide you with a property called “elastic determinism”. In our
example above, if we are going to train on 4 GPUs, we can set num_splits=8 and we will divide
the data as if we had 8 GPUs. Each rank will be assigned 2 splits and will pull from those two splits in
a round-robin fashion. This means each global batch that gets generated will be the same as the global batches
that were generated when we trained with 8 GPUs.
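To see why a fixed num_splits makes training elastic, the sketch below compares the set of rows consumed at each global step. It rests on several assumptions, all hypothetical: splits are contiguous, rank r owns splits r, r + world_size, and so on, and each step draws global_batch_size / num_splits rows from every split. With the number of splits held fixed, 8 GPUs and 4 GPUs produce identical global batches. With the default of one split per rank, they do not.

```python
from typing import Dict, Iterator, List, Set


def make_splits(num_rows: int, num_splits: int) -> List[List[int]]:
    """Divide row ids 0..num_rows-1 into num_splits contiguous splits of near-equal size."""
    base, extra = divmod(num_rows, num_splits)
    splits, start = [], 0
    for s in range(num_splits):
        size = base + (1 if s < extra else 0)
        splits.append(list(range(start, start + size)))
        start += size
    return splits


def assign_splits(num_splits: int, world_size: int) -> Dict[int, List[int]]:
    """Hypothetical assignment: rank r owns splits r, r + world_size, r + 2*world_size, ..."""
    if num_splits % world_size != 0:
        raise ValueError("num_splits must be a multiple of the world size")
    return {rank: list(range(rank, num_splits, world_size)) for rank in range(world_size)}


def rank_batches(splits: List[List[int]], owned: List[int], rows_per_split: int) -> Iterator[List[int]]:
    """Per step, take rows_per_split rows from each owned split in round-robin order."""
    step = 0
    while True:
        lo, hi = step * rows_per_split, (step + 1) * rows_per_split
        batch = [row for s in owned for row in splits[s][lo:hi]]
        if not batch:
            return
        yield batch
        step += 1


def global_batches(num_rows: int, num_splits: int, world_size: int, global_batch_size: int) -> List[Set[int]]:
    """The set of row ids consumed across all ranks at each training step."""
    if global_batch_size % num_splits != 0:
        raise ValueError("this sketch assumes global_batch_size is a multiple of num_splits")
    rows_per_split = global_batch_size // num_splits
    splits = make_splits(num_rows, num_splits)
    per_rank = [list(rank_batches(splits, owned, rows_per_split))
                for owned in assign_splits(num_splits, world_size).values()]
    num_steps = max(len(batches) for batches in per_rank)
    return [set().union(*(b[i] for b in per_rank if i < len(b))) for i in range(num_steps)]


print(global_batches(96, 8, 8, 16) == global_batches(96, 8, 4, 16))  # True: num_splits fixed at 8
print(global_batches(96, 8, 8, 16) == global_batches(96, 4, 4, 16))  # False: default splits differ
```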
In order for this determinism to work, the number of splits must be a multiple of the world size. Our simple
example above works because 8 is a multiple of 4. However, what would happen if we trained with 8 GPUs and then
wanted to train with 6 GPUs? In that case, num_splits=8 would not be a multiple of 6 and the data would not
be divided evenly across the ranks. To make this work, we can choose a num_splits value that is a multiple of
both 8 and 6. For example, we could choose num_splits=24. Good choices for num_splits are highly composite
numbers like 48 (allows for 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows for 1, 2, 3, 4, 5, 6, 10,
12, 15, 20, 30, and 60 GPUs).
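Choosing num_splits is a small number-theory exercise. You need a common multiple of every world size you expect to use, and the divisors of a candidate tell you which world sizes it supports. The helper names below are hypothetical.

```python
import math
from typing import Iterable, List


def supported_world_sizes(num_splits: int) -> List[int]:
    """World sizes that divide num_splits evenly."""
    return [n for n in range(1, num_splits + 1) if num_splits % n == 0]


def min_num_splits(world_sizes: Iterable[int]) -> int:
    """Smallest num_splits that is a multiple of every planned world size."""
    return math.lcm(*world_sizes)


print(min_num_splits([8, 6]))      # 24
print(supported_world_sizes(60))   # [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
```

math.lcm requires Python 3.9 or later.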
Checkpointing and resumability
Model training is an expensive process and failures can often occur partway through. Checkpointing allows you to save the model state and resume training from where you left off. Most modern deep learning frameworks support checkpointing models. However, we must also be able to checkpoint the data loader so that we can resume training from where we left off. To support this, the streaming dataset provides the state_dict and load_state_dict methods so that you can save and
load the state of the data loader. These methods should be called by your training framework when you want to save or
load a checkpoint. The state_dict method returns a simple Python dictionary that can easily be persisted.
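The state_dict/load_state_dict pattern is easiest to see on a toy iterable. The class below is hypothetical, and its keys (seed, epoch, position) only illustrate the idea. They are not the keys LanceDB's state_dict returns. Because the order is a pure function of seed and epoch, the only extra state needed is a position counter.

```python
import json
import random
from typing import Any, Dict, Iterator


class ResumableOrder:
    """Toy iterable over shuffled row ids that can checkpoint its position."""

    def __init__(self, num_rows: int, seed: int = 0, epoch: int = 0) -> None:
        self.num_rows = num_rows
        self.seed = seed
        self.epoch = epoch
        self.position = 0  # rows already handed out in this epoch

    def __iter__(self) -> Iterator[int]:
        order = list(range(self.num_rows))
        random.Random(f"{self.seed}:{self.epoch}").shuffle(order)  # same order for same seed/epoch
        while self.position < len(order):
            row = order[self.position]
            self.position += 1
            yield row

    def state_dict(self) -> Dict[str, Any]:
        return {"num_rows": self.num_rows, "seed": self.seed,
                "epoch": self.epoch, "position": self.position}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.num_rows = state["num_rows"]
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.position = state["position"]


loader = ResumableOrder(10, seed=3, epoch=1)
it = iter(loader)
seen = [next(it) for _ in range(4)]            # train a few steps...
checkpoint = json.dumps(loader.state_dict())   # ...persist next to the model checkpoint

restored = ResumableOrder(num_rows=0)
restored.load_state_dict(json.loads(checkpoint))
remaining = list(restored)                     # resumes where the checkpoint left off
print(seen + remaining == list(ResumableOrder(10, seed=3, epoch=1)))  # True
```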
Permutations
In some more complicated scenarios, you may want the flexibility of the StreamingDataset to shuffle, split, and select
data without buying into the full behavior of the iterable dataset. In these cases you can use the Permutation
class. This is a lower level class which the StreamingDataset is built on top of. A Permutation can be used to
define a custom ordering of the data. You can then index into the Permutation using the __getitems__ method to
access rows by their ordering in the Permutation. More details on the Permutation class can be found in the API
reference.
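The toy class below illustrates the concept: a filtered, optionally shuffled mapping from positions to rows of a base table, with batched access through __getitems__. It is a hypothetical stand-in and not LanceDB's Permutation, whose actual constructor and options are documented in the API reference. Recent PyTorch releases call a map-style dataset's __getitems__ with a whole list of indices when it is defined, which avoids per-row calls.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence

Row = Dict[str, object]


class ToyPermutation:
    """Hypothetical stand-in: a filtered, optionally shuffled view over a base table's rows."""

    def __init__(self, rows: Sequence[Row], seed: Optional[int] = None,
                 row_filter: Optional[Callable[[Row], bool]] = None) -> None:
        self._rows = rows
        self._ids = [i for i, row in enumerate(rows) if row_filter is None or row_filter(row)]
        if seed is not None:
            random.Random(seed).shuffle(self._ids)

    def __len__(self) -> int:
        return len(self._ids)

    def __getitem__(self, index: int) -> Row:
        return self._rows[self._ids[index]]

    def __getitems__(self, indices: List[int]) -> List[Row]:
        # Batched access by position in the permutation: one call per batch, not per row.
        return [self._rows[self._ids[i]] for i in indices]


table = [{"id": i} for i in range(8)]
perm = ToyPermutation(table, seed=0, row_filter=lambda r: r["id"] % 2 == 0)
batch = perm.__getitems__([0, 1])
```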