深入理解Keras Functional API与Wide & Deep模型构建

深入理解Keras Functional API与Wide & Deep模型构建

asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

前言

在机器学习模型构建过程中,Keras提供了两种主要的API:Sequential API和Functional API。本文我们将重点探讨Functional API的强大功能,并通过构建一个Wide & Deep模型来展示其灵活性。

Keras Functional API简介

Keras Functional API相比Sequential API提供了更高的灵活性,特别适合构建以下类型的模型:

  • 具有复杂拓扑结构的模型
  • 多输入或多输出的模型
  • 需要共享层的模型
  • 非顺序数据流模型(如残差连接)

Functional API的核心思想是通过显式定义输入和输出,以及它们之间的连接方式来构建模型。

Wide & Deep模型架构

Wide & Deep模型结合了两种学习方式:

  1. Wide部分(记忆组件):通过线性模型和特征交叉来学习特征与标签之间的相关性
  2. Deep部分(泛化组件):通过嵌入向量和深度神经网络来学习特征的深层表示

这种组合方式能够同时利用记忆和泛化的优势,在推荐系统等场景中表现优异。

数据准备

我们使用出租车费用预测数据集,包含以下特征:

  • 上车时间
  • 上车经纬度
  • 下车经纬度
  • 乘客数量

首先需要加载并预处理数据:

# 定义CSV列和默认值
CSV_COLUMNS = [
    "fare_amount", "pickup_datetime", 
    "pickup_longitude", "pickup_latitude",
    "dropoff_longitude", "dropoff_latitude",
    "passenger_count", "key"
]
LABEL_COLUMN = "fare_amount"
DEFAULTS = [[0.0], ["na"], [0.0], [0.0], [0.0], [0.0], [0.0], ["na"]]

# 创建数据集函数
def create_dataset(pattern, batch_size=1, mode="eval"):
    dataset = tf.data.experimental.make_csv_dataset(
        pattern, batch_size, CSV_COLUMNS, DEFAULTS
    )
    dataset = dataset.map(features_and_labels)
    if mode == "train":
        dataset = dataset.shuffle(buffer_size=1000).repeat()
    dataset = dataset.prefetch(1)
    return dataset

构建Wide & Deep模型

1. 定义输入层

首先为每个特征创建输入层:

INPUT_COLS = [
    "pickup_longitude", "pickup_latitude",
    "dropoff_longitude", "dropoff_latitude",
    "passenger_count"
]

inputs = {
    colname: Input(name=colname, shape=(1,), dtype="float32")
    for colname in INPUT_COLS
}

2. 特征工程

使用预处理层进行特征转换:

# 分桶处理
latbuckets = np.linspace(start=40.5, stop=41.0, num=NBUCKETS).tolist()
lonbuckets = np.linspace(start=-74.2, stop=-73.7, num=NBUCKETS).tolist()

plon = Discretization(lonbuckets, name="plon_bkt")(inputs["pickup_longitude"])
plat = Discretization(latbuckets, name="plat_bkt")(inputs["pickup_latitude"])
dlon = Discretization(lonbuckets, name="dlon_bkt")(inputs["dropoff_longitude"])
dlat = Discretization(latbuckets, name="dlat_bkt")(inputs["dropoff_latitude"])

# 特征交叉
p_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="p_fc")((plon, plat))
d_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="d_fc")((dlon, dlat))
pd_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="pd_fc")((plon, plat, dlon, dlat))

# 嵌入层
pd_embed = Embedding(NBUCKETS*NBUCKETS, 10, name="pd_embed")(pd_fc)

3. 构建Deep部分

# 拼接深度网络输入
deep = Concatenate(name='deep_input')(
    [inputs['pickup_longitude'], inputs['pickup_latitude'],
     inputs['dropoff_longitude'], inputs['dropoff_latitude'],
     Flatten(name='flatten_embedding')(pd_embed)]
)

# 添加隐藏层
for i, num_nodes in enumerate(dnn_hidden_units, start=1):
    deep = Dense(num_nodes, activation="relu", name=f'hidden_{i}')(deep)

4. 构建Wide部分

# One-hot编码
p_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="p_onehot")(p_fc)
d_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="d_onehot")(d_fc)
pd_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="pd_onehot")(pd_fc)

# 拼接宽度网络输入
wide = Concatenate(name="wide_input")([p_onehot, d_onehot, pd_onehot])

5. 合并Wide和Deep部分

concat = Concatenate(name="wide_deep_concatenate")([wide, deep])
output = Dense(1, name="fare_amount")(concat)

模型编译与训练

定义自定义RMSE指标并编译模型:

def rmse(y_true, y_pred):
    return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))

model = Model(inputs=list(inputs.values()), outputs=output)
model.compile(optimizer="adam", loss="mse", metrics=[rmse])

设置训练参数并开始训练:

BATCH_SIZE = 100
NUM_TRAIN_EXAMPLES = 10000 * 50
NUM_EVALS = 50
NUM_EVAL_EXAMPLES = 10000

trainds = create_dataset(pattern="../data/taxi-train*", batch_size=BATCH_SIZE, mode="train")
evalds = create_dataset(pattern="../data/taxi-valid*", batch_size=BATCH_SIZE, mode="eval").take(NUM_EVAL_EXAMPLES // BATCH_SIZE)

history = model.fit(
    x=trainds,
    steps_per_epoch=steps_per_epoch,
    epochs=NUM_EVALS,
    validation_data=evalds,
    callbacks=[TensorBoard(OUTDIR)],
)

结果分析

训练完成后,我们可以绘制训练和验证RMSE的变化曲线:

RMSE_COLS = ["rmse", "val_rmse"]
pd.DataFrame(history.history)[RMSE_COLS].plot()

总结

通过本文,我们学习了:

  1. Keras Functional API的核心概念和优势
  2. Wide & Deep模型的架构原理
  3. 如何使用预处理层进行特征工程
  4. 如何构建和训练一个完整的Wide & Deep模型

Functional API为我们提供了构建复杂模型的强大工具,而Wide & Deep模型则展示了如何结合不同学习方式的优势来解决实际问题。这种技术组合在推荐系统、搜索排名等场景中都有广泛应用。

asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

邵冠敬Robin

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值