from pyspark.sql import SparkSession, functions as F
from elasticsearch import Elasticsearch, helpers
from elasticsearch.exceptions import ConnectionError, RequestError, NotFoundError
from datetime import datetime
import base64
import math
import sys
import uuid
# ====================== Spark initialization ======================
spark = SparkSession.builder \
.appName("BigDataToES") \
.config("spark.sql.shuffle.partitions", 800) \
.config("spark.default.parallelism", 800) \
.config("spark.driver.memory", "30g") \
.config("spark.speculation", "false") \
.getOrCreate()
success_count = spark.sparkContext.accumulator(0)
failed_count = spark.sparkContext.accumulator(0)
# ====================== Data extraction function ======================
def get_spark_sql_data(sql_query):
    try:
        spark.sql(f"use ${mainDb}")  # ${mainDb} must be substituted with a valid database name
        df = spark.sql(sql_query)
        print(f"[INFO] Row count: {df.count()}")
        # Identify date/timestamp/string columns (optional, currently unused)
        date_columns = [
            col for col, datatype in df.dtypes
            if datatype in ('date', 'timestamp', 'string')
        ]
        columns = df.columns
        print(f"[INFO] Column list: {columns}")
        # Repartition for better write parallelism
        rows_rdd = df.rdd.map(lambda row: tuple(row)).repartition(800)
        print(f"[INFO] Number of RDD partitions: {rows_rdd.getNumPartitions()}")
        return columns, rows_rdd
    except Exception as e:
        print(f"[ERROR] Spark SQL execution failed: {str(e)}")
        sys.exit(1)
# ====================== ES write function ======================
def es_insert_data(es_host, es_user, es_pass, index_name, columns, rows_rdd,
                   mode='append', es_shards=5, es_replicas=1, failure_threshold=1.0):
    """
    Improved ES writer: supports overwrite/append modes, with error handling,
    automatic type mapping and optimized bulk writes.
    """
    try:
        # 1. Build ES connection parameters
        es_kwargs = {
            "hosts": [es_host],
            "basic_auth": (es_user, es_pass),
            "request_timeout": 300,  # request timeout in seconds
            "max_retries": 3,
            "retry_on_timeout": True
        }
        # TLS settings (HTTPS endpoints)
        if es_host.startswith("https://"):
            es_kwargs.update({
                "verify_certs": False,
                "ssl_assert_hostname": False
            })
        # 2. Initialize the ES client
        es = Elasticsearch(**es_kwargs)
        # 3. Infer ES field mappings from a sample value
        def get_es_type(col_name, value):
            # Note: bool is checked before int because bool is a subclass of int
            if isinstance(value, datetime):
                return {"type": "date", "format": "yyyy-MM-dd"}
            elif isinstance(value, bool):
                return {"type": "boolean"}
            elif isinstance(value, int):
                return {"type": "long"}
            elif isinstance(value, float):
                return {"type": "double"}
            else:
                return {"type": "text"}
        # 4. Prepare the index (overwrite/append logic)
        if mode == 'overwrite':
            # Try to delete the old index
            try:
                es.indices.delete(index=index_name)
                print(f"[INFO] Overwrite mode: old index {index_name} deleted")
            except NotFoundError:
                print(f"[INFO] Overwrite mode: index {index_name} does not exist, creating it directly")
            except RequestError as e:
                print(f"[ERROR] Failed to delete index: {str(e)}, aborting overwrite")
                sys.exit(1)
            # Sample one row to infer field mappings (fall back to "text" for an empty RDD)
            sample = rows_rdd.take(1)
            if sample:
                properties = {
                    col: get_es_type(col, val)
                    for col, val in zip(columns, sample[0])
                }
            else:
                properties = {col: {"type": "text"} for col in columns}
            # Create the new index
            es.indices.create(
                index=index_name,
                body={
                    "settings": {
                        "number_of_shards": es_shards,
                        "number_of_replicas": es_replicas
                    },
                    "mappings": {"properties": properties}
                }
            )
            print(f"[INFO] Overwrite mode: new index {index_name} created (shards: {es_shards}, replicas: {es_replicas})")
        elif mode == 'append':
            # Create the index only if it does not exist yet
            if not es.indices.exists(index=index_name):
                # Sample one row to infer field mappings (fall back to "text" for an empty RDD)
                sample = rows_rdd.take(1)
                if sample:
                    properties = {
                        col: get_es_type(col, val)
                        for col, val in zip(columns, sample[0])
                    }
                else:
                    properties = {col: {"type": "text"} for col in columns}
                es.indices.create(
                    index=index_name,
                    body={
                        "settings": {
                            "number_of_shards": es_shards,
                            "number_of_replicas": es_replicas
                        },
                        "mappings": {"properties": properties}
                    }
                )
                print(f"[INFO] Append mode: new index {index_name} created (shards: {es_shards}, replicas: {es_replicas})")
            else:
                print(f"[INFO] Append mode: index {index_name} already exists, appending data")
        # 5. Distributed write logic (bulk writes per partition)
        def write_partition_to_es(partition):
            es_part = Elasticsearch(**es_kwargs)  # create a fresh client per partition
            errors = []
            local_success = 0
            local_failed = 0
            docs = []
            for i, row_tuple in enumerate(partition):
                # Generate a unique document ID (UUID guarantees global uniqueness)
                unique_id = str(uuid.uuid4())
                doc = dict(zip(columns, row_tuple))
                # Type conversion (dates, binary data, non-finite floats)
                for col, val in doc.items():
                    if isinstance(val, datetime):
                        doc[col] = val.strftime("%Y-%m-%d")
                    elif isinstance(val, bytes):
                        doc[col] = base64.b64encode(val).decode('utf-8')
                    elif isinstance(val, float) and (math.isinf(val) or math.isnan(val)):
                        doc[col] = None  # NaN/Inf are not valid JSON values for ES
                # Build the bulk action
                docs.append({
                    "_op_type": "create",
                    "_index": index_name,
                    "_id": unique_id,
                    "_source": doc
                })
                # Bulk write every 500 documents
                if len(docs) >= 500:
                    success, failed = helpers.bulk(
                        es_part, docs,
                        chunk_size=500,
                        raise_on_error=False,
                        refresh=False
                    )
                    local_success += success
                    local_failed += len(failed)
                    errors.extend(failed)
                    docs = []  # reset the buffer
            # Flush the remaining documents
            if docs:
                success, failed = helpers.bulk(
                    es_part, docs,
                    chunk_size=500,
                    raise_on_error=False,
                    refresh=False
                )
                local_success += success
                local_failed += len(failed)
                errors.extend(failed)
            # Update the global accumulators
            success_count.add(local_success)
            failed_count.add(local_failed)
            # Per-partition statistics
            print(f"[INFO] Partition write: {local_success} succeeded, {local_failed} failed")
            if errors:
                print(f"[ERRORS] First 10 failures:")
                for error in errors[:10]:
                    print(f"    {error}")
            es_part.close()  # close the per-partition client
        # Execute the distributed write
        rows_rdd.foreachPartition(write_partition_to_es)
        # 6. Global statistics and validation
        total_success = success_count.value
        total_failed = failed_count.value
        total_count = total_success + total_failed
        # Guard against division by zero: divide by the total count (not the success
        # count) and fall back to 0.0 when nothing was processed at all
        failure_rate = (total_failed / total_count) if total_count > 0 else 0.0
        # Refresh the index so the written data becomes searchable
        es.indices.refresh(index=index_name)
        # Verify the actual number of documents written
        count_result = es.count(index=index_name)
        print(f"[INFO] Global statistics: {total_success} succeeded, {total_failed} failed")
        print(f"[INFO] Actual document count in index {index_name}: {count_result['count']}")
        # Failure-rate check
        if failure_rate > failure_threshold:
            print(f"[ERROR] Failure rate {failure_rate:.2%} exceeds threshold {failure_threshold:.2%}, aborting")
            spark.stop()
            sys.exit(1)
    except ConnectionError:
        print("[ERROR] ES connection failed, please check the host address/credentials")
        sys.exit(1)
    except RequestError as e:
        print(f"[ERROR] ES request error: {str(e)}")
        sys.exit(1)
    except Exception as e:
        print(f"[ERROR] Unknown error: {str(e)}")
        sys.exit(1)
# ====================== Main program ======================
if __name__ == "__main__":
    # User-defined configuration (adjust for your environment)
    SPARK_SQL_CONFIG = {
        "sql_query": "SELECT * FROM ${input1};"  # ${input1} must be substituted with a valid table name
    }
    print("Spark SQL config:", SPARK_SQL_CONFIG)
    ES_CONFIG = {
        "es_host": "http://tias8es.jcfwpt.cmbc.com.cn:30004",
        "es_user": "admin",
        "es_pass": "Cmbc1#tias",
        "index_name": "gffcm_pfs_indicator",
        "mode": "append",  # or "overwrite"
        "es_shards": 5,
        "es_replicas": 1,
        "failure_threshold": 0.1  # example: allow up to 10% failures
    }
    # Run the extraction and the ES write
    columns, rows_rdd = get_spark_sql_data(**SPARK_SQL_CONFIG)
    if columns and rows_rdd:
        es_insert_data(columns=columns, rows_rdd=rows_rdd, **ES_CONFIG)
        spark.stop()
    else:
        print("[ERROR] No valid Spark data was obtained, aborting the sync")
        spark.stop()
        sys.exit(1)
Based on this code, how can the "division by zero" error be avoided?
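The ZeroDivisionError comes from the statistics step: the original line read `failure_rate = total_failed / total_success if total_count > 0 else 0`, so the guard checks `total_count` while the division uses `total_success`, which is zero whenever every document fails. The version above divides by `total_count` instead and falls back to 0.0 when nothing was processed. A minimal sketch of that guarded calculation (the helper name `safe_failure_rate` is illustrative, not part of the script):

def safe_failure_rate(total_success, total_failed):
    # Return failed/total; 0.0 when no documents were processed at all
    total_count = total_success + total_failed
    if total_count == 0:
        return 0.0
    return total_failed / total_count

# Examples: safe_failure_rate(0, 0) -> 0.0, safe_failure_rate(0, 10) -> 1.0, safe_failure_rate(90, 10) -> 0.1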