sklearn: LinearRegression raises an error after splitting the dataset with train_test_split. How do I fix it?

Any help would be appreciated!

Symptoms and background

After splitting the dataset with sklearn's train_test_split, calling LinearRegression fails with: ValueError: Input X contains NaN. LinearRegression does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.

Code
import pandas as pd
import seaborn as sns
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

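# Load the target dataset with read_csv
adv_data = pd.read_csv("H:/学习用品/ML/multivariable regression/sale_data.csv")
# Select the data to keep (this slice keeps every row and column, so nothing is dropped here)
new_adv_data = adv_data.iloc[:, :]
# Inspect the first few rows and the shape of the dataset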
print('head:', new_adv_data.head(), '\nShape:', new_adv_data.shape)
print("-----------------------------------------------")

# Summary statistics
print(new_adv_data.describe())
print("-----------------------------------------------")
# Missing-value check
print(new_adv_data[new_adv_data.isnull() == True].count())
print("-----------------------------------------------")

new_adv_data.boxplot()
plt.savefig("boxplot.jpg")
plt.show()

# Correlation matrix: r (correlation coefficient) = covariance of x and y / (std of x * std of y) = cov(x, y) / (σx * σy)
# Correlation strength: 0~0.3 weak, 0.3~0.6 moderate, 0.6~1 strong
print(new_adv_data.corr())

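# Scatter plots to inspect how the data is distributed
# seaborn's pairplot draws a scatter plot of each feature in X against Y. The size and aspect parameters control the size and aspect ratio of each plot.
# The plots show a fairly strong linear relationship between TV and sales, a weaker one between radio and sales, and an even weaker one between newspaper and sales.
# Passing kind='reg' makes seaborn add a best-fit line with a 95% confidence band.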
sns.pairplot(new_adv_data, x_vars=['TV', 'radio', 'newspaper'], y_vars='sales', size=7, aspect=0.8, kind='reg')
plt.savefig("pairplot.jpg")
plt.show()

# Split the dataset into a training set and a test set with sklearn
# train_size is the fraction of the full dataset used for training
print(new_adv_data.iloc[:, :3].describe())
X_train, X_test, Y_train, Y_test = train_test_split(new_adv_data.iloc[:, :3], new_adv_data.sales, train_size=.80)
print(X_train)
print('++++++++++++++++++++++')
print(X_test)
print('++++++++++++++++++++')
print(X_train[X_train.isnull() == True].count())
print(X_test[X_test.isnull() == True].count())
print('++++++++++++++++++++')

print("Original features:", new_adv_data.iloc[:, :3].shape,
      ", training features:", X_train.shape,
      ", test features:", X_test.shape)

print("Original labels:", new_adv_data.sales.shape,
      ", training labels:", Y_train.shape,
      ", test labels:", Y_test.shape)

model = LinearRegression()

model.fit(X_train, Y_train)

a = model.intercept_  # intercept

b = model.coef_  # regression coefficients

print("Best-fit line: intercept", a, ", regression coefficients:", b)
# y = 2.668 + 0.0448*TV + 0.187*Radio - 0.00242*Newspaper

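# R-squared check
# Coefficient of determination R^2
# Used to evaluate how accurate the model is
# Sum of squared errors (SSE) = Σ(y_actual - y_predicted)^2
# Total variation in y = Σ(y_actual - y_mean)^2
# Share of the variation in y NOT explained by the regression line = SSE / total variation
# Share of the variation in y explained by the regression line = 1 - SSE / total variation = R^2
# Interpreting R^2: 1) goodness of fit: the share of the variation in y that the regression line explains (through the variation in x)
# 2) magnitude: the higher R^2, the more accurate the model; 1 means no error, 0 means no better than predicting the mean of y (usually 0~1, but it can be negative on test data)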
score = model.score(X_test, Y_test)

print(score)

# Predict with the fitted linear regression model

Y_pred = model.predict(X_test)

print(Y_pred)

plt.plot(range(len(Y_pred)), Y_pred, 'b', label="predict")
# Show the figure
# plt.savefig("predict.jpg")
plt.show()

plt.figure()
plt.plot(range(len(Y_pred)), Y_pred, 'b', label="predict")
plt.plot(range(len(Y_pred)), Y_test, 'r', label="test")
plt.legend(loc="upper right")  # show the legend
plt.xlabel("the number of sales")
plt.ylabel('value of sales')
plt.savefig("ROC.jpg")
plt.show()

Output and error message
head:       TV  radio  newspaper  sales
0  230.1   37.8       69.2   22.1
1   44.5   39.3       45.1   10.4
2   17.2   45.9       69.3    9.3
3  151.5   41.3       58.5   18.5
4  180.8   10.8       58.4   12.9 
Shape: (201, 4)
-----------------------------------------------
               TV       radio   newspaper       sales
count  200.000000  200.000000  200.000000  200.000000
mean   147.042500   23.264000   30.554000   14.022500
std     85.854236   14.846809   21.778621    5.217457
min      0.700000    0.000000    0.300000    1.600000
25%     74.375000    9.975000   12.750000   10.375000
50%    149.750000   22.900000   25.750000   12.900000
75%    218.825000   36.525000   45.100000   17.400000
max    296.400000   49.600000  114.000000   27.000000
-----------------------------------------------
TV           0
radio        0
newspaper    0
sales        0
dtype: int64
-----------------------------------------------
H:\学习用品\ML\multivariable regression\venv\lib\site-packages\seaborn\axisgrid.py:2095: UserWarning: The `size` parameter has been renamed to `height`; please update your code.
  warnings.warn(msg, UserWarning)
                 TV     radio  newspaper     sales
TV         1.000000  0.054809   0.056648  0.782224
radio      0.054809  1.000000   0.354104  0.576223
newspaper  0.056648  0.354104   1.000000  0.228299
sales      0.782224  0.576223   0.228299  1.000000
               TV       radio   newspaper
count  200.000000  200.000000  200.000000
mean   147.042500   23.264000   30.554000
std     85.854236   14.846809   21.778621
min      0.700000    0.000000    0.300000
25%     74.375000    9.975000   12.750000
50%    149.750000   22.900000   25.750000
75%    218.825000   36.525000   45.100000
max    296.400000   49.600000  114.000000
        TV  radio  newspaper
80    76.4   26.7       22.3
42   293.6   27.7        1.8
117   76.4    0.8       14.8
123  123.1   34.6       12.4
156   93.9   43.5       50.5
..     ...    ...        ...
110  225.8    8.2       56.5
147  243.2   49.0       44.3
89   109.8   47.8       51.4
160  172.5   18.1       30.7
84   213.5   43.0       33.8

[160 rows x 3 columns]
++++++++++++++++++++++
        TV  radio  newspaper
96   197.6    3.5        5.9
197  177.0    9.3        6.4
181  218.5    5.4       27.4
38    43.1   26.7       35.1
113  209.6   20.6       10.7
130    0.7   39.6        8.7
183  287.6   43.0       71.8
189   18.7   12.1       23.4
127   80.2    0.0        9.2
31   112.9   17.4       38.6
155    4.1   11.6        5.7
46    89.7    9.9       35.7
193  166.8   42.0        3.6
93   250.9   36.5       72.3
59   210.7   29.5        9.3
101  296.4   36.3      100.9
106   25.0   11.0       29.7
26   142.9   29.3       12.6
66    31.5   24.6        2.2
177  170.2    7.8       35.2
108   13.1    0.4       25.6
41   177.0   33.4       38.7
134   36.9   38.6       65.6
170   50.0   11.6       18.4
111  241.7   38.0       23.2
118  125.7   36.9       79.2
25   262.9    3.5       19.5
100  222.4    4.3       49.8
82    75.3   20.3       32.5
43   206.9    8.4       26.4
154  187.8   21.1        9.5
18    69.2   20.5       18.3
105  137.9   46.4       59.0
19   147.3   23.9       19.1
36   266.9   43.8        5.0
90   134.3    4.9        9.3
87   110.7   40.6       63.2
76    27.5    1.6       20.7
94   107.4   14.0       10.9
2     17.2   45.9       69.3
69   216.8   43.9       27.2
++++++++++++++++++++
TV           0
radio        0
newspaper    0
dtype: int64
TV           0
radio        0
newspaper    0
dtype: int64
++++++++++++++++++++
Original features: (201, 3) , training features: (160, 3) , test features: (41, 3)
Original labels: (201,) , training labels: (160,) , test labels: (41,)
Traceback (most recent call last):
  File "H:\学习用品\ML\multivariable regression\venv\MLR.py", line 60, in <module>
    model.fit(X_train, Y_train)
  File "H:\学习用品\ML\multivariable regression\venv\lib\site-packages\sklearn\linear_model\_base.py", line 684, in fit
    X, y = self._validate_data(
  File "H:\学习用品\ML\multivariable regression\venv\lib\site-packages\sklearn\base.py", line 596, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "H:\学习用品\ML\multivariable regression\venv\lib\site-packages\sklearn\utils\validation.py", line 1074, in check_X_y
    X = check_array(
  File "H:\学习用品\ML\multivariable regression\venv\lib\site-packages\sklearn\utils\validation.py", line 899, in check_array
    _assert_all_finite(
  File "H:\学习用品\ML\multivariable regression\venv\lib\site-packages\sklearn\utils\validation.py", line 146, in _assert_all_finite
    raise ValueError(msg_err)
ValueError: Input X contains NaN.
LinearRegression does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor which accept missing values encoded as NaNs natively. Alternatively, it is possible to preprocess the data, for instance by using an imputer transformer in a pipeline or drop samples with missing values. See https://scikit-learn.org/stable/modules/impute.html You can find a list of all estimators that handle NaN values at the following page: https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values

Process finished with exit code 1

What I want to achieve

I don't know these methods well enough to tell whether I'm using them incorrectly or whether something else is wrong. Please help!

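The input X contains missing values (NaN). The output already hints at this: the shape is (201, 4), but describe() counts only 200 values per column, so at least one row has no data. Drop the rows with missing data and the error goes away.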
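
Here is a minimal sketch of that fix. It assumes the CSV contains one completely empty row, which the shape and describe() output suggest but do not prove. The small DataFrame is a hypothetical stand-in for sale_data.csv, built from the five rows shown by head() plus one empty row. The sketch also shows why the missing-value check in the question printed 0 for every column: masking with `isnull() == True` leaves only NaN cells, and `count()` skips NaN.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for sale_data.csv: the five rows printed by head(),
# plus one empty row (an assumption about what the real file contains).
df = pd.DataFrame({
    "TV": [230.1, 44.5, 17.2, 151.5, 180.8, np.nan],
    "radio": [37.8, 39.3, 45.9, 41.3, 10.8, np.nan],
    "newspaper": [69.2, 45.1, 69.3, 58.5, 58.4, np.nan],
    "sales": [22.1, 10.4, 9.3, 18.5, 12.9, np.nan],
})

# The check from the question: the boolean mask keeps only the cells that are
# NaN, and count() ignores NaN, so the result is 0 for every column.
misleading = df[df.isnull() == True].count()

# isnull().sum() counts the NaN cells in each column instead.
missing = df.isnull().sum()

# Drop every row that has a NaN in any column before splitting.
clean = df.dropna()

X = clean[["TV", "radio", "newspaper"]]
y = clean["sales"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.8, random_state=0
)

model = LinearRegression()
model.fit(X_train, y_train)  # fits without the "Input X contains NaN" error

print(misleading.tolist(), missing.tolist(), clean.shape)
```

In the original script the same idea would be a `dropna()` call on the DataFrame right after `pd.read_csv(...)`, before `train_test_split`. If dropping rows loses too much data, the error message's other suggestion applies: fill the gaps with an imputer such as `sklearn.impute.SimpleImputer`.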