import time
import random
import threading
import queue
from typing import Callable, Any, Dict, List
from dataclasses import dataclass, field
from enum import Enum, auto
import json
import socket
from concurrent.futures import ThreadPoolExecutor
import logging
from datetime import datetime
class TaskStatus(Enum):
    """Lifecycle states a scheduled task moves through."""

    PENDING = 1    # queued, not yet dispatched to a worker
    RUNNING = 2    # currently executing on a worker node
    COMPLETED = 3  # finished successfully; result is populated
    FAILED = 4     # raised an exception; error is populated
@dataclass
class Task:
    """A schedulable unit of work plus its execution bookkeeping.

    Instances are created by DistributedTaskScheduler.submit_task() and
    mutated in place by the worker that executes them.
    """
    task_id: str                    # unique id assigned at submission time
    func: Callable[..., Any]        # the callable to execute
    args: tuple = ()                # positional arguments passed to func
    kwargs: Dict[str, Any] = field(default_factory=dict)  # keyword arguments passed to func
    status: TaskStatus = TaskStatus.PENDING  # current lifecycle state
    result: Any = None              # return value of func once COMPLETED
    error: str = ""                 # stringified exception when FAILED
    start_time: float = 0           # epoch seconds when execution began (0 = not started)
    end_time: float = 0             # epoch seconds when execution ended (0 = not finished)
class DistributedTaskScheduler:
    """Distributed task scheduler with load balancing and node health monitoring.

    Tasks submitted via submit_task() are queued centrally; a dispatcher
    thread assigns each one to the least-loaded worker node's thread pool,
    while a monitor thread heartbeats every node over TCP.
    """

    def __init__(self, worker_nodes: List[str], max_workers_per_node: int = 4):
        """Create the scheduler.

        Args:
            worker_nodes: node addresses in "host:port" form.
            max_workers_per_node: thread-pool size created for each node.
        """
        self.worker_nodes = worker_nodes
        self.max_workers = max_workers_per_node
        self.task_queue = queue.Queue()                      # pending tasks (FIFO)
        self.task_registry: Dict[str, "Task"] = {}           # task_id -> Task
        self.worker_pools: Dict[str, ThreadPoolExecutor] = {}
        self.heartbeat_interval = 5                          # seconds between health checks
        self._stop_event = threading.Event()                 # signals the monitor loop to exit
        self.logger = self._setup_logger()
        self._initialize_workers()

    def _setup_logger(self) -> logging.Logger:
        """Return the scheduler logger, attaching a stream handler only once."""
        logger = logging.getLogger('DistributedTaskScheduler')
        logger.setLevel(logging.INFO)
        # Bug fix: guard against duplicate handlers -- previously every
        # scheduler instantiation added another StreamHandler, so each log
        # record was printed once per scheduler ever created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        return logger

    def _initialize_workers(self):
        """Create one thread pool per worker node."""
        for node in self.worker_nodes:
            self.worker_pools[node] = ThreadPoolExecutor(max_workers=self.max_workers)

    def _execute_task(self, task: "Task", node: str):
        """Run *task* on the pool associated with *node*, recording its outcome.

        Exceptions are captured into task.error / FAILED status and never
        propagate into the pool.
        """
        try:
            task.status = TaskStatus.RUNNING
            task.start_time = time.time()
            self.logger.info(f"Task {task.task_id} started on {node}")
            result = task.func(*task.args, **task.kwargs)
            task.status = TaskStatus.COMPLETED
            task.result = result
            self.logger.info(f"Task {task.task_id} completed on {node}")
        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = str(e)
            self.logger.error(f"Task {task.task_id} failed on {node}: {str(e)}")
        finally:
            task.end_time = time.time()

    def _dispatch_tasks(self):
        """Dispatch loop: route each queued task to the least-loaded node.

        Terminates when the None sentinel (posted by shutdown()) is read.
        """
        while True:
            task = self.task_queue.get()
            if task is None:  # shutdown sentinel
                break
            # Pick the pool with the fewest queued work items.
            # NOTE(review): _work_queue is a private ThreadPoolExecutor
            # attribute; stable in CPython but not a public API.
            selected_node = min(
                self.worker_nodes,
                key=lambda node: self.worker_pools[node]._work_queue.qsize()
            )
            self.worker_pools[selected_node].submit(
                self._execute_task, task, selected_node
            )

    def _monitor_nodes(self):
        """Heartbeat every node once per interval until shutdown is signalled."""
        # Event.wait doubles as an interruptible sleep, so shutdown() can
        # stop this loop promptly instead of leaving it running forever.
        while not self._stop_event.wait(self.heartbeat_interval):
            for node in self.worker_nodes:
                host, _, port = node.partition(':')
                try:
                    # Bug fix: close the probe socket; previously one file
                    # descriptor leaked per node per heartbeat.
                    sock = socket.create_connection((host, int(port)), timeout=1)
                    sock.close()
                    self.logger.debug(f"Node {node} is alive")
                except Exception as e:
                    self.logger.warning(f"Node {node} is down: {str(e)}")
                    # Node-failure handling (e.g. rescheduling) could go here.

    def submit_task(self, func: Callable[..., Any], *args, **kwargs) -> str:
        """Queue func(*args, **kwargs) for execution and return its task id."""
        task_id = f"task-{int(time.time()*1000)}-{random.randint(1000, 9999)}"
        task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs)
        self.task_registry[task_id] = task
        self.task_queue.put(task)
        return task_id

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return a status snapshot for *task_id*, or an error dict if unknown."""
        task = self.task_registry.get(task_id)
        if not task:
            return {"error": "Task not found"}
        return {
            "task_id": task.task_id,
            "status": task.status.name,
            "result": task.result,
            "error": task.error,
            "start_time": datetime.fromtimestamp(task.start_time).isoformat() if task.start_time else None,
            "end_time": datetime.fromtimestamp(task.end_time).isoformat() if task.end_time else None,
            "duration": task.end_time - task.start_time if task.end_time and task.start_time else None
        }

    def start(self):
        """Start the dispatcher and monitor background threads (daemonized)."""
        self.logger.info("Starting distributed task scheduler")
        self.dispatcher_thread = threading.Thread(target=self._dispatch_tasks, daemon=True)
        self.dispatcher_thread.start()
        self.monitor_thread = threading.Thread(target=self._monitor_nodes, daemon=True)
        self.monitor_thread.start()

    def shutdown(self):
        """Stop monitoring, drain the dispatcher, then close all worker pools."""
        self.logger.info("Shutting down distributed task scheduler")
        self._stop_event.set()     # stop the heartbeat loop
        self.task_queue.put(None)  # dispatcher termination sentinel
        # Bug fix: wait for the dispatcher to exit before closing the pools,
        # otherwise a late submit() can race pool.shutdown() and raise
        # "cannot schedule new futures after shutdown".
        dispatcher = getattr(self, 'dispatcher_thread', None)
        if dispatcher is not None:
            dispatcher.join(timeout=5)
        for pool in self.worker_pools.values():
            pool.shutdown(wait=True)
# Demo workload used by the example below.
def sample_task(task_num: int, duration: float):
    """Sleep for *duration* seconds, then return a result string.

    Simulates flaky work: raises ValueError with ~10% probability.
    """
    print(f"Starting task {task_num}")
    time.sleep(duration)
    roll = random.random()
    if roll >= 0.1:  # succeeds 90% of the time
        return f"Result of task {task_num}"
    raise ValueError(f"Task {task_num} randomly failed")
if __name__ == "__main__":
    # Example worker nodes (format: "host:port").
    nodes = ["localhost:8000", "localhost:8001", "localhost:8002"]

    # Bring the scheduler up.
    scheduler = DistributedTaskScheduler(nodes, max_workers_per_node=2)
    scheduler.start()

    # Submit a batch of demo tasks with random durations.
    submitted = []
    for idx in range(10):
        tid = scheduler.submit_task(sample_task, idx, random.uniform(0.5, 2.0))
        submitted.append(tid)
        print(f"Submitted task {idx} with ID: {tid}")

    # Poll task states once a second until every task reaches a terminal state.
    terminal = ('COMPLETED', 'FAILED')
    while True:
        time.sleep(1)
        print("\nCurrent task status:")
        for tid in submitted:
            print(f"{tid}: {scheduler.get_task_status(tid)['status']}")
        if all(scheduler.get_task_status(tid)['status'] in terminal for tid in submitted):
            break

    # Tear everything down.
    scheduler.shutdown()
使用说明
-
功能特点:
- 分布式任务调度与负载均衡
- 多节点任务执行
- 实时状态监控
- 故障检测与处理
- FIFO 任务队列（优先级队列为扩展方向，见下文扩展建议）
- 详细任务日志记录
-
核心组件:
- Task: 封装任务信息的数据类
- DistributedTaskScheduler: 主调度器类
- TaskStatus: 任务状态枚举
-
使用方法:
    # 1. 初始化调度器 (指定工作节点)
    scheduler = DistributedTaskScheduler(
        worker_nodes=["node1:8000", "node2:8000"],
        max_workers_per_node=4,
    )

    # 2. 启动调度系统
    scheduler.start()

    # 3. 提交任务
    def my_task(arg1, arg2):
        # 你的任务逻辑
        return result

    task_id = scheduler.submit_task(my_task, "arg1", arg2="value")

    # 4. 获取任务状态
    status = scheduler.get_task_status(task_id)

    # 5. 关闭系统 (当所有任务完成后)
    scheduler.shutdown()
-
应用场景:
- 大规模数据处理
- 分布式计算
- 批处理作业
- 微服务任务调度
- CI/CD流水线
-
扩展建议:
- 添加任务优先级支持
- 实现任务依赖关系
- 添加任务超时机制
- 支持持久化任务队列
- 添加REST API接口
- 实现任务结果缓存
-
注意事项:
- 工作节点需要预先配置好
- 本实现基于进程内线程池，任务函数无需可序列化；若扩展为跨进程/跨机器分发，则任务函数需要可序列化
- 生产环境需要更健壮的错误处理
- 考虑添加任务重试机制