import pandas as pd
from sklearn.tree import DecisionTreeRegressor
from sklearn import preprocessing
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV
# 读取含有热误差数据的CSV文件
train_file = open('data1.csv', encoding='utf-8')
train_df = pd.read_csv(train_file)
# 读取测试集的含有热误差数据的CSV文件
test_file = open('data.csv', encoding='utf-8')
test_df = pd.read_csv(test_file)
# 对训练集进行数据预处理
X_train = train_df.iloc[:, :-1]
Y_train = train_df.iloc[:, -1]
X_train_scaled = preprocessing.scale(X_train)
# 对测试集进行数据预处理
X_test = test_df.iloc[:, :-1]
Y_test = test_df.iloc[:, -1]
X_test_scaled = preprocessing.scale(X_test, with_mean=X_train_scaled.mean(axis=0)[0], with_std=X_train_scaled.mean(axis=0)[0])
X = pd.concat([test_df.iloc[:, :-1], train_df.iloc[:, :-1]], axis=1)
Y = pd.concat([test_df.iloc[:, -1], train_df.iloc[:, -1]], axis=1)
X_test.columns = X_train.columns
# 创建决策树模型
dt = DecisionTreeRegressor()
# 定义网格搜索参数
param_grid = {
'max_depth': [1,2,3,4,5,6,7,8,9],
'min_samples_split': [2, 4, 6],
'min_samples_leaf': [1, 2, 3]
}
# 进行网格搜索优化
grid = GridSearchCV(dt, param_grid, cv=5)
grid.fit(X, Y)
grid_search = GridSearchCV(DecisionTreeRegressor(), param_grid, cv=5)
grid_search.fit(X_train, Y_train)
best_model = grid_search.best_estimator_
# 输出最优参数和模型得分
print('Best Parameters:', grid.best_params_)
# 定义新的温度数据
# 输出预测结果
Y_pred = best_model.predict(X_test)
mse = mean_squared_error(Y_test, Y_pred)
print(f"MSE: {mse:.4f}")
print(Y_pred)
代码运行预测值与真实值的方差有点高,怎么改善
可以尝试以下几个方法改善模型预测效果:
调节模型参数:适当地调节决策树模型的超参数可以提高模型预测精度。可以进一步调节参数范围以及尝试其他的机器学习模型。
特征工程:对原始特征进行处理可以提取更有用的信息,以提高模型的准确性。例如,可以进行特征选择、降维、离散化等。
改变模型方法:可以考虑使用其他的回归模型,例如线性回归、支持向量回归、随机森林等,以比较不同模型的预测效果,并选择预测效果更好的模型。
数据增强:可以通过数据扩充等方式增加样本数量,以提高模型的准确性和鲁棒性。
调整目标值:可以通过调整目标值的分布形态和缩放来改善模型的预测效果。例如,对目标值进行归一化或者对数变换。
增加训练数据的数量
n_samples=10
X=np.sort(np.random.uniform(-5,10,n_samples)).reshape(-1,1)
y=1.5*X**2-5*X-10
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
poly=PolynomialFeatures(2)
poly.fit(X) #数据预处理之拟合多项式:bais+x^1+x^2
X2=poly.transform(X)
print(X,'\n',X2)
显示结果