dataset和DataLoader
- Ai应用
- 5天前
- 181热度
- 0评论
简介
Dataset和DataLoader在pytorch中主要用于数据的组织。这两个类通常一起搭配处理深度学习中的数据流。
- Dataset 用于产出“单个样本”:定义怎么按索引取到一个样本,以及总共有多少个样本。
- DataLoader 负责“成批取样”:决定批大小、是否打乱、多进程加载、并用 collate_fn 把一个批里的样本“拼起来”(对齐、padding、mask、teacher forcing 等)。
一句话记忆:Dataset 只管“单条样本”;DataLoader 负责“多条怎么一起、怎么并行、怎么对齐”。变长就写 collate_fn,性能就调 workers/pin_memory/分桶。
Dataset
Dataset类作用:定义数据集的统一接口,支持自定义数据加载逻辑。
关键方法:
- init:初始化数据路径、预处理函数等。
- len:返回数据集样本总数。
- getitem:根据索引返回单个样本(数据+标签)。
通常情况下用户都会有自己的数据集,所以定义的数据集类继承dataset。
#准备一个数据集
pairs: List[Tuple[str, str]] = [
("我 有 一个 苹果", "i have an apple"),
("我 有 一本 书", "i have a book"),
("你 喜欢 书", "you like books"),
("我 吃 苹果", "i eat apples"),
]
def build_vocab(texts: List[str]):
tokens = set()
for s in texts:
tokens.update([w.lower() for w in s.split()])
itos = ["<pad>", "<bos>", "<eos>"] + sorted(tokens)
stoi = {t: i for i, t in enumerate(itos)}
return stoi, itos
src_texts = [s for s, _ in pairs]
tgt_texts = [t for _, t in pairs]
SRC_STOI, SRC_ITOS = build_vocab(src_texts)
TGT_STOI, TGT_ITOS = build_vocab(tgt_texts)
PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2
def encode_src(s: str) -> List[int]:
return [SRC_STOI[w.lower()] for w in s.split()]
def encode_tgt(s: str) -> List[int]:
return [BOS_IDX] + [TGT_STOI[w.lower()] for w in s.split()] + [EOS_IDX]
# Dataset:定义“单样本怎么取”
@dataclass
class Example:
src: List[int]
tgt: List[int]
class ToyDataset(Dataset):
def __init__(self, pairs: List[Tuple[str, str]]):
for s, t in pairs:
print("encode_src(s)",encode_src(s))
print("encode_tgt(t)",encode_tgt(t))
self.data = [Example(encode_src(s), encode_tgt(t)) for s, t in pairs]
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Example:
return self.data[idx]
- 样本结构:用 Example(src: List[int], tgt: List[int]) 表示一条样本的源序列与目标序列(都是 token id 列表)。
- 词表与编码:源序列仅分词并映射到 id。目标序列前加 bos、后加 eos,便于自回归训练。
- 协议:实现 len 和 getitem 两个方法即可被 DataLoader 使用。
DataLoader
class torch.utils.data.DataLoader(Data[T_co]):
def __init__(
self,
dataset,
batch_size: int = 1,
shuffle: bool | None = None,
sampler = None,
batch_sampler = None,
num_workers: int = 0,
collate_fn = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn = None,
multiprocessing_context = None,
generator = None,
prefetch_factor: int = 2,
persistent_workers: bool = False,
pin_memory_device: str = ""
): ...
- dataset: Dataset 或 IterableDataset 实例。
- batch_size: 每批样本数。
- shuffle: 是否在每个 epoch 打乱索引(Map-style 且未显式传 sampler 时有效)。
- sampler: 自定义样本采样器(与 shuffle 互斥;指定它就不要再用 shuffle)。
- batch_sampler: 一次直接产出“一个 batch 的索引列表”(与 batch_size、shuffle、sampler 互斥)。
- num_workers: 进程数(0 为主进程;>0 开多进程并行加载)。
- collate_fn(samples_list) -> batch: 批内拼接函数;变长序列需要自定义(默认会尝试堆叠等长 tensor)。
- pin_memory: 将 batch 固定到页锁内存,配合 CUDA 加速 H2D 拷贝。
- drop_last: 数据量不是 batch_size 整数倍时,是否丢弃最后不满的一批。
- timeout: 从 worker 等待数据的秒数(>0 时生效)。
- worker_init_fn(worker_id): 每个 worker 的初始化回调(设随机种子、打开文件等)。
- multiprocessing_context: 指定多进程上下文(spawn/forkserver 等)。
- generator: 控制随机性(打乱、采样)用的随机数生成器。
- prefetch_factor: 每个 worker 预取多少个 batch(num_workers > 0 时有效)。
- persistent_workers: True 时 DataLoader 第一次迭代后保持 worker 不销毁,提高多轮迭代性能。
- pin_memory_device: 当 pin_memory=True 时,指定固定内存的设备标签(一般留空即可)。
DataLoader返回是一个可迭代的对象,每次迭代产出一个批次的样本。一个批次的内容就是把当批样本列表交给 collate_fn 的返回值(若未自定义,则用 PyTorch 的默认 default_collate)。而类型取决于两点Dataset.getitem 返回什么(tensor/数值/dict/tuple…)和collate_fn 如何把一批“样本列表”拼成“批次”。
这里重点阐述一下collate_fn是一个用户需要注册的回调函数,目的是要把一个批的样本拼接起来。同时对于输入样本如果张量的形状不一致如变长序列,进行padding、对齐、mask等动作。
def collate_fn(batch: List[Example]):
src_max = max(len(b.src) for b in batch)
tgt_max = max(len(b.tgt) for b in batch)
src_batch: List[List[int]] = []
tgt_in_batch: List[List[int]] = []
tgt_out_batch: List[List[int]] = []
for ex in batch:
src = ex.src + [PAD_IDX] * (src_max - len(ex.src))
# teacher forcing:输入去掉最后一个、输出去掉第一个
tgt_in = ex.tgt[:-1] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[:-1]))
tgt_out = ex.tgt[1:] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[1:]))
src_batch.append(src)
tgt_in_batch.append(tgt_in)
tgt_out_batch.append(tgt_out)
src = torch.tensor(src_batch, dtype=torch.long) # (B,S)
tgt_in = torch.tensor(tgt_in_batch, dtype=torch.long) # (B,T)
tgt_out = torch.tensor(tgt_out_batch, dtype=torch.long) # (B,T)
src_pad_mask = src.eq(PAD_IDX) # (B,S) True=PAD
tgt_pad_mask = tgt_in.eq(PAD_IDX) # (B,T) True=PAD
return src, tgt_in, tgt_out, src_pad_mask, tgt_pad_mask
- 输入:batch 是若干个 Example,每个包含 src: List[int] 与 tgt: List[int](目标序列已含 bos/eos)。
- 核心:对齐变长序列(右侧 padding),构造 teacher forcing 的 (tgt_in, tgt_out),并生成 padding 掩码。
- 输出:
- src: (B, S)
- tgt_in: (B, T)
- tgt_out: (B, T)
- src_pad_mask: (B, S);True=PAD
- tgt_pad_mask: (B, T);True=PAD
首先使用src_max/tgt_max计算批内最长长度,这样能够将所有样本右侧补到同一长度,方便堆叠为矩阵。
接着定义批内累积的容器src_batch,tgt_in_batch,tgt_out_batch。
- src_batch: 编码器输入样本的批次。
- tgt_in_batch:解码器输入样本的批次。
- tgt_out_batch:解码器输出样本的批次。
其次使用for循环对每个样本进行补齐,使其跟src_max、tgt_max长度一致,[PAD_IDX] * (src_max - len(ex.src))的意思是将[PAD_IDX]的单元素列表重复src_max - len(ex.src)用于拼接追加到ex.src后,使其对齐。tgt_in和tgt_out同理。
在对tgt_in和tgt_out做样本补齐时,因为输入ex.tgt是包含了bos和eos目标序列,对于tgt_in输入需要去掉最后一个token bos,tgt_out输出需要去掉第一个token eos。
然后就是将补齐的序列依次添加到src_batch,tgt_in_batch,tgt_out_batch。这样就对输入的数据进行了分类,把编码器的输入整合了在一起,解码器的输入和输出整合了一起。
最后就是将批内对齐后的源序列列表转换为张量,同时计算src和tag_in的mask,也就是说对数据哪些位置添加了pad。
下面是collate_fn相关的打印数据,便于理解。
batch [Example(src=[9, 10, 3, 11], tgt=[1, 11, 10, 4, 5, 2]), Example(src=[6, 8, 5], tgt=[1, 13, 12, 8, 2])]
src [9, 10, 3, 11]
tgt_in [1, 11, 10, 4, 5]
tgt_out [11, 10, 4, 5, 2]
src [6, 8, 5, 0]
tgt_in [1, 13, 12, 8, 0]
tgt_out [13, 12, 8, 2, 0]
src_batch [[9, 10, 3, 11], [6, 8, 5, 0]]
tgt_in_batch [[1, 11, 10, 4, 5], [1, 13, 12, 8, 0]]
tgt_out_batch [[11, 10, 4, 5, 2], [13, 12, 8, 2, 0]]
src tensor([[ 9, 10, 3, 11],
[ 6, 8, 5, 0]])
tgt_in tensor([[ 1, 11, 10, 4, 5],
[ 1, 13, 12, 8, 0]])
tgt_out tensor([[11, 10, 4, 5, 2],
[13, 12, 8, 2, 0]])
src_pad_mask tensor([[False, False, False, False],
[False, False, False, True]])
tgt_pad_mask tensor([[False, False, False, False, False],
[False, False, False, False, True]])
src tensor([[ 9, 10, 3, 11],
[ 6, 8, 5, 0]])
tgt_in tensor([[ 1, 11, 10, 4, 5],
[ 1, 13, 12, 8, 0]])
tgt_out tensor([[11, 10, 4, 5, 2],
[13, 12, 8, 2, 0]])
src_mask tensor([[False, False, False, False],
[False, False, False, True]])
tgt_mask tensor([[False, False, False, False, False],
[False, False, False, False, True]])
最后完整的示例代码
#!/usr/bin/env python3
"""
最小可运行示例:用 Dataset + DataLoader(含 collate_fn)演示变长序列如何拼批并生成 padding 掩码。
运行:
python3 dataloader_demo.py
"""
from dataclasses import dataclass
from typing import List, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
# --------------------------
# 1) 准备一点语料(空格分词)
# --------------------------
pairs: List[Tuple[str, str]] = [
("我 有 一个 苹果", "i have an apple"),
("我 有 一本 书", "i have a book"),
("你 喜欢 书", "you like books"),
("我 吃 苹果", "i eat apples"),
]
def build_vocab(texts: List[str]):
tokens = set()
for s in texts:
tokens.update([w.lower() for w in s.split()])
itos = ["<pad>", "<bos>", "<eos>"] + sorted(tokens)
stoi = {t: i for i, t in enumerate(itos)}
return stoi, itos
src_texts = [s for s, _ in pairs]
tgt_texts = [t for _, t in pairs]
SRC_STOI, SRC_ITOS = build_vocab(src_texts)
TGT_STOI, TGT_ITOS = build_vocab(tgt_texts)
PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2
def encode_src(s: str) -> List[int]:
return [SRC_STOI[w.lower()] for w in s.split()]
def encode_tgt(s: str) -> List[int]:
return [BOS_IDX] + [TGT_STOI[w.lower()] for w in s.split()] + [EOS_IDX]
# --------------------------
# 2) Dataset:定义“单样本怎么取”
# --------------------------
@dataclass
class Example:
src: List[int]
tgt: List[int]
class ToyDataset(Dataset):
def __init__(self, pairs: List[Tuple[str, str]]):
for s, t in pairs:
print("encode_src(s)",encode_src(s))
print("encode_tgt(t)",encode_tgt(t))
self.data = [Example(encode_src(s), encode_tgt(t)) for s, t in pairs]
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Example:
return self.data[idx]
# --------------------------
# 3) collate_fn:把“样本列表”拼成一批(对齐 padding + 生成 mask + teacher forcing)
# --------------------------
def collate_fn(batch: List[Example]):
src_max = max(len(b.src) for b in batch)
#计算批次内最长长度,这样能将样本右侧补齐到同一长度,方便堆叠矩阵
tgt_max = max(len(b.tgt) for b in batch)
src_batch: List[List[int]] = []
tgt_in_batch: List[List[int]] = []
tgt_out_batch: List[List[int]] = []
print("batch",batch)
for ex in batch:
src = ex.src + [PAD_IDX] * (src_max - len(ex.src))
# teacher forcing:输入去掉最后一个、输出去掉第一个
tgt_in = ex.tgt[:-1] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[:-1]))
tgt_out = ex.tgt[1:] + [PAD_IDX] * (tgt_max - 1 - len(ex.tgt[1:]))
print("src",src)
print("tgt_in",tgt_in)
print("tgt_out",tgt_out)
src_batch.append(src)
tgt_in_batch.append(tgt_in)
tgt_out_batch.append(tgt_out)
print("src_batch",src_batch)
print("tgt_in_batch",tgt_in_batch)
print("tgt_out_batch",tgt_out_batch)
src = torch.tensor(src_batch, dtype=torch.long) # (B,S)
tgt_in = torch.tensor(tgt_in_batch, dtype=torch.long) # (B,T)
tgt_out = torch.tensor(tgt_out_batch, dtype=torch.long) # (B,T)
src_pad_mask = src.eq(PAD_IDX) # (B,S) True=PAD
tgt_pad_mask = tgt_in.eq(PAD_IDX) # (B,T) True=PAD
print("src",src)
print("tgt_in",tgt_in)
print("tgt_out",tgt_out)
print("src_pad_mask",src_pad_mask)
print("tgt_pad_mask",tgt_pad_mask)
return src, tgt_in, tgt_out, src_pad_mask, tgt_pad_mask
# --------------------------
# 4) DataLoader:定义“如何按批取样本”并演示输出
# --------------------------
def main():
dataset = ToyDataset(pairs)
for i in range(len(dataset)):
print("dataset",dataset.__getitem__(i))
loader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
num_workers=0, # 跨平台演示,用 0;Linux 可调大
collate_fn=collate_fn,
pin_memory=False,
)
# EPOCH=40
# for epoch in range(EPOCH):
# for src, tgt_in, tgt_out, src_mask, tgt_mask in loader:
# 前向、loss、反传、优化
total_steps = 1000
data_iter = iter(loader)
for step in range(total_steps):
try:
src, tgt_in, tgt_out, src_mask, tgt_mask = next(data_iter)
except StopIteration:
# 当前迭代器用尽,重建一个新的(相当于进入新一轮)
data_iter = iter(loader)
src, tgt_in, tgt_out, src_mask, tgt_mask = next(data_iter)
print("src",src)
print("tgt_in",tgt_in)
print("tgt_out",tgt_out)
print("src_mask",src_mask)
print("tgt_mask",tgt_mask)
if __name__ == "__main__":
main()
- iter(loader): 把可迭代的 DataLoader 变成“批次迭代器”。
- next(iterator): 从该迭代器中取“下一个批次”。第一次调用就是“第一个 batch”。
it = iter(loader)
batch1 = next(it)
batch2 = next(it)
在 shuffle=True 时,每次 iter(loader) 相当于开始“新的一轮遍历”,顺序会重新洗牌;drop_last、num_workers、pin_memory 等参数会影响批次数量、并行加载与传输性能。
当然除了用next迭代,还是用for循环的方式,如下:
for epoch in range(EPOCH):
for src, tgt_in, tgt_out, src_mask, tgt_mask in loader:
print("src", src)
print("tgt_in", tgt_in)
print("tgt_out", tgt_out)
print("src_mask", src_mask)
print("tgt_mask", tgt_mask)