参数 stratify=y
在训练集和测试集划分中非常关键,尤其是在分类任务中。
✅ stratify=y
的意思:
在划分训练集和测试集时,保持各类别在训练集和测试集中的“比例一致”。
换句话说,它会让划分后的每个子集中,各类别的样本比例和原始数据中是一样的,避免某个类别在训练或测试集中比例失衡。
🧪 举个例子更直观:
假设你原始数据中共有 1000 条样本,其中:
- 类别 0 有 600 个
- 类别 1 有 300 个
- 类别 2 有 100 个
如果你这样写:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
则划分后,测试集中可能会出现“类别 2 只有几条”甚至没有的极端情况,类别分布会随机、不均衡。
✅ 而如果你这样写:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
那么划分后:
- 训练集和测试集中仍然保持 60%:30%:10% 的类别比例,
- 这样训练时不会偏向多数类,测试评估也更公平。
🧠 总结:
是否加 stratify | 说明 |
---|---|
❌ 不加 | 数据随机打乱,可能某些类别偏斜甚至丢失 |
✅ 加 stratify=y | 保持标签分布一致,适合分类问题,强烈建议使用 |