Source code: https://github.com/google-research/google-research/tree/master/tft
I want to get this running on my own dataset. Following the steps in the README, I wrote an example.py, but it will not run and keeps throwing an error.
The code of example.py is as follows:
import data_formatters.base
import libs.utils as utils
import pandas as pd
import sklearn.preprocessing
GenericDataFormatter = data_formatters.base.GenericDataFormatter
DataTypes = data_formatters.base.DataTypes
InputTypes = data_formatters.base.InputTypes
class ExampleFormatter(GenericDataFormatter):
  """Defines and formats data for the dataset.

  Attributes:
    column_definition: Defines input and data type of each column used in the
      experiment.
    identifiers: Entity identifiers used in experiments.
  """
  _column_definition = [
      # ('Symbol', DataTypes.CATEGORICAL, InputTypes.ID),
      # My dataset has no identifier column, so I left the ID line out.
      ('date', DataTypes.DATE, InputTypes.TIME),
      # DATE: the date/time index
      ('log_vol', DataTypes.REAL_VALUED, InputTypes.TARGET),
      # TARGET: the forecast target
      ('HUFL', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
      ('HULL', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
      # OBSERVED_INPUT: inputs observed only up to the forecast time
      ('MUFL', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      ('count_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      # REAL_VALUED: continuous value; KNOWN_INPUT: input known in advance;
      # count_from_start: starts at the first record and increases by 1 per row
      ('MULL', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
      ('LUFL', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
      ('LULL', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
      ('month', DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
      # ('Region', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
      # CATEGORICAL: categorical variable; STATIC_INPUT: static per-entity input
  ]
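  # For reference, the header of my CSV matches the definition above:
  #
  #   date,log_vol,HUFL,HULL,MUFL,count_from_start,MULL,LUFL,LULL,month
  #
  # One thing I am unsure about: if I read data_formatters/base.py correctly,
  # get_column_definition() checks for exactly one ID column, so dropping the
  # ID line entirely may not be allowed (a constant dummy ID column might be
  # needed instead).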
  def __init__(self):
    """Initialises formatter."""
    self.identifiers = None
    self._real_scalers = None
    self._cat_scalers = None
    self._target_scaler = None
    self._num_classes_per_cat_input = None
  def split_data(self, df, valid_boundary=850, test_boundary=1250):
    """Splits data frame into training-validation-test data frames.

    This also calibrates the scaling object, and transforms data for each
    split.

    Args:
      df: Source data frame to split.
      valid_boundary: Starting count for validation data.
      test_boundary: Starting count for test data.

    Returns:
      Tuple of transformed (train, valid, test) data.
    """
    print('Formatting train-valid-test splits.')

    index = df['count_from_start']
    train = df.loc[index < valid_boundary]
    valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
    # There are only 1344 records at the moment, so the test set ends at 1344.
    test = df.loc[(index >= test_boundary) & (index <= 1344)]

    self.set_scalers(train)

    return (self.transform_inputs(data) for data in [train, valid, test])
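  # With the defaults above and count_from_start running 1..1344, the splits
  # work out to: train = counts 1-849, valid = counts 850-1249,
  # test = counts 1250-1344.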
  def set_scalers(self, df):
    """Calibrates scalers using the data supplied.

    Args:
      df: Data to use to calibrate scalers.
    """
    print('Setting scalers with training data...')

    column_definitions = self.get_column_definition()
    # id_column = utils.get_single_col_by_input_type(InputTypes.ID,
    #                                                column_definitions)
    target_column = utils.get_single_col_by_input_type(InputTypes.TARGET,
                                                       column_definitions)

    # Extract identifiers in case required. With the ID column removed above,
    # id_column is undefined, so I disabled this line as well.
    # self.identifiers = list(df[id_column].unique())
    # Format real scalers. Note extract_cols_from_data_type expects a set of
    # excluded input types.
    real_inputs = utils.extract_cols_from_data_type(
        DataTypes.REAL_VALUED, column_definitions, {InputTypes.TIME})

    # Standardise the real-valued columns.
    data = df[real_inputs].values
    self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
    self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
        df[[target_column]].values)  # used to rescale predictions later

    # Format categorical scalers.
    categorical_inputs = utils.extract_cols_from_data_type(
        DataTypes.CATEGORICAL, column_definitions, {InputTypes.TIME})

    categorical_scalers = {}
    num_classes = []
    for col in categorical_inputs:
      # Set all to str so that we don't have mixed integer/string columns.
      srs = df[col].apply(str)
      categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(
          srs.values)
      num_classes.append(srs.nunique())

    # Set categorical scaler outputs.
    self._cat_scalers = categorical_scalers
    self._num_classes_per_cat_input = num_classes
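  # For example, a categorical column holding the integers 7 and 12 becomes
  # the strings '12' and '7', and LabelEncoder assigns them the codes 0 and 1
  # (classes are sorted lexicographically after the str conversion).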
  def transform_inputs(self, df):
    """Performs feature transformations.

    This includes feature engineering, preprocessing and normalisation.

    Args:
      df: Data frame to transform.

    Returns:
      Transformed data frame.
    """
    output = df.copy()

    if self._real_scalers is None and self._cat_scalers is None:
      raise ValueError('Scalers have not been set!')

    column_definitions = self.get_column_definition()
    real_inputs = utils.extract_cols_from_data_type(
        DataTypes.REAL_VALUED, column_definitions, {InputTypes.TIME})
    categorical_inputs = utils.extract_cols_from_data_type(
        DataTypes.CATEGORICAL, column_definitions, {InputTypes.TIME})

    # Format real inputs.
    output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)

    # Format categorical inputs.
    for col in categorical_inputs:
      string_df = df[col].apply(str)
      output[col] = self._cat_scalers[col].transform(string_df)

    return output
  def format_predictions(self, predictions):
    """Reverts any normalisation to give predictions in original scale.

    Args:
      predictions: Dataframe of model predictions.

    Returns:
      Data frame of unnormalised predictions.
    """
    output = predictions.copy()
    column_names = predictions.columns

    for col in column_names:
      if col not in {'forecast_time', 'identifier'}:
        output[col] = self._target_scaler.inverse_transform(predictions[col])

    return output
  # Default params
  def get_fixed_params(self):
    """Returns fixed model parameters for experiments."""
    fixed_params = {
        'total_time_steps': 1344,
        'num_encoder_steps': 1344,
        'num_epochs': 3,
        'early_stopping_patience': 5,
        'multiprocessing_workers': 5,
    }
    return fixed_params
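  # I am also unsure about the two step counts above: if I understand
  # libs/tft_model.py correctly, the forecast horizon is
  # total_time_steps - num_encoder_steps, which would be zero here.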
  def get_default_model_params(self):
    """Returns default optimised model parameters."""
    model_params = {
        'dropout_rate': 0.3,
        'hidden_layer_size': 160,  # hidden layer size
        'learning_rate': 0.01,
        'minibatch_size': 64,
        'max_gradient_norm': 0.01,
        'num_heads': 1,
        'stack_size': 1
    }
    return model_params
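To sanity-check the formatter on its own, outside the TFT scripts, I use a minimal sketch like the one below. The synthetic frame just mirrors the columns in _column_definition with made-up values; it is not my real data.

import pandas as pd

from data_formatters.example import ExampleFormatter  # direct submodule import

# Tiny synthetic frame with the columns from _column_definition (values made up).
n = 1344
df = pd.DataFrame({
    'date': pd.date_range('2016-07-01', periods=n, freq='H'),
    'log_vol': [0.1 * (i % 7) for i in range(n)],
    'HUFL': [1.0 + i % 5 for i in range(n)],
    'HULL': [2.0] * n,
    'MUFL': [3.0] * n,
    'count_from_start': range(1, n + 1),
    'MULL': [i % 2 for i in range(n)],
    'LUFL': [i % 3 for i in range(n)],
    'LULL': [0] * n,
    'month': [7] * n,
})

formatter = ExampleFormatter()
# NB: if base.py really requires an ID column, this call would flag it here.
train, valid, test = formatter.split_data(df)  # fits scalers, returns splits
print(len(train), len(valid), len(test))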
The error:
Traceback (most recent call last):
  File "C:/Users/55459/tft/script_hyperparam_opt.py", line 224, in <module>
    formatter = config.make_data_formatter()
  File "C:\Users\55459\tft\expt_settings\configs.py", line 132, in make_data_formatter
    'example': data_formatters.example.ExampleFormatter
AttributeError: module 'data_formatters' has no attribute 'example'
I have checked that it is not a spelling mistake. How can I get this running on my own dataset?
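For context, make_data_formatter in expt_settings/configs.py dispatches through a dict keyed by experiment name, and the 'example' entry shown in the traceback is the one I added. Is the problem simply that data_formatters.example is never imported, so the submodule never becomes an attribute of the data_formatters package? A minimal sketch of what I think the fix looks like (abridged; the marked import is my guess, not something I have verified):

# expt_settings/configs.py (abridged)
import data_formatters.electricity
import data_formatters.volatility
import data_formatters.example  # my guess: this import is the missing piece


class ExperimentConfig(object):
  # ...

  def make_data_formatter(self):
    """Gets a data formatter object for the experiment."""
    data_formatter_class = {
        'volatility': data_formatters.volatility.VolatilityFormatter,
        'example': data_formatters.example.ExampleFormatter,  # my added entry
    }
    return data_formatter_class[self.experiment]()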