在量化投资领域,建立一个高效、稳定的数据处理和分析工具链是成功的基础。本项目旨在帮助学习者构建从数据采集、清洗、存储到特征提取、分析的完整流程,为后续策略开发奠定坚实基础。
学习目标:
class DataSource:
"""数据源基类,定义标准接口"""
def __init__(self, config):
"""
初始化数据源
参数:
config (dict): 配置信息,包含连接参数、认证信息等
"""
self.config = config
self.connection = None
def connect(self):
"""建立与数据源的连接"""
raise NotImplementedError
def disconnect(self):
"""断开与数据源的连接"""
raise NotImplementedError
def get_data(self, query_params):
"""
获取数据
参数:
query_params (dict): 查询参数
返回:
pd.DataFrame: 查询结果
"""
raise NotImplementedError
def check_health(self):
"""检查数据源连接状态"""
raise NotImplementedError
class ETLPipeline:
"""ETL (Extract-Transform-Load) 流程框架"""
def __init__(self, data_sources, transformers, data_store):
"""
初始化ETL流程
参数:
data_sources (list): 数据源列表
transformers (list): 数据转换器列表
data_store (DataStore): 数据存储对象
"""
self.data_sources = data_sources
self.transformers = transformers
self.data_store = data_store
self.logger = self._setup_logger()
def _setup_logger(self):
"""设置日志记录器"""
# 日志配置代码
pass
def extract(self, query_params):
"""
从数据源提取数据
参数:
query_params (dict): 查询参数
返回:
dict: 键为数据源ID,值为相应的DataFrame
"""
results = {}
for source in self.data_sources:
try:
source.connect()
data = source.get_data(query_params)
results[source.id] = data
self.logger.info(f"从数据源 {source.id} 成功提取 {len(data)} 条记录")
except Exception as e:
self.logger.error(f"从数据源 {source.id} 提取数据时发生错误: {str(e)}")
finally:
source.disconnect()
return results
def transform(self, data_dict):
"""
转换数据
参数:
data_dict (dict): 提取的原始数据
返回:
dict: 转换后的数据
"""
transformed = {}
for source_id, data in data_dict.items():
curr_data = data.copy()
for transformer in self.transformers:
try:
curr_data = transformer.transform(curr_data)
self.logger.info(f"使用转换器 {transformer.name} 成功转换数据源 {source_id} 的数据")
except Exception as e:
self.logger.error(f"转换数据源 {source_id} 的数据时发生错误: {str(e)}")
transformed[source_id] = curr_data
return transformed
def load(self, transformed_data):
"""
加载数据到存储系统
参数:
transformed_data (dict): 转换后的数据
返回:
bool: 操作是否成功
"""
try:
self.data_store.connect()
for source_id, data in transformed_data.items():
self.data_store.save(data, source_id)
self.logger.info(f"成功加载数据源 {source_id} 的 {len(data)} 条记录到存储系统")
return True
except Exception as e:
self.logger.error(f"加载数据到存储系统时发生错误: {str(e)}")
return False
finally:
self.data_store.disconnect()
def run(self, query_params):
"""
运行完整ETL流程
参数:
query_params (dict): 查询参数
返回:
bool: 操作是否成功
"""
extracted = self.extract(query_params)
if not extracted:
self.logger.warning("未提取到任何数据,ETL流程中止")
return False
transformed = self.transform(extracted)
if not transformed:
self.logger.warning("数据转换失败,ETL流程中止")
return False
loaded = self.load(transformed)
return loaded
def check_missing_values(df, threshold=0.05):
"""
检查数据框中的缺失值
参数:
df (pd.DataFrame): 待检查的数据框
threshold (float): 可接受的缺失值比例阈值
返回:
dict: 每列缺失值统计和整体评估结果
"""
missing_counts = df.isnull().sum()
missing_ratio = missing_counts / len(df)
columns_above_threshold = missing_ratio[missing_ratio > threshold].index.tolist()
return {
'missing_counts': missing_counts.to_dict(),
'missing_ratio': missing_ratio.to_dict(),
'columns_above_threshold': columns_above_threshold,
'passed': len(columns_above_threshold) == 0
}
def check_duplicates(df, subset=None):
"""
检查数据框中的重复记录
参数:
df (pd.DataFrame): 待检查的数据框
subset (list): 用于识别重复的列子集
返回:
dict: 重复记录统计和示例
"""
duplicates = df.duplicated(subset=subset, keep='first')
duplicate_indices = duplicates[duplicates].index.tolist()
return {
'duplicate_count': sum(duplicates),
'duplicate_ratio': sum(duplicates) / len(df) if len(df) > 0 else 0,
'duplicate_indices': duplicate_indices[:10], # 只返回前10个示例
'passed': sum(duplicates) == 0
}
def check_data_range(df, numeric_ranges=None, categorical_values=None):
"""
检查数据值是否在预期范围内
参数:
df (pd.DataFrame): 待检查的数据框
numeric_ranges (dict): 数值列的有效范围,格式为 {'column_name': (min, max)}
categorical_values (dict): 分类列的有效值集合,格式为 {'column_name': set(valid_values)}
返回:
dict: 范围检查结果
"""
results = {'numeric_columns': {}, 'categorical_columns': {}, 'passed': True}
# 检查数值范围
if numeric_ranges:
for column, (min_val, max_val) in numeric_ranges.items():
if column in df.columns:
out_of_range = df[(df[column] < min_val) | (df[column] > max_val)]
results['numeric_columns'][column] = {
'out_of_range_count': len(out_of_range),
'out_of_range_ratio': len(out_of_range) / len(df) if len(df) > 0 else 0,
'min_value': df[column].min(),
'max_value': df[column].max(),
'passed': len(out_of_range) == 0
}
if len(out_of_range) > 0:
results['passed'] = False
# 检查分类值
if categorical_values:
for column, valid_values in categorical_values.items():
if column in df.columns:
invalid_values = df[~df[column].isin(valid_values)]
results['categorical_columns'][column] = {
'invalid_count': len(invalid_values),
'invalid_ratio': len(invalid_values) / len(df) if len(df) > 0 else 0,
'unique_values': df[column].unique().tolist(),
'passed': len(invalid_values) == 0
}
if len(invalid_values) > 0:
results['passed'] = False
return results
def check_data_consistency(df, consistency_rules):
"""
检查数据一致性规则
参数:
df (pd.DataFrame): 待检查的数据框
consistency_rules (list): 一致性规则列表,每条规则是一个函数,接受df作为参数并返回布尔值
返回:
dict: 一致性检查结果
"""
results = {'rule_results': {}, 'passed': True}
for i, rule in enumerate(consistency_rules):
rule_name = getattr(rule, '__name__', f'rule_{i}')
try:
rule_passed = rule(df)
results['rule_results'][rule_name] = rule_passed
if not rule_passed:
results['passed'] = False
except Exception as e:
results['rule_results'][rule_name] = {
'error': str(e),
'passed': False
}
results['passed'] = False
return results
def run_data_quality_checks(df, checks_config):
"""
运行一组数据质量检查
参数:
df (pd.DataFrame): 待检查的数据框
checks_config (dict): 检查配置
返回:
dict: 所有质量检查的综合结果
"""
results = {}
overall_passed = True
# 缺失值检查
if 'missing_values' in checks_config:
threshold = checks_config['missing_values'].get('threshold', 0.05)
missing_check = check_missing_values(df, threshold)
results['missing_values'] = missing_check
overall_passed &= missing_check['passed']
# 重复记录检查
if 'duplicates' in checks_config:
subset = checks_config['duplicates'].get('subset', None)
duplicate_check = check_duplicates(df, subset)
results['duplicates'] = duplicate_check
overall_passed &= duplicate_check['passed']
# 数据范围检查
if 'data_range' in checks_config:
numeric_ranges = checks_config['data_range'].get('numeric_ranges', None)
categorical_values = checks_config['data_range'].get('categorical_values', None)
range_check = check_data_range(df, numeric_ranges, categorical_values)
results['data_range'] = range_check
overall_passed &= range_check['passed']
# 数据一致性检查
if 'consistency' in checks_config:
consistency_rules = checks_config['consistency'].get('rules', [])
consistency_check = check_data_consistency(df, consistency_rules)
results['consistency'] = consistency_check
overall_passed &= consistency_check['passed']
results['overall_passed'] = overall_passed
return results
class DataStore:
"""数据存储基类,定义标准接口"""
def __init__(self, config):
"""
初始化数据存储
参数:
config (dict): 配置信息,包含连接参数、存储路径等
"""
self.config = config
self.connection = None
def connect(self):
"""建立与存储系统的连接"""
raise NotImplementedError
def disconnect(self):
"""断开与存储系统的连接"""
raise NotImplementedError
def save(self, data, dataset_id):
"""
保存数据
参数:
data (pd.DataFrame): 待保存的数据
dataset_id (str): 数据集标识符
返回:
bool: 操作是否成功
"""
raise NotImplementedError
def load(self, dataset_id, query_params=None):
"""
加载数据
参数:
dataset_id (str): 数据集标识符
query_params (dict): 查询参数
返回:
pd.DataFrame: 加载的数据
"""
raise NotImplementedError
def delete(self, dataset_id, query_params=None):
"""
删除数据
参数:
dataset_id (str): 数据集标识符
query_params (dict): 查询参数,指定要删除的记录
返回:
bool: 操作是否成功
"""
raise NotImplementedError
def list_datasets(self):
"""
列出所有数据集
返回:
list: 数据集标识符列表
"""
raise NotImplementedError
# 示例分层数据存储结构
class TimeSeriesDataStore(DataStore):
"""
针对时间序列数据的分层存储实现
存储结构:
- 原始数据层 (raw)
- 清洗数据层 (cleaned)
- 特征数据层 (features)
- 分析结果层 (results)
"""
def __init__(self, config):
super().__init__(config)
self.base_path = config.get('base_path', './data')
self.layers = {
'raw': os.path.join(self.base_path, 'raw'),
'cleaned': os.path.join(self.base_path, 'cleaned'),
'features': os.path.join(self.base_path, 'features'),
'results': os.path.join(self.base_path, 'results')
}
self._ensure_directories()
def _ensure_directories(self):
"""确保存储目录存在"""
for path in self.layers.values():
os.makedirs(path, exist_ok=True)
def connect(self):
"""对于文件系统存储,连接操作简化为检查目录访问权限"""
for path in self.layers.values():
if not os.access(path, os.W_OK):
raise PermissionError(f"无法写入目录: {path}")
return True
def disconnect(self):
"""断开连接,对于文件系统无特殊操作"""
return True
def save(self, data, dataset_id, layer='raw'):
"""
保存数据到指定层
参数:
data (pd.DataFrame): 待保存的数据
dataset_id (str): 数据集标识符
layer (str): 目标存储层 ('raw', 'cleaned', 'features', 'results')
返回:
bool: 操作是否成功
"""
if layer not in self.layers:
raise ValueError(f"无效的存储层: {layer}")
file_path = os.path.join(self.layers[layer], f"{dataset_id}.parquet")
try:
data.to_parquet(file_path, index=True)
return True
except Exception as e:
print(f"保存数据失败: {str(e)}")
return False
def load(self, dataset_id, layer='raw', query_params=None):
"""
从指定层加载数据
参数:
dataset_id (str): 数据集标识符
layer (str): 目标存储层 ('raw', 'cleaned', 'features', 'results')
query_params (dict): 包含过滤条件
返回:
pd.DataFrame: 加载的数据
"""
if layer not in self.layers:
raise ValueError(f"无效的存储层: {layer}")
file_path = os.path.join(self.layers[layer], f"{dataset_id}.parquet")
if not os.path.exists(file_path):
raise FileNotFoundError(f"找不到数据集文件: {file_path}")
# 基本加载
data = pd.read_parquet(file_path)
# 应用查询过滤
if query_params:
# 日期范围过滤
if 'start_date' in query_params and 'date_column' in query_params:
date_col = query_params['date_column']
data = data[data[date_col] >= query_params['start_date']]
if 'end_date' in query_params and 'date_column' in query_params:
date_col = query_params['date_column']
data = data[data[date_col] <= query_params['end_date']]
# 列选择
if 'columns' in query_params:
columns = [col for col in query_params['columns'] if col in data.columns]
data = data[columns]
return data
def list_datasets(self, layer='raw'):
"""
列出指定层的所有数据集
参数:
layer (str): 目标存储层 ('raw', 'cleaned', 'features', 'results')
返回:
list: 数据集标识符列表
"""
if layer not in self.layers:
raise ValueError(f"无效的存储层: {layer}")
path = self.layers[layer]
files = [f for f in os.listdir(path) if f.endswith('.parquet')]
return [os.path.splitext(f)[0] for f in files]
要求:
参考实现框架:
import pandas as pd
# DataSource基类已提供,请勿修改
class DataSource:
"""数据源的抽象基类"""
def __init__(self, config):
"""
初始化数据源
参数:
config (dict): 配置参数
"""
self.config = config
def connect(self):
"""
连接到数据源
返回:
bool: 连接是否成功
"""
raise NotImplementedError("子类必须实现connect方法")
def disconnect(self):
"""
断开与数据源的连接
返回:
bool: 断开连接是否成功
"""
raise NotImplementedError("子类必须实现disconnect方法")
def get_data(self, query_params):
"""
从数据源获取数据
参数:
query_params: 查询参数
返回:
pd.DataFrame: 获取的数据
"""
raise NotImplementedError("子类必须实现get_data方法")
# 任务1: 实现SQLDataSource类
class SQLDataSource(DataSource):
"""从SQL数据库获取数据的数据源实现"""
def __init__(self, config):
# 在这里初始化SQL数据源
# 提示: 你需要从config中获取数据库连接信息
pass
def connect(self):
"""建立数据库连接"""
# 在这里实现数据库连接逻辑
pass
def disconnect(self):
"""关闭数据库连接"""
# 在这里实现关闭数据库连接的逻辑
pass
def get_data(self, query_params):
"""
执行SQL查询并获取数据
参数:
query_params (dict): 包含SQL查询的参数
返回:
pd.DataFrame: 查询结果
"""
# 在这里实现执行SQL查询并返回DataFrame的逻辑
pass
# 任务2: 实现CSVDataSource类
class CSVDataSource(DataSource):
"""从CSV文件获取数据的数据源实现"""
def __init__(self, config):
# 在这里初始化CSV数据源
# 提示: 你需要从config中获取文件路径等信息
pass
def connect(self):
"""验证CSV文件是否可访问"""
# 在这里实现验证CSV文件的逻辑
pass
def disconnect(self):
"""CSV数据源的断开连接操作"""
# 在这里实现断开连接的逻辑(如果需要的话)
pass
def get_data(self, query_params):
"""
从CSV文件读取数据
参数:
query_params (dict): 可能包含过滤、选择列等参数
返回:
pd.DataFrame: CSV数据
"""
# 在这里实现读取CSV并返回DataFrame的逻辑
pass
# 任务3: 实现WebScraperDataSource类
class WebScraperDataSource(DataSource):
"""通过网络爬虫获取数据的数据源实现"""
def __init__(self, config):
# 在这里初始化网络爬虫数据源
# 提示: 你需要从config中获取URL等信息
pass
def connect(self):
"""准备爬虫环境"""
# 在这里实现准备爬虫环境的逻辑
pass
def disconnect(self):
"""清理爬虫资源"""
# 在这里实现清理爬虫资源的逻辑
pass
def get_data(self, query_params):
"""
爬取网页并提取数据
参数:
query_params (dict): 爬虫参数,可能包含选择器等
返回:
pd.DataFrame: 爬取的数据
"""
# 在这里实现爬取网页并返回DataFrame的逻辑
pass
数据源连接池管理参考:
# 已经实现的DataSource类及其子类在此基础上进行操作
# 任务: 实现DataSourcePool类
class DataSourcePool:
"""数据源连接池,管理多个数据源的连接生命周期"""
def __init__(self, source_configs):
"""
初始化数据源连接池
参数:
source_configs (dict): 数据源配置,格式为 {'source_id': {'type': 'source_type', 'config': {...}}}
"""
# 在这里初始化数据源池
# 你需要:
# 1. 初始化sources字典用于存储数据源实例
# 2. 保存source_configs配置
# 3. 定义数据源类型映射
pass
def get_source(self, source_id):
"""
获取数据源实例,如果不存在则创建
参数:
source_id (str): 数据源ID
返回:
DataSource: 数据源实例
"""
# 在这里实现获取数据源的逻辑
# 需要:
# 1. 检查源是否已存在
# 2. 如果不存在,则创建新的实例
# 3. 处理错误情况
pass
def release_source(self, source_id):
"""
释放数据源实例
参数:
source_id (str): 数据源ID
"""
# 在这里实现释放数据源的逻辑
# 需要:
# 1. 断开连接
# 2. 从sources字典中删除
pass
def release_all(self):
"""释放所有数据源实例"""
# 在这里实现释放所有数据源的逻辑
pass
async def get_data_async(self, source_id, query_params):
"""
异步获取数据
参数:
source_id (str): 数据源ID
query_params (dict): 查询参数
返回:
pd.DataFrame: 查询结果
"""
# 在这里实现异步获取数据的逻辑
# 需要:
# 1. 获取数据源
# 2. 使用asyncio将同步操作包装为异步
# 3. 处理连接和断开
pass
# 使用示例 (取消注释后可以测试你的实现)
"""
if __name__ == "__main__":
# 创建数据源配置
configs = {
'sales_db': {
'type': 'sql',
'config': {
'host': 'localhost',
'database': 'sales',
'user': 'user',
'password': 'password'
}
},
'products_api': {
'type': 'api',
'config': {
'base_url': 'https://api.example.com/v1',
'api_key': 'your_api_key'
}
}
}
# 创建数据源池
pool = DataSourcePool(configs)
# 获取数据源
sales_source = pool.get_source('sales_db')
# 获取数据
sales_data = sales_source.get_data({'query': 'SELECT * FROM sales'})
# 释放所有数据源
pool.release_all()
"""
要求:
参考实现框架:
import pandas as pd
import numpy as np
class DataTransformer:
"""数据转换器基类"""
def __init__(self, name, config=None):
"""初始化转换器"""
self.name = name
self.config = config or {}
self.is_fitted = False
def transform(self, df):
"""转换数据"""
# 子类需要实现此方法
raise NotImplementedError
def fit(self, df):
"""从数据学习转换参数"""
self.is_fitted = True
return self
def fit_transform(self, df):
"""学习并应用转换"""
return self.fit(df).transform(df)
class MissingValueHandler(DataTransformer):
"""处理缺失值的转换器"""
def __init__(self, strategy='mean', columns=None, fill_value=None):
"""
初始化缺失值处理器
参数:
strategy: 填充策略 ('mean', 'median', 'mode', 'constant', 'ffill', 'bfill')
columns: 要处理的列,默认为所有列
fill_value: 当strategy='constant'时使用的填充值
"""
super().__init__('missing_value_handler')
self.strategy = strategy
self.columns = columns
self.fill_value = fill_value
self.fill_dict = {} # 存储各列的填充值
def fit(self, df):
"""学习各列的填充值"""
# TODO: 学习者实现此处逻辑
# 提示: 根据self.strategy为每列计算适当的填充值
self.is_fitted = True
return self
def transform(self, df):
"""应用填充值处理缺失值"""
# TODO: 学习者实现此处逻辑
# 提示: 使用self.fill_dict或其他方式填充缺失值
return df
class OutlierHandler(DataTransformer):
"""处理异常值的转换器"""
def __init__(self, method='z_score', threshold=3.0, columns=None):
"""
初始化异常值处理器
参数:
method: 检测方法 ('z_score', 'iqr', 'percentile')
threshold: 异常值阈值
columns: 要处理的列,默认为所有数值列
"""
super().__init__('outlier_handler')
self.method = method
self.threshold = threshold
self.columns = columns
self.bounds = {} # 存储各列的边界值
def fit(self, df):
"""学习异常值边界"""
# TODO: 学习者实现此处逻辑
return self
def transform(self, df):
"""应用异常值处理"""
# TODO: 学习者实现此处逻辑
return df
class DataNormalizer(DataTransformer):
"""数据归一化转换器"""
def __init__(self, method='min_max', columns=None):
"""
初始化归一化转换器
参数:
method: 归一化方法 ('min_max', 'max_abs')
columns: 要处理的列,默认为所有数值列
"""
super().__init__('data_normalizer')
self.method = method
self.columns = columns
self.params = {} # 存储归一化参数
def fit(self, df):
"""学习归一化参数"""
# TODO: 学习者实现此处逻辑
return self
def transform(self, df):
"""应用归一化转换"""
# TODO: 学习者实现此处逻辑
return df
class DataStandardizer(DataTransformer):
"""数据标准化转换器"""
def __init__(self, columns=None):
"""
初始化标准化转换器
参数:
columns: 要处理的列,默认为所有数值列
"""
super().__init__('data_standardizer')
self.columns = columns
self.means = {}
self.stds = {}
def fit(self, df):
"""学习标准化参数"""
# TODO: 学习者实现此处逻辑
return self
def transform(self, df):
"""应用标准化转换"""
# TODO: 学习者实现此处逻辑
return df
class DataValidator(DataTransformer):
"""数据验证器"""
def __init__(self, rules=None):
"""
初始化数据验证器
参数:
rules: 验证规则字典,键为列名,值为验证函数或条件
"""
super().__init__('data_validator')
self.rules = rules or {}
self.validation_results = {}
def transform(self, df):
"""验证数据并返回原始数据框"""
# TODO: 学习者实现此处逻辑
return df
class Pipeline:
"""转换器流水线"""
def __init__(self, transformers=None):
"""
初始化转换器流水线
参数:
transformers: 转换器列表
"""
self.transformers = transformers or []
def add_transformer(self, transformer):
"""添加转换器到流水线"""
self.transformers.append(transformer)
return self
def fit(self, df):
"""拟合所有转换器"""
data = df.copy()
for transformer in self.transformers:
transformer.fit(data)
return self
def transform(self, df):
"""应用所有转换器"""
data = df.copy()
for transformer in self.transformers:
data = transformer.transform(data)
return data
def fit_transform(self, df):
"""拟合并应用所有转换器"""
return self.fit(df).transform(df)
# 使用示例
def example():
# 创建转换器
missing_handler = MissingValueHandler(strategy='mean', columns=['age', 'salary'])
outlier_handler = OutlierHandler(method='z_score')
normalizer = DataNormalizer(method='min_max')
# 创建流水线
pipeline = Pipeline()
pipeline.add_transformer(missing_handler)
pipeline.add_transformer(outlier_handler)
pipeline.add_transformer(normalizer)
# 应用流水线
# df = pd.read_csv('data.csv')
# result = pipeline.fit_transform(df)
要求:
参考实现框架:
import pandas as pd
import numpy as np
class FeatureGenerator:
"""特征生成器基类"""
def __init__(self, name, config=None):
"""初始化特征生成器"""
self.name = name
self.config = config or {}
def generate(self, df):
"""生成特征"""
# 子类需要实现此方法
raise NotImplementedError
class TimeSeriesFeatures(FeatureGenerator):
"""时间序列特征提取器"""
def __init__(self, date_column, features=None):
"""
初始化时间序列特征提取器
参数:
date_column: 日期列名
features: 要生成的特征列表,可选项包括
['year', 'month', 'day', 'dayofweek', 'quarter',
'is_month_start', 'is_month_end', 'is_quarter_start',
'is_quarter_end', 'is_year_start', 'is_year_end']
"""
super().__init__('time_series_features')
self.date_column = date_column
self.features = features or ['year', 'month', 'day', 'dayofweek']
def generate(self, df):
"""生成时间特征"""
# TODO: 学习者实现此处逻辑
# 提示: 确保日期列为datetime类型,然后提取需要的时间特征
result = df.copy()
return result
class TechnicalIndicators(FeatureGenerator):
"""金融技术指标计算器"""
def __init__(self, price_column='close', volume_column=None, indicators=None):
"""
初始化技术指标计算器
参数:
price_column: 价格列名
volume_column: 成交量列名
indicators: 要计算的指标列表,如['sma', 'ema', 'rsi', 'macd', 'bbands']
"""
super().__init__('technical_indicators')
self.price_column = price_column
self.volume_column = volume_column
self.indicators = indicators or ['sma', 'ema', 'rsi']
def generate(self, df):
"""计算技术指标"""
# TODO: 学习者实现此处逻辑
# 提示: 使用pandas_ta或自行实现各类技术指标的计算
result = df.copy()
return result
class CrossFeatures(FeatureGenerator):
"""交叉特征生成器"""
def __init__(self, feature_pairs=None, operations=None):
"""
初始化交叉特征生成器
参数:
feature_pairs: 要交叉的特征对列表,如[('f1', 'f2'), ...]
operations: 要执行的操作列表,如['add', 'subtract', 'multiply', 'divide']
"""
super().__init__('cross_features')
self.feature_pairs = feature_pairs or []
self.operations = operations or ['add', 'subtract', 'multiply', 'divide']
def generate(self, df):
"""生成交叉特征"""
# TODO: 学习者实现此处逻辑
# 提示: 对每对特征执行指定的数学运算,生成新特征
result = df.copy()
return result
class FeatureSelector:
"""特征选择工具"""
def __init__(self, target_column, method='importance'):
"""
初始化特征选择器
参数:
target_column: 目标列名
method: 选择方法,可选项为'correlation', 'importance', 'mutual_info'
"""
self.target_column = target_column
self.method = method
self.selected_features = None
def select(self, df, n_features=None, threshold=None):
"""
选择特征
参数:
df: 输入数据框
n_features: 要选择的特征数量
threshold: 选择特征的阈值
返回:
选择的特征列表
"""
# TODO: 学习者实现此处逻辑
# 提示: 根据选择方法计算特征的重要性并选择最重要的特征
if self.method == 'correlation':
# 实现基于相关性的特征选择
pass
elif self.method == 'importance':
# 实现基于特征重要性的选择(如随机森林)
pass
elif self.method == 'mutual_info':
# 实现基于互信息的特征选择
pass
# 设置选择的特征
self.selected_features = []
return self.selected_features
def transform(self, df):
"""应用特征选择"""
if self.selected_features is None:
raise ValueError("必须先调用select方法")
# 返回选择的特征和目标变量(如果存在)
columns = self.selected_features.copy()
if self.target_column in df.columns:
columns.append(self.target_column)
return df[columns]
def fit_transform(self, df, n_features=None, threshold=None):
"""选择特征并应用"""
self.select(df, n_features, threshold)
return self.transform(df)
class LagFeatures(FeatureGenerator):
"""滞后特征生成器"""
def __init__(self, columns, lag_periods=None):
"""
初始化滞后特征生成器
参数:
columns: 要创建滞后特征的列列表
lag_periods: 滞后期数列表,如[1, 2, 3, 5, 7]
"""
super().__init__('lag_features')
self.columns = columns if isinstance(columns, list) else [columns]
self.lag_periods = lag_periods or [1, 2, 3]
def generate(self, df):
"""生成滞后特征"""
# TODO: 学习者实现此处逻辑
# 提示: 为指定的列创建指定滞后期数的特征
result = df.copy()
return result
class WindowFeatures(FeatureGenerator):
"""窗口统计特征生成器"""
def __init__(self, columns, window_sizes=None, functions=None):
"""
初始化窗口统计特征生成器
参数:
columns: 要创建窗口特征的列列表
window_sizes: 窗口大小列表,如[3, 5, 7]
functions: 窗口函数列表,如['mean', 'std', 'min', 'max']
"""
super().__init__('window_features')
self.columns = columns if isinstance(columns, list) else [columns]
self.window_sizes = window_sizes or [3, 5, 7]
self.functions = functions or ['mean', 'std']
def generate(self, df):
"""生成窗口统计特征"""
# TODO: 学习者实现此处逻辑
# 提示: 对每列计算窗口统计特征
result = df.copy()
return result
class FeaturePipeline:
"""特征生成流水线"""
def __init__(self, generators=None):
"""
初始化特征生成流水线
参数:
generators: 特征生成器列表
"""
self.generators = generators or []
def add_generator(self, generator):
"""添加特征生成器到流水线"""
self.generators.append(generator)
return self
def generate(self, df):
"""依次应用所有特征生成器"""
result = df.copy()
for generator in self.generators:
result = generator.generate(result)
return result
# 使用示例
def example():
# 创建特征生成器
time_features = TimeSeriesFeatures(date_column='date')
tech_indicators = TechnicalIndicators(price_column='close', volume_column='volume')
cross_features = CrossFeatures(feature_pairs=[('close', 'volume'), ('high', 'low')])
lag_features = LagFeatures(columns=['close', 'volume'], lag_periods=[1, 2, 3])
# 创建流水线
pipeline = FeaturePipeline()
pipeline.add_generator(time_features)
pipeline.add_generator(tech_indicators)
pipeline.add_generator(cross_features)
pipeline.add_generator(lag_features)
# 应用流水线
# df = pd.read_csv('data.csv')
# enhanced_df = pipeline.generate(df)
# 特征选择
# selector = FeatureSelector(target_column='target', method='importance')
# selected_df = selector.fit_transform(enhanced_df, n_features=20)
要求:
参考实现框架:
class PipelineNode:
"""数据分析流水线的节点"""
def __init__(self, name, processor, input_keys=None, output_keys=None):
"""
初始化流水线节点
参数:
name (str): 节点名称
processor (callable): 处理函数或对象,必须实现__call__(data)方法
input_keys (list): 输入数据的键列表
output_keys (list): 输出数据的键列表
"""
# TODO: 实现节点初始化
pass
def process(self, data_dict):
"""
处理输入数据
参数:
data_dict (dict): 输入数据字典
返回:
dict: 输出数据字典
"""
# TODO: 实现数据处理逻辑
# 1. 从data_dict中获取输入数据
# 2. 调用processor处理数据
# 3. 整理输出数据并返回
pass
class Pipeline:
"""数据分析流水线"""
def __init__(self, name='DataPipeline'):
"""
初始化流水线
参数:
name (str): 流水线名称
"""
# TODO: 实现流水线初始化
pass
def _setup_logger(self):
"""设置日志记录器"""
# TODO: 配置并返回日志记录器
pass
def add_node(self, node):
"""
添加节点到流水线
参数:
node (PipelineNode): 要添加的节点
返回:
自身,支持链式调用
"""
# TODO: 实现添加节点功能
pass
def run(self, initial_data=None):
"""
运行流水线
参数:
initial_data (dict): 初始输入数据
返回:
dict: 流水线处理结果
"""
# TODO: 实现流水线运行逻辑
# 1. 按顺序执行每个节点
# 2. 处理异常情况
# 3. 返回最终结果
pass
def visualize(self, format='text'):
"""
可视化流水线结构
参数:
format (str): 输出格式 ('text', 'html', 'graph')
返回:
输出结果,格式取决于format参数
"""
# TODO: 实现流水线可视化功能
# 根据不同format参数生成不同形式的可视化
pass
def generate_report(self, results, format='html'):
"""
生成分析报告
参数:
results (dict): 流水线运行结果
format (str): 报告格式 ('html', 'markdown', 'json')
返回:
str: 生成的报告
"""
# TODO: 实现报告生成功能
# 根据不同format参数生成不同格式的报告
pass
要求:
参考实现框架:
class IncrementalDataProcessor:
"""增量数据处理器"""
def __init__(self, data_store, versioning=True):
"""
初始化增量数据处理器
参数:
data_store (DataStore): 数据存储对象
versioning (bool): 是否启用版本控制
"""
# TODO: 初始化增量数据处理器
# 1. 存储data_store引用
# 2. 设置versioning标志
# 3. 初始化版本历史记录字典
pass
def _get_latest_version(self, dataset_id):
"""获取最新版本号"""
# TODO: 获取指定数据集的最新版本号
# 如果数据集不存在于版本历史中,返回0
pass
def _generate_version_id(self, dataset_id, version):
"""生成带版本的数据集ID"""
# TODO: 生成格式为"{dataset_id}_v{version}"的版本化数据集ID
pass
def _save_with_version(self, data, dataset_id, metadata=None):
"""保存数据并管理版本"""
# TODO: 实现带版本控制的数据保存功能
# 1. 如果不启用版本控制,直接保存数据
# 2. 如果启用版本控制,获取新版本号并生成版本化ID
# 3. 保存数据并更新版本历史记录
# 4. 用最新数据更新当前版本
pass
def detect_changes(self, current_data, new_data, key_columns, change_detection_columns=None):
"""
检测两个数据集之间的变更
参数:
current_data (pd.DataFrame): 当前数据集
new_data (pd.DataFrame): 新数据集
key_columns (list): 用于识别记录的键列
change_detection_columns (list): 用于检测变更的列,默认为除键列外的所有列
返回:
dict: 变更摘要
"""
# TODO: 实现数据变更检测功能
# 1. 确定用于检测变更的列
# 2. 设置键列作为索引以便比较
# 3. 找出新增、删除和共有的键
# 4. 对于共有的键,检测列值的变更
# 5. 返回变更统计信息
pass
def merge_incremental(self, current_data, new_data, key_columns, update_existing=True):
"""
合并增量数据
参数:
current_data (pd.DataFrame): 当前数据集
new_data (pd.DataFrame): 新增量数据
key_columns (list): 用于识别记录的键列
update_existing (bool): 是否更新已存在的记录
返回:
pd.DataFrame: 合并后的数据集
"""
# TODO: 实现增量数据合并功能
# 1. 验证输入数据类型和键列的存在
# 2. 设置键列为索引以便合并
# 3. 根据update_existing参数决定如何合并数据
# 4. 重置索引,恢复键列
pass
def update_incremental(self, dataset_id, new_data, key_columns, update_existing=True, metadata=None):
"""
根据增量数据更新数据集
参数:
dataset_id (str): 数据集ID
new_data (pd.DataFrame): 新增量数据
key_columns (list): 用于识别记录的键列
update_existing (bool): 是否更新已存在的记录
metadata (dict): 更新相关的元数据
返回:
dict: 更新摘要
"""
# TODO: 实现增量数据更新流程
# 1. 尝试加载当前数据,如果不存在则创建新的
# 2. 检测当前数据和新数据之间的变更
# 3. 如果有变更,合并数据并保存新版本
# 4. 返回更新摘要
pass
def get_version_history(self, dataset_id):
"""
获取数据集的版本历史
参数:
dataset_id (str): 数据集ID
返回:
dict: 版本历史
"""
# TODO: 获取数据集的版本历史
# 1. 如果不启用版本控制,返回相应信息
# 2. 如果数据集不存在于版本历史中,返回空版本历史
# 3. 返回完整的版本历史和最新版本号
pass
def load_version(self, dataset_id, version=None):
"""
加载特定版本的数据集
参数:
dataset_id (str): 数据集ID
version (int): 要加载的版本号,默认为最新版本
返回:
pd.DataFrame: 数据集
"""
# TODO: 加载特定版本的数据集
# 1. 如果不启用版本控制,直接加载最新数据
# 2. 验证数据集存在于版本历史中
# 3. 确定要加载的版本号(如果未指定则使用最新版本)
# 4. 验证版本号有效
# 5. 生成版本化ID并加载对应数据
pass
在量化分析中,常见的金融数据结构包括:
rolling().apply()
优于自定义循环)