Mlflow 学习教程

目录

MLflow 概述

核心组件

常用操作

安装 MLflow, 启动 MLflow 用户界面

基本实验跟踪流程

mflow web 页面展示

加载和使用已记录的模型

基本案例:完整工作流

总结


MLflow 概述

MLflow 是一个开源的机器学习生命周期管理平台,由 Databricks 创建并维护。它主要解决机器学习开发中的以下问题:

  1. 实验跟踪:记录和查询实验代码、参数、指标和输出

  2. 项目管理:打包可重用、可复现的代码

  3. 模型管理:从多种 ML 库中部署模型

  4. 模型注册:集中管理模型的完整生命周期

核心组件

  1. Tracking:记录和查询实验

  2. Projects:打包可重用代码

  3. Models:管理模型

  4. Model Registry:集中模型管理

常用操作

安装 MLflow, 启动 MLflow 用户界面
# 安装环境
pip install mlflow 

#指定mlruns目录启动mflow,否则会抓不到目标值
mlflow ui --backend-store-uri ./mlruns  


# 使用 SQLite 作为后端存储
mlflow ui --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./mlruns
基本实验跟踪流程
import mlflow
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# 设置实验名称(如果不存在会自动创建)
mlflow.set_experiment("Iris Classification")

# 开始一个 MLflow 运行
with mlflow.start_run():
    # 加载数据
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)

    # 记录参数
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("max_depth", 5)

    # 训练模型
    model = RandomForestClassifier(n_estimators=100, max_depth=5)
    model.fit(X_train, y_train)

    # 计算并记录指标
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    mlflow.log_metric("accuracy", accuracy)

    # 记录模型
    mlflow.sklearn.log_model(model, "random_forest_model")

    # 记录一个 artifact(如 CSV 文件或图片)
    import pandas as pd

    pd.DataFrame(X_test).to_csv("test_data.csv", index=False)
    mlflow.log_artifact("test_data.csv")
运行结果:
2025/07/03 23:19:36 INFO mlflow.tracking.fluent: Experiment with name 'Iris Classification' does not exist. Creating a new experiment.
2025/07/03 23:19:37 WARNING mlflow.models.model: `artifact_path` is deprecated. Please use `name` instead.
2025/07/03 23:19:40 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.
mflow web 页面展示

 详情页面信息

test_data.csv 表格信息

模型注册 -- 注册管理页面

记录 artifacts

# 在 with mlflow.start_run() 块内
import matplotlib.pyplot as plt

plt.figure()
plt.scatter(y_test, predictions)
plt.xlabel("Actual")
plt.ylabel("Predicted")
plt.title("Actual vs Predicted")
plt.savefig("scatter_plot.png")

# 记录图像文件
mlflow.log_artifact("scatter_plot.png")

运行结果图片展示:

加载和使用已记录的模型
# 根据运行ID加载模型
run_id = "your_run_id_here"
model_uri = f"runs:/{run_id}/random_forest_model"
model = mlflow.sklearn.load_model(model_uri)

# 使用模型进行预测
new_predictions = model.predict(X_test)

基本案例:完整工作流

import mlflow
import mlflow.sklearn
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score

# 设置实验
mlflow.set_experiment("Iris Classification")


# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 定义模型参数
params = {
    "solver": "lbfgs",
    "max_iter": 1000,
    "multi_class": "auto",
    "random_state": 42
}

# 开始运行
with mlflow.start_run():
    # 记录参数
    for key in params:
        mlflow.log_param(key, params[key])
    
    # 训练模型
    lr = LogisticRegression(**params)
    lr.fit(X_train, y_train)
    
    # 评估模型
    y_pred = lr.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average="weighted")
    recall = recall_score(y_test, y_pred, average="weighted")
    
    # 记录指标
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)
    
    # 记录模型
    mlflow.sklearn.log_model(lr, "model")
    
    # 记录标签名称
    mlflow.log_dict({"target_names": iris.target_names.tolist()}, "target_names.json")

运行结果详情页面展示:

总结

步骤

操作

1. 记录数据

使用mlflow.log_param(), mlflow.log_metric(), mlflow.log_artifact()

2. 启动 UI

运行 mlflow ui,访问 http://localhost:5000

3. 查看数据

在 Web 界面查看参数、指标、模型、文件

4. 管理模型

注册模型到 Model Registry,或下载模型进行推理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值