dataset和DataLoader

简介

Dataset和DataLoader在pytorch中主要用于数据的组织。这两个类通常一起搭配处理深度学习中的数据流。

  • Dataset 用于产出“单个样本”:定义怎么按索引取到一个样本,以及总共有多少个样本。
  • DataLoader 负责“成批取样”:决定批大小、是否打乱、多进程加载、并用 collate_fn 把一个批里的样本“拼起来”(对齐、padding、mask、teacher forcing 等)。

一句话记忆:Dataset 只管“单条样本”;DataLoader 负责“多条怎么一起、怎么并行、怎么对齐”。变长就写 collate_fn,性能就调 workers/pin_memory/分桶。

Dataset

Dataset类作用:定义数据集的统一接口,支持自定义数据加载逻辑。

关键方法:

  • init:初始化数据路径、预处理函数等。
  • len:返回数据集样本总数。
  • getitem:根据索引返回单个样本(数据+标签)。

通常情况下用户都会有自己的数据集,所以定义的数据集类继承dataset。

#准备一个数据集
pairs: List[Tuple[str, str]] = [
    ("我 有 一个 苹果", "i have an apple"),
    ("我 有 一本 书", "i have a book"),
    ("你 喜欢 书", "you like books"),
    ("我 吃 苹果", "i eat apples"),
]


def build_vocab(texts: List[str]):
    tokens = set()
    for s in texts:
        tokens.update([w.lower() for w in s.split()])
    itos = ["<pad>", "<bos>", "<eos>"] + sorted(tokens)
    stoi = {t: i for i, t in enumerate(itos)}
    return stoi, itos


src_texts = [s for s, _ in pairs]
tgt_texts = [t for _, t in pairs]
SRC_STOI, SRC_ITOS = build_vocab(src_texts)
TGT_STOI, TGT_ITOS = build_vocab(tgt_texts)

PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2


def encode_src(s: str) -> List[int]:
    return [SRC_STOI[w.lower()] for w in s.split()]


def encode_tgt(s: str) -> List[int]:
    return [BOS_IDX] + [TGT_STOI[w.lower()] for w in s.split()] + [EOS_IDX]



# Dataset:定义“单样本怎么取”
@dataclass
class Example:
    src: List[int]
    tgt: List[int]


class ToyDataset(Dataset):
    def __init__(self, pairs: List[Tuple[str, str]]):
        for s, t in pairs:
            print("encode_src(s)",encode_src(s))
            print("encode_tgt(t)",encode_tgt(t))
        self.data = [Example(encode_src(s), encode_tgt(t)) for s, t in pairs]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Example:
        return self.data[idx]
  • 样本结构:用 Example(src: List[int], tgt: List[int]) 表示一条样本的源序列与目标序列(都是 token id 列表)。
  • 词表与编码:源序列仅分词并映射到 id。目标序列前加 bos、后加 eos,便于自回归训练。
  • 协议:实现 lengetitem 两个方法即可被 DataLoader 使用。

DataLoader

  class torch.utils.data.DataLoader(Data[T_co]):
      def __init__(
          self,
          dataset,
          batch_size: int = 1,
          shuffle: bool | None = None,
          sampler = None,
          batch_sampler = None,
          num_workers: int = 0,
          collate_fn = None,
          pin_memory: bool = False,
          drop_last: bool = False,
          timeout: float = 0,
          worker_init_fn = None,
          multiprocessing_context = None,
          generator = None,
          prefetch_factor: int = 2,
          persistent_workers: bool = False,
          pin_memory_device: str = ""
      ): ...
  • dataset: Dataset 或 IterableDataset 实例。
  • batch_size: 每批样本数。
  • shuffle: 是否在每个 epoch 打乱索引(Map-style 且未显式传 sampler 时有效)。
  • sampler: 自定义样本采样器(与 shuffle 互斥;指定它就不要再用 shuffle)。
  • batch_sampler: 一次直接产出“一个 batch 的索引列表”(与 batch_size、shuffle、sampler 互斥)。
  • num_workers: 进程数(0 为主进程;>0 开多进程并行加载)。
  • collate_fn(samples_list) -> batch: 批内拼接函数;变长序列需要自定义(默认会尝试堆叠等长 tensor)。
  • pin_memory: 将 batch 固定到页锁内存,配合 CUDA 加速 H2D 拷贝。
  • drop_last: 数据量不是 batch_size 整数倍时,是否丢弃最后不满的一批。
  • timeout: 从 worker 等待数据的秒数(>0 时生效)。
  • worker_init_fn(worker_id): 每个 worker 的初始化回调(设随机种子、打开文件等)。
  • multiprocessing_context: 指定多进程上下文(spawn/forkserver 等)。
  • generator: 控制随机性(打乱、采样)用的随机数生成器。
  • prefetch_factor: 每个 worker 预取多少个 batch(num_workers > 0 时有效)。
  • persistent_workers: True 时 DataLoader 第一次迭代后保持 worker 不销毁,提高多轮迭代性能。
  • pin_memory_device: 当 pin_memory=True 时,指定固定内存的设备标签(一般留空即可)。

DataLoader返回是一个可迭代的对象,每次迭代产出一个批次的样本。一个批次的内容就是把当批样本列表交给 collate_fn 的返回值(若未自定义,则用 PyTorch 的默认 default_collate)。而类型取决于两点Dataset.getitem 返回什么(tensor/数值/dict/tuple…)和collate_fn 如何把一批“样本列表”拼成“批次”。

这里重点阐述一下collate_fn是一个用户需要注册的回调函数,目的是要把一个批的样本拼接起来。同时对于输入样本如果张量的形状不一致如变长序列,进行padding、对齐、mask等动作。

def collate_fn(batch: List[Example]):
    src_max = max(len(b.src) for b in batch)
    tgt_max = max(len(b.tgt) for b in batch)

    src_batch: List[List[int]] = []
    tgt_in_batch: List[List[int]] = []
    tgt_out_batch: List[List[int]] = []

    for ex in batch:
        src = ex.src + [PAD_IDX] * (src_max - len(ex.src))
        # teacher forcing:输入去掉最后一个、输出去掉第一个
        tgt_in = ex.tgt[:-1] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[:-1]))
        tgt_out = ex.tgt[1:] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[1:]))

        src_batch.append(src)
        tgt_in_batch.append(tgt_in)
        tgt_out_batch.append(tgt_out)

    src = torch.tensor(src_batch, dtype=torch.long)         # (B,S)
    tgt_in = torch.tensor(tgt_in_batch, dtype=torch.long)   # (B,T)
    tgt_out = torch.tensor(tgt_out_batch, dtype=torch.long) # (B,T)
    src_pad_mask = src.eq(PAD_IDX)                          # (B,S) True=PAD
    tgt_pad_mask = tgt_in.eq(PAD_IDX)                       # (B,T) True=PAD
    return src, tgt_in, tgt_out, src_pad_mask, tgt_pad_mask
  • 输入:batch 是若干个 Example,每个包含 src: List[int] 与 tgt: List[int](目标序列已含 bos/eos)。
  • 核心:对齐变长序列(右侧 padding),构造 teacher forcing 的 (tgt_in, tgt_out),并生成 padding 掩码。
  • 输出:
    • src: (B, S)
    • tgt_in: (B, T)
    • tgt_out: (B, T)
    • src_pad_mask: (B, S);True=PAD
    • tgt_pad_mask: (B, T);True=PAD

首先使用src_max/tgt_max计算批内最长长度,这样能够将所有样本右侧补到同一长度,方便堆叠为矩阵。

接着定义批内累积的容器src_batch,tgt_in_batch,tgt_out_batch。

  • src_batch: 编码器输入样本的批次。
  • tgt_in_batch:解码器输入样本的批次。
  • tgt_out_batch:解码器输出样本的批次。

其次使用for循环对每个样本进行补齐,使其跟src_max、tgt_max长度一致,[PAD_IDX] * (src_max - len(ex.src))的意思是将[PAD_IDX]的单元素列表重复src_max - len(ex.src)用于拼接追加到ex.src后,使其对齐。tgt_in和tgt_out同理。

在对tgt_in和tgt_out做样本补齐时,因为输入ex.tgt是包含了bos和eos目标序列,对于tgt_in输入需要去掉最后一个token bos,tgt_out输出需要去掉第一个token eos。

然后就是将补齐的序列依次添加到src_batch,tgt_in_batch,tgt_out_batch。这样就对输入的数据进行了分类,把编码器的输入整合了在一起,解码器的输入和输出整合了一起。

最后就是将批内对齐后的源序列列表转换为张量,同时计算src和tag_in的mask,也就是说对数据哪些位置添加了pad。

下面是collate_fn相关的打印数据,便于理解。

batch [Example(src=[9, 10, 3, 11], tgt=[1, 11, 10, 4, 5, 2]), Example(src=[6, 8, 5], tgt=[1, 13, 12, 8, 2])]
src [9, 10, 3, 11]
tgt_in [1, 11, 10, 4, 5]
tgt_out [11, 10, 4, 5, 2]
src [6, 8, 5, 0]
tgt_in [1, 13, 12, 8, 0]
tgt_out [13, 12, 8, 2, 0]
src_batch [[9, 10, 3, 11], [6, 8, 5, 0]]
tgt_in_batch [[1, 11, 10, 4, 5], [1, 13, 12, 8, 0]]
tgt_out_batch [[11, 10, 4, 5, 2], [13, 12, 8, 2, 0]]
src tensor([[ 9, 10,  3, 11],
        [ 6,  8,  5,  0]])
tgt_in tensor([[ 1, 11, 10,  4,  5],
        [ 1, 13, 12,  8,  0]])
tgt_out tensor([[11, 10,  4,  5,  2],
        [13, 12,  8,  2,  0]])
src_pad_mask tensor([[False, False, False, False],
        [False, False, False,  True]])
tgt_pad_mask tensor([[False, False, False, False, False],
        [False, False, False, False,  True]])
src tensor([[ 9, 10,  3, 11],
        [ 6,  8,  5,  0]])
tgt_in tensor([[ 1, 11, 10,  4,  5],
        [ 1, 13, 12,  8,  0]])
tgt_out tensor([[11, 10,  4,  5,  2],
        [13, 12,  8,  2,  0]])
src_mask tensor([[False, False, False, False],
        [False, False, False,  True]])
tgt_mask tensor([[False, False, False, False, False],
        [False, False, False, False,  True]])

最后完整的示例代码

#!/usr/bin/env python3
"""
最小可运行示例:用 Dataset + DataLoader(含 collate_fn)演示变长序列如何拼批并生成 padding 掩码。

运行:
  python3 dataloader_demo.py
"""

from dataclasses import dataclass
from typing import List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader


# --------------------------
# 1) 准备一点语料(空格分词)
# --------------------------
pairs: List[Tuple[str, str]] = [
    ("我 有 一个 苹果", "i have an apple"),
    ("我 有 一本 书", "i have a book"),
    ("你 喜欢 书", "you like books"),
    ("我 吃 苹果", "i eat apples"),
]


def build_vocab(texts: List[str]):
    tokens = set()
    for s in texts:
        tokens.update([w.lower() for w in s.split()])
    itos = ["<pad>", "<bos>", "<eos>"] + sorted(tokens)
    stoi = {t: i for i, t in enumerate(itos)}
    return stoi, itos


src_texts = [s for s, _ in pairs]
tgt_texts = [t for _, t in pairs]
SRC_STOI, SRC_ITOS = build_vocab(src_texts)
TGT_STOI, TGT_ITOS = build_vocab(tgt_texts)

PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2


def encode_src(s: str) -> List[int]:
    return [SRC_STOI[w.lower()] for w in s.split()]


def encode_tgt(s: str) -> List[int]:
    return [BOS_IDX] + [TGT_STOI[w.lower()] for w in s.split()] + [EOS_IDX]


# --------------------------
# 2) Dataset:定义“单样本怎么取”
# --------------------------
@dataclass
class Example:
    src: List[int]
    tgt: List[int]


class ToyDataset(Dataset):
    def __init__(self, pairs: List[Tuple[str, str]]):
        for s, t in pairs:
            print("encode_src(s)",encode_src(s))
            print("encode_tgt(t)",encode_tgt(t))
        self.data = [Example(encode_src(s), encode_tgt(t)) for s, t in pairs]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Example:
        return self.data[idx]


# --------------------------
# 3) collate_fn:把“样本列表”拼成一批(对齐 padding + 生成 mask + teacher forcing)
# --------------------------
def collate_fn(batch: List[Example]):
    src_max = max(len(b.src) for b in batch)
    #计算批次内最长长度,这样能将样本右侧补齐到同一长度,方便堆叠矩阵
    tgt_max = max(len(b.tgt) for b in batch)

    src_batch: List[List[int]] = []
    tgt_in_batch: List[List[int]] = []
    tgt_out_batch: List[List[int]] = []

    print("batch",batch)
    for ex in batch:
        src = ex.src + [PAD_IDX] * (src_max - len(ex.src))
        # teacher forcing:输入去掉最后一个、输出去掉第一个
        tgt_in = ex.tgt[:-1] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[:-1]))
        tgt_out = ex.tgt[1:] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[1:]))
        print("src",src)
        print("tgt_in",tgt_in)
        print("tgt_out",tgt_out)

        src_batch.append(src)
        tgt_in_batch.append(tgt_in)
        tgt_out_batch.append(tgt_out)

    print("src_batch",src_batch)
    print("tgt_in_batch",tgt_in_batch)
    print("tgt_out_batch",tgt_out_batch)
    src = torch.tensor(src_batch, dtype=torch.long)         # (B,S)
    tgt_in = torch.tensor(tgt_in_batch, dtype=torch.long)   # (B,T)
    tgt_out = torch.tensor(tgt_out_batch, dtype=torch.long) # (B,T)
    src_pad_mask = src.eq(PAD_IDX)                          # (B,S) True=PAD
    tgt_pad_mask = tgt_in.eq(PAD_IDX)                       # (B,T) True=PAD
    print("src",src)
    print("tgt_in",tgt_in)
    print("tgt_out",tgt_out)
    print("src_pad_mask",src_pad_mask)
    print("tgt_pad_mask",tgt_pad_mask)
    return src, tgt_in, tgt_out, src_pad_mask, tgt_pad_mask


# --------------------------
# 4) DataLoader:定义“如何按批取样本”并演示输出
# --------------------------
def main():
    dataset = ToyDataset(pairs)
    for i in range(len(dataset)):
        print("dataset",dataset.__getitem__(i))
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=0,        # 跨平台演示,用 0;Linux 可调大
        collate_fn=collate_fn,
        pin_memory=False,
    )
    # EPOCH=40
    # for epoch in range(EPOCH):
    #   for src, tgt_in, tgt_out, src_mask, tgt_mask in loader:
        # 前向、loss、反传、优化
    total_steps = 1000
    data_iter = iter(loader)
    for step in range(total_steps):
        try:
            src, tgt_in, tgt_out, src_mask, tgt_mask = next(data_iter)
        except StopIteration:
            # 当前迭代器用尽,重建一个新的(相当于进入新一轮)
            data_iter = iter(loader)
            src, tgt_in, tgt_out, src_mask, tgt_mask = next(data_iter)
        print("src",src)
        print("tgt_in",tgt_in)
        print("tgt_out",tgt_out)
        print("src_mask",src_mask)
        print("tgt_mask",tgt_mask)

if __name__ == "__main__":
    main()
  • iter(loader): 把可迭代的 DataLoader 变成“批次迭代器”。
  • next(iterator): 从该迭代器中取“下一个批次”。第一次调用就是“第一个 batch”。
    it = iter(loader)
    batch1 = next(it)
    batch2 = next(it)

在 shuffle=True 时,每次 iter(loader) 相当于开始“新的一轮遍历”,顺序会重新洗牌;drop_last、num_workers、pin_memory 等参数会影响批次数量、并行加载与传输性能。

当然除了用next迭代,还是用for循环的方式,如下:

  for epoch in range(EPOCH):
      for src, tgt_in, tgt_out, src_mask, tgt_mask in loader:
          print("src", src)
          print("tgt_in", tgt_in)
          print("tgt_out", tgt_out)
          print("src_mask", src_mask)
          print("tgt_mask", tgt_mask)