我想求出degree!!!但是我的图跑出来一直不对!!!!可以帮助指导一下吗!!!应该是4,但是我做出来的是2 !!可以帮忙看一下吗!!!谢谢大佬
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error as mse
import matplotlib.pyplot as plt
import pandas as pd
x = mod_dum['days_after_open']
X = x #np.hstack((x,x**2,x**3,x**4,x**5,x**6,x**7,x**8,x**9,x**10,x**11,x**12,x**13,x**14,x**15))
y = mod_y.to_numpy()
# ---------------------
# Train-vali-test split
# ---------------------
X_tv, X_test, y_tv, y_test = train_test_split(X, y, test_size=0.25, random_state = 1)
X_train, X_vali, y_train, y_vali = train_test_split(X_tv, y_tv, test_size=1/3, random_state = 1)
mse_train = []
mse_vali = []
max_deg = 15
degrees = range(1, max_deg+1)
for deg in degrees:
X_train_df = pd.DataFrame(X_train[:deg])
y_train_df = pd.DataFrame(y_train[:deg])
X_vali_df = pd.DataFrame(X_vali[:deg])
# -----------------------------------
# Build and fit your regression model
# -----------------------------------
linear_reg = LinearRegression()
linear_reg.fit(X_train_df, y_train_df)
# 因为是要进行deg的循环,所以要对x进行处理
# ----------------------------------
# Predict with your regression model
# ----------------------------------
train_pre = linear_reg.predict(X_train_df)
vali_pre = linear_reg.predict(X_vali_df)
# -----------------
# Calculate and save the MSE
# -----------------
train_pre_df = pd.DataFrame(train_pre)
vali_pre_df = pd.DataFrame(vali_pre)
y_vali_df = pd.DataFrame(y_vali[:deg])
print(f"for degree {deg}:", mse(train_pre_df, y_train_df), mse(vali_pre_df, y_vali_df))
mse_train = mse_train + [mse(train_pre_df, y_train_df)]
mse_vali = mse_vali + [mse(vali_pre_df, y_vali_df)]
# ------------------------
# Plot the holdout results (Do not modify)
# ------------------------
mse_train_df = pd.DataFrame(mse_train)
mse_vali_df = pd.DataFrame(mse_vali)
plt.plot(mse_train_df, color = 'green', label = 'Training data')
plt.plot(mse_vali_df, color = 'orange', label = 'Validation data')
plt.xlabel('Polynomial degree')
plt.ylabel('Mean squared error')
plt.title('Holdout')
plt.xticks(degrees)
plt.yscale('log')
# plt.ylim([5e3, 5e11])
plt.legend()
plt.savefig('plot.png')
plt.show()
2和4分别是啥?