深入理解Keras Functional API与Wide & Deep模型构建
前言
在机器学习模型构建过程中,Keras提供了两种主要的API:Sequential API和Functional API。本文我们将重点探讨Functional API的强大功能,并通过构建一个Wide & Deep模型来展示其灵活性。
Keras Functional API简介
Keras Functional API相比Sequential API提供了更高的灵活性,特别适合构建以下类型的模型:
- 具有复杂拓扑结构的模型
- 多输入或多输出的模型
- 需要共享层的模型
- 非顺序数据流模型(如残差连接)
Functional API的核心思想是通过显式定义输入和输出,以及它们之间的连接方式来构建模型。
Wide & Deep模型架构
Wide & Deep模型结合了两种学习方式:
- Wide部分(记忆组件):通过线性模型和特征交叉来学习特征与标签之间的相关性
- Deep部分(泛化组件):通过嵌入向量和深度神经网络来学习特征的深层表示
这种组合方式能够同时利用记忆和泛化的优势,在推荐系统等场景中表现优异。
数据准备
我们使用出租车费用预测数据集,包含以下特征:
- 上车时间
- 上车经纬度
- 下车经纬度
- 乘客数量
首先需要加载并预处理数据:
# 定义CSV列和默认值
CSV_COLUMNS = [
"fare_amount", "pickup_datetime",
"pickup_longitude", "pickup_latitude",
"dropoff_longitude", "dropoff_latitude",
"passenger_count", "key"
]
LABEL_COLUMN = "fare_amount"
DEFAULTS = [[0.0], ["na"], [0.0], [0.0], [0.0], [0.0], [0.0], ["na"]]
# 创建数据集函数
def create_dataset(pattern, batch_size=1, mode="eval"):
dataset = tf.data.experimental.make_csv_dataset(
pattern, batch_size, CSV_COLUMNS, DEFAULTS
)
dataset = dataset.map(features_and_labels)
if mode == "train":
dataset = dataset.shuffle(buffer_size=1000).repeat()
dataset = dataset.prefetch(1)
return dataset
构建Wide & Deep模型
1. 定义输入层
首先为每个特征创建输入层:
INPUT_COLS = [
"pickup_longitude", "pickup_latitude",
"dropoff_longitude", "dropoff_latitude",
"passenger_count"
]
inputs = {
colname: Input(name=colname, shape=(1,), dtype="float32")
for colname in INPUT_COLS
}
2. 特征工程
使用预处理层进行特征转换:
# 分桶处理
latbuckets = np.linspace(start=40.5, stop=41.0, num=NBUCKETS).tolist()
lonbuckets = np.linspace(start=-74.2, stop=-73.7, num=NBUCKETS).tolist()
plon = Discretization(lonbuckets, name="plon_bkt")(inputs["pickup_longitude"])
plat = Discretization(latbuckets, name="plat_bkt")(inputs["pickup_latitude"])
dlon = Discretization(lonbuckets, name="dlon_bkt")(inputs["dropoff_longitude"])
dlat = Discretization(latbuckets, name="dlat_bkt")(inputs["dropoff_latitude"])
# 特征交叉
p_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="p_fc")((plon, plat))
d_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="d_fc")((dlon, dlat))
pd_fc = HashedCrossing(num_bins=NBUCKETS*NBUCKETS, name="pd_fc")((plon, plat, dlon, dlat))
# 嵌入层
pd_embed = Embedding(NBUCKETS*NBUCKETS, 10, name="pd_embed")(pd_fc)
3. 构建Deep部分
# 拼接深度网络输入
deep = Concatenate(name='deep_input')(
[inputs['pickup_longitude'], inputs['pickup_latitude'],
inputs['dropoff_longitude'], inputs['dropoff_latitude'],
Flatten(name='flatten_embedding')(pd_embed)]
)
# 添加隐藏层
for i, num_nodes in enumerate(dnn_hidden_units, start=1):
deep = Dense(num_nodes, activation="relu", name=f'hidden_{i}')(deep)
4. 构建Wide部分
# One-hot编码
p_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="p_onehot")(p_fc)
d_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="d_onehot")(d_fc)
pd_onehot = CategoryEncoding(num_tokens=NBUCKETS*NBUCKETS, name="pd_onehot")(pd_fc)
# 拼接宽度网络输入
wide = Concatenate(name="wide_input")([p_onehot, d_onehot, pd_onehot])
5. 合并Wide和Deep部分
concat = Concatenate(name="wide_deep_concatenate")([wide, deep])
output = Dense(1, name="fare_amount")(concat)
模型编译与训练
定义自定义RMSE指标并编译模型:
def rmse(y_true, y_pred):
return tf.sqrt(tf.reduce_mean(tf.square(y_pred - y_true)))
model = Model(inputs=list(inputs.values()), outputs=output)
model.compile(optimizer="adam", loss="mse", metrics=[rmse])
设置训练参数并开始训练:
BATCH_SIZE = 100
NUM_TRAIN_EXAMPLES = 10000 * 50
NUM_EVALS = 50
NUM_EVAL_EXAMPLES = 10000
trainds = create_dataset(pattern="../data/taxi-train*", batch_size=BATCH_SIZE, mode="train")
evalds = create_dataset(pattern="../data/taxi-valid*", batch_size=BATCH_SIZE, mode="eval").take(NUM_EVAL_EXAMPLES // BATCH_SIZE)
history = model.fit(
x=trainds,
steps_per_epoch=steps_per_epoch,
epochs=NUM_EVALS,
validation_data=evalds,
callbacks=[TensorBoard(OUTDIR)],
)
结果分析
训练完成后,我们可以绘制训练和验证RMSE的变化曲线:
RMSE_COLS = ["rmse", "val_rmse"]
pd.DataFrame(history.history)[RMSE_COLS].plot()
总结
通过本文,我们学习了:
- Keras Functional API的核心概念和优势
- Wide & Deep模型的架构原理
- 如何使用预处理层进行特征工程
- 如何构建和训练一个完整的Wide & Deep模型
Functional API为我们提供了构建复杂模型的强大工具,而Wide & Deep模型则展示了如何结合不同学习方式的优势来解决实际问题。这种技术组合在推荐系统、搜索排名等场景中都有广泛应用。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考