030-数据库操作与ORM
🔴 难度: 高级 | ⏱️ 预计时间: 6小时 | 📋 前置: 029-网络编程基础
学习目标
完成本章节后，你将能够：
- 掌握Python数据库编程的基础概念
- 使用原生SQL进行数据库操作
- 理解和使用ORM框架(SQLAlchemy)
- 实现数据库连接池和事务管理
- 设计和优化数据库查询
- 处理数据库迁移和版本控制
- 实现数据库安全和性能优化
内容大纲
数据库编程基础
数据库类型和选择
# Database programming basics demo: print the section banner.
print("=== 数据库编程基础 ===")
import sqlite3
import mysql.connector
import psycopg2
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass, field
from datetime import datetime, date
from enum import Enum
import json
import logging
from contextlib import contextmanager
import threading
import time
from abc import ABC, abstractmethod
class DatabaseType(Enum):
    """Database back-ends known to this module, keyed by a lowercase identifier.

    Only some members have a connection implementation; see DatabaseFactory.
    """

    # relational engines
    SQLITE = "sqlite"
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    ORACLE = "oracle"
    # document store
    MONGODB = "mongodb"
@dataclass
class DatabaseConfig:
    """Connection settings for one database.

    ``port`` defaults to 3306 for every back-end, so set it explicitly for
    servers that listen elsewhere. ``password`` is excluded from ``repr()`` so
    that printing or logging a config does not reveal the credential.
    """
    db_type: DatabaseType
    host: str = "localhost"
    port: int = 3306
    database: str = ""
    username: str = ""
    # repr=False: keep the secret out of repr()/logs; still a normal __init__ argument.
    password: str = field(default="", repr=False)
    charset: str = "utf8mb4"
    pool_size: int = 10
    max_overflow: int = 20
    pool_timeout: int = 30
    pool_recycle: int = 3600

    def get_connection_string(self) -> str:
        """Build a SQLAlchemy-style database URL for this configuration.

        The username and password are percent-encoded, so characters such as
        ``@``, ``:``, ``/`` or ``%`` cannot break the URL structure.

        Returns:
            The URL string.

        Raises:
            ValueError: if ``db_type`` is not SQLite, MySQL or PostgreSQL.
        """
        from urllib.parse import quote

        if self.db_type == DatabaseType.SQLITE:
            return f"sqlite:///{self.database}"

        # safe="" also encodes '/', which would otherwise be read as the start of the path.
        user = quote(self.username, safe="")
        secret = quote(self.password, safe="")
        if self.db_type == DatabaseType.MYSQL:
            return (f"mysql+pymysql://{user}:{secret}@{self.host}:{self.port}/"
                    f"{self.database}?charset={self.charset}")
        if self.db_type == DatabaseType.POSTGRESQL:
            return f"postgresql://{user}:{secret}@{self.host}:{self.port}/{self.database}"
        raise ValueError(f"不支持的数据库类型: {self.db_type}")
class DatabaseConnection(ABC):
    """Abstract interface that each concrete database connection implements.

    A subclass receives a DatabaseConfig and provides connection lifecycle,
    statement execution, row fetching and explicit transaction control.
    """

    def __init__(self, config: DatabaseConfig):
        self.config = config
        # Driver-level connection object; presumably assigned by a subclass's connect().
        self.connection = None
        self.is_connected = False

    # -- lifecycle -------------------------------------------------------
    @abstractmethod
    def connect(self) -> bool:
        """Open the connection and report success as a bool."""
        ...

    @abstractmethod
    def disconnect(self):
        """Close the connection."""
        ...

    # -- statements ------------------------------------------------------
    @abstractmethod
    def execute(self, sql: str, params: Optional[tuple] = None) -> Any:
        """Run one SQL statement with optional positional parameters."""
        ...

    @abstractmethod
    def fetch_one(self, sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """Run a query and return its first row as a dict, or None."""
        ...

    @abstractmethod
    def fetch_all(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a query and return every row as a dict."""
        ...

    # -- transactions ----------------------------------------------------
    @abstractmethod
    def begin_transaction(self):
        """Start an explicit transaction."""
        ...

    @abstractmethod
    def commit(self):
        """Commit the current transaction."""
        ...

    @abstractmethod
    def rollback(self):
        """Roll back the current transaction."""
        ...
class SQLiteConnection(DatabaseConnection):
    """SQLite implementation of DatabaseConnection built on the sqlite3 module.

    Rows are returned as plain dicts. Outside an explicit transaction, every
    execute() call commits on success and rolls back on failure. After
    begin_transaction(), execute() neither commits nor rolls back; the caller
    ends the transaction with commit() or rollback().
    """

    def __init__(self, config: DatabaseConfig):
        super().__init__(config)
        # True between begin_transaction() and the matching commit()/rollback();
        # suppresses execute()'s per-statement commit/rollback.
        self._in_transaction = False

    def _ensure_connected(self):
        """Raise if connect() has not succeeded or disconnect() was called."""
        if not self.is_connected:
            raise Exception("数据库未连接")

    def connect(self) -> bool:
        """Open the SQLite database named by ``config.database``.

        Returns:
            True on success. On any error, prints it and returns False.
        """
        try:
            self.connection = sqlite3.connect(
                self.config.database,
                # allow use from threads other than the one that created it
                check_same_thread=False,
                # seconds to wait when the database is locked
                timeout=self.config.pool_timeout
            )
            # sqlite3.Row lets fetch_* turn each row into a dict with dict(row)
            self.connection.row_factory = sqlite3.Row
            self.is_connected = True
            print(f"已连接到SQLite数据库: {self.config.database}")
            return True
        except Exception as e:
            print(f"连接SQLite数据库失败: {e}")
            return False

    def disconnect(self):
        """Close the connection if one is open. Calling it again is a no-op."""
        if self.connection:
            self.connection.close()
            # Drop the closed handle so commit()/rollback() become no-ops afterwards.
            self.connection = None
            self.is_connected = False
            self._in_transaction = False
            print("SQLite连接已断开")

    def execute(self, sql: str, params: Optional[tuple] = None) -> Any:
        """Execute one statement.

        Outside an explicit transaction, commits on success and rolls back on
        error. Inside one (after begin_transaction()), leaves that decision to
        the caller.

        Returns:
            The cursor used. It is already closed, but attributes such as
            ``rowcount`` and ``lastrowid`` remain readable.

        Raises:
            Exception: if not connected; any sqlite3 error is re-raised.
        """
        self._ensure_connected()
        cursor = self.connection.cursor()
        try:
            result = cursor.execute(sql, params or ())
            if not self._in_transaction:
                self.connection.commit()
            return result
        except Exception:
            if not self._in_transaction:
                self.connection.rollback()
            raise  # bare raise keeps the original traceback
        finally:
            cursor.close()

    def fetch_one(self, sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
        """Run a query and return its first row as a dict, or None if there are no rows."""
        self._ensure_connected()
        cursor = self.connection.cursor()
        try:
            cursor.execute(sql, params or ())
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            cursor.close()

    def fetch_all(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a query and return all rows as a list of dicts."""
        self._ensure_connected()
        cursor = self.connection.cursor()
        try:
            cursor.execute(sql, params or ())
            return [dict(row) for row in cursor.fetchall()]
        finally:
            cursor.close()

    def begin_transaction(self):
        """Start an explicit transaction that execute() will not auto-commit."""
        self._ensure_connected()
        # sqlite3 can already have an implicit transaction open; issuing BEGIN
        # then would fail with "cannot start a transaction within a transaction".
        if not self.connection.in_transaction:
            self.connection.execute("BEGIN")
        self._in_transaction = True

    def commit(self):
        """Commit pending work (if connected) and leave explicit-transaction mode."""
        if self.connection:
            self.connection.commit()
        self._in_transaction = False

    def rollback(self):
        """Discard pending work (if connected) and leave explicit-transaction mode."""
        if self.connection:
            self.connection.rollback()
        self._in_transaction = False
# Database connection factory
class DatabaseFactory:
    """Map a DatabaseConfig to the matching DatabaseConnection implementation."""

    @staticmethod
    def create_connection(config: DatabaseConfig) -> DatabaseConnection:
        """Return a new, not yet connected, connection object for ``config.db_type``.

        Raises:
            NotImplementedError: for MySQL and PostgreSQL, which have no
                connection class yet.
            ValueError: for any other database type.
        """
        kind = config.db_type
        if kind == DatabaseType.SQLITE:
            return SQLiteConnection(config)
        # Recognised back-ends whose connection classes are still to be written.
        not_yet_supported = (
            (DatabaseType.MYSQL, "MySQL连接尚未实现"),
            (DatabaseType.POSTGRESQL, "PostgreSQL连接尚未实现"),
        )
        for candidate, message in not_yet_supported:
            if kind == candidate:
                raise NotImplementedError(message)
        raise ValueError(f"不支持的数据库类型: {kind}")
# Database basics demo
def run_database_basics_demo():
    """Demonstrate create/insert/select/update/aggregate with raw SQL on SQLite.

    Creates a ``users`` table in ``demo.db`` if it is missing, inserts three
    sample users, and then queries, updates and aggregates them, printing each
    step. Inserts that fail (for example a username already stored by an
    earlier run, which the UNIQUE constraint rejects) are reported and skipped.

    Returns:
        The connection object. It has always been disconnected by the time it
        is returned, including when the initial connect fails.
    """
    print("\n--- 数据库基础操作演示 ---")
    config = DatabaseConfig(
        db_type=DatabaseType.SQLITE,
        database="demo.db"
    )
    db = DatabaseFactory.create_connection(config)
    try:
        if not db.connect():
            # Return the connection object here too, so callers always get the same type.
            return db

        create_table_sql = """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username VARCHAR(50) UNIQUE NOT NULL,
            email VARCHAR(100) NOT NULL,
            age INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
        db.execute(create_table_sql)
        print("用户表创建成功")

        users_data = [
            ("alice", "alice@example.com", 25),
            ("bob", "bob@example.com", 30),
            ("charlie", "charlie@example.com", 35)
        ]
        insert_sql = "INSERT INTO users (username, email, age) VALUES (?, ?, ?)"
        for username, email, age in users_data:
            try:
                db.execute(insert_sql, (username, email, age))
                print(f"插入用户: {username}")
            except Exception as e:
                print(f"插入用户 {username} 失败: {e}")

        print("\n所有用户:")
        users = db.fetch_all("SELECT * FROM users ORDER BY id")
        for user in users:
            print(f" ID: {user['id']}, 用户名: {user['username']}, 邮箱: {user['email']}, 年龄: {user['age']}")

        print("\n查询用户 'alice':")
        user = db.fetch_one("SELECT * FROM users WHERE username = ?", ("alice",))
        if user:
            print(f" 找到用户: {user['username']} ({user['email']})")

        update_sql = "UPDATE users SET age = ?, updated_at = CURRENT_TIMESTAMP WHERE username = ?"
        db.execute(update_sql, (26, "alice"))
        print("\n更新用户 alice 的年龄")

        updated_user = db.fetch_one("SELECT * FROM users WHERE username = ?", ("alice",))
        if updated_user:
            print(f" 更新后年龄: {updated_user['age']}")

        count_result = db.fetch_one("SELECT COUNT(*) as total FROM users")
        print(f"\n总用户数: {count_result['total']}")
        avg_age = db.fetch_one("SELECT AVG(age) as avg_age FROM users")["avg_age"]
        # AVG() is NULL when no row has an age; formatting None with :.1f would raise.
        if avg_age is None:
            print("平均年龄: N/A")
        else:
            print(f"平均年龄: {avg_age:.1f}")
    except Exception as e:
        print(f"数据库操作失败: {e}")
    finally:
        db.disconnect()
    return db
# Run the database basics demo when this module executes; keep its return value.
print("运行数据库基础操作演示...")
db_demo = run_database_basics_demo()
原生SQL操作
SQL查询构建器
# Raw SQL operations section: print the banner for the query-builder demo.
print("\n=== 原生SQL操作 ===")
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
from enum import Enum
import re
class SQLOperator(Enum):
    """Predicate operators supported by WhereCondition.

    Each value is the SQL text emitted for that operator.
    """

    # binary comparisons
    EQ = "="
    NE = "!="
    GT = ">"
    GTE = ">="
    LT = "<"
    LTE = "<="
    # pattern and set membership
    LIKE = "LIKE"
    IN = "IN"
    NOT_IN = "NOT IN"
    # unary NULL tests
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    # range test (two operands)
    BETWEEN = "BETWEEN"
class JoinType(Enum):
    """JOIN flavours; each value is the keyword sequence placed before the joined table."""

    INNER = "INNER JOIN"
    LEFT = "LEFT JOIN"
    RIGHT = "RIGHT JOIN"
    FULL = "FULL OUTER JOIN"
@dataclass
class WhereCondition:
    """One predicate of a WHERE/HAVING clause, rendered with ``?`` placeholders.

    ``field`` is inserted into the SQL verbatim, so it must be a trusted column
    name or expression. ``value`` is always returned as bound parameter(s).
    ``logic_operator`` is the connective ("AND"/"OR") a builder places before
    this condition when it is not the first one.
    """
    field: str
    operator: SQLOperator
    value: Any
    logic_operator: str = "AND"  # AND or OR

    def to_sql(self) -> Tuple[str, List[Any]]:
        """Return ``(sql_fragment, params)`` for this condition.

        IN/NOT IN accept any iterable of values. A single str/bytes value is
        treated as one value, not as a sequence of characters. An empty
        iterable yields a constant predicate. BETWEEN expects
        ``value = [low, high]``.
        """
        op = self.operator
        if op == SQLOperator.IS_NULL:
            return f"{self.field} IS NULL", []
        if op == SQLOperator.IS_NOT_NULL:
            return f"{self.field} IS NOT NULL", []
        if op in (SQLOperator.IN, SQLOperator.NOT_IN):
            return self._membership_sql(op)
        if op == SQLOperator.BETWEEN:
            return f"{self.field} BETWEEN ? AND ?", [self.value[0], self.value[1]]
        return f"{self.field} {op.value} ?", [self.value]

    def _membership_sql(self, op: SQLOperator) -> Tuple[str, List[Any]]:
        """Render an IN / NOT IN predicate with one placeholder per value."""
        values = self.value
        # A bare string would otherwise be expanded character by character.
        if isinstance(values, (str, bytes)):
            values = [values]
        values = list(values)
        if not values:
            # "x IN ()" is not standard SQL and several engines reject it.
            # Use constants with the same meaning instead: nothing is IN an
            # empty set, and everything is NOT IN it.
            return ("1 = 0" if op == SQLOperator.IN else "1 = 1"), []
        placeholders = ",".join(["?"] * len(values))
        return f"{self.field} {op.value} ({placeholders})", values
@dataclass
class JoinClause:
    """A single JOIN fragment: join keyword(s), target table and raw ON condition.

    ``table`` and ``on_condition`` are emitted verbatim and must be trusted text.
    """
    join_type: JoinType
    table: str
    on_condition: str

    def to_sql(self) -> str:
        """Render as ``<JOIN keyword> <table> ON <condition>``."""
        return "{} {} ON {}".format(self.join_type.value, self.table, self.on_condition)
class SQLQueryBuilder:
    """Fluent builder for SELECT statements using ``?`` (qmark) placeholders.

    Table names, column names, JOIN conditions and ORDER BY fields are inserted
    verbatim and must come from trusted code. Condition values are always
    returned as bound parameters. LIMIT/OFFSET are coerced to ``int`` and the
    sort direction is restricted to ASC/DESC, so neither can inject SQL.
    """

    def __init__(self, table: str):
        self.table = table
        self.select_fields = ["*"]
        self.where_conditions = []
        self.join_clauses = []
        self.group_by_fields = []
        self.having_conditions = []
        self.order_by_fields = []
        self.limit_count = None
        self.offset_count = None

    def select(self, *fields: str) -> 'SQLQueryBuilder':
        """Replace the selected columns; calling with no arguments selects ``*``."""
        self.select_fields = list(fields) if fields else ["*"]
        return self

    def where(self, field: str, operator: SQLOperator, value: Any, logic_operator: str = "AND") -> 'SQLQueryBuilder':
        """Add a WHERE condition joined to the previous one by ``logic_operator``."""
        condition = WhereCondition(field, operator, value, logic_operator)
        self.where_conditions.append(condition)
        return self

    def where_in(self, field: str, values: List[Any]) -> 'SQLQueryBuilder':
        """Add ``field IN (...)``."""
        return self.where(field, SQLOperator.IN, values)

    def where_between(self, field: str, start: Any, end: Any) -> 'SQLQueryBuilder':
        """Add ``field BETWEEN start AND end``."""
        return self.where(field, SQLOperator.BETWEEN, [start, end])

    def where_like(self, field: str, pattern: str) -> 'SQLQueryBuilder':
        """Add ``field LIKE pattern``."""
        return self.where(field, SQLOperator.LIKE, pattern)

    def where_null(self, field: str) -> 'SQLQueryBuilder':
        """Add ``field IS NULL``."""
        return self.where(field, SQLOperator.IS_NULL, None)

    def where_not_null(self, field: str) -> 'SQLQueryBuilder':
        """Add ``field IS NOT NULL``."""
        return self.where(field, SQLOperator.IS_NOT_NULL, None)

    def join(self, table: str, on_condition: str, join_type: JoinType = JoinType.INNER) -> 'SQLQueryBuilder':
        """Add a JOIN clause (INNER by default)."""
        join_clause = JoinClause(join_type, table, on_condition)
        self.join_clauses.append(join_clause)
        return self

    def left_join(self, table: str, on_condition: str) -> 'SQLQueryBuilder':
        """Add a LEFT JOIN clause."""
        return self.join(table, on_condition, JoinType.LEFT)

    def right_join(self, table: str, on_condition: str) -> 'SQLQueryBuilder':
        """Add a RIGHT JOIN clause."""
        return self.join(table, on_condition, JoinType.RIGHT)

    def group_by(self, *fields: str) -> 'SQLQueryBuilder':
        """Append GROUP BY fields."""
        self.group_by_fields.extend(fields)
        return self

    def having(self, field: str, operator: SQLOperator, value: Any) -> 'SQLQueryBuilder':
        """Add a HAVING condition; multiple HAVING conditions are always ANDed."""
        condition = WhereCondition(field, operator, value)
        self.having_conditions.append(condition)
        return self

    def order_by(self, field: str, direction: str = "ASC") -> 'SQLQueryBuilder':
        """Append ``field DIRECTION`` to ORDER BY.

        Raises:
            ValueError: if ``direction`` is not ASC or DESC (case-insensitive).
                The direction is inserted into the SQL text, so it is validated.
        """
        normalized = str(direction).strip().upper()
        if normalized not in ("ASC", "DESC"):
            raise ValueError(f"排序方向必须是 ASC 或 DESC: {direction!r}")
        self.order_by_fields.append(f"{field} {normalized}")
        return self

    def limit(self, count: int) -> 'SQLQueryBuilder':
        """Set LIMIT (None clears it). The value is coerced with int() before use."""
        self.limit_count = None if count is None else int(count)
        return self

    def offset(self, count: int) -> 'SQLQueryBuilder':
        """Set OFFSET (None clears it). The value is coerced with int() before use."""
        self.offset_count = None if count is None else int(count)
        return self

    @staticmethod
    def _render_conditions(conditions, params: List[Any], use_logic_operator: bool) -> str:
        """Join condition fragments and append their parameters to ``params`` in order."""
        parts = []
        for index, condition in enumerate(conditions):
            condition_sql, condition_params = condition.to_sql()
            if index > 0:
                parts.append(condition.logic_operator if use_logic_operator else "AND")
            parts.append(condition_sql)
            params.extend(condition_params)
        return " ".join(parts)

    def build(self) -> Tuple[str, List[Any]]:
        """Return ``(sql, params)`` for the configured SELECT."""
        params: List[Any] = []
        sql_parts = [f"SELECT {', '.join(self.select_fields)}", f"FROM {self.table}"]
        sql_parts.extend(join_clause.to_sql() for join_clause in self.join_clauses)

        if self.where_conditions:
            sql_parts.append("WHERE " + self._render_conditions(self.where_conditions, params, True))
        if self.group_by_fields:
            sql_parts.append(f"GROUP BY {', '.join(self.group_by_fields)}")
        if self.having_conditions:
            sql_parts.append("HAVING " + self._render_conditions(self.having_conditions, params, False))
        if self.order_by_fields:
            sql_parts.append(f"ORDER BY {', '.join(self.order_by_fields)}")

        if self.limit_count is not None:
            sql_parts.append(f"LIMIT {self.limit_count}")
        elif self.offset_count is not None:
            # SQLite only accepts OFFSET after a LIMIT; a negative LIMIT means "no upper bound".
            sql_parts.append("LIMIT -1")
        if self.offset_count is not None:
            sql_parts.append(f"OFFSET {self.offset_count}")

        return " ".join(sql_parts), params
class SQLInsertBuilder:
    """Fluent builder for a single-row INSERT using ``?`` placeholders.

    Conflict handling uses the ``INSERT OR <action>`` form (IGNORE / REPLACE).
    """

    def __init__(self, table: str):
        self.table = table
        self.data = {}
        self.on_conflict_action = None

    def values(self, **kwargs) -> 'SQLInsertBuilder':
        """Merge column=value pairs into the row; later calls override earlier keys."""
        self.data.update(kwargs)
        return self

    def on_conflict_ignore(self) -> 'SQLInsertBuilder':
        """Emit ``INSERT OR IGNORE``."""
        self.on_conflict_action = "IGNORE"
        return self

    def on_conflict_replace(self) -> 'SQLInsertBuilder':
        """Emit ``INSERT OR REPLACE``."""
        self.on_conflict_action = "REPLACE"
        return self

    def build(self) -> Tuple[str, List[Any]]:
        """Return ``(sql, params)``; params follow the column order of ``data``.

        Raises:
            ValueError: if no values were set.
        """
        if not self.data:
            raise ValueError("没有设置插入数据")
        columns = ", ".join(self.data)
        marks = ", ".join("?" for _ in self.data)
        verb = f"INSERT OR {self.on_conflict_action}" if self.on_conflict_action else "INSERT"
        statement = f"{verb} INTO {self.table} ({columns}) VALUES ({marks})"
        return statement, list(self.data.values())
class SQLUpdateBuilder:
    """Fluent builder for ``UPDATE <table> SET ... [WHERE ...]`` with ``?`` placeholders.

    WHERE conditions are ANDed. With no where() call, the statement has no
    WHERE clause and therefore updates every row.
    """

    def __init__(self, table: str):
        self.table = table
        self.set_data = {}
        self.where_conditions = []

    def set(self, **kwargs) -> 'SQLUpdateBuilder':
        """Merge column=value assignments; later calls override earlier keys."""
        self.set_data.update(kwargs)
        return self

    def where(self, field: str, operator: SQLOperator, value: Any) -> 'SQLUpdateBuilder':
        """Add a condition (combined with AND)."""
        self.where_conditions.append(WhereCondition(field, operator, value))
        return self

    def build(self) -> Tuple[str, List[Any]]:
        """Return ``(sql, params)``: SET values first, then WHERE parameters.

        Raises:
            ValueError: if no assignments were set.
        """
        if not self.set_data:
            raise ValueError("没有设置更新数据")
        assignments = ", ".join(f"{column} = ?" for column in self.set_data)
        bound: List[Any] = list(self.set_data.values())
        statement = f"UPDATE {self.table} SET {assignments}"
        if self.where_conditions:
            predicates = []
            for condition in self.where_conditions:
                predicate, predicate_params = condition.to_sql()
                predicates.append(predicate)
                bound.extend(predicate_params)
            statement += " WHERE " + " AND ".join(predicates)
        return statement, bound
class SQLDeleteBuilder:
    """Fluent builder for ``DELETE FROM <table> [WHERE ...]`` with ``?`` placeholders.

    Conditions are ANDed. With no where() call, the statement has no WHERE
    clause and would delete every row.
    """

    def __init__(self, table: str):
        self.table = table
        self.where_conditions = []

    def where(self, field: str, operator: SQLOperator, value: Any) -> 'SQLDeleteBuilder':
        """Add a condition (combined with AND)."""
        self.where_conditions.append(WhereCondition(field, operator, value))
        return self

    def build(self) -> Tuple[str, List[Any]]:
        """Return ``(sql, params)`` for the DELETE statement."""
        statement = f"DELETE FROM {self.table}"
        bound: List[Any] = []
        if not self.where_conditions:
            return statement, bound
        predicates = []
        for condition in self.where_conditions:
            predicate, predicate_params = condition.to_sql()
            predicates.append(predicate)
            bound.extend(predicate_params)
        return f"{statement} WHERE {' AND '.join(predicates)}", bound
# SQL builder demo
def run_sql_builder_demo():
    """Exercise the INSERT/SELECT/UPDATE/DELETE builders against ``demo.db``.

    Creates an ``orders`` table if it is missing and inserts four sample orders
    with SQLInsertBuilder. Every run inserts them again, because the table has
    no uniqueness constraint on these rows. It then runs several
    SQLQueryBuilder queries (filter, BETWEEN, LEFT JOIN with ``users``,
    GROUP BY/HAVING), marks pending orders as ``shipped`` with
    SQLUpdateBuilder, and only prints a DELETE statement without executing it.

    NOTE(review): the JOIN reads the ``users`` table, which this function does
    not create; it relies on an earlier step (e.g. run_database_basics_demo)
    having created it in demo.db.

    Returns:
        The disconnected connection object, or None if connecting failed.
    """
    print("\n--- SQL构建器演示 ---")
    # Connect to the same SQLite file the basics demo uses
    config = DatabaseConfig(
        db_type=DatabaseType.SQLITE,
        database="demo.db"
    )
    db = DatabaseFactory.create_connection(config)
    try:
        if not db.connect():
            return
        # Create the orders table
        create_orders_table = """
        CREATE TABLE IF NOT EXISTS orders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            product_name VARCHAR(100) NOT NULL,
            quantity INTEGER NOT NULL,
            price DECIMAL(10,2) NOT NULL,
            status VARCHAR(20) DEFAULT 'pending',
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
        """
        db.execute(create_orders_table)
        # Insert the sample orders with the INSERT builder
        print("\n1. 使用INSERT构建器插入数据:")
        orders_data = [
            {"user_id": 1, "product_name": "笔记本电脑", "quantity": 1, "price": 5999.99, "status": "completed"},
            {"user_id": 2, "product_name": "手机", "quantity": 2, "price": 3999.99, "status": "pending"},
            {"user_id": 1, "product_name": "键盘", "quantity": 1, "price": 299.99, "status": "shipped"},
            {"user_id": 3, "product_name": "鼠标", "quantity": 3, "price": 99.99, "status": "completed"}
        ]
        for order_data in orders_data:
            insert_builder = SQLInsertBuilder("orders")
            sql, params = insert_builder.values(**order_data).build()
            print(f" SQL: {sql}")
            print(f" 参数: {params}")
            db.execute(sql, tuple(params))
        print("订单数据插入完成")
        # Query data with the SELECT builder
        print("\n2. 使用SELECT构建器查询数据:")
        # All completed orders, newest first
        query_builder = SQLQueryBuilder("orders")
        sql, params = (query_builder
                       .select("id", "product_name", "quantity", "price")
                       .where("status", SQLOperator.EQ, "completed")
                       .order_by("created_at", "DESC")
                       .build())
        print(f" 查询已完成订单 SQL: {sql}")
        print(f" 参数: {params}")
        completed_orders = db.fetch_all(sql, tuple(params))
        print(" 已完成的订单:")
        for order in completed_orders:
            print(f" ID: {order['id']}, 产品: {order['product_name']}, 数量: {order['quantity']}, 价格: {order['price']}")
        # Orders priced between 100 and 1000, cheapest first
        print("\n 查询价格在100-1000之间的订单:")
        query_builder = SQLQueryBuilder("orders")
        sql, params = (query_builder
                       .select("product_name", "price", "status")
                       .where_between("price", 100, 1000)
                       .order_by("price", "ASC")
                       .build())
        print(f" SQL: {sql}")
        price_range_orders = db.fetch_all(sql, tuple(params))
        for order in price_range_orders:
            print(f" 产品: {order['product_name']}, 价格: {order['price']}, 状态: {order['status']}")
        # LEFT JOIN users with their orders. The WHERE filter on orders.status
        # drops users without a matching order, so the rows equal an inner join here.
        print("\n 使用JOIN查询用户和订单信息:")
        query_builder = SQLQueryBuilder("users")
        sql, params = (query_builder
                       .select("users.username", "users.email", "orders.product_name", "orders.price", "orders.status")
                       .left_join("orders", "users.id = orders.user_id")
                       .where("orders.status", SQLOperator.IN, ["completed", "shipped"])
                       .order_by("users.username")
                       .build())
        print(f" SQL: {sql}")
        user_orders = db.fetch_all(sql, tuple(params))
        for record in user_orders:
            print(f" 用户: {record['username']}, 产品: {record['product_name']}, 价格: {record['price']}, 状态: {record['status']}")
        # Per-user order count and amount, keeping only users with more than one order
        print("\n 按用户统计订单信息:")
        query_builder = SQLQueryBuilder("orders")
        sql, params = (query_builder
                       .select("user_id", "COUNT(*) as order_count", "SUM(price * quantity) as total_amount")
                       .group_by("user_id")
                       .having("COUNT(*)", SQLOperator.GT, 1)
                       .order_by("total_amount", "DESC")
                       .build())
        print(f" SQL: {sql}")
        user_stats = db.fetch_all(sql, tuple(params))
        for stat in user_stats:
            print(f" 用户ID: {stat['user_id']}, 订单数: {stat['order_count']}, 总金额: {stat['total_amount']}")
        # Mark every pending order as shipped with the UPDATE builder
        print("\n3. 使用UPDATE构建器更新数据:")
        update_builder = SQLUpdateBuilder("orders")
        sql, params = (update_builder
                       .set(status="shipped")
                       .where("status", SQLOperator.EQ, "pending")
                       .build())
        print(f" 更新待处理订单状态 SQL: {sql}")
        print(f" 参数: {params}")
        db.execute(sql, tuple(params))
        print(" 订单状态更新完成")
        # Verify the update by counting the remaining pending orders
        pending_count = db.fetch_one("SELECT COUNT(*) as count FROM orders WHERE status = 'pending'")
        print(f" 剩余待处理订单数: {pending_count['count']}")
        # DELETE builder: the statement is only generated and printed, never executed
        print("\n4. 使用DELETE构建器(仅演示SQL生成):")
        delete_builder = SQLDeleteBuilder("orders")
        sql, params = (delete_builder
                       .where("status", SQLOperator.EQ, "cancelled")
                       .build())
        print(f" 删除已取消订单 SQL: {sql}")
        print(f" 参数: {params}")
    except Exception as e:
        print(f"SQL构建器演示失败: {e}")
    finally:
        db.disconnect()
    return db
# Run the SQL builder demo when this module executes; keep its return value.
print("运行SQL构建器演示...")
sql_builder_demo = run_sql_builder_demo()
SQLAlchemy ORM
ORM模型定义
# SQLAlchemy ORM section: print the banner.
print("\n=== SQLAlchemy ORM ===")
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, Boolean, Numeric, Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship, Session, scoped_session
from sqlalchemy.sql import func
from sqlalchemy.pool import StaticPool
from datetime import datetime
from typing import List, Optional
import json
# Declarative base class shared by all ORM models below.
# sqlalchemy.orm.declarative_base (SQLAlchemy >= 1.4) is the current home of this
# function; the sqlalchemy.ext.declarative variant is a deprecated alias in 2.0.
try:
    from sqlalchemy.orm import declarative_base as _orm_declarative_base
except ImportError:  # SQLAlchemy < 1.4: fall back to the already-imported legacy function
    _orm_declarative_base = declarative_base
Base = _orm_declarative_base()
# Association table for the many-to-many User <-> Role relationship;
# the composite primary key prevents assigning the same role to a user twice.
user_roles = Table(
    'user_roles',
    Base.metadata,
    Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
    Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True)
)
class User(Base):
    """User account.

    Owns orders and an optional one-to-one profile, and holds roles through the
    ``user_roles`` association table. Orders and the profile are deleted
    together with the user (``delete-orphan`` cascade).
    """
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(100), nullable=False, index=True)
    password_hash = Column(String(255), nullable=False)
    full_name = Column(String(100))
    age = Column(Integer)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    # Relationships
    orders = relationship("Order", back_populates="user", cascade="all, delete-orphan")
    profile = relationship("UserProfile", back_populates="user", uselist=False, cascade="all, delete-orphan")
    roles = relationship("Role", secondary=user_roles, back_populates="users")

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict of the public columns.

        ``password_hash`` is not included. Timestamps are ISO-8601 strings or None.
        """
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'full_name': self.full_name,
            'age': self.age,
            'is_active': self.is_active,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @classmethod
    def create(cls, session: Session, **kwargs) -> 'User':
        """Create, commit and refresh a new user built from ``kwargs``."""
        user = cls(**kwargs)
        session.add(user)
        session.commit()
        session.refresh(user)
        return user

    def update(self, session: Session, **kwargs):
        """Set the given attributes, bump ``updated_at`` and commit.

        Only mapped attributes (columns and relationships) are assigned; other
        keys are ignored. A previous ``hasattr`` check also accepted method
        names, so e.g. ``update(session, delete=...)`` overwrote the method.
        """
        mapped_attrs = self.__mapper__.attrs
        for key, value in kwargs.items():
            if key in mapped_attrs:
                setattr(self, key, value)
        self.updated_at = func.now()
        session.commit()

    def delete(self, session: Session):
        """Delete this user (cascading to orders and profile) and commit."""
        session.delete(self)
        session.commit()
class UserProfile(Base):
    """Optional one-to-one profile of a User: bio, contact details and JSON preferences."""
    __tablename__ = 'user_profiles'
    id = Column(Integer, primary_key=True, autoincrement=True)
    # unique=True on the FK enforces at most one profile per user
    user_id = Column(Integer, ForeignKey('users.id'), nullable=False, unique=True)
    bio = Column(Text)
    avatar_url = Column(String(255))
    phone = Column(String(20))
    address = Column(Text)
    preferences = Column(Text)  # JSON text; read/write through get_/set_preferences
    created_at = Column(DateTime, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    # Relationship back to the owning user
    user = relationship("User", back_populates="profile")

    def __repr__(self):
        return f"<UserProfile(id={self.id}, user_id={self.user_id})>"

    def get_preferences(self) -> Dict[str, Any]:
        """Decode ``preferences``; returns {} when it is empty or not valid JSON."""
        raw = self.preferences
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}

    def set_preferences(self, preferences: Dict[str, Any]):
        """Store ``preferences`` as JSON, keeping non-ASCII text unescaped."""
        self.preferences = json.dumps(preferences, ensure_ascii=False)
class Role(Base):
    """Named role with a JSON-encoded permission list; linked to users via ``user_roles``."""
    __tablename__ = 'roles'
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(50), unique=True, nullable=False)
    description = Column(Text)
    permissions = Column(Text)  # JSON array of permission strings; see get_/set_permissions
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=func.now())
    # Users holding this role
    users = relationship("User", secondary=user_roles, back_populates="roles")

    def __repr__(self):
        return f"<Role(id={self.id}, name='{self.name}')>"

    def get_permissions(self) -> List[str]:
        """Decode ``permissions``; returns [] when it is empty or not valid JSON."""
        encoded = self.permissions
        if not encoded:
            return []
        try:
            return json.loads(encoded)
        except json.JSONDecodeError:
            return []

    def set_permissions(self, permissions: List[str]):
        """Store ``permissions`` as a JSON array."""
        self.permissions = json.dumps(permissions)
class Category(Base):
    """Product category.

    Categories form a tree through the nullable self-referencing ``parent_id``.
    """
    __tablename__ = 'categories'
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    parent_id = Column(Integer, ForeignKey('categories.id'))  # NULL for a root category
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=func.now())
    # Self-referential many-to-one; remote_side=[id] marks "parent" as the one side,
    # and the backref adds a "children" collection on the parent.
    parent = relationship("Category", remote_side=[id], backref="children")
    # Products filed under this category
    products = relationship("Product", back_populates="category")

    def __repr__(self):
        return "<Category(id={}, name='{}')>".format(self.id, self.name)
class Product(Base):
    """Sellable product with a price, a stock counter and a unique SKU."""
    __tablename__ = 'products'
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(200), nullable=False, index=True)
    description = Column(Text)
    price = Column(Numeric(10, 2), nullable=False)
    stock_quantity = Column(Integer, default=0)
    category_id = Column(Integer, ForeignKey('categories.id'))
    sku = Column(String(50), unique=True, nullable=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    # Relationships
    category = relationship("Category", back_populates="products")
    order_items = relationship("OrderItem", back_populates="product")

    def __repr__(self):
        return f"<Product(id={self.id}, name='{self.name}', price={self.price})>"

    def is_in_stock(self) -> bool:
        """Return True when ``stock_quantity`` is positive.

        The column default (0) is only applied on INSERT, so an unsaved product
        may still hold None; that is treated as 0 instead of raising TypeError.
        """
        return (self.stock_quantity or 0) > 0

    def reduce_stock(self, quantity: int) -> bool:
        """Deduct ``quantity`` from stock if enough is available.

        Returns:
            True if stock was reduced; False (stock unchanged) if insufficient.

        Raises:
            ValueError: if ``quantity`` is negative, which would otherwise
                silently increase the stock.
        """
        if quantity < 0:
            raise ValueError(f"扣减数量不能为负数: {quantity}")
        available = self.stock_quantity or 0
        if available >= quantity:
            self.stock_quantity = available - quantity
            return True
        return False
class Order(Base):
    """Customer order made of OrderItem lines.

    ``total_amount`` caches the sum of the line subtotals; add_item() keeps it
    in sync via calculate_total().
    """
    __tablename__ = 'orders'
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    order_number = Column(String(50), unique=True, nullable=False)
    status = Column(String(20), default='pending')  # pending, confirmed, shipped, delivered, cancelled
    total_amount = Column(Numeric(10, 2), default=0)
    shipping_address = Column(Text)
    notes = Column(Text)
    created_at = Column(DateTime, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    # Relationships
    user = relationship("User", back_populates="orders")
    items = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<Order(id={self.id}, order_number='{self.order_number}', status='{self.status}')>"

    def calculate_total(self) -> Numeric:
        """Recompute ``total_amount`` as the sum of item subtotals, store it and return it.

        An order with no items totals 0.
        """
        total = sum(item.subtotal for item in self.items)
        self.total_amount = total
        return total

    def add_item(self, product: Product, quantity: int, unit_price: Optional[Numeric] = None):
        """Add ``quantity`` units of ``product`` and refresh ``total_amount``.

        If the order already has a line for the same product, that line's
        quantity and subtotal are increased instead of adding a new line.

        Args:
            product: the product being ordered.
            quantity: number of units; must be positive.
            unit_price: price per unit; defaults to ``product.price``.

        Raises:
            ValueError: if ``quantity`` is not positive.
        """
        if quantity <= 0:
            raise ValueError(f"商品数量必须大于0: {quantity}")
        if unit_price is None:
            unit_price = product.price

        existing_item = next((item for item in self.items if self._is_same_product(item, product)), None)
        if existing_item is not None:
            existing_item.quantity += quantity
            existing_item.subtotal = existing_item.quantity * existing_item.unit_price
        else:
            # Append only. back_populates sets item.order for us, whereas also
            # passing order=self makes the backref append the item to
            # self.items, and the explicit append then adds it a second time,
            # so calculate_total() would count it twice.
            self.items.append(OrderItem(
                product=product,
                quantity=quantity,
                unit_price=unit_price,
                subtotal=quantity * unit_price
            ))
        self.calculate_total()

    @staticmethod
    def _is_same_product(item: 'OrderItem', product: Product) -> bool:
        """Check whether ``item`` is a line for ``product``, flushed or not."""
        # For an unflushed line product_id may still be None (the FK column is
        # filled in at flush time), so compare the related object first.
        if item.product is product:
            return True
        return item.product_id is not None and item.product_id == product.id
class OrderItem(Base):
    """One line of an Order.

    Stores the product, the quantity, the unit price at the time of ordering
    and the precomputed subtotal.
    """
    __tablename__ = 'order_items'
    id = Column(Integer, primary_key=True, autoincrement=True)
    order_id = Column(Integer, ForeignKey('orders.id'), nullable=False)
    product_id = Column(Integer, ForeignKey('products.id'), nullable=False)
    quantity = Column(Integer, nullable=False)
    unit_price = Column(Numeric(10, 2), nullable=False)
    subtotal = Column(Numeric(10, 2), nullable=False)  # quantity * unit_price, set by Order.add_item
    created_at = Column(DateTime, default=func.now())
    # Relationships
    order = relationship("Order", back_populates="items")
    product = relationship("Product", back_populates="order_items")

    def __repr__(self):
        details = (f"id={self.id}, order_id={self.order_id}, "
                   f"product_id={self.product_id}, quantity={self.quantity}")
        return f"<OrderItem({details})>"
# ORM database manager
class ORMDatabaseManager:
    """Create the SQLAlchemy engine, session factory and schema for the ORM models.

    Args:
        database_url: SQLAlchemy database URL. Defaults to a SQLite file
            ``orm_demo.db`` in the working directory.
    """

    def __init__(self, database_url: str = "sqlite:///orm_demo.db"):
        self.database_url = database_url
        self.engine = None
        self.SessionLocal = None
        # Optional long-lived session; close() closes it if a caller assigned one.
        self.session = None

    def _is_sqlite(self) -> bool:
        """True for sqlite:// URLs, including driver variants such as sqlite+pysqlite://."""
        return self.database_url.startswith("sqlite")

    def _is_in_memory_sqlite(self) -> bool:
        """True for SQLite URLs that point at an in-memory database."""
        if not self._is_sqlite():
            return False
        location = self.database_url.split("://", 1)[-1].lstrip("/")
        return location == "" or ":memory:" in location or "mode=memory" in location

    def initialize(self):
        """Create the engine and session factory, then create all tables if missing."""
        engine_options = {"echo": False}  # set echo to True to log every SQL statement
        if self._is_sqlite():
            # Let sessions use the SQLite connection from threads other than its creator.
            engine_options["connect_args"] = {"check_same_thread": False}
        if self._is_in_memory_sqlite():
            # Every new connection to an in-memory SQLite database would see an empty
            # database, so share a single connection. File-based SQLite and server
            # databases keep SQLAlchemy's default pooling instead of one shared connection.
            engine_options["poolclass"] = StaticPool
        self.engine = create_engine(self.database_url, **engine_options)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        Base.metadata.create_all(bind=self.engine)
        print(f"数据库初始化完成: {self.database_url}")

    def get_session(self) -> Session:
        """Return a new Session. The caller is responsible for closing it.

        Raises:
            Exception: if initialize() has not been called.
        """
        if self.SessionLocal is None:
            raise Exception("数据库未初始化")
        return self.SessionLocal()

    def close(self):
        """Close the tracked session (if any) and dispose of the engine's connections."""
        if self.session:
            self.session.close()
            self.session = None
        if self.engine:
            self.engine.dispose()
        print("数据库连接已关闭")
# ORM operations demo
def run_orm_demo():
    """Walk through creating, querying and updating the ORM models.

    Uses ORMDatabaseManager's default database URL. Every run inserts sample
    roles, users, profiles, categories, products and orders. Several of those
    columns are UNIQUE (role name, username, SKU, order number), so a second
    run against the same database file will most likely fail with a
    uniqueness violation, which the except branch reports.

    Returns:
        The ORMDatabaseManager, already closed.
    """
    print("\n--- ORM操作演示 ---")
    db_manager = ORMDatabaseManager()
    db_manager.initialize()
    # Open the session before the try block so except/finally can always use it.
    session = db_manager.get_session()
    try:
        # 1. Roles
        print("\n1. 创建角色:")
        admin_role = Role(
            name="admin",
            description="系统管理员",
        )
        admin_role.set_permissions(["read", "write", "delete", "admin"])
        user_role = Role(
            name="user",
            description="普通用户"
        )
        user_role.set_permissions(["read"])
        session.add_all([admin_role, user_role])
        session.commit()
        print(f" 创建角色: {admin_role.name}, {user_role.name}")

        # 2. Users
        print("\n2. 创建用户:")
        users_data = [
            {
                "username": "admin",
                "email": "admin@example.com",
                "password_hash": "hashed_password_1",
                "full_name": "系统管理员",
                "age": 30
            },
            {
                "username": "alice",
                "email": "alice@example.com",
                "password_hash": "hashed_password_2",
                "full_name": "Alice Smith",
                "age": 25
            },
            {
                "username": "bob",
                "email": "bob@example.com",
                "password_hash": "hashed_password_3",
                "full_name": "Bob Johnson",
                "age": 28
            }
        ]
        created_users = []
        for user_data in users_data:
            user = User(**user_data)
            session.add(user)
            created_users.append(user)
        session.commit()
        # Role assignment: admin gets the admin role; alice and bob get the user role
        created_users[0].roles.append(admin_role)
        created_users[1].roles.append(user_role)
        created_users[2].roles.append(user_role)
        session.commit()
        for user in created_users:
            roles = [role.name for role in user.roles]
            print(f" 创建用户: {user.username} ({user.full_name}) - 角色: {roles}")

        # 3. Profiles (every user except admin)
        print("\n3. 创建用户资料:")
        for user in created_users[1:]:
            profile = UserProfile(
                user_id=user.id,
                bio=f"这是{user.full_name}的个人简介",
                phone=f"138{user.id:08d}",
                address=f"北京市朝阳区{user.username}街道"
            )
            profile.set_preferences({
                "theme": "dark" if user.username == "alice" else "light",
                "language": "zh-CN",
                "notifications": True
            })
            session.add(profile)
        session.commit()
        print(" 用户资料创建完成")

        # 4. Category tree: electronics -> computers, phones
        print("\n4. 创建商品分类:")
        electronics = Category(name="电子产品", description="各种电子设备")
        computers = Category(name="计算机", description="计算机及配件", parent=electronics)
        phones = Category(name="手机", description="智能手机及配件", parent=electronics)
        session.add_all([electronics, computers, phones])
        session.commit()
        print(f" 创建分类: {electronics.name} -> {computers.name}, {phones.name}")

        # 5. Products
        print("\n5. 创建商品:")
        products_data = [
            {
                "name": "MacBook Pro",
                "description": "苹果笔记本电脑",
                "price": 12999.00,
                "stock_quantity": 10,
                "category": computers,
                "sku": "MBP-001"
            },
            {
                "name": "iPhone 15",
                "description": "苹果智能手机",
                "price": 5999.00,
                "stock_quantity": 20,
                "category": phones,
                "sku": "IP15-001"
            },
            {
                "name": "iPad Air",
                "description": "苹果平板电脑",
                "price": 4399.00,
                "stock_quantity": 15,
                "category": computers,
                "sku": "IPA-001"
            }
        ]
        created_products = []
        for product_data in products_data:
            product = Product(**product_data)
            session.add(product)
            created_products.append(product)
        session.commit()
        for product in created_products:
            print(f" 创建商品: {product.name} - 价格: ¥{product.price} - 库存: {product.stock_quantity}")

        # 6. Orders
        print("\n6. 创建订单:")
        import uuid
        alice = session.query(User).filter_by(username="alice").first()
        alice_order = Order(
            user=alice,
            order_number=f"ORD-{uuid.uuid4().hex[:8].upper()}",
            shipping_address="北京市朝阳区alice街道123号",
            notes="请在工作日送达"
        )
        alice_order.add_item(created_products[1], 1)  # iPhone 15
        alice_order.add_item(created_products[2], 1)  # iPad Air
        session.add(alice_order)
        bob = session.query(User).filter_by(username="bob").first()
        bob_order = Order(
            user=bob,
            order_number=f"ORD-{uuid.uuid4().hex[:8].upper()}",
            shipping_address="北京市朝阳区bob街道456号"
        )
        bob_order.add_item(created_products[0], 1)  # MacBook Pro
        session.add(bob_order)
        session.commit()
        print(f" 创建订单: {alice_order.order_number} (用户: {alice.username}) - 总金额: ¥{alice_order.total_amount}")
        print(f" 创建订单: {bob_order.order_number} (用户: {bob.username}) - 总金额: ¥{bob_order.total_amount}")

        # 7. Queries
        print("\n7. 查询操作演示:")
        print("\n 所有用户:")
        all_users = session.query(User).all()
        for user in all_users:
            print(f" {user.username} - {user.email} - 角色: {[role.name for role in user.roles]}")

        print("\n 年龄大于25的用户:")
        adult_users = session.query(User).filter(User.age > 25).all()
        for user in adult_users:
            print(f" {user.username} - 年龄: {user.age}")

        print("\n 用户及其订单信息:")
        users_with_orders = session.query(User).join(Order).all()
        for user in users_with_orders:
            print(f" 用户: {user.username}")
            for order in user.orders:
                print(f" 订单: {order.order_number} - 状态: {order.status} - 金额: ¥{order.total_amount}")
                for item in order.items:
                    print(f" 商品: {item.product.name} x {item.quantity} = ¥{item.subtotal}")

        # Aggregates; `func` is the module-level sqlalchemy.sql.func (a local
        # re-import here would make it a function-local name for the whole body).
        print("\n 订单统计:")
        order_stats = session.query(
            func.count(Order.id).label('total_orders'),
            func.sum(Order.total_amount).label('total_revenue'),
            func.avg(Order.total_amount).label('avg_order_value')
        ).first()
        print(f" 总订单数: {order_stats.total_orders}")
        print(f" 总收入: ¥{order_stats.total_revenue}")
        # AVG() is NULL when there are no orders; format 0 instead of crashing on None.
        avg_order_value = order_stats.avg_order_value or 0
        print(f" 平均订单金额: ¥{avg_order_value:.2f}")

        print("\n 按分类统计商品:")
        category_stats = session.query(
            Category.name,
            func.count(Product.id).label('product_count'),
            func.sum(Product.stock_quantity).label('total_stock')
        ).join(Product).group_by(Category.id).all()
        for stat in category_stats:
            print(f" 分类: {stat.name} - 商品数: {stat.product_count} - 总库存: {stat.total_stock}")

        # 8. Updates
        print("\n8. 更新操作:")
        alice.age = 26
        alice.updated_at = func.now()
        session.commit()
        print(f" 更新用户 {alice.username} 的年龄为 {alice.age}")
        # Bulk UPDATE issued directly as SQL, without loading the orders
        session.query(Order).filter(Order.status == 'pending').update({
            Order.status: 'confirmed',
            Order.updated_at: func.now()
        })
        session.commit()
        print(" 批量更新订单状态为已确认")

        # 9. Deletes are only described, never executed
        print("\n9. 删除操作演示(不实际执行):")
        print(" 删除单个记录: session.delete(object)")
        print(" 批量删除: session.query(Model).filter(...).delete()")
    except Exception as e:
        print(f"ORM操作失败: {e}")
        session.rollback()
    finally:
        session.close()
        db_manager.close()
    return db_manager
# Execute the ORM walkthrough defined above and keep the returned manager.
print("运行ORM操作演示...")
orm_demo = run_orm_demo()
数据库连接池
连接池管理
# Banner for the database connection-pool section.
print("\n=== 数据库连接池 ===")
from sqlalchemy import create_engine, pool
from sqlalchemy.pool import QueuePool, StaticPool, NullPool
from contextlib import contextmanager
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import queue
from typing import Generator
class ConnectionPoolManager:
    """Own a SQLAlchemy engine (and its connection pool) and hand out sessions.

    The engine is created lazily by :meth:`initialize`. Sessions are obtained
    through the :meth:`get_session` context manager, which rolls back on error
    and always closes the session so its connection returns to the pool.
    Simple usage counters are kept in ``self._stats``.
    """

    def __init__(self, database_url: str, pool_config: Optional[Dict[str, Any]] = None):
        self.database_url = database_url
        # Extra create_engine() keyword arguments; they override the defaults in initialize().
        self.pool_config = pool_config or {}
        self.engine = None
        self.SessionLocal = None
        # Guards one-time initialization and the shared counters below, which
        # are updated from many worker threads at once.
        self._lock = threading.Lock()
        self._stats = {
            'total_connections': 0,
            'active_connections': 0,
            'pool_hits': 0,
            'pool_misses': 0,
            'connection_errors': 0
        }

    def initialize(self):
        """Create the engine and session factory; calling it again is a no-op."""
        with self._lock:
            if self.engine is not None:
                return
            # Default pool settings.
            default_config = {
                'poolclass': QueuePool,
                'pool_size': 10,        # persistent connections kept in the pool
                'max_overflow': 20,     # extra connections allowed beyond pool_size
                'pool_timeout': 30,     # seconds to wait for a free connection
                'pool_recycle': 3600,   # recycle connections older than this (seconds)
                'pool_pre_ping': True,  # test a connection before handing it out
                'echo': False           # log emitted SQL
            }
            # User configuration wins over the defaults.
            config = {**default_config, **self.pool_config}
            self.engine = create_engine(self.database_url, **config)
            self.SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine
            )
            print(f"连接池初始化完成: {self.database_url}")
            print(f"连接池配置: {config}")

    def _bump(self, key: str, delta: int = 1) -> None:
        """Adjust one usage counter under the lock (``+=`` is not atomic across threads)."""
        with self._lock:
            self._stats[key] += delta

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Yield a new session; roll back if the body raises, always close it.

        Raises:
            RuntimeError: if :meth:`initialize` has not been called.
        """
        if self.SessionLocal is None:
            raise RuntimeError("连接池未初始化")
        session = None
        try:
            session = self.SessionLocal()
            self._bump('active_connections')
            self._bump('pool_hits')
            yield session
        except Exception:
            self._bump('connection_errors')
            if session:
                session.rollback()
            raise
        finally:
            if session:
                session.close()
                self._bump('active_connections', -1)

    def get_pool_status(self) -> Dict[str, Any]:
        """Return pool metrics plus a snapshot of the usage counters.

        Metrics the configured pool class does not provide are reported as
        ``None``: NullPool/StaticPool have no size/checkout counters, and
        ``invalid()`` is not part of QueuePool's public API.
        """
        if self.engine is None:
            return {'status': 'not_initialized'}
        engine_pool = self.engine.pool

        def metric(name: str):
            getter = getattr(engine_pool, name, None)
            return getter() if callable(getter) else None

        with self._lock:
            stats_snapshot = self._stats.copy()
        return {
            'pool_size': metric('size'),
            'checked_in': metric('checkedin'),
            'checked_out': metric('checkedout'),
            'overflow': metric('overflow'),
            'invalid': metric('invalid'),
            'stats': stats_snapshot
        }

    def close(self):
        """Dispose of the engine's pooled connections (if initialized)."""
        if self.engine:
            self.engine.dispose()
            print("连接池已关闭")
class DatabaseService:
    """CRUD helpers for ``User`` rows, each call running in its own pooled session.

    Write operations report failures by printing a message and returning
    ``None``/``False`` instead of raising.
    """

    def __init__(self, pool_manager: ConnectionPoolManager):
        self.pool_manager = pool_manager

    @staticmethod
    def _find_user(session, user_id: int) -> Optional[User]:
        """Return the user whose primary key is ``user_id``, or ``None``."""
        return session.query(User).filter(User.id == user_id).first()

    def create_user(self, user_data: Dict[str, Any]) -> Optional[User]:
        """Insert a user built from ``user_data``; return it, or ``None`` on failure."""
        with self.pool_manager.get_session() as session:
            try:
                new_user = User(**user_data)
                session.add(new_user)
                session.commit()
                # Reload generated columns (e.g. id) before the session closes.
                session.refresh(new_user)
            except Exception as e:
                session.rollback()
                print(f"创建用户失败: {e}")
                return None
            return new_user

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Look up one user by primary key."""
        with self.pool_manager.get_session() as session:
            return self._find_user(session, user_id)

    def get_users_by_age_range(self, min_age: int, max_age: int) -> List[User]:
        """Return users whose age lies in the inclusive range [min_age, max_age]."""
        with self.pool_manager.get_session() as session:
            age_filter = User.age.between(min_age, max_age)
            return session.query(User).filter(age_filter).all()

    def update_user(self, user_id: int, update_data: Dict[str, Any]) -> bool:
        """Set every existing attribute named in ``update_data``; return whether the user was found."""
        with self.pool_manager.get_session() as session:
            try:
                target = self._find_user(session, user_id)
                if not target:
                    return False
                for field_name, new_value in update_data.items():
                    if hasattr(target, field_name):
                        setattr(target, field_name, new_value)
                target.updated_at = func.now()
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"更新用户失败: {e}")
                return False

    def delete_user(self, user_id: int) -> bool:
        """Delete a user; return ``True`` only when a row was found and removed."""
        with self.pool_manager.get_session() as session:
            try:
                target = self._find_user(session, user_id)
                if not target:
                    return False
                session.delete(target)
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"删除用户失败: {e}")
                return False

    def get_user_statistics(self) -> Dict[str, Any]:
        """Return total/active/inactive user counts and the mean age (0 when unknown)."""
        with self.pool_manager.get_session() as session:
            count_query = session.query(func.count(User.id))
            total_users = count_query.scalar()
            active_users = count_query.filter(User.is_active == True).scalar()  # noqa: E712 (SQL expression)
            mean_age = session.query(func.avg(User.age)).scalar()
            return {
                'total_users': total_users,
                'active_users': active_users,
                'inactive_users': total_users - active_users,
                'average_age': float(mean_age) if mean_age else 0
            }
# Connection pool demo
def run_connection_pool_demo():
    """Exercise ConnectionPoolManager and DatabaseService, single- and multi-threaded.

    Uses the SQLite file ``pool_demo.db``. It inserts test users, runs a
    concurrent create/read/update workload from a thread pool, and prints pool
    status and throughput. The pool is always closed, and the manager is
    returned so its final statistics can be inspected.
    """
    print("\n--- 连接池演示 ---")
    # Demo pool sizing, merged over ConnectionPoolManager's defaults.
    pool_config = {
        'pool_size': 5,
        'max_overflow': 10,
        'pool_timeout': 10,
        'pool_recycle': 1800,
        # A pooled SQLite connection may be checked out by a different worker
        # thread than the one that opened it; the sqlite3 driver rejects such
        # cross-thread use unless its same-thread check is disabled.
        'connect_args': {'check_same_thread': False},
    }
    pool_manager = ConnectionPoolManager(
        database_url="sqlite:///pool_demo.db",
        pool_config=pool_config
    )
    try:
        pool_manager.initialize()
        db_service = DatabaseService(pool_manager)
        # Create the tables if they do not exist yet.
        with pool_manager.get_session() as session:
            Base.metadata.create_all(bind=session.bind)

        print("\n1. 单线程操作测试:")
        test_users = [
            {"username": f"user_{i}", "email": f"user_{i}@example.com",
             "password_hash": f"hash_{i}", "age": 20 + i}
            for i in range(5)
        ]
        for user_data in test_users:
            user = db_service.create_user(user_data)
            if user:
                print(f" 创建用户: {user.username} (ID: {user.id})")
        users = db_service.get_users_by_age_range(20, 25)
        print(f" 年龄在20-25之间的用户数: {len(users)}")
        stats = db_service.get_user_statistics()
        print(f" 用户统计: {stats}")
        pool_status = pool_manager.get_pool_status()
        print(f" 连接池状态: {pool_status}")

        print("\n2. 多线程并发测试:")

        def worker_task(worker_id: int, num_operations: int) -> Dict[str, Any]:
            """Run create -> read -> update cycles; return counts and timing."""
            results = {
                'worker_id': worker_id,
                'operations': 0,
                'errors': 0,
                'start_time': time.time()
            }
            for i in range(num_operations):
                try:
                    user_data = {
                        "username": f"worker_{worker_id}_user_{i}",
                        "email": f"worker_{worker_id}_user_{i}@example.com",
                        "password_hash": f"hash_{worker_id}_{i}",
                        "age": 20 + (i % 30)
                    }
                    user = db_service.create_user(user_data)
                    if user:
                        found_user = db_service.get_user_by_id(user.id)
                        if found_user:
                            db_service.update_user(user.id, {'age': found_user.age + 1})
                            results['operations'] += 1
                    # Simulate some per-operation processing time.
                    time.sleep(0.01)
                except Exception as e:
                    results['errors'] += 1
                    print(f" 工作线程 {worker_id} 操作失败: {e}")
            results['end_time'] = time.time()
            results['duration'] = results['end_time'] - results['start_time']
            return results

        num_workers = 8
        operations_per_worker = 10
        worker_results = []
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [
                executor.submit(worker_task, i, operations_per_worker)
                for i in range(num_workers)
            ]
            for future in as_completed(futures):
                try:
                    result = future.result()
                    worker_results.append(result)
                    print(f" 工作线程 {result['worker_id']} 完成: "
                          f"操作数={result['operations']}, "
                          f"错误数={result['errors']}, "
                          f"耗时={result['duration']:.2f}秒")
                except Exception as e:
                    print(f" 工作线程执行失败: {e}")

        total_operations = sum(r['operations'] for r in worker_results)
        total_errors = sum(r['errors'] for r in worker_results)
        # Wall time is bounded by the slowest worker. Guard against every
        # worker failing (empty list) and a zero duration before dividing.
        total_duration = max((r['duration'] for r in worker_results), default=0.0)
        throughput = total_operations / total_duration if total_duration > 0 else 0.0
        print(f"\n 并发测试结果:")
        print(f" 总操作数: {total_operations}")
        print(f" 总错误数: {total_errors}")
        print(f" 总耗时: {total_duration:.2f}秒")
        print(f" 平均TPS: {throughput:.2f}")

        final_pool_status = pool_manager.get_pool_status()
        print(f"\n 最终连接池状态: {final_pool_status}")
        final_stats = db_service.get_user_statistics()
        print(f" 最终用户统计: {final_stats}")
    except Exception as e:
        print(f"连接池演示失败: {e}")
    finally:
        pool_manager.close()
    return pool_manager


# Run the connection pool demo.
print("运行连接池演示...")
pool_demo = run_connection_pool_demo()
事务管理
事务控制和回滚
# Banner for the transaction-management section.
print("\n=== 事务管理 ===")
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from contextlib import contextmanager
from typing import Callable, Any
import functools
class TransactionManager:
    """Helpers for running work inside database transactions.

    ``session_factory`` is a zero-argument callable returning a new session.
    The class provides a commit/rollback context manager, SAVEPOINT nesting,
    and decorators for transactional execution and retry with exponential
    backoff.
    """

    def __init__(self, session_factory: Callable[[], Session]):
        self.session_factory = session_factory

    @contextmanager
    def transaction(self, auto_commit: bool = True) -> Generator[Session, None, None]:
        """Yield a fresh session and always close it afterwards.

        On success the session is committed when ``auto_commit`` is true. If
        the body raises, the session is rolled back and the exception
        re-raised.
        """
        session = self.session_factory()
        try:
            yield session
            if auto_commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @contextmanager
    def nested_transaction(self, session: Session) -> Generator[Session, None, None]:
        """Run a block inside a SAVEPOINT of ``session``.

        The savepoint is released on success and rolled back if the block
        raises (the exception is re-raised). The outer transaction continues
        either way.
        """
        savepoint = session.begin_nested()
        try:
            yield session
            savepoint.commit()
        except Exception:
            savepoint.rollback()
            raise

    def with_transaction(self, auto_commit: bool = True):
        """Decorator: run the function inside :meth:`transaction`.

        The session is passed as the FIRST positional argument. NOTE: on an
        instance method that places it before ``self``, so use this on plain
        functions of the form ``f(session, ...)``.
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with self.transaction(auto_commit=auto_commit) as session:
                    return func(session, *args, **kwargs)
            return wrapper
        return decorator

    def with_retry(self, max_retries: int = 3, retry_delay: float = 0.1):
        """Decorator: retry on SQLAlchemy errors with exponential backoff.

        The call is attempted up to ``max_retries + 1`` times, sleeping
        ``retry_delay * 2**attempt`` seconds between attempts. Non-database
        exceptions propagate immediately. The last database error is re-raised
        once the retries are exhausted.

        Raises:
            ValueError: if ``max_retries`` is negative (the call would never run).
        """
        if max_retries < 0:
            raise ValueError("max_retries must be >= 0")

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for attempt in range(max_retries + 1):
                    try:
                        return func(*args, **kwargs)
                    except SQLAlchemyError:  # IntegrityError is a subclass
                        if attempt >= max_retries:
                            print(f" 事务重试 {max_retries} 次后仍然失败")
                            raise
                        print(f" 事务失败,第 {attempt + 1} 次重试...")
                        time.sleep(retry_delay * (2 ** attempt))  # exponential backoff
            return wrapper
        return decorator
def _transactional_method(func):
    """Run an OrderService method inside a transaction unless one is supplied.

    The decorated method takes a session as its first parameter after ``self``.
    If the caller already passes one (positionally or as ``session=``), the
    method joins that unit of work and the caller stays responsible for
    committing. Otherwise a new session is opened via ``self.tx_manager``,
    committed when the method returns, and rolled back if it raises.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Duck-typed so that Session subclasses and scoped_session proxies are
        # both recognised as a caller-supplied session.
        if 'session' in kwargs or (args and hasattr(args[0], 'begin_nested')):
            return func(self, *args, **kwargs)
        with self.tx_manager.transaction(auto_commit=False) as session:
            result = func(self, session, *args, **kwargs)
            session.commit()
            # Commit expires loaded attributes, and the session closes as soon
            # as this block exits. Reload a returned order now so callers can
            # still read e.g. ``order.id`` / ``order.order_number``.
            if isinstance(result, Order):
                session.refresh(result)
            return result
    return wrapper


def _retrying_method(max_retries: int = 3, retry_delay: float = 0.1):
    """Retry an OrderService method with ``self.tx_manager.with_retry``.

    The retry policy is resolved at call time because the TransactionManager
    only exists per instance, not while the class body executes.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            retry = self.tx_manager.with_retry(max_retries=max_retries, retry_delay=retry_delay)
            return retry(functools.partial(func, self))(*args, **kwargs)
        return wrapper
    return decorator


class OrderService:
    """Order operations that demonstrate transaction management.

    Methods with a ``session`` parameter accept it optionally. Pass a session
    to run inside an existing transaction (as ``batch_process_orders`` does),
    or omit it to get a fresh, auto-committed transaction.
    """

    def __init__(self, transaction_manager: TransactionManager):
        self.tx_manager = transaction_manager

    # Kept for backward compatibility: decorators bound to this instance's
    # TransactionManager, usable at runtime on plain functions. They cannot
    # decorate methods in the class body because no instance exists yet.
    @property
    def with_transaction(self):
        return self.tx_manager.with_transaction()

    @property
    def with_retry(self):
        return self.tx_manager.with_retry()

    @staticmethod
    def _start_order(session: Session, user_id: int) -> Order:
        """Add a pending order for ``user_id`` and flush it so it gets an id.

        Raises:
            ValueError: if the user does not exist.
        """
        import uuid
        user = session.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError(f"用户不存在: {user_id}")
        order = Order(
            user=user,
            order_number=f"ORD-{uuid.uuid4().hex[:8].upper()}",
            status='pending'
        )
        session.add(order)
        session.flush()  # assign the order id
        return order

    @staticmethod
    def _add_order_item(session: Session, order: Order, item_data: Dict[str, Any]) -> OrderItem:
        """Reserve stock for one ``{'product_id', 'quantity'}`` entry and attach an OrderItem.

        Raises:
            ValueError: if the product is missing, the quantity is not
                positive, or the stock is insufficient.
        """
        product = session.query(Product).filter(
            Product.id == item_data['product_id']
        ).first()
        if not product:
            raise ValueError(f"商品不存在: {item_data['product_id']}")
        quantity = item_data['quantity']
        # A zero or negative quantity would otherwise pass the stock check and
        # could increase stock through reduce_stock().
        if quantity <= 0:
            raise ValueError(f"商品数量无效: {quantity}")
        if not product.is_in_stock() or product.stock_quantity < quantity:
            raise ValueError(f"商品库存不足: {product.name}")
        product.reduce_stock(quantity)
        order_item = OrderItem(
            order=order,
            product=product,
            quantity=quantity,
            unit_price=product.price,
            subtotal=product.price * quantity
        )
        session.add(order_item)
        return order_item

    @_transactional_method
    def create_order_simple(self, session: Session, user_id: int, items: List[Dict[str, Any]]) -> Optional[Order]:
        """Create a pending order from ``items`` as one all-or-nothing unit.

        Any failure raises, so every stock reduction made for earlier items is
        rolled back together with the order.
        """
        try:
            order = self._start_order(session, user_id)
            total_amount = 0
            for item_data in items:
                total_amount += self._add_order_item(session, order, item_data).subtotal
            order.total_amount = total_amount
            print(f" 创建订单成功: {order.order_number} - 总金额: ¥{total_amount}")
            return order
        except Exception as e:
            print(f" 创建订单失败: {e}")
            raise

    def create_order_with_nested_transactions(self, user_id: int, items: List[Dict[str, Any]]) -> Optional[Order]:
        """Create an order, adding each item inside its own SAVEPOINT.

        Failing items are skipped and rolled back to their savepoint. The order
        is kept if at least one item was added; otherwise ValueError is raised
        and the whole transaction rolls back.
        """
        return self._create_order_with_savepoints(user_id, items)

    @_transactional_method
    def _create_order_with_savepoints(self, session: Session, user_id: int, items: List[Dict[str, Any]]) -> Optional[Order]:
        """Implementation of :meth:`create_order_with_nested_transactions`.

        NOTE(review): with the pysqlite driver, SAVEPOINT only behaves
        correctly when SQLAlchemy's documented SQLite transaction workaround is
        installed on the engine. Confirm this for SQLite-backed demos.
        """
        try:
            order = self._start_order(session, user_id)
            total_amount = 0
            added_items = 0
            for item_data in items:
                try:
                    with self.tx_manager.nested_transaction(session):
                        order_item = self._add_order_item(session, order, item_data)
                except Exception as e:
                    # The savepoint was rolled back; the main transaction continues.
                    print(f" 订单项处理失败,跳过: {e}")
                    continue
                # Count the item only once its savepoint was released successfully.
                total_amount += order_item.subtotal
                added_items += 1
                print(f" 添加订单项: {order_item.product.name} x {item_data['quantity']}")
            if added_items == 0:
                raise ValueError("没有有效的订单项")
            order.total_amount = total_amount
            print(f" 使用嵌套事务创建订单成功: {order.order_number} - 总金额: ¥{total_amount}")
            return order
        except Exception as e:
            print(f" 创建订单失败: {e}")
            raise

    @_retrying_method(max_retries=3)
    @_transactional_method
    def create_order_with_retry(self, session: Session, user_id: int, items: List[Dict[str, Any]]) -> Optional[Order]:
        """Create an order, retrying in a fresh transaction on simulated conflicts."""
        import random
        # Simulate a concurrency conflict 30% of the time.
        if random.random() < 0.3:
            raise IntegrityError("模拟的并发冲突", None, None)
        return self.create_order_simple(session, user_id, items)

    @_transactional_method
    def cancel_order(self, session: Session, order_id: int) -> bool:
        """Cancel a pending order and restore the stock of its items.

        Raises:
            ValueError: if the order is missing or not in ``'pending'`` status.
        """
        try:
            order = session.query(Order).filter(Order.id == order_id).first()
            if not order:
                raise ValueError(f"订单不存在: {order_id}")
            if order.status != 'pending':
                raise ValueError(f"订单状态不允许取消: {order.status}")
            for item in order.items:
                product = item.product
                product.stock_quantity += item.quantity
                print(f" 恢复库存: {product.name} +{item.quantity}")
            order.status = 'cancelled'
            order.updated_at = func.now()
            print(f" 订单取消成功: {order.order_number}")
            return True
        except Exception as e:
            print(f" 取消订单失败: {e}")
            raise

    def batch_process_orders(self, order_operations: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Apply create/cancel operations in one transaction, one SAVEPOINT each.

        A failing operation is recorded in ``errors`` and rolled back alone.
        The successful ones are committed together at the end.
        """
        results = {
            'success_count': 0,
            'error_count': 0,
            'errors': []
        }
        with self.tx_manager.transaction() as session:
            for operation in order_operations:
                try:
                    with self.tx_manager.nested_transaction(session):
                        op_type = operation['type']
                        if op_type == 'create':
                            self.create_order_simple(
                                session,
                                operation['user_id'],
                                operation['items']
                            )
                        elif op_type == 'cancel':
                            self.cancel_order(session, operation['order_id'])
                        else:
                            raise ValueError(f"不支持的操作类型: {op_type}")
                    results['success_count'] += 1
                except Exception as e:
                    results['error_count'] += 1
                    results['errors'].append({
                        'operation': operation,
                        'error': str(e)
                    })
                    print(f" 批量操作失败: {e}")
        return results
# Transaction management demo
def run_transaction_demo():
    """Walk through simple, rolled-back, nested, retried, cancelled and batch transactions.

    Uses ``transaction_demo.db`` (SQLite). The test user and product
    identifiers carry a per-run suffix so the demo can be re-run against the
    same persistent database file. Returns the TransactionManager.
    """
    import uuid

    print("\n--- 事务管理演示 ---")
    db_manager = ORMDatabaseManager("sqlite:///transaction_demo.db")
    db_manager.initialize()
    tx_manager = TransactionManager(db_manager.get_session)
    order_service = OrderService(tx_manager)
    # NOTE(review): username/email/sku are presumably unique columns (model
    # defined elsewhere). The suffix avoids collisions with earlier runs.
    run_tag = uuid.uuid4().hex[:6]
    try:
        # Seed one user and three products; the third product has no stock.
        with tx_manager.transaction() as session:
            test_user = User(
                username=f"test_user_{run_tag}",
                email=f"test_{run_tag}@example.com",
                password_hash="test_hash",
                full_name="测试用户",
                age=25
            )
            test_products = [
                Product(
                    name="测试商品1",
                    description="测试商品1描述",
                    price=100.00,
                    stock_quantity=10,
                    sku=f"TEST-001-{run_tag}"
                ),
                Product(
                    name="测试商品2",
                    description="测试商品2描述",
                    price=200.00,
                    stock_quantity=5,
                    sku=f"TEST-002-{run_tag}"
                ),
                Product(
                    name="测试商品3",
                    description="测试商品3描述",
                    price=50.00,
                    stock_quantity=0,  # out of stock
                    sku=f"TEST-003-{run_tag}"
                )
            ]
            session.add(test_user)
            session.add_all(test_products)
            # Flush to obtain primary keys while the session is still open;
            # the transaction commits when this block exits.
            session.flush()
            user_id = test_user.id
            product_ids = [p.id for p in test_products]

        print("\n1. 简单事务测试:")
        order_items = [
            {'product_id': product_ids[0], 'quantity': 2},
            {'product_id': product_ids[1], 'quantity': 1}
        ]
        order1 = order_service.create_order_simple(user_id, order_items)
        if order1:
            print(f" 订单创建成功: {order1.order_number}")

        print("\n2. 事务回滚测试:")
        try:
            order_items_fail = [
                {'product_id': product_ids[0], 'quantity': 20},  # exceeds stock
                {'product_id': product_ids[2], 'quantity': 1}    # out of stock
            ]
            order_service.create_order_simple(user_id, order_items_fail)
        except Exception as e:
            print(f" 预期的事务回滚: {e}")

        print("\n3. 嵌套事务测试:")
        order_items_mixed = [
            {'product_id': product_ids[0], 'quantity': 1},  # succeeds
            {'product_id': product_ids[2], 'quantity': 1},  # fails (no stock)
            {'product_id': product_ids[1], 'quantity': 1}   # succeeds
        ]
        order2 = order_service.create_order_with_nested_transactions(user_id, order_items_mixed)
        if order2:
            print(f" 部分成功订单: {order2.order_number}")

        print("\n4. 重试机制测试:")
        order_items_retry = [
            {'product_id': product_ids[0], 'quantity': 1}
        ]
        try:
            order3 = order_service.create_order_with_retry(user_id, order_items_retry)
            if order3:
                print(f" 重试成功订单: {order3.order_number}")
        except Exception as e:
            print(f" 重试最终失败: {e}")

        print("\n5. 订单取消测试:")
        if order1:
            success = order_service.cancel_order(order1.id)
            if success:
                print(f" 订单取消成功: {order1.order_number}")

        print("\n6. 批量操作测试:")
        batch_operations = [
            {
                'type': 'create',
                'user_id': user_id,
                'items': [{'product_id': product_ids[0], 'quantity': 1}]
            },
            {
                'type': 'create',
                'user_id': user_id,
                'items': [{'product_id': product_ids[2], 'quantity': 1}]  # expected to fail
            },
            {
                'type': 'create',
                'user_id': user_id,
                'items': [{'product_id': product_ids[1], 'quantity': 1}]
            }
        ]
        batch_results = order_service.batch_process_orders(batch_operations)
        print(f" 批量操作结果: 成功 {batch_results['success_count']}, 失败 {batch_results['error_count']}")

        print("\n7. 最终库存状态:")
        with tx_manager.transaction() as session:
            products = session.query(Product).filter(Product.id.in_(product_ids)).all()
            for product in products:
                print(f" {product.name}: 库存 {product.stock_quantity}")
    except Exception as e:
        print(f"事务管理演示失败: {e}")
    finally:
        db_manager.close()
    return tx_manager


# Run the transaction management demo.
print("运行事务管理演示...")
transaction_demo = run_transaction_demo()
高级查询技术
复杂查询构建
from sqlalchemy import func, case, exists, and_, or_, not_
from sqlalchemy.orm import aliased
from datetime import datetime, timedelta
class AdvancedQueryBuilder:
    """Examples of advanced SQLAlchemy query patterns over User/Order/Product."""

    def __init__(self, session):
        self.session = session

    def subquery_example(self):
        """Return distinct users with at least one order above the overall average amount."""
        avg_order_amount = self.session.query(
            func.avg(Order.total_amount)
        ).scalar_subquery()
        return self.session.query(User).join(Order).filter(
            Order.total_amount > avg_order_amount
        ).distinct().all()

    def window_function_example(self):
        """Return (username, total_amount, order_rank) rows; rank 1 is each user's largest order."""
        return self.session.query(
            User.username,
            Order.total_amount,
            func.row_number().over(
                partition_by=User.id,
                order_by=Order.total_amount.desc()
            ).label('order_rank')
        ).join(Order).all()

    def conditional_aggregation(self):
        """Count and sum 'completed' and 'pending' orders in a single query.

        Uses the positional ``case((condition, value), ...)`` form: the list
        form ``case([...])`` was deprecated in SQLAlchemy 1.4 and removed in 2.0.
        """
        is_completed = Order.status == 'completed'
        is_pending = Order.status == 'pending'
        return self.session.query(
            func.count(case((is_completed, 1))).label('completed_orders'),
            func.count(case((is_pending, 1))).label('pending_orders'),
            func.sum(case((is_completed, Order.total_amount), else_=0)).label('completed_amount'),
            func.sum(case((is_pending, Order.total_amount), else_=0)).label('pending_amount')
        ).first()

    def exists_query(self):
        """Return ``(users_with_orders, users_without_orders)`` using correlated EXISTS."""
        has_order = exists().where(Order.user_id == User.id)
        users_with_orders = self.session.query(User).filter(has_order).all()
        users_without_orders = self.session.query(User).filter(not_(has_order)).all()
        return users_with_orders, users_without_orders

    def self_join_example(self):
        """Return (product, related_product) name pairs of distinct products sharing a category."""
        product_alias = aliased(Product)
        return self.session.query(
            Product.name.label('product'),
            product_alias.name.label('related_product')
        ).join(
            product_alias,
            and_(
                Product.category_id == product_alias.category_id,
                Product.id != product_alias.id
            )
        ).all()

    def dynamic_filter_builder(self, filters):
        """Query products with optional criteria taken from the ``filters`` dict.

        Recognised keys:
          - ``name``: case-insensitive substring match.
          - ``min_price`` / ``max_price``: inclusive bounds; a bound of 0 is applied.
          - ``category_ids``, ``in_stock``.
          - ``sort_by``: a Product column name, with ``sort_desc``. Unknown or
            non-orderable names are ignored.
          - ``page`` / ``per_page``: 1-based pagination, applied only when both
            are positive.
        """
        query = self.session.query(Product)
        if filters.get('name'):
            query = query.filter(Product.name.ilike(f"%{filters['name']}%"))
        # Compare with None/'' rather than truthiness so a bound of 0 still filters.
        min_price = filters.get('min_price')
        if min_price not in (None, ''):
            query = query.filter(Product.price >= min_price)
        max_price = filters.get('max_price')
        if max_price not in (None, ''):
            query = query.filter(Product.price <= max_price)
        if filters.get('category_ids'):
            query = query.filter(Product.category_id.in_(filters['category_ids']))
        if filters.get('in_stock'):
            query = query.filter(Product.stock_quantity > 0)
        sort_by = filters.get('sort_by')
        if sort_by:
            sort_field = getattr(Product, sort_by, None)
            # Only attributes that support .desc() (mapped columns) can order the
            # query; this also ignores methods such as Product.is_in_stock.
            if sort_field is not None and hasattr(sort_field, 'desc'):
                query = query.order_by(sort_field.desc() if filters.get('sort_desc') else sort_field)
        page, per_page = filters.get('page'), filters.get('per_page')
        if page and per_page and page > 0 and per_page > 0:
            query = query.offset((page - 1) * per_page).limit(per_page)
        return query.all()
# Advanced query demo
def advanced_query_demo():
    """Print sample results of the AdvancedQueryBuilder patterns."""
    # NOTE(review): the other demos construct ORMDatabaseManager with a URL and
    # call initialize(); confirm that init_database() exists on that class.
    db_manager = ORMDatabaseManager()
    db_manager.init_database()
    with db_manager.get_session() as session:
        builder = AdvancedQueryBuilder(session)
        print("=== 高级查询技术演示 ===")

        print("\n1. 子查询 - 订单金额大于平均值的用户:")
        for user in builder.subquery_example():
            print(f" 用户: {user.username}")

        print("\n2. 条件聚合 - 订单统计:")
        order_stats = builder.conditional_aggregation()
        print(f" 已完成订单: {order_stats.completed_orders}个, 金额: ${order_stats.completed_amount}")
        print(f" 待处理订单: {order_stats.pending_orders}个, 金额: ${order_stats.pending_amount}")

        print("\n3. EXISTS查询:")
        users_with, users_without = builder.exists_query()
        print(f" 有订单用户: {len(users_with)}个")
        print(f" 无订单用户: {len(users_without)}个")

        print("\n4. 动态过滤器:")
        product_filters = dict(
            min_price=50,
            max_price=200,
            in_stock=True,
            sort_by='price',
            sort_desc=False,
        )
        for product in builder.dynamic_filter_builder(product_filters):
            print(f" 产品: {product.name}, 价格: ${product.price}")


if __name__ == "__main__":
    advanced_query_demo()
查询优化技术
from sqlalchemy.orm import selectinload, joinedload, subqueryload
from sqlalchemy import event
import time
class QueryOptimizer:
    """Demonstrations of query-count and loading-strategy optimizations.

    Construction attaches cursor-execute listeners to the session's bind
    (engine) that count, print and time every SQL statement. Call
    :meth:`remove_query_logging`, or use the instance as a context manager, to
    detach them. Otherwise they stay registered on the engine and keep firing
    for every later user of it.
    """

    def __init__(self, session):
        self.session = session
        self.query_count = 0
        # (target, event name, handler) triples, kept so they can be removed.
        self._listeners = []
        self.setup_query_logging()

    def setup_query_logging(self):
        """Register statement counting/timing listeners (at most once per instance)."""
        if self._listeners:
            return
        engine = self.session.bind

        def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            self.query_count += 1
            context._query_start_time = time.perf_counter()
            print(f"Query {self.query_count}: {statement[:100]}...")

        def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            elapsed = time.perf_counter() - context._query_start_time
            print(f"Query completed in {elapsed:.4f}s\n")

        for identifier, handler in (("before_cursor_execute", before_cursor_execute),
                                    ("after_cursor_execute", after_cursor_execute)):
            event.listen(engine, identifier, handler)
            self._listeners.append((engine, identifier, handler))

    def remove_query_logging(self):
        """Detach the listeners registered by :meth:`setup_query_logging`."""
        for target, identifier, handler in self._listeners:
            if event.contains(target, identifier, handler):
                event.remove(target, identifier, handler)
        self._listeners = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_query_logging()
        return False

    def n_plus_one_problem_demo(self):
        """Compare lazy loading (1 + N queries) with selectinload (2 queries)."""
        print("=== N+1查询问题演示 ===")
        print("\n❌ 错误方式 (N+1查询):")
        self.query_count = 0
        users = self.session.query(User).all()
        for user in users:
            orders = user.orders  # lazy load: one query per user
            print(f"用户 {user.username} 有 {len(orders)} 个订单")
        print(f"总查询次数: {self.query_count}")

        print("\n✅ 正确方式 (预加载):")
        self.query_count = 0
        users = self.session.query(User).options(
            selectinload(User.orders)
        ).all()
        for user in users:
            orders = user.orders  # already loaded, no extra query
            print(f"用户 {user.username} 有 {len(orders)} 个订单")
        print(f"总查询次数: {self.query_count}")

    def eager_loading_strategies(self):
        """Report the query count of selectinload, joinedload and subqueryload."""
        print("\n=== 预加载策略对比 ===")
        # selectinload: one extra SELECT ... IN per relationship level.
        print("\n1. selectinload策略:")
        self.query_count = 0
        users = self.session.query(User).options(
            selectinload(User.orders).selectinload(Order.items)
        ).all()
        print(f"查询次数: {self.query_count}")

        # joinedload: loads the relationship in the same statement via JOIN.
        # NOTE(review): assumes User defines a ``profile`` relationship, which
        # is not visible in this section; confirm.
        print("\n2. joinedload策略:")
        self.query_count = 0
        users = self.session.query(User).options(
            joinedload(User.profile)
        ).all()
        print(f"查询次数: {self.query_count}")

        # subqueryload: re-runs the parent query as a subquery for the children.
        print("\n3. subqueryload策略:")
        self.query_count = 0
        users = self.session.query(User).options(
            subqueryload(User.orders)
        ).all()
        print(f"查询次数: {self.query_count}")

    def batch_operations(self):
        """Time a bulk insert of 100 products and a single bulk UPDATE of their prices."""
        print("\n=== 批量操作优化 ===")
        print("\n1. 批量插入:")
        start_time = time.perf_counter()
        # NOTE(review): assumes a category with id 1 exists (FK); confirm.
        batch_products = [
            {
                'name': f'批量产品{i}',
                'price': 10.0 + i,
                'stock_quantity': 100,
                'category_id': 1
            }
            for i in range(100)
        ]
        self.session.bulk_insert_mappings(Product, batch_products)
        self.session.commit()
        print(f"批量插入100个产品耗时: {time.perf_counter() - start_time:.4f}s")

        print("\n2. 批量更新:")
        start_time = time.perf_counter()
        # One UPDATE statement; in-session objects are not synchronized.
        self.session.query(Product).filter(
            Product.name.like('批量产品%')
        ).update({
            Product.price: Product.price * 1.1
        }, synchronize_session=False)
        self.session.commit()
        print(f"批量更新耗时: {time.perf_counter() - start_time:.4f}s")

    def query_caching_demo(self):
        """Show that an lru_cache'd lookup skips the database on repeat calls.

        The cache is local to this call. Cached ORM instances are not refreshed
        and can become stale.
        """
        from functools import lru_cache
        print("\n=== 查询缓存演示 ===")

        @lru_cache(maxsize=128)
        def get_user_by_id(user_id):
            """Cached user lookup."""
            return self.session.query(User).filter(User.id == user_id).first()

        print("\n第一次查询:")
        self.query_count = 0
        user1 = get_user_by_id(1)
        print(f"查询次数: {self.query_count}")

        print("\n第二次查询(缓存):")
        self.query_count = 0
        user2 = get_user_by_id(1)
        print(f"查询次数: {self.query_count}")
        print(f"是否为同一对象: {user1 is user2}")
# Query optimization demo
def query_optimization_demo():
    """Run every QueryOptimizer demonstration on one session of the default database."""
    # NOTE(review): the other demos construct ORMDatabaseManager with a URL and
    # call initialize(); confirm that init_database() exists on that class.
    db_manager = ORMDatabaseManager()
    db_manager.init_database()
    with db_manager.get_session() as session:
        optimizer = QueryOptimizer(session)
        demo_steps = (
            optimizer.n_plus_one_problem_demo,   # N+1 problem vs eager loading
            optimizer.eager_loading_strategies,  # loading-strategy comparison
            optimizer.batch_operations,          # bulk insert / update
            optimizer.query_caching_demo,        # lru_cache-based caching
        )
        for run_step in demo_steps:
            run_step()


if __name__ == "__main__":
    query_optimization_demo()
数据库迁移
Alembic迁移工具
from alembic import command
from alembic.config import Config
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
from alembic.runtime.environment import EnvironmentContext
import os
from pathlib import Path
class DatabaseMigration:
    """Thin wrapper around Alembic commands for one database and script directory.

    The Alembic config file lives next to ``migrations_dir`` (e.g.
    ``./alembic.ini`` for ``migrations``), not inside it. ``alembic init``
    refuses to populate a directory that already contains files.
    """

    def __init__(self, database_url: str, migrations_dir: str = "migrations"):
        self.database_url = database_url
        self.migrations_dir = Path(migrations_dir)
        self.alembic_cfg = self._setup_alembic_config()

    def _setup_alembic_config(self) -> Config:
        """Ensure the script directory and alembic.ini exist, then build the Config."""
        self.migrations_dir.mkdir(parents=True, exist_ok=True)
        alembic_ini_path = self.migrations_dir.parent / "alembic.ini"
        if not alembic_ini_path.exists():
            self._create_alembic_ini(alembic_ini_path)
        alembic_cfg = Config(str(alembic_ini_path))
        # set_main_option values go through ConfigParser interpolation, so a
        # literal '%' (e.g. in a URL-encoded password) must be doubled.
        alembic_cfg.set_main_option("script_location", str(self.migrations_dir).replace("%", "%%"))
        alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))
        return alembic_cfg

    def _create_alembic_ini(self, ini_path: Path):
        """Write a minimal alembic.ini; the URL and script location are overridden at runtime."""
        import textwrap

        ini_content = textwrap.dedent("""
            [alembic]
            script_location = migrations
            sqlalchemy.url = driver://user:pass@localhost/dbname

            [post_write_hooks]

            [loggers]
            keys = root,sqlalchemy,alembic

            [handlers]
            keys = console

            [formatters]
            keys = generic

            [logger_root]
            level = WARN
            handlers = console
            qualname =

            [logger_sqlalchemy]
            level = WARN
            handlers =
            qualname = sqlalchemy.engine

            [logger_alembic]
            level = INFO
            handlers =
            qualname = alembic

            [handler_console]
            class = StreamHandler
            args = (sys.stderr,)
            level = NOTSET
            formatter = generic

            [formatter_generic]
            format = %(levelname)-5.5s [%(name)s] %(message)s
            datefmt = %H:%M:%S
        """)
        ini_path.write_text(ini_content.strip(), encoding="utf-8")

    def init_migrations(self):
        """Populate the script directory (env.py, script.py.mako, versions/)."""
        try:
            command.init(self.alembic_cfg, str(self.migrations_dir))
            print(f"✅ 迁移环境初始化成功: {self.migrations_dir}")
        except Exception as e:
            print(f"❌ 迁移环境初始化失败: {e}")

    def create_migration(self, message: str, auto_generate: bool = True):
        """Create a revision file, optionally autogenerated from model metadata."""
        try:
            if auto_generate:
                command.revision(self.alembic_cfg, message=message, autogenerate=True)
            else:
                command.revision(self.alembic_cfg, message=message)
            print(f"✅ 迁移文件创建成功: {message}")
        except Exception as e:
            print(f"❌ 迁移文件创建失败: {e}")

    def upgrade_database(self, revision: str = "head"):
        """Upgrade the database to ``revision`` (default: latest)."""
        try:
            command.upgrade(self.alembic_cfg, revision)
            print(f"✅ 数据库升级成功到版本: {revision}")
        except Exception as e:
            print(f"❌ 数据库升级失败: {e}")

    def downgrade_database(self, revision: str):
        """Downgrade the database to ``revision``."""
        try:
            command.downgrade(self.alembic_cfg, revision)
            print(f"✅ 数据库降级成功到版本: {revision}")
        except Exception as e:
            print(f"❌ 数据库降级失败: {e}")

    def get_current_revision(self) -> str:
        """Return the database's current revision id, "无版本" if none, "未知" on error."""
        engine = None
        try:
            from sqlalchemy import create_engine
            engine = create_engine(self.database_url)
            with engine.connect() as connection:
                context = MigrationContext.configure(connection)
                return context.get_current_revision() or "无版本"
        except Exception as e:
            print(f"❌ 获取当前版本失败: {e}")
            return "未知"
        finally:
            # Release the temporary engine's pooled connections.
            if engine is not None:
                engine.dispose()

    def get_migration_history(self) -> list:
        """List revisions from the script directory, newest first; [] on error."""
        try:
            script_dir = ScriptDirectory.from_config(self.alembic_cfg)
            return [
                {
                    'revision': revision.revision,
                    'down_revision': revision.down_revision,
                    'message': revision.doc,
                    'create_date': getattr(revision, 'create_date', None)
                }
                for revision in script_dir.walk_revisions()
            ]
        except Exception as e:
            print(f"❌ 获取迁移历史失败: {e}")
            return []

    def show_migration_status(self):
        """Print the current revision and the five most recent migrations."""
        print("=== 数据库迁移状态 ===")
        print(f"当前版本: {self.get_current_revision()}")
        print("\n迁移历史:")
        history = self.get_migration_history()
        for migration in history[:5]:
            print(f" {migration['revision'][:8]} - {migration['message']}")
# Migration demo
def migration_demo():
    """Show the Alembic migration status of ``sqlite:///example.db``.

    Creating and applying migrations is left commented out: it needs the ORM
    models wired into the Alembic environment first.
    """
    print("=== 数据库迁移演示 ===")
    manager = DatabaseMigration(
        database_url="sqlite:///example.db",
        migrations_dir="migrations",
    )
    manager.show_migration_status()

    print("\n创建示例迁移...")
    # manager.create_migration("添加用户表")

    print("\n升级数据库...")
    # manager.upgrade_database()


if __name__ == "__main__":
    migration_demo()
数据库版本控制
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional
import json
@dataclass
class SchemaVersion:
    """One database schema version record, convertible to and from a JSON-friendly dict."""

    version: str                          # version identifier, e.g. "1.2.0"
    description: str                      # human-readable summary of the change
    created_at: datetime                  # when the version record was created
    applied_at: Optional[datetime] = None  # None until the version has been applied
    rollback_sql: Optional[str] = None    # optional SQL that undoes the change

    def to_dict(self) -> dict:
        """Return the fields as a dict, with datetimes as ISO-8601 strings (None kept)."""
        return {
            'version': self.version,
            'description': self.description,
            'created_at': self.created_at.isoformat(),
            'applied_at': self.applied_at.isoformat() if self.applied_at else None,
            'rollback_sql': self.rollback_sql
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'SchemaVersion':
        """Inverse of :meth:`to_dict`.

        ``version``, ``description`` and ``created_at`` are required keys.
        ``applied_at`` and ``rollback_sql`` may be missing or ``None``.
        """
        applied_at = data.get('applied_at')
        return cls(
            version=data['version'],
            description=data['description'],
            created_at=datetime.fromisoformat(data['created_at']),
            applied_at=datetime.fromisoformat(applied_at) if applied_at else None,
            rollback_sql=data.get('rollback_sql')
        )
from sqlalchemy import text  # raw SQL below; the file imports sqlalchemy.text only locally elsewhere


class SchemaVersionManager:
    """Track schema versions in a ``schema_versions`` table using raw SQL.

    The table is created on construction if missing. Every write method
    commits the given session immediately.
    """

    def __init__(self, session):
        self.session = session
        self._ensure_version_table()

    def _ensure_version_table(self):
        """Create the ``schema_versions`` table if it does not exist."""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS schema_versions (
            version VARCHAR(50) PRIMARY KEY,
            description TEXT NOT NULL,
            created_at TIMESTAMP NOT NULL,
            applied_at TIMESTAMP,
            rollback_sql TEXT
        )
        """
        self.session.execute(text(create_table_sql))
        self.session.commit()

    def add_version(self, version: SchemaVersion):
        """Insert a version record and commit."""
        insert_sql = """
        INSERT INTO schema_versions
        (version, description, created_at, applied_at, rollback_sql)
        VALUES (:version, :description, :created_at, :applied_at, :rollback_sql)
        """
        self.session.execute(text(insert_sql), {
            'version': version.version,
            'description': version.description,
            'created_at': version.created_at,
            'applied_at': version.applied_at,
            'rollback_sql': version.rollback_sql
        })
        self.session.commit()

    def mark_applied(self, version: str):
        """Stamp ``version`` as applied now and commit."""
        update_sql = """
        UPDATE schema_versions
        SET applied_at = :applied_at
        WHERE version = :version
        """
        self.session.execute(text(update_sql), {
            'version': version,
            'applied_at': datetime.now()
        })
        self.session.commit()

    def get_current_version(self) -> Optional[str]:
        """Return the most recently applied version id, or ``None`` if none is applied."""
        query_sql = """
        SELECT version FROM schema_versions
        WHERE applied_at IS NOT NULL
        ORDER BY applied_at DESC
        LIMIT 1
        """
        result = self.session.execute(text(query_sql)).fetchone()
        return result[0] if result else None

    @staticmethod
    def _to_datetime(value):
        """Coerce a TIMESTAMP column value to ``datetime`` (``None`` passes through).

        Untyped ``text()`` results are returned as the driver delivers them.
        SQLite, for example, yields the stored ISO-format string rather than a
        datetime, which would break ``SchemaVersion.to_dict()``.
        """
        if value is None or isinstance(value, datetime):
            return value
        return datetime.fromisoformat(str(value))

    def _row_to_version(self, row) -> SchemaVersion:
        """Build a SchemaVersion from a (version, description, created_at, applied_at, rollback_sql) row."""
        return SchemaVersion(
            version=row[0],
            description=row[1],
            created_at=self._to_datetime(row[2]),
            applied_at=self._to_datetime(row[3]),
            rollback_sql=row[4]
        )

    def get_pending_versions(self) -> List[SchemaVersion]:
        """Return versions not yet applied, oldest first."""
        query_sql = """
        SELECT version, description, created_at, applied_at, rollback_sql
        FROM schema_versions
        WHERE applied_at IS NULL
        ORDER BY created_at ASC
        """
        rows = self.session.execute(text(query_sql)).fetchall()
        return [self._row_to_version(row) for row in rows]

    def get_version_history(self) -> List[SchemaVersion]:
        """Return all versions, newest first."""
        query_sql = """
        SELECT version, description, created_at, applied_at, rollback_sql
        FROM schema_versions
        ORDER BY created_at DESC
        """
        rows = self.session.execute(text(query_sql)).fetchall()
        return [self._row_to_version(row) for row in rows]
# 版本控制演示
def version_control_demo():
    """Demonstrate SchemaVersionManager.

    Steps: insert three version records (1.0.0 already applied), then print
    the current version, the pending versions and the full history.
    """
    db_manager = ORMDatabaseManager()
    db_manager.init_database()
    with db_manager.get_session() as session:
        version_manager = SchemaVersionManager(session)
        print("=== 数据库版本控制演示 ===")
        # Add version records
        versions = [
            SchemaVersion(
                version="1.0.0",
                description="初始数据库结构",
                created_at=datetime.now(),
                applied_at=datetime.now()
            ),
            SchemaVersion(
                version="1.1.0",
                description="添加用户角色表",
                created_at=datetime.now()
            ),
            SchemaVersion(
                version="1.2.0",
                description="添加订单索引",
                created_at=datetime.now()
            )
        ]
        for version in versions:
            try:
                version_manager.add_version(version)
                print(f"✅ 添加版本: {version.version}")
            except Exception as e:
                # Usually a duplicate primary key from an earlier run, but report
                # the real error rather than assuming it. Roll back so the
                # session can continue.
                session.rollback()
                print(f"⚠️ 版本已存在或添加失败: {version.version} ({e})")
        # Current version
        current = version_manager.get_current_version()
        print(f"\n当前版本: {current}")
        # Pending versions
        pending = version_manager.get_pending_versions()
        print(f"\n待应用版本: {len(pending)}个")
        for version in pending:
            print(f" {version.version} - {version.description}")
        # Full history
        history = version_manager.get_version_history()
        print("\n版本历史:")
        for version in history:
            status = "✅ 已应用" if version.applied_at else "⏳ 待应用"
            print(f" {version.version} - {version.description} ({status})")


if __name__ == "__main__":
    version_control_demo()
性能优化
数据库索引优化
from sqlalchemy import Index, text
from sqlalchemy.sql import func
import time
from typing import Dict, List
class DatabaseIndexOptimizer:
    """Query-plan inspection, index suggestion/creation and query timing.

    Works through a SQLAlchemy-style ``session``. ``analyze_query_performance``
    (``EXPLAIN QUERY PLAN``) and the size part of ``analyze_table_statistics``
    (``pragma_page_count``/``pragma_page_size``) are SQLite-specific.
    Table, column and index names are interpolated into SQL text without
    escaping, so they must come from trusted code, never from user input.
    """

    def __init__(self, session):
        self.session = session

    def analyze_query_performance(self, query_sql: str, params: dict = None) -> Dict:
        """Execute *query_sql* once, time it, and fetch its SQLite query plan.

        Args:
            query_sql: SQL to analyze.
            params: optional bind parameters, used for both runs.

        Returns:
            dict with ``execution_time`` (seconds, including fetching all
            rows), ``query_plan`` (plan rows as dicts) and ``row_count``
            (rows returned, or the driver's rowcount for statements that
            return no rows).
        """
        bind_params = params or {}
        start = time.perf_counter()
        result = self.session.execute(text(query_sql), bind_params)
        # Materialise rows inside the timed region and count them. For SELECT,
        # DB-API rowcount is usually -1 and therefore useless.
        if result.returns_rows:
            row_count = len(result.fetchall())
        else:
            row_count = result.rowcount
        execution_time = time.perf_counter() - start

        plan_result = self.session.execute(text(f"EXPLAIN QUERY PLAN {query_sql}"), bind_params)
        query_plan = [dict(row._mapping) for row in plan_result]
        return {
            'execution_time': execution_time,
            'query_plan': query_plan,
            'row_count': row_count
        }

    def suggest_indexes(self, table_name: str, frequent_queries: List[str]) -> List[str]:
        """Suggest single-column indexes for columns compared in WHERE clauses.

        This is a heuristic text scan, not a SQL parser. For each query it
        takes the text after the first ``WHERE`` up to ``ORDER``/``GROUP``/
        ``LIMIT``/``HAVING``. Within that text it collects identifiers directly
        followed by ``=``, ``<``, ``>`` or ``!=``.

        Returns:
            ``CREATE INDEX`` statements, one per distinct column, in the order
            the columns were first seen.
        """
        import re

        where_re = re.compile(r'\bWHERE\b')
        # Word boundaries so that names such as ``border`` are not cut at "ORDER".
        clause_end_re = re.compile(r'\b(?:ORDER|GROUP|LIMIT|HAVING)\b')
        # Identifiers must start with a letter/underscore, so numeric literals are skipped.
        column_re = re.compile(r'\b([A-Za-z_]\w*)\s*(?:!=|[=<>])')

        # dict keys keep first-seen order and drop duplicates. A set's iteration
        # order varies between runs, which made the output non-deterministic.
        where_columns = {}
        for query in frequent_queries:
            parts = where_re.split(query.upper(), maxsplit=1)
            if len(parts) < 2:
                continue
            where_part = clause_end_re.split(parts[1], maxsplit=1)[0]
            for column in column_re.findall(where_part):
                where_columns.setdefault(column.lower(), None)

        return [
            f"CREATE INDEX idx_{table_name}_{column} ON {table_name}({column})"
            for column in where_columns
        ]

    def create_composite_index(self, table_name: str, columns: List[str], index_name: str = None):
        """Create an index on *columns* and commit; print the outcome.

        The default index name is ``idx_<table>_<col1>_<col2>...``. Errors are
        printed, not raised, and the session is rolled back so it stays usable.
        """
        if not index_name:
            index_name = f"idx_{table_name}_{'_'.join(columns)}"
        columns_str = ', '.join(columns)
        create_index_sql = f"CREATE INDEX {index_name} ON {table_name}({columns_str})"
        try:
            self.session.execute(text(create_index_sql))
            self.session.commit()
            print(f"✅ 创建索引成功: {index_name}")
        except Exception as e:
            self.session.rollback()
            print(f"❌ 创建索引失败: {e}")

    def analyze_table_statistics(self, table_name: str) -> Dict:
        """Return ``row_count`` for *table_name* plus a SQLite size estimate.

        Note: ``table_size_bytes`` is ``page_count * page_size``, i.e. the size
        of the whole SQLite database file, not of this table alone. It is 0
        when the pragma query fails (e.g. on a non-SQLite backend).
        """
        stats = {}
        count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
        result = self.session.execute(text(count_sql)).fetchone()
        stats['row_count'] = result[0]
        try:
            size_sql = "SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()"
            result = self.session.execute(text(size_sql)).fetchone()
            stats['table_size_bytes'] = result[0] if result else 0
        except Exception:
            stats['table_size_bytes'] = 0
        return stats

    def benchmark_query(self, query_sql: str, params: dict = None, iterations: int = 10) -> Dict:
        """Run *query_sql* ``iterations`` times and report timings in seconds.

        Each timing covers execution plus fetching every row.

        Raises:
            ValueError: if ``iterations`` is less than 1.
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")
        bind_params = params or {}
        statement = text(query_sql)
        execution_times = []
        for _ in range(iterations):
            start = time.perf_counter()
            list(self.session.execute(statement, bind_params))
            execution_times.append(time.perf_counter() - start)
        total = sum(execution_times)
        return {
            'min_time': min(execution_times),
            'max_time': max(execution_times),
            'avg_time': total / len(execution_times),
            'total_time': total,
            'iterations': iterations
        }
# 性能优化演示
def performance_optimization_demo():
    """Demonstrate DatabaseIndexOptimizer against the demo schema.

    Prints, in order: statistics for three tables, a single timed run of three
    sample queries, index suggestions for two ``users`` lookups, and a
    5-iteration benchmark of ``SELECT COUNT(*) FROM users``. An error in any
    step is printed and the demo continues.
    """
    manager = ORMDatabaseManager()
    manager.init_database()

    with manager.get_session() as session:
        tuner = DatabaseIndexOptimizer(session)
        print("=== 数据库性能优化演示 ===")

        # Step 1: per-table statistics.
        print("\n1. 表统计信息:")
        for table_name in ('users', 'orders', 'products'):
            try:
                table_stats = tuner.analyze_table_statistics(table_name)
                print(f" {table_name}: {table_stats['row_count']}行, {table_stats['table_size_bytes']}字节")
            except Exception as err:
                print(f" {table_name}: 分析失败 - {err}")

        # Step 2: time each sample query once.
        print("\n2. 查询性能分析:")
        sample_queries = (
            "SELECT * FROM users WHERE email = 'test@example.com'",
            "SELECT * FROM orders WHERE user_id = 1",
            "SELECT * FROM products WHERE price > 100",
        )
        for sql in sample_queries:
            try:
                report = tuner.analyze_query_performance(sql)
                print(f" 查询: {sql[:50]}...")
                print(f" 执行时间: {report['execution_time']:.4f}s")
            except Exception as err:
                print(f" 查询分析失败: {err}")

        # Step 3: index suggestions derived from WHERE clauses.
        print("\n3. 索引建议:")
        lookups = [
            "SELECT * FROM users WHERE email = ?",
            "SELECT * FROM users WHERE username = ?",
        ]
        for statement in tuner.suggest_indexes('users', lookups):
            print(f" {statement}")

        # Step 4: repeated-run benchmark.
        print("\n4. 查询基准测试:")
        count_sql = "SELECT COUNT(*) FROM users"
        try:
            timing = tuner.benchmark_query(count_sql, iterations=5)
            print(f" 查询: {count_sql}")
            print(f" 平均时间: {timing['avg_time']:.4f}s")
            print(f" 最小时间: {timing['min_time']:.4f}s")
            print(f" 最大时间: {timing['max_time']:.4f}s")
        except Exception as err:
            print(f" 基准测试失败: {err}")


if __name__ == "__main__":
    performance_optimization_demo()
连接池优化
from sqlalchemy.pool import QueuePool, StaticPool, NullPool
from sqlalchemy import create_engine
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
class ConnectionPoolOptimizer:
    """Build SQLAlchemy engines with configurable pools, load-test them and
    suggest pool settings."""

    def __init__(self):
        # Statistics from test_pool_configuration(), keyed by configuration name.
        self.test_results = {}

    def create_optimized_engine(self, database_url: str, pool_config: dict):
        """Create an engine from *pool_config*.

        Recognised keys (with defaults): ``poolclass`` (QueuePool),
        ``pool_size`` (5), ``max_overflow`` (10), ``pool_timeout`` (30),
        ``pool_recycle`` (3600), ``pool_pre_ping`` (True), ``echo`` (False).
        ``pool_size``/``max_overflow``/``pool_timeout`` are only passed for
        QueuePool-based pools. Other pools such as NullPool or StaticPool
        reject them, and ``create_engine`` would raise TypeError.
        """
        poolclass = pool_config.get('poolclass', QueuePool)
        engine_options = {
            'poolclass': poolclass,
            'pool_recycle': pool_config.get('pool_recycle', 3600),
            'pool_pre_ping': pool_config.get('pool_pre_ping', True),
            'echo': pool_config.get('echo', False),
        }
        if isinstance(poolclass, type) and issubclass(poolclass, QueuePool):
            engine_options.update(
                pool_size=pool_config.get('pool_size', 5),
                max_overflow=pool_config.get('max_overflow', 10),
                pool_timeout=pool_config.get('pool_timeout', 30),
            )
        return create_engine(database_url, **engine_options)

    def test_pool_configuration(self, database_url: str, pool_configs: List[dict],
                                concurrent_users: int = 10, operations_per_user: int = 5):
        """Load-test each configuration with concurrent ``SELECT 1`` calls.

        Args:
            database_url: SQLAlchemy URL to connect to.
            pool_configs: configs for :meth:`create_optimized_engine`. An
                optional ``'name'`` key labels the result; the default label
                is ``Config_<n>``.
            concurrent_users: number of worker threads.
            operations_per_user: connections each worker opens in sequence.

        Returns:
            dict mapping config name to timing statistics. The same data is
            also merged into ``self.test_results``.
        """
        results = {}
        for position, config in enumerate(pool_configs, start=1):
            config_name = config.get('name', f'Config_{position}')
            print(f"\n测试配置: {config_name}")
            engine = self.create_optimized_engine(database_url, config)
            try:
                stats = self._measure_workload(engine, concurrent_users, operations_per_user)
            finally:
                # Release pooled connections even if the workload raised.
                engine.dispose()
            results[config_name] = stats
            print(f" 总时间: {stats['total_time']:.2f}s")
            print(f" 平均操作时间: {stats['avg_operation_time']:.4f}s")
            print(f" 成功率: {stats['success_rate']:.1f}%")
        self.test_results.update(results)
        return results

    @staticmethod
    def _measure_workload(engine, concurrent_users: int, operations_per_user: int) -> dict:
        """Run ``SELECT 1`` from ``concurrent_users`` threads and summarise timings."""

        def worker_task(worker_id):
            # One entry per operation: elapsed seconds, or None if it failed.
            timings = []
            for _ in range(operations_per_user):
                op_start = time.perf_counter()
                try:
                    with engine.connect() as conn:
                        conn.execute(text("SELECT 1")).fetchall()
                        elapsed = time.perf_counter() - op_start
                except Exception as e:
                    print(f" Worker {worker_id} 操作失败: {e}")
                    timings.append(None)
                else:
                    timings.append(elapsed)
            return timings

        start_time = time.perf_counter()
        all_times = []
        with ThreadPoolExecutor(max_workers=concurrent_users) as executor:
            futures = [executor.submit(worker_task, i) for i in range(concurrent_users)]
            for future in as_completed(futures):
                all_times.extend(future.result())
        total_time = time.perf_counter() - start_time

        valid_times = [t for t in all_times if t is not None]
        return {
            'total_time': total_time,
            'avg_operation_time': sum(valid_times) / len(valid_times) if valid_times else 0,
            'min_operation_time': min(valid_times, default=0),
            'max_operation_time': max(valid_times, default=0),
            # Guard against zero operations (e.g. operations_per_user=0).
            'success_rate': len(valid_times) / len(all_times) * 100 if all_times else 0.0,
            'total_operations': len(all_times),
            'successful_operations': len(valid_times)
        }

    def recommend_pool_settings(self, expected_concurrent_users: int,
                                avg_query_time: float) -> dict:
        """Heuristic QueuePool settings for the expected load.

        ``pool_size`` is half the expected users, clamped to [5, 20].
        ``max_overflow`` is the expected users, capped at 30. ``pool_timeout``
        is at least 30 s, or 10x the average query time if that is larger.
        """
        pool_size = min(max(expected_concurrent_users // 2, 5), 20)
        max_overflow = min(expected_concurrent_users, 30)
        pool_timeout = max(avg_query_time * 10, 30)
        return {
            'pool_size': pool_size,
            'max_overflow': max_overflow,
            'pool_timeout': pool_timeout,
            'pool_recycle': 3600,  # 1 hour
            'pool_pre_ping': True,
            'poolclass': QueuePool
        }
# 连接池优化演示
def connection_pool_optimization_demo():
    """Compare three QueuePool sizes on a local SQLite file, then print the
    best configuration and heuristic recommended settings."""
    print("=== 连接池优化演示 ===")
    optimizer = ConnectionPoolOptimizer()
    database_url = "sqlite:///test_pool.db"
    # Pool configurations to compare
    pool_configs = [
        {
            'name': '小连接池',
            'pool_size': 2,
            'max_overflow': 3,
            'pool_timeout': 10
        },
        {
            'name': '中等连接池',
            'pool_size': 5,
            'max_overflow': 10,
            'pool_timeout': 30
        },
        {
            'name': '大连接池',
            'pool_size': 10,
            'max_overflow': 20,
            'pool_timeout': 60
        }
    ]
    results = optimizer.test_pool_configuration(
        database_url,
        pool_configs,
        concurrent_users=8,
        operations_per_user=3
    )
    print("\n=== 配置对比 ===")
    # Only configurations that completed at least one operation can be ranked.
    # One where every operation failed reports avg_operation_time == 0 and would
    # otherwise look "fastest". Rank by success rate first, then by speed.
    ranked = {
        name: stats for name, stats in results.items()
        if stats['successful_operations'] > 0
    }
    if ranked:
        best_name, best_stats = min(
            ranked.items(),
            key=lambda item: (-item[1]['success_rate'], item[1]['avg_operation_time'])
        )
        print(f"最佳配置: {best_name}")
        print(f"平均操作时间: {best_stats['avg_operation_time']:.4f}s")
    else:
        print("没有可用的配置:所有操作均失败")
    print("\n=== 推荐设置 ===")
    recommended = optimizer.recommend_pool_settings(
        expected_concurrent_users=20,
        avg_query_time=0.01
    )
    print(f"推荐连接池大小: {recommended['pool_size']}")
    print(f"推荐最大溢出: {recommended['max_overflow']}")
    print(f"推荐超时时间: {recommended['pool_timeout']}s")


if __name__ == "__main__":
    connection_pool_optimization_demo()
实践练习
练习1:博客系统数据库设计
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Table
from sqlalchemy.orm import relationship
from datetime import datetime
from typing import List, Optional
# Association table for the many-to-many Article <-> Tag relationship.
# The composite primary key (article_id, tag_id) prevents linking the same
# tag to the same article twice.
article_tags = Table(
    'article_tags',
    Base.metadata,
    Column('article_id', Integer, ForeignKey('articles.id'), primary_key=True),
    Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True)
)
class Author(Base):
    """Blog author (``authors`` table).

    Owns articles (removed together with the author through the
    ``all, delete-orphan`` cascade) and may be linked to comments.
    """
    __tablename__ = 'authors'
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)
    email = Column(String(100), unique=True, nullable=False)
    full_name = Column(String(100), nullable=False)
    bio = Column(Text)
    avatar_url = Column(String(255))
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)  # naive UTC timestamp
    # Relationships
    articles = relationship('Article', back_populates='author', cascade='all, delete-orphan')
    comments = relationship('Comment', back_populates='author')

    def __repr__(self):
        return f"<Author(username='{self.username}', email='{self.email}')>"

    def get_article_count(self) -> int:
        """Return how many articles this author has, published or not.

        Counts the loaded ``articles`` collection (with the default lazy
        loading this fetches every article).
        """
        return len(self.articles)

    def get_published_articles(self) -> List['Article']:
        """Return this author's articles whose ``is_published`` flag is set."""
        return [article for article in self.articles if article.is_published]
class Category(Base):
    """Article category (``categories`` table); ``name`` and ``slug`` are unique."""
    __tablename__ = 'categories'
    id = Column(Integer, primary_key=True)
    name = Column(String(50), unique=True, nullable=False)
    description = Column(Text)
    slug = Column(String(50), unique=True, nullable=False)  # URL-friendly identifier
    created_at = Column(DateTime, default=datetime.utcnow)
    # Relationships
    articles = relationship('Article', back_populates='category')

    def __repr__(self):
        return f"<Category(name='{self.name}', slug='{self.slug}')>"
class Tag(Base):
    """Article tag (``tags`` table), linked to articles through ``article_tags``."""
    __tablename__ = 'tags'
    id = Column(Integer, primary_key=True)
    name = Column(String(30), unique=True, nullable=False)
    color = Column(String(7), default='#007bff')  # hex colour string, e.g. '#007bff'
    created_at = Column(DateTime, default=datetime.utcnow)
    # Relationships (many-to-many via the article_tags association table)
    articles = relationship('Article', secondary=article_tags, back_populates='tags')

    def __repr__(self):
        return f"<Tag(name='{self.name}', color='{self.color}')>"
class Article(Base):
    """Blog article (``articles`` table).

    Belongs to one Author and one Category, has many Tags (via
    ``article_tags``) and owns its Comments (deleted with the article).
    """
    __tablename__ = 'articles'
    id = Column(Integer, primary_key=True)
    title = Column(String(200), nullable=False)
    slug = Column(String(200), unique=True, nullable=False)
    content = Column(Text, nullable=False)
    excerpt = Column(Text)  # short summary
    featured_image = Column(String(255))
    is_published = Column(Boolean, default=False)
    is_featured = Column(Boolean, default=False)
    view_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    published_at = Column(DateTime)
    # Foreign keys
    author_id = Column(Integer, ForeignKey('authors.id'), nullable=False)
    category_id = Column(Integer, ForeignKey('categories.id'), nullable=False)
    # Relationships
    author = relationship('Author', back_populates='articles')
    category = relationship('Category', back_populates='articles')
    tags = relationship('Tag', secondary=article_tags, back_populates='articles')
    comments = relationship('Comment', back_populates='article', cascade='all, delete-orphan')

    def __repr__(self):
        return f"<Article(title='{self.title}', author='{self.author.username if self.author else None}')>"

    def publish(self):
        """Mark the article as published and stamp ``published_at`` (naive UTC)."""
        self.is_published = True
        self.published_at = datetime.utcnow()

    def unpublish(self):
        """Mark the article as unpublished and clear ``published_at``."""
        self.is_published = False
        self.published_at = None

    def increment_view_count(self):
        """Add one to ``view_count``.

        The column default (0) is applied only on INSERT, so a new, unflushed
        article still holds ``None``. Treat that as 0 instead of raising
        TypeError.
        """
        self.view_count = (self.view_count or 0) + 1

    def get_comment_count(self) -> int:
        """Return the number of *approved* comments (unapproved ones are excluded)."""
        return sum(1 for comment in self.comments if comment.is_approved)
class Comment(Base):
    """Comment on an article (``comments`` table); may be a reply to another comment.

    ``author_name``/``author_email`` are always stored. ``author_id`` is set
    only when the commenter is a registered Author. New comments are
    unapproved (``is_approved`` defaults to False).
    """
    __tablename__ = 'comments'
    id = Column(Integer, primary_key=True)
    content = Column(Text, nullable=False)
    author_name = Column(String(50), nullable=False)
    author_email = Column(String(100), nullable=False)
    author_website = Column(String(255))
    is_approved = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Foreign keys
    article_id = Column(Integer, ForeignKey('articles.id'), nullable=False)
    author_id = Column(Integer, ForeignKey('authors.id'))  # optional: comment by a registered author
    parent_id = Column(Integer, ForeignKey('comments.id'))  # set when this comment is a reply
    # Relationships
    article = relationship('Article', back_populates='comments')
    author = relationship('Author', back_populates='comments')
    # Self-referential adjacency list. Both sides describe the same parent_id
    # foreign key, so they must be linked with back_populates. Two unlinked
    # relationships would each manage parent_id independently; SQLAlchemy
    # warns about this, and the in-memory sides can disagree.
    parent = relationship('Comment', remote_side=[id], back_populates='replies')
    replies = relationship('Comment', back_populates='parent', cascade='all, delete-orphan')

    def __repr__(self):
        return f"<Comment(author_name='{self.author_name}', article_id={self.article_id})>"

    def approve(self):
        """Mark the comment as approved (not committed here)."""
        self.is_approved = True

    def reject(self):
        """Mark the comment as not approved (not committed here)."""
        self.is_approved = False
class BlogService:
    """Service layer for the blog models.

    Each public write method runs as one transaction. It commits on success;
    on any exception it rolls the session back and re-raises, which keeps the
    session usable after errors such as unique-constraint violations.
    """

    def __init__(self, session):
        self.session = session

    @contextmanager
    def _transaction(self):
        """Commit the work done in the ``with`` body; roll back and re-raise on failure."""
        try:
            yield
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def _get_or_create_tag(self, name: str) -> Tag:
        """Return the tag called *name*, adding (but not committing) a new one if missing."""
        tag = self.session.query(Tag).filter(Tag.name == name).first()
        if tag is None:
            tag = Tag(name=name, color='#007bff')
            self.session.add(tag)
        return tag

    def create_author(self, username: str, email: str, full_name: str,
                      bio: str = None, avatar_url: str = None) -> Author:
        """Create and commit an author."""
        author = Author(
            username=username,
            email=email,
            full_name=full_name,
            bio=bio,
            avatar_url=avatar_url
        )
        with self._transaction():
            self.session.add(author)
        return author

    def create_category(self, name: str, slug: str, description: str = None) -> Category:
        """Create and commit a category."""
        category = Category(
            name=name,
            slug=slug,
            description=description
        )
        with self._transaction():
            self.session.add(category)
        return category

    def create_tag(self, name: str, color: str = '#007bff') -> Tag:
        """Create and commit a tag."""
        tag = Tag(name=name, color=color)
        with self._transaction():
            self.session.add(tag)
        return tag

    def create_article(self, title: str, slug: str, content: str,
                       author_id: int, category_id: int,
                       excerpt: str = None, tag_names: List[str] = None) -> Article:
        """Create an article (unpublished) and attach tags, creating missing ones.

        The article and any new tags are committed together in one
        transaction. Duplicate names in *tag_names* are ignored; attaching the
        same tag twice would violate the ``article_tags`` primary key.
        """
        article = Article(
            title=title,
            slug=slug,
            content=content,
            excerpt=excerpt,
            author_id=author_id,
            category_id=category_id
        )
        with self._transaction():
            # dict.fromkeys de-duplicates while keeping the caller's order.
            for tag_name in dict.fromkeys(tag_names or ()):
                article.tags.append(self._get_or_create_tag(tag_name))
            self.session.add(article)
        return article

    def get_published_articles(self, limit: int = 10, offset: int = 0) -> List[Article]:
        """Return one page of published articles, newest ``published_at`` first."""
        return self.session.query(Article).filter(
            Article.is_published == True
        ).order_by(
            Article.published_at.desc()
        ).offset(offset).limit(limit).all()

    def get_articles_by_category(self, category_slug: str) -> List[Article]:
        """Return published articles in the category with *category_slug*, newest first."""
        return self.session.query(Article).join(Category).filter(
            Category.slug == category_slug,
            Article.is_published == True
        ).order_by(Article.published_at.desc()).all()

    def get_articles_by_tag(self, tag_name: str) -> List[Article]:
        """Return published articles carrying the tag *tag_name*, newest first."""
        return self.session.query(Article).join(Article.tags).filter(
            Tag.name == tag_name,
            Article.is_published == True
        ).order_by(Article.published_at.desc()).all()

    def search_articles(self, keyword: str) -> List[Article]:
        """Case-insensitive substring search over published titles and contents.

        ``%`` and ``_`` in *keyword* match literally: they are escaped (with
        ``/`` as the LIKE escape character) instead of acting as wildcards.
        """
        escaped = keyword.replace('/', '//').replace('%', '/%').replace('_', '/_')
        pattern = f'%{escaped}%'
        return self.session.query(Article).filter(
            or_(
                Article.title.ilike(pattern, escape='/'),
                Article.content.ilike(pattern, escape='/')
            ),
            Article.is_published == True
        ).order_by(Article.published_at.desc()).all()

    def get_popular_articles(self, limit: int = 5) -> List[Article]:
        """Return the *limit* most-viewed published articles."""
        return self.session.query(Article).filter(
            Article.is_published == True
        ).order_by(
            Article.view_count.desc()
        ).limit(limit).all()

    def add_comment(self, article_id: int, author_name: str, author_email: str,
                    content: str, author_website: str = None,
                    parent_id: int = None) -> Comment:
        """Create and commit a comment (unapproved by default); *parent_id* makes it a reply."""
        comment = Comment(
            article_id=article_id,
            author_name=author_name,
            author_email=author_email,
            author_website=author_website,
            content=content,
            parent_id=parent_id
        )
        with self._transaction():
            self.session.add(comment)
        return comment

    def get_blog_statistics(self) -> dict:
        """Return row counts for the blog tables.

        ``total_comments`` counts approved comments only.
        """
        return {
            'total_articles': self.session.query(Article).count(),
            'published_articles': self.session.query(Article).filter(
                Article.is_published == True
            ).count(),
            'total_authors': self.session.query(Author).count(),
            'total_categories': self.session.query(Category).count(),
            'total_tags': self.session.query(Tag).count(),
            'total_comments': self.session.query(Comment).filter(
                Comment.is_approved == True
            ).count()
        }
# 博客系统演示
def blog_system_demo():
    """Run BlogService end to end.

    Creates an author, a category, two tags and a published article with an
    approved comment. Then prints the published articles and blog statistics.
    """
    manager = ORMDatabaseManager()
    manager.init_database()

    with manager.get_session() as session:
        service = BlogService(session)
        print("=== 博客系统演示 ===")

        print("\n1. 创建作者:")
        alice = service.create_author(
            username="alice",
            email="alice@example.com",
            full_name="Alice Johnson",
            bio="技术博客作者,专注于Python和Web开发",
        )
        print(f" 创建作者: {alice.full_name}")

        print("\n2. 创建分类:")
        tech = service.create_category(name="技术", slug="tech", description="技术相关文章")
        print(f" 创建分类: {tech.name}")

        print("\n3. 创建标签:")
        tag_specs = (("Python", "#3776ab"), ("Web开发", "#61dafb"))
        created_tags = [service.create_tag(label, colour) for label, colour in tag_specs]
        print(f" 创建标签: {created_tags[0].name}, {created_tags[1].name}")

        print("\n4. 创建文章:")
        article = service.create_article(
            title="Python数据库操作入门",
            slug="python-database-intro",
            content="这是一篇关于Python数据库操作的入门文章...",
            excerpt="学习Python数据库操作的基础知识",
            author_id=alice.id,
            category_id=tech.id,
            tag_names=["Python", "数据库"],
        )
        article.publish()
        session.commit()
        print(f" 创建文章: {article.title}")

        print("\n5. 添加评论:")
        comment = service.add_comment(
            article_id=article.id,
            author_name="Bob",
            author_email="bob@example.com",
            content="很好的文章,学到了很多!",
        )
        comment.approve()
        session.commit()
        print(f" 添加评论: {comment.content[:20]}...")

        print("\n6. 获取已发布文章:")
        for published in service.get_published_articles():
            print(f" {published.title} - {published.author.full_name}")

        print("\n7. 博客统计:")
        for stat_name, stat_value in service.get_blog_statistics().items():
            print(f" {stat_name}: {stat_value}")


if __name__ == "__main__":
    blog_system_demo()
练习2:电商系统库存管理
import logging
from datetime import datetime, timedelta
from decimal import Decimal as PyDecimal
from typing import List, Optional

from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Numeric, String, func
# SQLAlchemy exports no `Decimal` type (importing it raises ImportError); keep the
# `Decimal` name as an alias of Numeric so `Decimal(10, 2)` column definitions work.
from sqlalchemy import Numeric as Decimal
from sqlalchemy.orm import relationship
class Product(Base):
    """Sellable product (``products`` table) with a one-to-one Inventory row
    and a history of StockMovement records."""
    __tablename__ = 'products'
    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    sku = Column(String(50), unique=True, nullable=False)
    # Numeric(10, 2): fixed-precision money columns. SQLAlchemy has no `Decimal`
    # type; Numeric returns decimal.Decimal values by default.
    price = Column(Numeric(10, 2), nullable=False)
    cost = Column(Numeric(10, 2), nullable=False)
    description = Column(String(500))
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Relationships
    inventory = relationship('Inventory', back_populates='product', uselist=False)
    stock_movements = relationship('StockMovement', back_populates='product')

    def __repr__(self):
        return f"<Product(sku='{self.sku}', name='{self.name}')>"

    def get_current_stock(self) -> int:
        """Return on-hand quantity, or 0 when there is no inventory row or no quantity yet."""
        inventory = self.inventory
        if inventory is None or inventory.quantity is None:
            return 0
        return inventory.quantity

    def get_stock_value(self) -> PyDecimal:
        """Return the value of on-hand stock at cost: ``cost * current stock``."""
        return self.cost * self.get_current_stock()
class Inventory(Base):
    """Stock levels for one product (``inventory`` table, one row per product)."""
    __tablename__ = 'inventory'
    id = Column(Integer, primary_key=True)
    product_id = Column(Integer, ForeignKey('products.id'), unique=True, nullable=False)
    quantity = Column(Integer, default=0, nullable=False)  # units on hand
    reserved_quantity = Column(Integer, default=0, nullable=False)  # units held for pending orders
    min_stock_level = Column(Integer, default=10, nullable=False)  # low-stock threshold
    max_stock_level = Column(Integer, default=1000, nullable=False)  # upper stock limit
    last_updated = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    # Relationships
    product = relationship('Product', back_populates='inventory')

    def __init__(self, **kwargs):
        """Create an inventory row with the column defaults pre-filled in Python.

        Column ``default=`` values are applied only at INSERT time. Without
        this, an unflushed instance would hold ``None`` and the properties
        below would raise TypeError. Explicit keyword arguments always win.
        Keep these values in sync with the Column defaults above.
        """
        kwargs.setdefault('quantity', 0)
        kwargs.setdefault('reserved_quantity', 0)
        kwargs.setdefault('min_stock_level', 10)
        kwargs.setdefault('max_stock_level', 1000)
        super().__init__(**kwargs)

    def __repr__(self):
        return f"<Inventory(product_id={self.product_id}, quantity={self.quantity})>"

    @property
    def available_quantity(self) -> int:
        """Units free to sell: on-hand minus reserved, never below 0."""
        return max(0, self.quantity - self.reserved_quantity)

    @property
    def is_low_stock(self) -> bool:
        """True when on-hand quantity is at or below ``min_stock_level``."""
        return self.quantity <= self.min_stock_level

    @property
    def is_out_of_stock(self) -> bool:
        """True when no units are on hand."""
        return self.quantity <= 0

    def can_fulfill_order(self, quantity: int) -> bool:
        """True when ``available_quantity`` covers *quantity* units."""
        return self.available_quantity >= quantity
class StockMovement(Base):
    """Audit record of one stock change (``stock_movements`` table)."""
    __tablename__ = 'stock_movements'
    id = Column(Integer, primary_key=True)
    product_id = Column(Integer, ForeignKey('products.id'), nullable=False)
    movement_type = Column(String(20), nullable=False)  # IN, OUT, ADJUSTMENT
    # As written by InventoryService, IN/OUT rows store a positive quantity and
    # ADJUSTMENT rows store the signed difference (new - old).
    quantity = Column(Integer, nullable=False)
    reference_id = Column(String(50))  # related document number, e.g. a purchase/sales order id
    reference_type = Column(String(20))  # document type: PURCHASE, SALE, ADJUSTMENT, INITIAL
    reason = Column(String(200))
    created_at = Column(DateTime, default=datetime.utcnow)
    created_by = Column(String(50))
    # Relationships
    product = relationship('Product', back_populates='stock_movements')

    def __repr__(self):
        return f"<StockMovement(product_id={self.product_id}, type='{self.movement_type}', quantity={self.quantity})>"
class InventoryService:
    """Stock operations: create products, receive, ship, reserve, adjust, report.

    Each public write method is one unit of work. It commits on success; on
    any exception it rolls the session back, logs the error and re-raises.
    """

    def __init__(self, session):
        self.session = session
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _require_positive(quantity: int) -> None:
        """Raise ValueError unless *quantity* is a positive number of units."""
        if quantity <= 0:
            raise ValueError(f"数量必须大于0: {quantity}")

    def _get_inventory(self, product_id: int) -> Inventory:
        """Return the Inventory row for *product_id*.

        Raises:
            ValueError: if the product has no inventory record.
        """
        # NOTE(review): callers read, check and then write. Concurrent sessions
        # can race here; consider .with_for_update() on backends with row locks.
        inventory = self.session.query(Inventory).filter(
            Inventory.product_id == product_id
        ).first()
        if inventory is None:
            raise ValueError(f"商品 {product_id} 的库存记录不存在")
        return inventory

    def _active_inventory_query(self, *conditions):
        """Query active products joined to their inventory, filtered by *conditions*."""
        return self.session.query(Product).join(Inventory).filter(
            *conditions, Product.is_active == True
        )

    def create_product(self, name: str, sku: str, price: PyDecimal,
                       cost: PyDecimal, description: str = None,
                       initial_stock: int = 0, min_stock: int = 10) -> Product:
        """Create a product together with its inventory row.

        A positive *initial_stock* is also recorded as an ``IN`` movement with
        reference type ``INITIAL``.

        Raises:
            ValueError: if *initial_stock* is negative.
        """
        try:
            if initial_stock < 0:
                raise ValueError(f"初始库存不能为负数: {initial_stock}")
            product = Product(
                name=name,
                sku=sku,
                price=price,
                cost=cost,
                description=description
            )
            self.session.add(product)
            self.session.flush()  # assigns product.id for the inventory row
            self.session.add(Inventory(
                product_id=product.id,
                quantity=initial_stock,
                min_stock_level=min_stock
            ))
            if initial_stock > 0:
                self._record_stock_movement(
                    product_id=product.id,
                    movement_type='IN',
                    quantity=initial_stock,
                    reference_type='INITIAL',
                    reason='初始库存'
                )
            self.session.commit()
            self.logger.info("创建商品: %s, 初始库存: %s", sku, initial_stock)
            return product
        except Exception as e:
            self.session.rollback()
            self.logger.error("创建商品失败: %s", e)
            raise

    def stock_in(self, product_id: int, quantity: int,
                 reference_id: str = None, reason: str = None) -> bool:
        """Receive *quantity* units (a PURCHASE ``IN`` movement).

        Raises:
            ValueError: if *quantity* is not positive or the inventory row is missing.
        """
        try:
            self._require_positive(quantity)
            inventory = self._get_inventory(product_id)
            inventory.quantity += quantity
            self._record_stock_movement(
                product_id=product_id,
                movement_type='IN',
                quantity=quantity,
                reference_id=reference_id,
                reference_type='PURCHASE',
                reason=reason or '采购入库'
            )
            self.session.commit()
            self.logger.info("商品 %s 入库 %s 件", product_id, quantity)
            return True
        except Exception as e:
            self.session.rollback()
            self.logger.error("入库失败: %s", e)
            raise

    def stock_out(self, product_id: int, quantity: int,
                  reference_id: str = None, reason: str = None) -> bool:
        """Ship *quantity* units (a SALE ``OUT`` movement) from unreserved stock.

        Raises:
            ValueError: if *quantity* is not positive, the inventory row is
                missing, or available (unreserved) stock is insufficient.
        """
        try:
            self._require_positive(quantity)
            inventory = self._get_inventory(product_id)
            if not inventory.can_fulfill_order(quantity):
                raise ValueError(f"库存不足,可用库存: {inventory.available_quantity}, 需要: {quantity}")
            inventory.quantity -= quantity
            self._record_stock_movement(
                product_id=product_id,
                movement_type='OUT',
                quantity=quantity,
                reference_id=reference_id,
                reference_type='SALE',
                reason=reason or '销售出库'
            )
            self.session.commit()
            self.logger.info("商品 %s 出库 %s 件", product_id, quantity)
            return True
        except Exception as e:
            self.session.rollback()
            self.logger.error("出库失败: %s", e)
            raise

    def reserve_stock(self, product_id: int, quantity: int) -> bool:
        """Reserve *quantity* available units (no movement record is written).

        Raises:
            ValueError: if *quantity* is not positive, the inventory row is
                missing, or available stock is insufficient.
        """
        try:
            self._require_positive(quantity)
            inventory = self._get_inventory(product_id)
            if inventory.available_quantity < quantity:
                raise ValueError(f"可用库存不足,可用: {inventory.available_quantity}, 需要: {quantity}")
            inventory.reserved_quantity += quantity
            self.session.commit()
            self.logger.info("商品 %s 预留库存 %s 件", product_id, quantity)
            return True
        except Exception as e:
            self.session.rollback()
            self.logger.error("预留库存失败: %s", e)
            raise

    def release_reserved_stock(self, product_id: int, quantity: int) -> bool:
        """Release *quantity* previously reserved units.

        Raises:
            ValueError: if *quantity* is not positive, the inventory row is
                missing, or fewer units are reserved.
        """
        try:
            self._require_positive(quantity)
            inventory = self._get_inventory(product_id)
            if inventory.reserved_quantity < quantity:
                raise ValueError(f"预留库存不足,预留: {inventory.reserved_quantity}, 释放: {quantity}")
            inventory.reserved_quantity -= quantity
            self.session.commit()
            self.logger.info("商品 %s 释放预留库存 %s 件", product_id, quantity)
            return True
        except Exception as e:
            self.session.rollback()
            self.logger.error("释放预留库存失败: %s", e)
            raise

    def adjust_stock(self, product_id: int, new_quantity: int, reason: str) -> bool:
        """Set on-hand stock to *new_quantity* and record the signed difference.

        Raises:
            ValueError: if *new_quantity* is negative or the inventory row is missing.
        """
        try:
            if new_quantity < 0:
                raise ValueError(f"库存数量不能为负数: {new_quantity}")
            inventory = self._get_inventory(product_id)
            old_quantity = inventory.quantity
            inventory.quantity = new_quantity
            self._record_stock_movement(
                product_id=product_id,
                movement_type='ADJUSTMENT',
                quantity=new_quantity - old_quantity,
                reference_type='ADJUSTMENT',
                reason=reason
            )
            self.session.commit()
            self.logger.info("商品 %s 库存调整: %s -> %s", product_id, old_quantity, new_quantity)
            return True
        except Exception as e:
            self.session.rollback()
            self.logger.error("库存调整失败: %s", e)
            raise

    def get_low_stock_products(self) -> List[Product]:
        """Active products whose on-hand quantity is at or below their threshold."""
        return self._active_inventory_query(
            Inventory.quantity <= Inventory.min_stock_level
        ).all()

    def get_out_of_stock_products(self) -> List[Product]:
        """Active products with no units on hand."""
        return self._active_inventory_query(Inventory.quantity <= 0).all()

    def get_inventory_report(self) -> dict:
        """Summary counts and total inventory value at cost for active products.

        The counts are computed with COUNT queries instead of loading the
        matching products.
        """
        total_products = self.session.query(Product).filter(
            Product.is_active == True
        ).count()
        low_stock_count = self._active_inventory_query(
            Inventory.quantity <= Inventory.min_stock_level
        ).count()
        out_of_stock_count = self._active_inventory_query(Inventory.quantity <= 0).count()
        # select_from(Product) makes Product the explicit left side of the join;
        # the query selects only an aggregate, not an entity.
        total_inventory_value = self.session.query(
            func.sum(Product.cost * Inventory.quantity)
        ).select_from(Product).join(Inventory).filter(
            Product.is_active == True
        ).scalar() or 0
        return {
            'total_products': total_products,
            'low_stock_count': low_stock_count,
            'out_of_stock_count': out_of_stock_count,
            'total_inventory_value': float(total_inventory_value)
        }

    def get_stock_movements(self, product_id: int = None,
                            days: int = 30) -> List[StockMovement]:
        """Return stock movements, newest first.

        Args:
            product_id: restrict to one product; ``None`` means all products.
            days: only movements from the last *days* days (UTC); a falsy
                value (0/None) disables the date filter.
        """
        query = self.session.query(StockMovement)
        if product_id is not None:
            query = query.filter(StockMovement.product_id == product_id)
        if days:
            from_date = datetime.utcnow() - timedelta(days=days)
            query = query.filter(StockMovement.created_at >= from_date)
        return query.order_by(StockMovement.created_at.desc()).all()

    def _record_stock_movement(self, product_id: int, movement_type: str,
                               quantity: int, reference_id: str = None,
                               reference_type: str = None, reason: str = None):
        """Add (without committing) a StockMovement row attributed to 'system'."""
        self.session.add(StockMovement(
            product_id=product_id,
            movement_type=movement_type,
            quantity=quantity,
            reference_id=reference_id,
            reference_type=reference_type,
            reason=reason,
            created_by='system'
        ))
# Inventory management demo
def inventory_management_demo():
    """Walk through the inventory service end to end, printing each step.

    Steps, in order:

    1. Create a product.
    2. Stock it in.
    3. Reserve part of it.
    4. Stock some out.
    5. Adjust the count.
    6. Print the inventory report.
    7. Print up to five of the product's stock movements.

    ``ORMDatabaseManager``, ``InventoryService``, ``Inventory`` and
    ``PyDecimal`` are defined or imported elsewhere in this tutorial.
    """
    # Route the service's logger output to the console.
    logging.basicConfig(level=logging.INFO)

    manager = ORMDatabaseManager()
    manager.init_database()

    with manager.get_session() as session:
        service = InventoryService(session)
        print("=== 库存管理系统演示 ===")

        # Step 1: create a product with 100 units of opening stock.
        print("\n1. 创建商品:")
        phone = service.create_product(
            name="iPhone 15",
            sku="IP15-001",
            price=PyDecimal('6999.00'),
            cost=PyDecimal('5000.00'),
            description="苹果iPhone 15手机",
            initial_stock=100,
            min_stock=20
        )
        print(f" 创建商品: {phone.name}, SKU: {phone.sku}")

        # Step 2: receive 50 more units against a purchase order reference.
        print("\n2. 入库操作:")
        service.stock_in(
            product_id=phone.id,
            quantity=50,
            reference_id="PO-2024-001",
            reason="采购入库"
        )
        print(f" 入库50件,当前库存: {phone.get_current_stock()}")

        # Step 3: reserve 30 units, then read the Inventory row back.
        print("\n3. 预留库存:")
        service.reserve_stock(phone.id, 30)
        stock_row = session.query(Inventory).filter(
            Inventory.product_id == phone.id
        ).first()
        print(f" 预留30件,可用库存: {stock_row.available_quantity}")

        # Step 4: ship 20 units against a sales order reference.
        print("\n4. 出库操作:")
        service.stock_out(
            product_id=phone.id,
            quantity=20,
            reference_id="SO-2024-001",
            reason="销售出库"
        )
        print(f" 出库20件,当前库存: {phone.get_current_stock()}")

        # Step 5: overwrite the quantity, as after a stock count.
        print("\n5. 库存调整:")
        service.adjust_stock(
            product_id=phone.id,
            new_quantity=120,
            reason="盘点调整"
        )
        print(f" 调整后库存: {phone.get_current_stock()}")

        # Step 6: print every field of the summary report.
        print("\n6. 库存报告:")
        summary = service.get_inventory_report()
        for field_name, field_value in summary.items():
            print(f" {field_name}: {field_value}")

        # Step 7: show at most five movements (the service orders them newest first).
        print("\n7. 库存变动记录:")
        history = service.get_stock_movements(phone.id)
        for entry in history[:5]:
            print(f" {entry.created_at.strftime('%Y-%m-%d %H:%M')} - "
                  f"{entry.movement_type}: {entry.quantity} - {entry.reason}")


if __name__ == "__main__":
    inventory_management_demo()
总结
核心知识点回顾
- 数据库编程基础
  - 数据库类型选择和配置管理
  - 连接工厂模式的应用
  - 数据库连接的生命周期管理
- 原生SQL操作
  - SQL查询构建器的设计模式
  - 动态SQL生成和参数绑定
  - 复杂查询的构建技巧
- SQLAlchemy ORM
  - 模型定义和关系映射
  - 查询API的高级用法
  - 会话管理和对象生命周期
- 数据库连接池
  - 连接池的配置和优化
  - 并发环境下的连接管理
  - 性能监控和调优
- 事务管理
  - 事务的ACID特性
  - 嵌套事务和保存点
  - 事务装饰器和上下文管理器
- 数据库迁移
  - Alembic迁移工具的使用
  - 版本控制和回滚策略
  - 生产环境迁移最佳实践
- 性能优化
  - 索引策略和查询优化
  - 连接池参数调优
  - 查询性能分析和监控
最佳实践总结
1. 数据库设计原则
- 规范化设计: 避免数据冗余,保持数据一致性
- 索引策略: 合理创建索引,平衡查询性能和写入性能
- 约束定义: 使用数据库约束保证数据完整性
- 命名规范: 统一的表名、字段名和索引命名规范
2. ORM使用技巧
- 加载策略优化: 默认的懒加载(lazy loading)在循环中容易引发N+1查询问题,应按需使用预加载(eager loading)
- 批量操作: 使用bulk操作提高大量数据处理性能
- 查询优化: 使用join和subquery优化复杂查询
- 缓存策略: 合理使用查询缓存和对象缓存
3. 事务管理策略
- 事务边界: 明确事务的开始和结束边界
- 异常处理: 完善的异常处理和回滚机制
- 隔离级别: 根据业务需求选择合适的隔离级别
- 死锁预防: 统一的锁获取顺序避免死锁
4. 性能优化建议
- 连接池配置: 根据应用负载调整连接池参数
- 查询分析: 定期分析慢查询并进行优化
- 索引维护: 定期检查和维护数据库索引
- 监控告警: 建立完善的数据库监控和告警机制
5. 安全考虑
- SQL注入防护: 始终使用参数化查询
- 权限控制: 最小权限原则,合理分配数据库权限
- 敏感数据: 对敏感数据进行加密存储
- 审计日志: 记录重要的数据库操作日志
进阶学习方向
- 分布式数据库
  - 数据库分片和读写分离
  - 分布式事务处理
  - 数据一致性保证
- NoSQL数据库
  - MongoDB、Redis等NoSQL数据库的使用
  - 多数据源整合
  - 数据库选型策略
- 数据库运维
  - 数据库备份和恢复
  - 性能调优和故障排查
  - 高可用架构设计
- 大数据处理
  - 数据仓库和数据湖
  - ETL流程设计
  - 实时数据处理
学习建议
- 理论与实践结合: 在理解理论的基础上,多做实际项目练习
- 关注性能: 从一开始就养成关注性能的习惯
- 学习源码: 深入学习SQLAlchemy等框架的源码实现
- 持续学习: 关注数据库技术的最新发展和最佳实践
- 实际项目: 在真实项目中应用所学知识,积累经验
通过本教程的学习,你应该已经掌握了Python数据库操作的核心技能。继续实践和深入学习,你将能够构建出高性能、可靠的数据库应用系统。
文档版本: v1.0.0
创建时间: 2024-12-19
最后更新: 2024-12-19
作者: Python教程团队