基于ASL-ML-Immersion项目的Keras函数式API实践指南

基于ASL-ML-Immersion项目的Keras函数式API实践指南

asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

理解Keras函数式API的核心概念

在机器学习模型构建中,Keras提供了两种主要API:Sequential API和Functional API。函数式API相比顺序API提供了更大的灵活性,能够构建更复杂的模型拓扑结构。本教程将基于ASL-ML-Immersion项目中的出租车费预测案例,深入讲解如何使用Keras函数式API构建Wide & Deep模型。

Wide & Deep模型架构解析

Wide & Deep模型结合了两种学习方式:

  • Wide部分:负责记忆功能,通过线性模型和特征交叉学习数据相关性
  • Deep部分:负责泛化功能,通过嵌入向量和深度神经网络学习特征表示

这种组合架构特别适合推荐系统和预测任务,能够同时捕捉明确的特征关系和深层次的模式。

数据准备与预处理

首先我们需要准备出租车费预测数据集,包含以下关键字段:

  • 上车/下车经纬度
  • 乘客数量
  • 车费金额(预测目标)
CSV_COLUMNS = [
    "fare_amount",
    "pickup_datetime",
    "pickup_longitude",
    "pickup_latitude",
    "dropoff_longitude",
    "dropoff_latitude",
    "passenger_count",
    "key",
]

使用tf.data API高效加载和处理数据,构建输入管道:

def create_dataset(pattern, batch_size=1, mode="eval"):
    dataset = tf.data.experimental.make_csv_dataset(
        pattern, batch_size, CSV_COLUMNS, DEFAULTS
    )
    dataset = dataset.map(features_and_labels)
    if mode == "train":
        dataset = dataset.shuffle(buffer_size=1000).repeat()
    dataset = dataset.prefetch(1)
    return dataset

构建模型输入层

使用函数式API首先需要定义输入层,明确每个输入特征的形状和数据类型:

inputs = {
    colname: Input(name=colname, shape=(1,), dtype="float32")
    for colname in INPUT_COLS
}

特征工程关键步骤

  1. 分桶处理(Discretization): 将连续经纬度值转换为离散桶,便于后续处理

  2. 特征交叉(HashedCrossing): 创建位置特征组合,捕捉空间关系

  3. 嵌入层(Embedding): 将高维稀疏特征转换为低维稠密表示

# 分桶处理
plon = Discretization(lonbuckets, name="plon_bkt")(inputs["pickup_longitude"])

# 特征交叉
p_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="p_fc")((plon, plat))

# 嵌入层
pd_embed = Embedding(input_dim=NBUCKETS**4, output_dim=10, name="pd_embed")(pd_fc)

构建Wide & Deep网络

Deep部分构建

# 拼接原始特征和嵌入特征
deep = Concatenate(name="deep_input")([
    inputs["pickup_longitude"],
    inputs["pickup_latitude"],
    inputs["dropoff_longitude"],
    inputs["dropoff_latitude"],
    Flatten(name="flatten_embedding")(pd_embed),
])

# 添加全连接层
for i, num_nodes in enumerate(dnn_hidden_units, start=1):
    deep = Dense(num_nodes, activation="relu", name=f"hidden_{i}")(deep)

Wide部分构建

# 对交叉特征进行one-hot编码
p_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="p_onehot")(p_fc)

# 拼接所有wide特征
wide = Concatenate(name="wide_input")([p_onehot, d_onehot, pd_onehot])

合并Wide & Deep部分

concat = Concatenate(name="concatenate")([deep, wide])
output = Dense(1, activation=None, name="output")(concat)

模型编译与训练

定义自定义RMSE评估指标并编译模型:

def rmse(y_true, y_pred):
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))

model = Model(inputs=list(inputs.values()), outputs=output)
model.compile(optimizer="adam", loss="mse", metrics=[rmse])

配置训练参数并启动训练:

BATCH_SIZE = 100
NUM_TRAIN_EXAMPLES = 10000 * 50
NUM_EVALS = 50

history = model.fit(
    x=trainds,
    steps_per_epoch=steps_per_epoch,
    epochs=NUM_EVALS,
    validation_data=evalds,
    callbacks=[TensorBoard(OUTDIR)],
)

模型评估与可视化

训练完成后,我们可以绘制训练过程中的RMSE变化曲线,评估模型性能:

RMSE_COLS = ["rmse", "val_rmse"]
pd.DataFrame(history.history)[RMSE_COLS].plot()

关键知识点总结

  1. 函数式API优势:相比Sequential API,函数式API可以构建更复杂的模型结构,包括多输入/输出、共享层等

  2. 特征工程技巧

    • 分桶处理将连续值离散化
    • 特征交叉捕捉特征组合关系
    • 嵌入层处理高维稀疏特征
  3. Wide & Deep架构特点

    • Wide部分擅长记忆明确特征关系
    • Deep部分擅长发现潜在模式
    • 两者结合提升模型整体表现
  4. 训练优化

    • 使用tf.data构建高效输入管道
    • 合理设置批大小和评估频率
    • 监控训练和验证指标

通过本教程,您应该已经掌握了使用Keras函数式API构建复杂模型的核心方法,特别是Wide & Deep架构的实现技巧。这种技术可以广泛应用于各种预测和推荐场景。

asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

郁俪晟Gertrude

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值