悠悠楠杉
解决PyTorch多标签分类中批量大小不一致的实战技巧
解决PyTorch多标签分类中批量大小不一致的实战技巧
问题背景:多标签分类的"变长"困境
在电商商品分类、医疗影像诊断等场景中,我们常遇到样本标签数量不一致的情况——有的商品有3个标签,有的可能有20个。传统PyTorch的DataLoader
默认要求batch内样本保持维度一致,这就像试图把不同长度的绳子硬塞进同一个盒子。
python
典型错误示例
labels = [
[1, 0, 1], # 样本1有3个标签
[0, 1, 1, 0, 1] # 样本2有5个标签
] # 直接stack会报错
核心解决方案:填充与掩码技术
方法1:动态填充(Padding)
使用pad_sequence
统一长度,配合掩码标记无效位置:
python
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
inputs = [item[0] for item in batch]
labels = [torch.tensor(item[1]) for item in batch]
# 填充到当前batch最大长度
padded_labels = pad_sequence(labels, batch_first=True, padding_value=-1)
# 生成掩码 (1有效, 0无效)
masks = (padded_labels != -1).float()
return torch.stack(inputs), padded_labels, masks
方法2:稀疏矩阵表示
对于极稀疏的多标签场景(如百万级标签中每个样本仅激活少数),可用torch.sparse_coo_tensor
:
python
def to_sparse(labels):
indices = [[i for i, val in enumerate(label) if val == 1]
for label in labels]
values = [1] * sum(len(idx) for idx in indices)
return torch.sparse_coo_tensor(indices, values, size=len(labels), max_label)
进阶技巧:损失函数的适应性改造
掩码加权BCE损失
python
class MaskedBCELoss(nn.Module):
def forward(self, pred, target, mask):
loss = F.binary_cross_entropy(pred, target, reduction='none')
return (loss * mask).sum() / mask.sum() # 只计算有效部分
标签采样策略
对超多标签(如>10万)的情况,可采用负采样:
python
def negative_sampling(labels, ratio=0.1):
pos_indices = torch.where(labels == 1)
neg_pool = torch.where(labels == 0)
sampled_neg = neg_pool[torch.randperm(len(neg_pool))[:int(len(pos_indices)*ratio)]]
return torch.cat([pos_indices, sampled_neg])
工程实践中的性能优化
- 内存映射技术:对超大规模标签集使用
torch.load(..., map_location='cpu')
- 异步数据加载:配合
DataLoader
的num_workers
和pin_memory
- 混合精度训练:
scaler.scale(loss).backward()
减少显存占用
验证指标的特殊处理
多标签场景下需使用特殊指标:python
from sklearn.metrics import jaccard_score
def multilabelmetrics(ytrue, ypred, threshold=0.5):
ypredbin = (ypred > threshold).astype(int)
return {
'microf1': f1score(ytrue, ypredbin, average='micro'),
'macroiou': jaccardscore(ytrue, ypredbin, average='macro')
}
完整Pipeline示例
python
dataset = CustomDataset(...)
dataloader = DataLoader(
dataset,
batchsize=32,
collatefn=collatefn,
numworkers=4
)
model = MultiLabelModel(numclasses=1000) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) lossfn = MaskedBCELoss()
for epoch in range(10):
for inputs, labels, masks in dataloader:
optimizer.zerograd()
outputs = model(inputs)
loss = lossfn(outputs, labels, masks)
loss.backward()
optimizer.step()