基于ASL-ML-Immersion项目的Keras函数式API实践指南
理解Keras函数式API的核心概念
在机器学习模型构建中,Keras提供了两种主要API:Sequential API和Functional API。函数式API相比顺序API提供了更大的灵活性,能够构建更复杂的模型拓扑结构。本教程将基于ASL-ML-Immersion项目中的出租车费预测案例,深入讲解如何使用Keras函数式API构建Wide & Deep模型。
Wide & Deep模型架构解析
Wide & Deep模型结合了两种学习方式:
- Wide部分:负责记忆功能,通过线性模型和特征交叉学习数据相关性
- Deep部分:负责泛化功能,通过嵌入向量和深度神经网络学习特征表示
这种组合架构特别适合推荐系统和预测任务,能够同时捕捉明确的特征关系和深层次的模式。
数据准备与预处理
首先我们需要准备出租车费预测数据集,包含以下关键字段:
- 上车/下车经纬度
- 乘客数量
- 车费金额(预测目标)
CSV_COLUMNS = [
"fare_amount",
"pickup_datetime",
"pickup_longitude",
"pickup_latitude",
"dropoff_longitude",
"dropoff_latitude",
"passenger_count",
"key",
]
使用tf.data API高效加载和处理数据,构建输入管道:
def create_dataset(pattern, batch_size=1, mode="eval"):
dataset = tf.data.experimental.make_csv_dataset(
pattern, batch_size, CSV_COLUMNS, DEFAULTS
)
dataset = dataset.map(features_and_labels)
if mode == "train":
dataset = dataset.shuffle(buffer_size=1000).repeat()
dataset = dataset.prefetch(1)
return dataset
构建模型输入层
使用函数式API首先需要定义输入层,明确每个输入特征的形状和数据类型:
inputs = {
colname: Input(name=colname, shape=(1,), dtype="float32")
for colname in INPUT_COLS
}
特征工程关键步骤
-
分桶处理(Discretization): 将连续经纬度值转换为离散桶,便于后续处理
-
特征交叉(HashedCrossing): 创建位置特征组合,捕捉空间关系
-
嵌入层(Embedding): 将高维稀疏特征转换为低维稠密表示
# 分桶处理
plon = Discretization(lonbuckets, name="plon_bkt")(inputs["pickup_longitude"])
# 特征交叉
p_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="p_fc")((plon, plat))
# 嵌入层
pd_embed = Embedding(input_dim=NBUCKETS**4, output_dim=10, name="pd_embed")(pd_fc)
构建Wide & Deep网络
Deep部分构建
# 拼接原始特征和嵌入特征
deep = Concatenate(name="deep_input")([
inputs["pickup_longitude"],
inputs["pickup_latitude"],
inputs["dropoff_longitude"],
inputs["dropoff_latitude"],
Flatten(name="flatten_embedding")(pd_embed),
])
# 添加全连接层
for i, num_nodes in enumerate(dnn_hidden_units, start=1):
deep = Dense(num_nodes, activation="relu", name=f"hidden_{i}")(deep)
Wide部分构建
# 对交叉特征进行one-hot编码
p_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="p_onehot")(p_fc)
# 拼接所有wide特征
wide = Concatenate(name="wide_input")([p_onehot, d_onehot, pd_onehot])
合并Wide & Deep部分
concat = Concatenate(name="concatenate")([deep, wide])
output = Dense(1, activation=None, name="output")(concat)
模型编译与训练
定义自定义RMSE评估指标并编译模型:
def rmse(y_true, y_pred):
return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))
model = Model(inputs=list(inputs.values()), outputs=output)
model.compile(optimizer="adam", loss="mse", metrics=[rmse])
配置训练参数并启动训练:
BATCH_SIZE = 100
NUM_TRAIN_EXAMPLES = 10000 * 50
NUM_EVALS = 50
history = model.fit(
x=trainds,
steps_per_epoch=steps_per_epoch,
epochs=NUM_EVALS,
validation_data=evalds,
callbacks=[TensorBoard(OUTDIR)],
)
模型评估与可视化
训练完成后,我们可以绘制训练过程中的RMSE变化曲线,评估模型性能:
RMSE_COLS = ["rmse", "val_rmse"]
pd.DataFrame(history.history)[RMSE_COLS].plot()
关键知识点总结
-
函数式API优势:相比Sequential API,函数式API可以构建更复杂的模型结构,包括多输入/输出、共享层等
-
特征工程技巧:
- 分桶处理将连续值离散化
- 特征交叉捕捉特征组合关系
- 嵌入层处理高维稀疏特征
-
Wide & Deep架构特点:
- Wide部分擅长记忆明确特征关系
- Deep部分擅长发现潜在模式
- 两者结合提升模型整体表现
-
训练优化:
- 使用tf.data构建高效输入管道
- 合理设置批大小和评估频率
- 监控训练和验证指标
通过本教程,您应该已经掌握了使用Keras函数式API构建复杂模型的核心方法,特别是Wide & Deep架构的实现技巧。这种技术可以广泛应用于各种预测和推荐场景。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考