PyTorch DataLoader的工作机制

PyTorch-DataLoader的工作机制

阅读前提:

  • 了解 pytorch dataloader 的基本用法

最近调试代码时,遇到了PyTorch多进程加载数据的问题,在windows中遇到多进程启动失败的问题,问题连接:https://pytorch.org/docs/stable/notes/windows.html#multiprocessing-error-without-if-clause-protection 然后在github的issue 找到了解决办法:在if __name__ == '__main__':中去启动dataloader的迭代,这样可以避免启动子进程时windows系统又去启动dataloader的迭代。虽然问题解决了,但是我比较好奇,DataLoader多进程加载数据的机制。

为了学习DataLoader的机制,最好是看它的源码,而且看最早期的源码,因为比较简单,但核心思想是一致的。因此我看了PyTorch v0.1.4 版本的源码,比现在最新版本的源码简单多了。看完之后总算对多进程加载数据有了更清楚的理解,也理解了其中进程之间队列的死锁问题及其处理方法。

代码: github v0.1.4 版本的 dataloader.py

https://github.com/pytorch/pytorch/blob/v0.1.4/torch/utils/data/dataloader.py https://github.com/pytorch/pytorch/blob/v0.1.4/torch/utils/data/dataset.py https://github.com/pytorch/pytorch/blob/v0.1.4/torch/utils/data/sampler.py

dataset、sampler 与 dataloader的关系?

dataset 提供索引的数据 sampler 提供如何返回datasetindex,比如:顺序与随机 dataloader 提供一个batch的数据,并通过多进程的方式返回数据。它可以通过sampler 构造一个batch index ,然后利用多进程、队列的方法调用dataset的方法来返回一个batch的数据。

dataset源码:

class Dataset(object):

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class TensorDataset(Dataset):

    def __init__(self, data_tensor, target_tensor):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        if self.data_tensor.dim() == 1:
            self.data_tensor = self.data_tensor.view(-1, 1)
        if self.target_tensor.dim() == 1:
            self.target_tensor = self.target_tensor.view(-1, 1)

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)

sampler 源码:

class Sampler(object):

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class SequentialSampler(Sampler):

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(range(self.num_samples))

    def __len__(self):
        return self.num_samples


class RandomSampler(Sampler):

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(torch.randperm(self.num_samples).long())

    def __len__(self):
        return self.num_samples

dataloader 源码:

class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False,
                 sampler=None, num_workers=0, collate_fn=default_collate):
        self.dataset     = dataset
        self.batch_size  = batch_size
        self.num_workers = num_workers
        self.collate_fn  = collate_fn

        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        elif not shuffle:
            self.sampler = SequentialSampler(dataset)

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)

dataloader.py 主进程与工作进程的关系?

DataLoader 类:

class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False,
                 sampler=None, num_workers=0, collate_fn=default_collate):
        self.dataset     = dataset
        self.batch_size  = batch_size
        self.num_workers = num_workers
        self.collate_fn  = collate_fn

        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        elif not shuffle:
            self.sampler = SequentialSampler(dataset)

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)

DataLoader 类中定义了__iter____iter__ 是一个特殊方法(special method),用于定义一个可迭代对象(iterable)。可迭代对象是指实现了 __iter__ 方法的对象,该方法返回一个迭代器(iterator)。 这个迭代器是:DataLoaderIter,它__init__方法中定义了两个队列,用于进程之间的通信,如下所示:

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    # loader 是 DataLoader  的实例
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.batch_size = loader.batch_size
        self.collate_fn = loader.collate_fn
        self.sampler = loader.sampler
        self.num_workers = loader.num_workers

        self.samples_remaining = len(self.sampler)
        self.sample_iter = iter(self.sampler)

        if self.num_workers:
            self.index_queue = multiprocessing.Queue()
            self.data_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.joined = False

            self.workers = [
                multiprocessing.Process(
                    target=_workerLoop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for i in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True # ensure that the worker exits on process exit
                w.start()
                # prime the prefetch loop with exactly 1 batch per process
                # this ensures no deadlocks on the queues using the blocking queue API
                self._putBatch()

其中self.index_queue 表示一个batch的索引,数据的填入由主进程完成,工作进程会从self.index_queue队列中获取数据,然后将一个批量的数据放入另外一个队列:self.data_queue 然后又next函数返回给迭代器的调用者。

数据时如何从迭代器DataLoaderIter返回的?

迭代器DataLoaderIter 中定义了__next__方法:

def next(self):
    if self.num_workers:
        # multi-process loading
        if self.batches_outstanding:
            assert(not self.joined)
            # maintain at most len(workers)+1 outstanding batches
            # to avoid deadlocks in the queues, using the blocking queue API
            # TODO: add and use non-blocking queue API
            self._putBatch()
            assert(self.batches_outstanding <= len(self.workers) + 1)
            self.batches_outstanding -= 1
            data = self.data_queue.get()

            if isinstance(data, ExceptionWrapper):
                raise data.exc_type(data.exc_msg)
            else:
                return data
        else:
            self._joinWorkers()
            raise StopIteration()
    else:
        # single-process loading
        if self.samples_remaining:
            return _processBatch(self.dataset, self._nextBatch(), self.collate_fn)
        else:
            raise StopIteration()

__next__ = next

dataloader开启多进程时,self.num_workers 表示进程的数量,self.batches_outstanding记录 self.index_queue batch index 的数量,然后先执行self._putBatch()self.index_queue添加一个batch’s index ,然后执行data = self.data_queue.get()来获取一个batch的数据,然后返回data, 这就完成一个迭代过程。

dataloader.py中主进程和工作进程是何时切换的?会发生死锁吗?

PyTorch 中的多进程数据加载机制使用了阻塞式队列(blocking queue)来进行进程间通信。阻塞式队列在队列为空或队列已满时会自动阻塞,直到有新的数据可用或队列中有空余空间为止。但是当工作进程A想要从空队列中取数据时,它首先会获得队列加锁,其他进程暂时无法访问队列。但是工作进程A由于取不到数据,也会被阻塞。这时候操作系统切换到主进程后,依然可以向空队列发送数据,因为工作进程A已经阻塞了。

具体解释是: 在阻塞式队列中,如果一个进程A想要从空队列(阻塞式队列)中获取数据,它会被自动阻塞,并等待队列中有新的数据可用。如果队列中一直没有数据,那么进程A会被阻塞并挂起。在这种情况下,工作进程A通常会释放对队列的锁,允许其他进程访问该队列。这是因为阻塞式队列的设计目的之一就是允许多个生产者和消费者并发地访问队列。

其它工作进程在空队列上等待,如果主进程取得cpu控制权,此时,由于工作队列已经释放了锁,主进程可以成功向队列中发送数据,并且有数据可用,它可以向队列中发送数据。后来当工作进程会被唤醒时,就可以获取到数据,并继续执行后续操作。

需要注意的是,当多个进程同时访问同一个队列时,可能会出现一些竞争条件和锁竞争的情况。这时候就需要使用线程安全的队列(如 queue.Queue)或者进程安全的队列(如 multiprocessing.Queue)来保证多个进程之间的数据交换安全可靠。此外,为了避免死锁和竞争条件等问题,还需要合理地设置队列的容量和缓冲策略。

启动多进程的过程:

for w in self.workers:
    w.daemon = True # ensure that the worker exits on process exit
    w.start()
    # prime the prefetch loop with exactly 1 batch per process
    # this ensures no deadlocks on the queues using the blocking queue API
    self._putBatch()

w.daemon = True : 将工作进程A放入后台 w.start() :启动工作进程A self._putBatch() :向 self.index_queue 发送一个批量的索引,免得在主进程执行next 函数时发送死锁。

当启动工作进程A后,就算主进程没来得及向执行self._putBatch()也没关系,因为工作进程A由于获取不到数据最后会被挂起,并且工作进程A通常会释放对队列的锁。那么切换到主进程后可以继续执行self._putBatch(),最后在主进程调用next()函数从队列(self.data_queue)中拿数据时能拿到数据,因为工作进程由于之前的self._putBatch()函数会往队列self.index_queue发送了数据,然后工作队列会根据self.index_queue的数据往队列(self.data_queue)发送数据,所有下一次主进程向队列(self.data_queue)中拿数据时,可以保证拿到数据,不会造成死锁的局面。

这种死锁的局面就是:主进程向空队列(self.data_queue)拿数据,但是由于是空队列,于是阻塞挂起了 而工作进程也在向空队列(self.index_queue)拿数据,但是由于是空队列,于是阻塞也挂起了。于是都挂起了,但是self.index_queue需要主进程往其中填数据,才能保证工作往队列(self.data_queue)填数据。这样这两个进程互相死锁了。

为了避免这种死锁的情况,在主进程向空队列(self.data_queue)拿数据前要先执行self._putBatch()