目录
MLflow 概述
MLflow 是一个开源的机器学习生命周期管理平台,由 Databricks 创建并维护。它主要解决机器学习开发中的以下问题:
-
实验跟踪:记录和查询实验代码、参数、指标和输出
-
项目管理:打包可重用、可复现的代码
-
模型管理:从多种 ML 库中部署模型
-
模型注册:集中管理模型的完整生命周期
核心组件
-
Tracking:记录和查询实验
-
Projects:打包可重用代码
-
Models:管理模型
-
Model Registry:集中模型管理
常用操作
安装 MLflow, 启动 MLflow 用户界面
# 安装环境
pip install mlflow
#指定mlruns目录启动mflow,否则会抓不到目标值
mlflow ui --backend-store-uri ./mlruns
# 使用 SQLite 作为后端存储
mlflow ui --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./mlruns
基本实验跟踪流程
import mlflow
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
# 设置实验名称(如果不存在会自动创建)
mlflow.set_experiment("Iris Classification")
# 开始一个 MLflow 运行
with mlflow.start_run():
# 加载数据
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
# 记录参数
mlflow.log_param("n_estimators", 100)
mlflow.log_param("max_depth", 5)
# 训练模型
model = RandomForestClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
# 计算并记录指标
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
mlflow.log_metric("accuracy", accuracy)
# 记录模型
mlflow.sklearn.log_model(model, "random_forest_model")
# 记录一个 artifact(如 CSV 文件或图片)
import pandas as pd
pd.DataFrame(X_test).to_csv("test_data.csv", index=False)
mlflow.log_artifact("test_data.csv")
运行结果:
2025/07/03 23:19:36 INFO mlflow.tracking.fluent: Experiment with name 'Iris Classification' does not exist. Creating a new experiment.
2025/07/03 23:19:37 WARNING mlflow.models.model: `artifact_path` is deprecated. Please use `name` instead.
2025/07/03 23:19:40 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.
mflow web 页面展示
详情页面信息
test_data.csv 表格信息
模型注册 -- 注册管理页面
记录 artifacts
# 在 with mlflow.start_run() 块内
import matplotlib.pyplot as plt
plt.figure()
plt.scatter(y_test, predictions)
plt.xlabel("Actual")
plt.ylabel("Predicted")
plt.title("Actual vs Predicted")
plt.savefig("scatter_plot.png")
# 记录图像文件
mlflow.log_artifact("scatter_plot.png")
运行结果图片展示:
加载和使用已记录的模型
# 根据运行ID加载模型
run_id = "your_run_id_here"
model_uri = f"runs:/{run_id}/random_forest_model"
model = mlflow.sklearn.load_model(model_uri)
# 使用模型进行预测
new_predictions = model.predict(X_test)
基本案例:完整工作流
import mlflow
import mlflow.sklearn
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score
# 设置实验
mlflow.set_experiment("Iris Classification")
# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 定义模型参数
params = {
"solver": "lbfgs",
"max_iter": 1000,
"multi_class": "auto",
"random_state": 42
}
# 开始运行
with mlflow.start_run():
# 记录参数
for key in params:
mlflow.log_param(key, params[key])
# 训练模型
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
# 评估模型
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average="weighted")
recall = recall_score(y_test, y_pred, average="weighted")
# 记录指标
mlflow.log_metric("accuracy", accuracy)
mlflow.log_metric("precision", precision)
mlflow.log_metric("recall", recall)
# 记录模型
mlflow.sklearn.log_model(lr, "model")
# 记录标签名称
mlflow.log_dict({"target_names": iris.target_names.tolist()}, "target_names.json")
运行结果详情页面展示:
总结
步骤 | 操作 |
1. 记录数据 | 使用mlflow.log_param(), mlflow.log_metric(), mlflow.log_artifact() |
2. 启动 UI | 运行 mlflow ui,访问 http://localhost:5000 |
3. 查看数据 | 在 Web 界面查看参数、指标、模型、文件 |
4. 管理模型 | 注册模型到 Model Registry,或下载模型进行推理 |