LanceDB makes an excellent data backend for training machine learning models. A Table can be passed directly to a data loader, but this approach is typically limited. A Permutation can be created to control which rows are accessed and in what order. For a more complete solution, LanceDB also provides a StreamingDataset, which adapts the lower-level Permutation API into a simple iterable dataset that supports prefetching, elastic determinism, resumability, and multi-threaded transformations.

Basic Data Loading

Most model training frameworks iterate through data in batches and feed this data into the model. This process is often referred to as data loading. The simplest way to load data into a model is to iterate a LanceDB table in a loop and feed the data into the model.
import lancedb

db = lancedb.connect("file://some/db/path")
table = db.open_table("some_table")

for batch in table:
    print(batch.to_pydict())
In practice, this is too simplistic for effective training. We may not want to load all the data, or we may want to load the data in a different order, or we may need to apply some sort of processing to the data before training. To achieve this, we can use the StreamingDataset.
from lancedb.streaming import StreamingDataset
import lancedb

db = lancedb.connect("file://some/db/path")
table = db.open_table("some_table")

ds = StreamingDataset(table, shuffle_seed=42)
for sample in ds:
    # sample is a plain Python dict, e.g. {"feature": 0.82, "label": "cat"}
    train_step(sample)

Advanced Data Loading

The StreamingDataset wraps a LanceDB Table and, by default, simply adds prefetching and conversion from Arrow format to Python objects. However, it can be configured to handle more advanced scenarios. To illustrate, consider a model trained with stochastic gradient descent (SGD) and distributed data parallelism (DDP). In this setting we must feed batches of data to multiple GPUs spread across multiple servers. After each batch is processed, the GPUs exchange (all-reduce) their gradients so that every replica applies the same weight update, and then the next batch is loaded. This introduces several concepts, and we use PyTorch terminology in our examples:
  • World size - The world size is the number of GPUs that we are loading. For example, if we have 2 servers and each server has 4 GPUs then the world size is 8.
  • Rank - When loading data we will create a process for each GPU. Each process will have its own rank. This is an integer in the range [0, world_size). This is important for data loading because each rank should get its own portion of the data (e.g. rank 3 and rank 4 will see different rows).
  • Global batch size - The global batch size is the number of rows, across all GPUs, that we process in each step of the SGD algorithm. For example, if we have 8 GPUs and the global batch size is 1024 then we need to load 128 rows onto each GPU for each step. The global batch size must be divisible by the world size.
  • Batch size - The batch size is the number of rows, for a single GPU, that we process in each step of the SGD algorithm. Once again, if we have 8 GPUs and the global batch size is 1024 then the batch size is 128.
Other concepts (read batch size, num_workers) are specific to individual sections and are introduced where they are used.
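The relationship between world size, rank, global batch size, and per-rank batch size can be checked with a few lines of plain Python. The helper name per_rank_batch_size is our own, used only for illustration; it is not part of the LanceDB API.

```python
def per_rank_batch_size(global_batch_size: int, world_size: int) -> int:
    """Rows each rank (GPU) must load per SGD step (hypothetical helper)."""
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible "
            f"by world size {world_size}"
        )
    return global_batch_size // world_size


# 2 servers with 4 GPUs each -> world size 8, ranks 0..7
world_size = 2 * 4
ranks = list(range(world_size))
batch_size = per_rank_batch_size(1024, world_size)
print(world_size, ranks[-1], batch_size)  # 8 7 128
```

The divisibility check mirrors the rule above: a global batch size of 1000 on 6 GPUs would raise instead of silently giving ranks unequal work.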

Prefetching

PyTorch datasets were originally built around in-memory structures such as a Pandas DataFrame, and iterating them yields a single sample at a time. That works for in-memory data, but if we access a (potentially remote) database one row at a time, the per-call overhead is typically far too expensive. To avoid this, the StreamingDataset fetches and transforms data in batches. The read_batch_size parameter controls how many rows are read per call to the underlying Table.

In addition to batching requests to the database, the prefetching mechanism reads ahead in the background: while one batch is being transformed and processed by the GPU, the StreamingDataset is already reading the next batch. The prefetch_batches parameter controls how many batches are read ahead. This should typically be at least 2. A larger value provides more buffering against jittery workloads but requires more RAM.
ds = StreamingDataset(
    table,
    shuffle_seed=42,
    read_batch_size=256,   # rows fetched per LanceDB call per split
    prefetch_batches=8,    # batches to keep in flight per split
)
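To make the read-ahead idea concrete, here is a conceptual sketch built on a bounded queue and a background thread. It is not the StreamingDataset's implementation: fetch_batch and fake_fetch are hypothetical stand-ins for one call that reads read_batch_size rows, and prefetch_batches here simply bounds the queue.

```python
import queue
import threading
import time
from typing import Callable, Iterator, List


def prefetching_reader(
    fetch_batch: Callable[[int], List[int]],
    num_batches: int,
    prefetch_batches: int = 2,
) -> Iterator[List[int]]:
    """Yield batches while a background thread reads ahead."""
    # Bounded queue: the producer blocks once prefetch_batches are waiting.
    buffer: queue.Queue = queue.Queue(maxsize=prefetch_batches)

    def producer() -> None:
        try:
            for i in range(num_batches):
                buffer.put(("batch", fetch_batch(i)))
        except Exception as exc:  # forward read errors to the consumer
            buffer.put(("error", exc))
            return
        buffer.put(("done", None))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buffer.get()
        if kind == "error":
            raise payload
        if kind == "done":
            return
        yield payload


def fake_fetch(i: int) -> List[int]:
    time.sleep(0.01)  # simulate per-call I/O latency
    return [i * 256 + j for j in range(256)]


batches = list(prefetching_reader(fake_fetch, num_batches=4, prefetch_batches=2))
print(len(batches), batches[1][0])  # 4 256
```

The queue bound is what trades RAM for buffering: raising prefetch_batches lets the reader get further ahead of a consumer that stalls occasionally.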

Transformation

Many model training workloads require a transformation step between loading the data and training the model, for example decoding images, tokenizing text, or normalizing values. A transformation function can be provided using the transform parameter. Transformations can be expensive, so we often want to spread them across multiple CPUs. By default, transformations are applied with a ThreadPoolExecutor whose number of workers equals the number of CPUs. Transformations are applied to batches of data, not individual samples, so that per-batch overhead is amortized. A transformation function receives an Arrow record batch and should return an iterable of samples (one sample per row). The StreamingDataset does not care what format these samples take, but they should match what your data loader expects. For example, the default PyTorch DataLoader’s collation function accepts a variety of sample types, with a Python dictionary being one of the most common. The default transformation function converts the Arrow record batch into an iterable of Python dictionaries without otherwise processing the data.
import pyarrow as pa

def normalize(batch: pa.RecordBatch) -> list[dict]:
    # This pure-Python loop holds the GIL and is shown for illustration only.
    # In practice, prefer a library like torchvision or numpy that releases the
    # GIL so the ThreadPoolExecutor can run transforms in parallel.
    rows = batch.to_pylist()
    for row in rows:
        row["image"] = [v / 255.0 for v in row["image"]]
    return rows

ds = StreamingDataset(table, shuffle_seed=42, transform=normalize)

Worker Info

The thread-based transformation model that StreamingDataset uses by default is only effective when the transform function releases the GIL. This is true for most Python scientific libraries (e.g. numpy, pandas, pyarrow, torchvision), but some libraries do not release it. To work around this limitation, PyTorch supports launching multiple worker processes per rank (the num_workers parameter of the DataLoader). The StreamingDataset handles this scenario: it calls torch.utils.data.get_worker_info to determine the worker id and the total number of workers, and adjusts accordingly. However, we find this multiprocessing inefficient (it adds pickling and transfer overhead and significantly increases the amount of RAM required), so we suggest starting with num_workers=1 and only using a higher value once you have confirmed an unavoidable GIL bottleneck.
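One plausible way to combine rank and worker information into a split index is sketched below, following the default of world_size * num_workers splits. This is a hypothetical helper, not StreamingDataset's internal code. The worker_info argument mirrors what torch.utils.data.get_worker_info() returns (None in the main process, otherwise an object with id and num_workers attributes); SimpleNamespace stands in for it so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Tuple


def split_assignment(rank: int, world_size: int, worker_info=None) -> Tuple[int, int]:
    """Return (split_id, total_splits) for this process (hypothetical scheme).

    worker_info mirrors torch.utils.data.get_worker_info(): None when loading
    in the main process, otherwise an object with .id and .num_workers.
    """
    if worker_info is None:
        worker_id, num_workers = 0, 1
    else:
        worker_id, num_workers = worker_info.id, worker_info.num_workers
    total_splits = world_size * num_workers
    # Workers of the same rank get adjacent split ids.
    split_id = rank * num_workers + worker_id
    return split_id, total_splits


# Rank 3 of 8, second of two DataLoader workers on that rank.
print(split_assignment(3, 8, SimpleNamespace(id=1, num_workers=2)))  # (7, 16)
print(split_assignment(3, 8))  # (3, 8)
```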

Observability & performance

Optimizing data loader performance is tricky because the bottleneck can be difficult to locate: what is often blamed on I/O turns out to be a CPU bottleneck in the transform stage (or vice versa). To help, the StreamingDataset exposes several observability counters. Poll raw_queue_depth regularly to see how many rows have been loaded (I/O finished) but not yet transformed. Poll prefetch_queue_depth to see how many rows have been transformed and are waiting to be consumed by the GPU. As long as the prefetch queue is non-empty, the GPU is not waiting on data. If prefetch_queue_depth is consistently zero but raw_queue_depth is not, you have a CPU transformation bottleneck: investigate GIL contention or look for ways to optimize your transformation, often by batching the compute work. If both prefetch_queue_depth and raw_queue_depth are consistently zero, you are bottlenecked by I/O; a larger read batch size or clumped shuffling can help.
import threading, time

ds = StreamingDataset(table, shuffle_seed=42)

def log_pipeline_health():
    while True:
        print(
            f"unscanned={ds.unscanned_rows} "
            f"raw={ds.raw_queue_depth} "
            f"cooked={ds.prefetch_queue_depth} "
            f"consumed={ds.consumed_rows}"
        )
        time.sleep(1.0)

monitor = threading.Thread(target=log_pipeline_health, daemon=True)
monitor.start()

for sample in ds:
    train_step(sample)

print(f"fetch time: {ds.fetch_time:.2f}s  transform time: {ds.transform_time:.2f}s")
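The diagnostic rules above can be encoded in a small helper that classifies a series of polled (raw_queue_depth, prefetch_queue_depth) pairs, such as those printed by the monitor thread. The helper name diagnose and its 0.8 "consistently starved" threshold are arbitrary choices for illustration, not part of the LanceDB API.

```python
from typing import Sequence, Tuple


def diagnose(samples: Sequence[Tuple[int, int]], starved_fraction: float = 0.8) -> str:
    """Classify the bottleneck from polled (raw_queue_depth, prefetch_queue_depth) pairs.

    Returns "ok" (GPU is being fed), "transform" (CPU-bound transform stage),
    or "io" (reads are not keeping up). The 0.8 threshold is arbitrary.
    """
    if not samples:
        raise ValueError("need at least one sample")
    starved = [(raw, cooked) for raw, cooked in samples if cooked == 0]
    if len(starved) / len(samples) < starved_fraction:
        return "ok"
    # GPU is starved: is work piling up before the transform stage, or not arriving at all?
    backlog = sum(1 for raw, _ in starved if raw > 0)
    return "transform" if backlog / len(starved) >= 0.5 else "io"


print(diagnose([(0, 12), (3, 9), (0, 0)]))    # ok
print(diagnose([(40, 0), (38, 0), (41, 0)]))  # transform
print(diagnose([(0, 0), (0, 0), (0, 0)]))     # io
```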

Filtering data

By default, the streaming dataset loads all rows and all columns. LanceDB is a columnar database that also supports efficient random access. Reducing the number of columns you load has a direct impact on I/O performance. Reducing the number of rows you load also improves I/O performance, especially if your filter is very selective, you are loading large values, or the data is local or in the LanceDB Enterprise cache. Use the columns parameter to specify which columns to load and the filter parameter to specify which rows to load. Rows that do not match the filter are never loaded from storage: LanceDB first calculates the row ids that match the filter, then divides those row ids into splits for loading.
ds = StreamingDataset(
    table,
    shuffle_seed=42,
    columns=["image", "label"],       # skip all other columns
    filter="category = 'train'",      # only training rows
)
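The two-step process (compute matching row ids, then divide them into splits) can be sketched over an in-memory list of rows. Everything here is an assumption for illustration: the rows list and Python predicate stand in for the table and its SQL filter, and the even sequential division stands in for however LanceDB actually sizes its splits.

```python
from typing import Callable, Dict, List


def plan_splits(
    rows: List[Dict], predicate: Callable[[Dict], bool], num_splits: int
) -> List[List[int]]:
    """Return the matching row ids divided into num_splits sequential splits."""
    # Step 1: compute matching row ids; non-matching rows are never read.
    matching = [row_id for row_id, row in enumerate(rows) if predicate(row)]
    # Step 2: divide them sequentially; earlier splits get one extra id if uneven.
    base, extra = divmod(len(matching), num_splits)
    splits, start = [], 0
    for i in range(num_splits):
        size = base + (1 if i < extra else 0)
        splits.append(matching[start:start + size])
        start += size
    return splits


rows = [{"category": c} for c in ["train", "test", "train", "train", "test", "train"]]
print(plan_splits(rows, lambda r: r["category"] == "train", num_splits=2))
# [[0, 2], [3, 5]]
```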

Shuffling rows

By default, the streaming dataset accesses the data in the order it is stored in the table. This can cause our model to learn artifacts specific to that ordering, which is one of many ways a model can “overfit” to its data. To avoid this, we typically shuffle the data before training by setting the shuffle parameter to True. If it is not set, the data is divided into splits sequentially. By default, the shuffle seed is a combination of the provided shuffle_seed and the provided epoch, which ensures that each epoch has a different ordering. If you want every epoch to have the same ordering, set the epoch parameter to 0 (it is only used to determine the shuffle seed). If you do not want to provide a shuffle_seed, set it to None and a random seed will be used instead.
# Training loop: each epoch gets a different shuffled ordering
for epoch in range(num_epochs):
    ds = StreamingDataset(
        table,
        shuffle=True,
        shuffle_seed=42,
        epoch=epoch,       # changes the permutation each epoch
    )
    for sample in ds:
        train_step(sample)

# Evaluation: deterministic sequential order, no shuffle
eval_ds = StreamingDataset(table, shuffle=False)
for sample in eval_ds:
    eval_step(sample)
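A minimal sketch of deriving a per-epoch ordering from a seed and an epoch is shown below. Mixing both values through numpy's seed sequence is an assumption made for illustration; LanceDB's actual seed combination may differ. The property it demonstrates matches the text: the same (shuffle_seed, epoch) pair always reproduces the same order, and a different epoch gives a different one.

```python
import numpy as np


def epoch_permutation(num_rows: int, shuffle_seed: int, epoch: int) -> np.ndarray:
    """Row order for one epoch, derived from (shuffle_seed, epoch).

    Assumption: both values are mixed via numpy's SeedSequence; LanceDB's
    actual combination may differ.
    """
    rng = np.random.default_rng([shuffle_seed, epoch])
    return rng.permutation(num_rows)


order_e0 = epoch_permutation(10, shuffle_seed=42, epoch=0)
order_e1 = epoch_permutation(10, shuffle_seed=42, epoch=1)
print(order_e0, order_e1)
```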
Shuffling can significantly hurt I/O performance, especially when loading data from cloud storage, because rows are no longer read in contiguous runs. In many cases the GPU pipeline is slow enough that this penalty is not noticeable. When it is, the shuffle_clump_size parameter shuffles the data in clumps: small runs of contiguous rows are kept together and the clumps, rather than individual rows, are shuffled. This makes the shuffle somewhat less random but can significantly improve I/O performance.
# Clumped shuffle: runs of 16 contiguous rows stay together and the order of
# the clumps is shuffled, preserving read locality while still randomising the
# global ordering.
ds = StreamingDataset(
    table,
    shuffle_seed=42,
    shuffle_clump_size=16,
)
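The clumping idea can be sketched in a few lines: split the row ids into fixed-size runs, shuffle the order of the runs, and concatenate. This is an illustration under the assumption that clumps are fixed-size contiguous ranges; the function is hypothetical and not LanceDB's implementation.

```python
import numpy as np


def clumped_permutation(num_rows: int, clump_size: int, seed: int) -> np.ndarray:
    """Shuffle the order of fixed-size clumps; rows inside a clump stay contiguous."""
    if num_rows == 0:
        return np.array([], dtype=np.int64)
    rng = np.random.default_rng(seed)
    clump_starts = np.arange(0, num_rows, clump_size)
    rng.shuffle(clump_starts)  # in-place shuffle of clump order only
    return np.concatenate(
        [np.arange(start, min(start + clump_size, num_rows)) for start in clump_starts]
    )


order = clumped_permutation(50, clump_size=16, seed=42)
print(order[:20])
```

Each read then covers up to 16 adjacent rows, which is why clumping recovers much of the sequential I/O pattern.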

Data splits and elasticity

The data is divided across processes based on the world size and the number of workers per rank. These divisions are called “splits”, and by default the dataset creates world_size * num_workers splits. This is simple but causes problems if you need to rerun training with a different number of GPUs (i.e. a different world size). For example, with 8 GPUs and 1 worker per rank there are 8 splits. If we later train with 4 GPUs, there are only 4 splits and the data is divided differently. This can produce a different model, which makes training results harder to reproduce.

To work around this, you can specify the number of splits manually. Choosing more splits than the default gives you a property called “elastic determinism”. Continuing the example above, if we train on 4 GPUs with num_splits=8, the data is divided as if we had 8 GPUs. Each rank is assigned 2 splits and pulls from them in round-robin fashion, so each global batch is the same as the corresponding global batch from the 8-GPU run.

For this to work, the number of splits must be a multiple of the world size. Our example works because 8 is a multiple of 4. But what if we trained with 8 GPUs and then wanted to train with 6? num_splits=8 is not a multiple of 6, so the data cannot be divided evenly across the ranks. Instead, choose a num_splits value that is a multiple of both 8 and 6, such as num_splits=24. Good choices for num_splits are highly composite numbers like 48 (allows 1, 2, 3, 4, 6, 8, 12, 16, 24, and 48 GPUs) and 60 (allows 1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, and 60 GPUs).
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# num_splits=48 is divisible by 1, 2, 3, 4, 6, 8, 12, 16, 24, 48
# so this dataset works unchanged as you scale up or down GPUs.
ds = StreamingDataset(
    table,
    num_splits=48,
    shuffle_seed=42,
    epoch=current_epoch,
    rank=rank,
    world_size=world_size,
)

for sample in ds:
    train_step(sample)
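The arithmetic behind elastic determinism can be checked with plain Python: which world sizes a given num_splits supports, the smallest num_splits that supports two world sizes, and why the global batch does not change when fewer ranks share the same splits. The strided split assignment and the fixed number of rows each split contributes per step are assumptions for illustration; LanceDB's actual assignment may differ.

```python
import math


def valid_world_sizes(num_splits: int) -> list[int]:
    """World sizes that divide num_splits evenly (i.e. its divisors)."""
    return [w for w in range(1, num_splits + 1) if num_splits % w == 0]


def assign_splits(rank: int, world_size: int, num_splits: int) -> list[int]:
    """Hypothetical strided assignment: rank r owns splits r, r + W, r + 2W, ..."""
    if num_splits % world_size != 0:
        raise ValueError("num_splits must be a multiple of world_size")
    return list(range(rank, num_splits, world_size))


def global_batch(step: int, world_size: int, num_splits: int, rows_per_split: int) -> set:
    """(split, offset) pairs read by all ranks together at one step.

    Assumes each split contributes rows_per_split rows per step, so a rank
    that owns k splits reads k * rows_per_split rows, round-robin.
    """
    rows = set()
    for rank in range(world_size):
        for split in assign_splits(rank, world_size, num_splits):
            start = step * rows_per_split
            rows.update((split, offset) for offset in range(start, start + rows_per_split))
    return rows


print(valid_world_sizes(48))  # [1, 2, 3, 4, 6, 8, 12, 16, 24, 48]
print(math.lcm(8, 6))         # 24
print(assign_splits(rank=1, world_size=4, num_splits=8))  # [1, 5]
# Same global batch whether 8 GPUs or 4 GPUs share the 8 splits.
print(global_batch(3, 8, 8, 128) == global_batch(3, 4, 8, 128))  # True
```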

Checkpointing and resumability

Model training is expensive, and failures often occur partway through. Checkpointing lets you save the training state and resume from where you left off. Most modern deep learning frameworks support checkpointing models, but the data loader must be checkpointed too, so that the resumed run neither repeats nor skips rows. To support this, the StreamingDataset provides the state_dict and load_state_dict methods for saving and restoring the state of the data loader. Your training framework should call these methods whenever it saves or loads a checkpoint. The state_dict method returns a plain Python dictionary that can easily be persisted.
import torch

ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size)

for step, sample in enumerate(ds):
    train_step(sample)

    if step % checkpoint_interval == 0:
        torch.save(
            {"model": model.state_dict(), "dataloader": ds.state_dict()},
            f"checkpoint_{step}.pt",
        )

# --- resuming after a crash ---
checkpoint = torch.load("checkpoint_100.pt")
model.load_state_dict(checkpoint["model"])

ds = StreamingDataset(table, num_splits=48, shuffle_seed=42, rank=rank, world_size=world_size)
ds.load_state_dict(checkpoint["dataloader"])

for sample in ds:   # resumes after the sample consumed at step 100; no rows repeated or skipped
    train_step(sample)
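To show why a small dictionary is enough, here is a minimal resumable iterator with the same state_dict/load_state_dict shape. It is a hypothetical class, not LanceDB's implementation. One detail matters for a prefetching loader: the state should count rows handed to the training loop, not rows fetched, otherwise rows sitting in the prefetch queue at checkpoint time would be skipped on resume.

```python
import json
from typing import Iterable, Iterator


class ResumableReader:
    """Hypothetical minimal resumable iterator; not LanceDB's implementation."""

    def __init__(self, row_ids: Iterable[int]):
        self.row_ids = list(row_ids)
        self.consumed = 0  # rows handed to the training loop so far

    def __iter__(self) -> Iterator[int]:
        while self.consumed < len(self.row_ids):
            row_id = self.row_ids[self.consumed]
            # Count the row as consumed before yielding it, so a checkpoint
            # taken after train_step() never replays this row.
            self.consumed += 1
            yield row_id

    def state_dict(self) -> dict:
        return {"consumed": self.consumed}

    def load_state_dict(self, state: dict) -> None:
        self.consumed = state["consumed"]


reader = ResumableReader(range(10))
stream = iter(reader)
first_three = [next(stream) for _ in range(3)]
saved = json.dumps(reader.state_dict())  # plain dict, easy to persist

resumed = ResumableReader(range(10))
resumed.load_state_dict(json.loads(saved))
print(first_three, list(resumed))  # [0, 1, 2] [3, 4, 5, 6, 7, 8, 9]
```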

Permutations

In some more complicated scenarios, you may want the flexibility of the StreamingDataset to shuffle, split, and select data without adopting the full behavior of the iterable dataset. In these cases you can use the Permutation class, the lower-level class on which the StreamingDataset is built. A Permutation defines a custom ordering of the data, and you can index into it with the __getitems__ method to access rows by their position in that ordering. More details on the Permutation class can be found in the API reference.