# 待调整涨跌停标记后完成高质量数据加载及异常数据标记
# (High-quality data loading with limit-up/down marking and anomaly flagging;
#  limit marks pending adjustment.)
# -*- coding: utf-8 -*-
import jqdata
import pandas as pd
import numpy as np
import logging
from jqdata import *
from datetime import datetime
from scipy.stats import iqr
from sklearn.ensemble import IsolationForest
import arch
from scipy.stats import gaussian_kde
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class StockDataLoader:
    """Load A-share daily bars and annotate them with trading-status flags.

    Responsibilities:
      * universe construction (stocks listed before 2013, not delisted through 2023)
      * ST filtering, market-cap bucketing and random sampling
      * vectorized limit-up / limit-down marking
      * special-event tagging and statistical anomaly flagging (MAD / KDE / volatility)
    """

    def __init__(self):
        # code -> market cap (float; NaN when the fundamentals query failed)
        self.market_cap_data = {}
        # event type -> list of pd.Timestamp event dates
        self.special_events = self._create_special_events_db()
        # code -> security-info object; avoids repeated API round-trips
        self.security_info_cache = {}

    def _create_special_events_db(self):
        """Build the hard-coded special-event calendar (event type -> Timestamps).

        NOTE(review): run_data_validation keeps its own copy of this table for
        independent cross-checking; the two copies must stay in sync.
        """
        events = {
            'circuit_breaker': ['2016-01-04', '2016-01-07'],
            'major_events': [
                '2015-06-12', '2015-07-08', '2016-01-04', '2016-01-07',
                '2018-03-23', '2019-05-06', '2020-01-23', '2020-02-03',
                '2020-03-09', '2020-03-12', '2020-03-16', '2020-03-18',
                '2020-07-06', '2021-01-04', '2021-02-18', '2021-07-26',
                '2022-02-24', '2022-03-07', '2022-03-08', '2022-03-09',
                '2022-04-25', '2022-10-24', '2022-10-28', '2023-01-30'
            ],
            'black_swan': [
                '2015-06-12', '2018-03-23', '2020-01-23',
                '2020-03-09', '2021-07-26', '2022-02-24'
            ],
            'extreme_market': [
                '2015-06-19', '2015-06-26', '2015-07-03', '2015-07-09',
                '2015-08-24', '2015-08-25', '2016-01-04', '2016-01-07',
                '2018-10-11', '2019-05-06', '2020-02-03', '2020-07-06',
                '2020-07-16', '2021-02-18', '2021-07-27', '2022-03-15',
                '2022-04-25', '2022-10-24'
            ],
            'gem_reform': [
                '2020-08-24', '2020-08-27', '2020-09-09',
                '2021-02-18', '2021-04-19', '2021-08-24'
            ]
        }
        return {key: [pd.Timestamp(date) for date in date_list]
                for key, date_list in events.items()}

    def get_security_info_cached(self, stock_code):
        """Memoized wrapper around get_security_info."""
        if stock_code not in self.security_info_cache:
            self.security_info_cache[stock_code] = get_security_info(stock_code)
        return self.security_info_cache[stock_code]

    def is_special_event(self, date, event_type):
        """Return True when `date` (pd.Timestamp) is listed for `event_type`."""
        return date in self.special_events.get(event_type, [])

    def get_all_stocks(self):
        """Return codes listed before 2013-01-01 and not delisted by 2023-12-31."""
        all_stocks = get_all_securities(types=['stock'], date='2012-12-31').index.tolist()
        stocks = []
        for stock in all_stocks:
            # Fix: go through the instance cache instead of calling the API
            # once per stock on every invocation.
            stock_info = self.get_security_info_cached(stock)
            if stock_info:
                start_date = pd.Timestamp(stock_info.start_date)
                end_date = pd.Timestamp(stock_info.end_date) if stock_info.end_date else None
                if start_date < pd.Timestamp('2013-01-01'):
                    if end_date is None or end_date > pd.Timestamp('2023-12-31'):
                        stocks.append(stock)
                    else:
                        logger.debug(f"排除退市股票: {stock}, 退市日期: {stock_info.end_date}")
        return stocks

    def filter_st_stocks(self, stocks, start_date, end_date):
        """Drop stocks flagged ST at any point inside [start_date, end_date].

        On API failure the stock is kept (best-effort behaviour, logged).
        """
        logger.info(f"开始过滤ST股票 ({len(stocks)}只)")
        non_st_stocks = []
        batch_size = 100
        # Fix: removed the unused get_trade_days call that was here.
        for i in range(0, len(stocks), batch_size):
            batch = stocks[i:i+batch_size]
            logger.info(f"正在检查ST状态批次 {i//batch_size+1}/{(len(stocks)-1)//batch_size+1}")
            for stock in batch:
                try:
                    st_status = get_extras('is_st', stock, start_date=start_date, end_date=end_date)
                    if st_status is None or st_status.dropna().empty:
                        non_st_stocks.append(stock)
                        continue
                    # Fix: treat NaN as "not ST" explicitly; a NaN-bearing
                    # boolean mask would raise on DataFrame indexing below.
                    flagged = st_status.iloc[:, 0].fillna(False).astype(bool)
                    if not flagged.any():
                        non_st_stocks.append(stock)
                    else:
                        st_dates = st_status[flagged].index.tolist()
                        logger.debug(f"过滤ST股票: {stock}, ST日期: {st_dates}")
                except Exception as e:
                    logger.error(f"检查{stock}的ST状态失败: {str(e)}")
                    non_st_stocks.append(stock)
        logger.info(f"过滤后剩余股票数量: {len(non_st_stocks)}")
        return non_st_stocks

    def get_market_cap(self, stock, date='2012-12-31'):
        """Return the market cap of `stock` at `date`, memoized; NaN on failure."""
        if stock not in self.market_cap_data:
            try:
                q = query(valuation).filter(valuation.code == stock)
                df = get_fundamentals(q, date=date)
                if not df.empty:
                    self.market_cap_data[stock] = df['market_cap'].iloc[0]
                else:
                    self.market_cap_data[stock] = np.nan
            except Exception as e:
                logger.warning(f"获取{stock}市值失败: {str(e)}")
                self.market_cap_data[stock] = np.nan
        return self.market_cap_data[stock]

    def categorize_stocks(self, stocks):
        """Split `stocks` into (large, mid, small) tertiles by descending market cap.

        Stocks whose market cap could not be fetched (NaN) are excluded.
        """
        market_caps = []
        for stock in stocks:
            cap = self.get_market_cap(stock)
            if not np.isnan(cap):
                market_caps.append((stock, cap))
        sorted_stocks = sorted(market_caps, key=lambda x: x[1], reverse=True)
        total = len(sorted_stocks)
        large_cap = [s[0] for s in sorted_stocks[:total//3]]
        mid_cap = [s[0] for s in sorted_stocks[total//3:2*total//3]]
        small_cap = [s[0] for s in sorted_stocks[2*total//3:]]
        return large_cap, mid_cap, small_cap

    def sample_stocks(self, large_cap, mid_cap, small_cap, n=100):
        """Randomly sample up to `n` codes from each cap bucket, without replacement."""
        large_sample = np.random.choice(large_cap, min(n, len(large_cap)), replace=False) if large_cap else []
        mid_sample = np.random.choice(mid_cap, min(n, len(mid_cap)), replace=False) if mid_cap else []
        small_sample = np.random.choice(small_cap, min(n, len(small_cap)), replace=False) if small_cap else []
        return list(large_sample) + list(mid_sample) + list(small_sample)

    def calculate_price_limits(self, price_data):
        """Vectorized limit-up / limit-down marking.

        Adds `price_limit_threshold`, `up_limit_hit`, `down_limit_hit` and
        `limit_one_way`; temporary helper columns are dropped before return.
        Threshold rules: 10% default, 20% for STAR ('ks') and for ChiNext
        ('gem') from 2020-08-24 (registration reform), 30% for 'bj'.
        """
        price_data = price_data.copy()
        unique_codes = price_data['code'].unique()
        security_types = {
            code: self.get_security_info_cached(code).type if self.get_security_info_cached(code) else 'normal'
            for code in unique_codes
        }
        price_data['security_type'] = price_data['code'].map(security_types)
        price_data['price_limit_threshold'] = 0.10
        gem_mask = (price_data['security_type'] == 'gem') & (price_data['date'] >= '2020-08-24')
        price_data.loc[gem_mask, 'price_limit_threshold'] = 0.20
        ks_mask = price_data['security_type'] == 'ks'
        price_data.loc[ks_mask, 'price_limit_threshold'] = 0.20
        bj_mask = price_data['security_type'] == 'bj'
        price_data.loc[bj_mask, 'price_limit_threshold'] = 0.30
        # Exchange limit prices are rounded to 0.01 CNY.
        price_data['up_limit'] = np.round(price_data['pre_close'] * (1 + price_data['price_limit_threshold']), 2)
        price_data['down_limit'] = np.round(price_data['pre_close'] * (1 - price_data['price_limit_threshold']), 2)
        # The 0.015 tolerance absorbs float noise around the 2-decimal limit price.
        price_data['up_limit_hit'] = (
            (price_data['high'] >= price_data['up_limit'] - 0.015) &
            (price_data['low'] <= price_data['up_limit'] + 0.015)
        ).astype(int)
        price_data['down_limit_hit'] = (
            (price_data['low'] <= price_data['down_limit'] + 0.015) &
            (price_data['high'] >= price_data['down_limit'] - 0.015)
        ).astype(int)
        # "Sealed" one-way limit day: limit hit with a flat OHLC bar.
        price_data['limit_one_way'] = (
            (price_data['up_limit_hit'] == 1) &
            (price_data['low'] == price_data['high']) &
            (price_data['open'] == price_data['close'])
        ).astype(int)
        price_data.drop(['security_type', 'up_limit', 'down_limit'], axis=1, inplace=True)
        return price_data

    def mark_special_events_vectorized(self, price_data):
        """Tag each row with a comma-joined list of special-event types (NaN if none)."""
        price_data = price_data.copy()
        price_data['special_events'] = ''
        for event_type in self.special_events.keys():
            event_mask = price_data['date'].isin(self.special_events[event_type])
            price_data.loc[event_mask, 'special_events'] = price_data.loc[event_mask, 'special_events'] + event_type + ','
        price_data['special_events'] = price_data['special_events'].str.rstrip(',')
        price_data['special_events'] = price_data['special_events'].replace('', np.nan)
        return price_data

    def mark_anomalies(self, price_data):
        """Anomaly-marking layer: MAD, KDE and nonparametric volatility tests.

        Adds integer flag columns `mad_anomaly`, `kde_anomaly`, `vol_anomaly`.
        Only normal trading days (not suspended, no limit hit, no special
        event) are analyzed; every other row keeps a 0 flag.
        """
        if price_data.empty:
            return price_data
        # Fix: work on a copy so the caller's frame is not mutated in place,
        # and create the flag columns unconditionally so downstream code can
        # rely on their presence even when there are no valid rows.
        price_data = price_data.copy()
        price_data['mad_anomaly'] = 0
        price_data['kde_anomaly'] = 0
        price_data['vol_anomaly'] = 0
        valid_mask = (
            (price_data['suspended'] == 0) &
            (price_data['up_limit_hit'] == 0) &
            (price_data['down_limit_hit'] == 0) &
            price_data['special_events'].isna()
        )
        valid_data = price_data[valid_mask].copy()
        if valid_data.empty:
            return price_data
        valid_data['return'] = np.log(valid_data['close'] / valid_data['pre_close'])
        # Single pass per stock; each detector has its own minimum-sample gate.
        for stock, group in valid_data.groupby('code'):
            returns = group['return']
            # --- MAD detector on log returns ---
            if len(returns) >= 10:
                median = returns.median()
                mad = np.median(np.abs(returns - median))
                # Fix: mad == 0 would make the threshold 0 and flag every
                # nonzero deviation; skip the degenerate case.
                if mad > 0:
                    threshold = 5 * 1.4826 * mad
                    anomaly_mask = np.abs(returns - median) > threshold
                    price_data.loc[group[anomaly_mask].index, 'mad_anomaly'] = 1
            # --- KDE detector on (return, volume) ---
            X = group[['return', 'volume']].values
            if len(X) >= 20:
                stds = X.std(axis=0)
                # Fix: zero variance would divide by zero and hand
                # gaussian_kde a singular covariance matrix.
                if np.all(stds > 0):
                    X_norm = (X - X.mean(axis=0)) / stds
                    try:
                        return_kde = gaussian_kde(X_norm[:, 0])
                        volume_kde = gaussian_kde(X_norm[:, 1])
                    except np.linalg.LinAlgError:
                        return_kde = volume_kde = None
                    if return_kde is not None:
                        densities = return_kde(X_norm[:, 0]) * volume_kde(X_norm[:, 1])
                        threshold = np.percentile(densities, 1)
                        anomaly_mask = densities < threshold
                        price_data.loc[group[anomaly_mask].index, 'kde_anomaly'] = 1
            # --- Nonparametric volatility detector (rolling 5-day std) ---
            if len(returns) >= 20:
                realized_vol = returns.rolling(5).std()
                realized_vol_no_na = realized_vol.dropna()
                if len(realized_vol_no_na) > 0:
                    med_vol = realized_vol_no_na.median()
                    mad_vol = np.median(np.abs(realized_vol_no_na - med_vol))
                    # Fix: mad_vol == 0 would set the threshold at the median
                    # and flag roughly half the sample; skip that case.
                    if mad_vol > 0:
                        threshold = med_vol + 3 * 1.4826 * mad_vol
                        anomaly_mask = realized_vol > threshold
                        price_data.loc[group[anomaly_mask].index, 'vol_anomaly'] = 1
        return price_data

    def load_price_data(self, stocks, start_date, end_date):
        """Batch-load pre-adjusted daily bars and run the full marking pipeline.

        Returns a DataFrame indexed by (date, code); empty DataFrame when
        nothing could be loaded.
        """
        trade_days = get_trade_days(start_date=start_date, end_date=end_date)
        logger.info(f"交易日数量: {len(trade_days)}")
        # Fix: guard against an empty trading calendar before trade_days[0].
        if len(trade_days) == 0:
            return pd.DataFrame()
        data_frames = []
        batch_size = 100
        total = len(stocks)
        for i in range(0, total, batch_size):
            batch = stocks[i:i+batch_size]
            logger.info(f"加载股票批次 {i//batch_size+1}/{(total-1)//batch_size+1} ({len(batch)}只股票)")
            try:
                batch_data = get_price(
                    batch,
                    start_date=trade_days[0],
                    end_date=trade_days[-1],
                    fields=['open', 'close', 'high', 'low', 'volume', 'pre_close'],
                    frequency='daily',
                    panel=False,
                    skip_paused=False,
                    fq='pre',
                    fill_paused=True
                )
                if batch_data is None or batch_data.empty:
                    logger.warning(f"批次加载失败,跳过此批次")
                    continue
                if 'time' in batch_data.columns:
                    batch_data.rename(columns={'time': 'date'}, inplace=True)
                batch_data['date'] = pd.to_datetime(batch_data['date'])
                # Zero volume on a filled-paused row means a suspended day.
                batch_data['suspended'] = (batch_data['volume'] == 0).astype(int)
                batch_data.sort_values(['code', 'date'], inplace=True)
                batch_data['prev_suspended'] = batch_data.groupby('code')['suspended'].shift(1)
                batch_data['resumption_first_day'] = ((batch_data['suspended'] == 0) &
                                                      (batch_data['prev_suspended'] == 1)).astype(int)
                batch_data.drop('prev_suspended', axis=1, inplace=True)
                logger.info("向量化计算涨跌停标记...")
                batch_data = self.calculate_price_limits(batch_data)
                logger.info("向量化标记特殊事件...")
                batch_data = self.mark_special_events_vectorized(batch_data)
                logger.info("执行异常标记层...")
                batch_data = self.mark_anomalies(batch_data)
                data_frames.append(batch_data)
            except Exception as e:
                logger.error(f"批次加载失败: {str(e)},跳过此批次")
        if data_frames:
            combined = pd.concat(data_frames)
            return combined.set_index(['date', 'code']).sort_index()
        return pd.DataFrame()
def run_data_validation(data):
    """Validate the loaded dataset and return a dict of counters.

    Checks: missing values per column, re-derived limit-up marks vs the
    stored flags, special-event tag coverage, anomaly-flag totals, and zero
    volume on days not marked suspended.

    Parameters
    ----------
    data : pd.DataFrame indexed by (date, code), as produced by
        StockDataLoader.load_price_data.

    Returns
    -------
    dict with keys 'missing_values', 'limit_issues', 'event_mark_issues',
    'anomaly_mark_stats', 'special_cases'.
    """
    # Bind the logger locally so the function is self-contained.
    logger = logging.getLogger(__name__)
    logger.info("\n" + "="*60)
    logger.info("开始运行数据验证测试")
    logger.info("="*60)
    data_reset = data.reset_index()
    results = {
        'missing_values': {},
        'limit_issues': 0,
        'event_mark_issues': {},
        'anomaly_mark_stats': {'mad': 0, 'kde': 0, 'vol': 0},
        'special_cases': {'zero_volume': 0}
    }
    # 1. Missing values per column.
    null_counts = data_reset.isnull().sum()
    results['missing_values'] = {col: count for col, count in null_counts.items() if count > 0}
    # 2. Re-derive the limit-up price and cross-check the stored marks
    # (same 0.015 tolerance as calculate_price_limits).
    data_reset['calculated_threshold'] = np.round(
        data_reset['pre_close'] * (1 + data_reset['price_limit_threshold']), 2
    )
    false_negatives = data_reset[
        (data_reset['high'] >= data_reset['calculated_threshold'] - 0.015) &
        (data_reset['up_limit_hit'] == 0) &
        (data_reset['suspended'] == 0)
    ]
    false_positives = data_reset[
        (data_reset['high'] < data_reset['calculated_threshold'] - 0.015) &
        (data_reset['up_limit_hit'] == 1) &
        (data_reset['suspended'] == 0)
    ]
    results['limit_issues'] = len(false_negatives) + len(false_positives)
    # 3. Special-event marks. This table intentionally duplicates
    # StockDataLoader._create_special_events_db for independent checking;
    # keep the two copies in sync.
    special_events_db = {
        'circuit_breaker': ['2016-01-04', '2016-01-07'],
        'major_events': [
            '2015-06-12', '2015-07-08', '2016-01-04', '2016-01-07',
            '2018-03-23', '2019-05-06', '2020-01-23', '2020-02-03',
            # Fix: was '2016-03-16', inconsistent with the loader's calendar.
            '2020-03-09', '2020-03-12', '2020-03-16', '2020-03-18',
            '2020-07-06', '2021-01-04', '2021-02-18', '2021-07-26',
            '2022-02-24', '2022-03-07', '2022-03-08', '2022-03-09',
            '2022-04-25', '2022-10-24', '2022-10-28', '2023-01-30'
        ],
        'black_swan': [
            '2015-06-12', '2018-03-23', '2020-01-23',
            '2020-03-09', '2021-07-26', '2022-02-24'
        ],
        'extreme_market': [
            '2015-06-19', '2015-06-26', '2015-07-03', '2015-07-09',
            '2015-08-24', '2015-08-25', '2016-01-04', '2016-01-07',
            '2018-10-11', '2019-05-06', '2020-02-03', '2020-07-06',
            '2020-07-16', '2021-02-18', '2021-07-27', '2022-03-15',
            '2022-04-25', '2022-10-24'
        ],
        'gem_reform': [
            '2020-08-24', '2020-08-27', '2020-09-09',
            '2021-02-18', '2021-04-19', '2021-08-24'
        ]
    }
    for event_type, date_list in special_events_db.items():
        dates = [pd.Timestamp(date) for date in date_list]
        marked = data_reset[data_reset['special_events'].str.contains(event_type, na=False)]
        expected = len(dates)
        actual = marked['date'].nunique()
        # Difference between calendar size and distinct marked dates; nonzero
        # values can also mean some event dates fell outside the loaded range
        # or were not trading days.
        results['event_mark_issues'][event_type] = abs(actual - expected)
    # 4. Anomaly-flag totals.
    results['anomaly_mark_stats']['mad'] = data_reset['mad_anomaly'].sum()
    results['anomaly_mark_stats']['kde'] = data_reset['kde_anomaly'].sum()
    results['anomaly_mark_stats']['vol'] = data_reset['vol_anomaly'].sum()
    # 5. Zero volume on days not marked suspended.
    results['special_cases']['zero_volume'] = data_reset[
        (data_reset['volume'] == 0) &
        (data_reset['suspended'] == 0)
    ].shape[0]
    # Report.
    logger.info("验证结果:")
    if results['missing_values']:
        logger.warning(f"⚠️ 缺失值: {results['missing_values']}")
    else:
        logger.info("✅ 无缺失值")
    if results['limit_issues'] > 0:
        logger.warning(f"⚠️ 涨跌停标记问题: {results['limit_issues']}处")
    else:
        logger.info("✅ 涨跌停标记正确")
    for event_type, issues in results['event_mark_issues'].items():
        if issues > 0:
            logger.warning(f"⚠️ {event_type}事件标记不匹配: 差异{issues}处")
        else:
            logger.info(f"✅ {event_type}事件标记正确")
    logger.info(f"异常标记统计 - MAD: {results['anomaly_mark_stats']['mad']}, "
                f"KDE: {results['anomaly_mark_stats']['kde']}, "
                f"波动率: {results['anomaly_mark_stats']['vol']}")
    if results['special_cases']['zero_volume'] > 0:
        logger.warning(f"⚠️ 非停牌日零成交量: {results['special_cases']['zero_volume']}处")
    else:
        logger.info("✅ 无非停牌日零成交量问题")
    logger.info("="*60)
    logger.info("数据验证测试完成")
    logger.info("="*60)
    return results
def main():
    """Entry point: build the sampled universe, load 2014-2023 bars, validate."""
    banner = "=" * 60
    logger.info(banner)
    logger.info("开始执行数据加载")
    logger.info(banner)

    data_loader = StockDataLoader()

    # Universe: listed before 2013, still listed through 2023.
    logger.info("获取2013年前上市且在2013~2023年间未退市的股票...")
    universe = data_loader.get_all_stocks()
    logger.info(f"共找到{len(universe)}只符合条件的股票")

    # Remove anything flagged ST inside the study window.
    logger.info("过滤2014~2023年间被ST的股票...")
    clean_universe = data_loader.filter_st_stocks(universe, '2014-01-01', '2023-12-31')
    logger.info(f"过滤后剩余股票: {len(clean_universe)}")

    # Market-cap tertiles, then a random sample from each bucket.
    logger.info("按市值分组...")
    big, medium, tiny = data_loader.categorize_stocks(clean_universe)
    logger.info(f"分组完成: 大盘股({len(big)}), 中盘股({len(medium)}), 微盘股({len(tiny)})")
    logger.info("随机抽取股票...")
    sample = data_loader.sample_stocks(big, medium, tiny, n=100)
    logger.info(f"抽样完成: 共选取{len(sample)}只股票")

    # Load and annotate the price panel; bail out early if nothing came back.
    logger.info("开始加载2014-2023年价格数据(前复权)...")
    panel = data_loader.load_price_data(sample, '2014-01-01', '2023-12-31')
    if panel.empty:
        logger.error("数据加载失败,无有效数据")
        return
    logger.info(f"数据加载完成,共{len(panel)}条记录")

    run_data_validation(panel)

    logger.info(banner)
    logger.info("数据加载和测试完成")
    logger.info(banner)


if __name__ == "__main__":
    main()
# NOTE(review): removed trailing web-page residue ("当前代码有什么问题?", "最新发布")
# that was not Python and broke parsing.