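Converting batch gradient descent to mini-batch gradient descent: after changing the training batch size, the Python code fails with a dimension-mismatch error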

Symptoms and background

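Background: the task is to fit a function to a 3600 x 5 matrix, where the first column is the label y and the remaining four columns are the four inputs x.
The first 3000 rows are the training set and the last 600 rows are the test set. The script should plot how the loss evolves over the iterations.
Symptom: after switching from training on all 3000 rows at once to training in batches of 200, the code fails with a dimension-mismatch error. Debugging suggests the problem lies in a variable assignment.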

Relevant code (please do not paste screenshots)

# Code that runs correctly
import pandas
import numpy as np
import matplotlib.pyplot as plt

# Load the dataset (same file as in the failing version)
df = pandas.read_csv('temperature_dataset.csv')
data = np.array(df)
num = 3000
train_x = data[0:num, 1:]
train_y = data[0:num, 0].reshape(num, 1)
val_x = data[num:, 1:]
val_y = data[num:, 0].reshape(len(data) - num, 1)
w = np.zeros([4, 1])
b = 0
w_g = np.zeros([4, 1])
b_g = 0
learn_rate = 0.0001
epoch = 200
loss = np.zeros(epoch)
N = len(train_x)
for i in range(epoch):
    # Prediction on the full training set, shape (3000, 1)
    y = np.dot(train_x, w) + b

    print(y.shape)

    # Root-mean-square error, then the gradients of the mean squared error
    loss[i] = np.sqrt(np.sum(np.power(y - train_y, 2)) / N)
    w_g = 2 * np.dot(train_x.T, (y - train_y)) / N

    print((y - train_y).shape, w_g.shape)

    b_g = np.sum(2 * (y - train_y) / N)
    w -= learn_rate * w_g
    b -= learn_rate * b_g

plt.figure(figsize=(16.2, 12.15))
plt.plot(loss)
plt.show()

## Code that raises the error
import pandas
import numpy as np
import matplotlib.pyplot as plt

# load dataset

df = pandas.read_csv('temperature_dataset.csv')
data = np.array(df)
num = 3000
train_x = data[0:3000, 1:]
train_y = data[0:3000, 0].reshape(num, 1)
val_x = data[3000:, 1:]
val_y = data[3000:, 0].reshape(len(data) - num, 1)
w = np.zeros([4, 1])
b = 0
w_g = np.zeros([4, 1])
b_g = 0
learn_rate = 0.0001
epoch = 200
loss = np.zeros([epoch, 1])
temp_x = np.zeros([epoch, 4])
temp_y = np.zeros([epoch, 1])
y = np.zeros((epoch, 1))
print(temp_x.shape, "temp_y", temp_y.shape)
N = len(temp_x)
#e=np.zeros(epoch,1)

for i in range(epoch):
    # Take a batch of 200 rows (epoch is also used as the batch size here)
    temp_x = train_x[i*epoch:(i+1)*epoch, :]
    y = np.dot(temp_x, w) + b

    print(y.shape)

    #e=y-temp_y
    loss[i] = np.sqrt(np.sum(np.power(y - temp_y, 2) / 200))
    w_g = 2 * np.dot(temp_x.T, (y - temp_y)) / N

    print((y - train_y).shape, w_g.shape)

    b_g = np.sum(2 * (y - temp_y) / N)
    w -= learn_rate * w_g
    b -= learn_rate * b_g

plt.figure(figsize=(16.2, 12.15))
plt.plot(loss)
plt.show()

# Code for checking array dimensions
#temp_x=train_x[i*epoch:(i+1)*epoch,:]
#y = np.dot(w.T,temp_x.T)+b
#print(loss.shape,temp_x.shape,temp_y.shape)
#print("temp_x.shape",temp_x.shape," shape.w=",w.shape," (np.dot(temp_x,w)).shape",(np.dot(w.T,temp_x.T)).shape," y.shape=",y.shape)

print("(temp_x.T.dot(y)).shape=",(temp_x.T.dot(y)).shape)

Output and error message

[Two screenshots of the run output and the error message, not reproduced here]

My approach and what I have tried

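I went through the code statement by statement and checked the dimensions of every matrix and vector. The check points to the assignment to temp_x: its row count should equal epoch = 200, but on that line it suddenly becomes 0, and I do not know why.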
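
One possible explanation for the row count dropping to 0, offered as a sketch on synthetic arrays rather than a confirmed diagnosis: the loop index i runs up to epoch - 1 = 199, but train_x has only 3000 rows, so from i = 15 onward the slice train_x[i*epoch:(i+1)*epoch, :] starts past the end of the array, and NumPy returns an empty (0, 4) array instead of raising. The names batch_size, rows_per_slice and can_subtract are hypothetical, and the arrays are zero-filled stand-ins with the shapes from the post.

```python
import numpy as np

# Stand-in for the 3000 x 4 training inputs (synthetic, not the real dataset).
train_x = np.zeros((3000, 4))
batch_size = 200  # the failing code reuses `epoch` for this value

# Number of rows returned by the slice for a few values of the loop index i.
rows_per_slice = {
    i: train_x[i * batch_size:(i + 1) * batch_size, :].shape[0]
    for i in (0, 14, 15, 199)
}
print(rows_per_slice)


def can_subtract(a, b):
    """Return True if a - b broadcasts, False if NumPy raises ValueError."""
    try:
        a - b
        return True
    except ValueError:
        return False


# An empty (0, 1) prediction against the (200, 1) temp_y placeholder.
empty_vs_batch = can_subtract(np.zeros((0, 1)), np.zeros((200, 1)))
# A (200, 1) batch prediction against the full (3000, 1) train_y,
# which is what the debug print inside the loop computes.
batch_vs_full = can_subtract(np.zeros((200, 1)), np.zeros((3000, 1)))
# Matching (200, 1) shapes for comparison.
batch_vs_batch = can_subtract(np.zeros((200, 1)), np.zeros((200, 1)))
print(empty_vs_batch, batch_vs_full, batch_vs_batch)
```

If this reading is right, two separate lines can raise the mismatch: the debug print that subtracts the full train_y from a 200-row prediction, and any subtraction involving temp_y once the slice comes back empty. Note also that temp_y is never re-sliced from train_y inside the loop, so it stays all zeros.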

The result I want
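
A minimal sketch of what a working mini-batch version might look like, assuming the goal is 200-row batches over the 3000 training rows. It uses synthetic data in place of temperature_dataset.csv, and the names batch_size, n_batches, epochs, rows and true_w are assumptions introduced for illustration. The learning rate of 0.01 is chosen for the synthetic data only; the post uses 0.0001 on the real data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the 3600 x 5 matrix: column 0 is y, columns 1-4 are x.
x_all = rng.normal(size=(3600, 4))
true_w = np.array([[1.0], [-2.0], [0.5], [3.0]])  # hypothetical weights
y_all = x_all @ true_w + 0.7 + 0.01 * rng.normal(size=(3600, 1))
data = np.hstack([y_all, x_all])

num = 3000
train_x = data[:num, 1:]
train_y = data[:num, 0].reshape(num, 1)

w = np.zeros((4, 1))
b = 0.0
learn_rate = 0.01              # assumption for the synthetic data
epochs = 200                   # full passes over the training set
batch_size = 200               # rows per batch, kept separate from `epochs`
n_batches = num // batch_size  # 15 batches cover all 3000 rows
loss = np.zeros(epochs)

for i in range(epochs):
    for k in range(n_batches):
        rows = slice(k * batch_size, (k + 1) * batch_size)
        temp_x = train_x[rows]           # (200, 4)
        temp_y = train_y[rows]           # (200, 1), sliced together with temp_x
        err = temp_x @ w + b - temp_y    # (200, 1)
        w_g = 2 * (temp_x.T @ err) / batch_size
        b_g = 2 * np.sum(err) / batch_size
        w -= learn_rate * w_g
        b -= learn_rate * b_g
    # RMSE over the full training set after each epoch
    loss[i] = np.sqrt(np.mean((train_x @ w + b - train_y) ** 2))

print(loss[0], loss[-1])
```

The batch index k never exceeds n_batches - 1, so every slice has exactly 200 rows, and temp_y always has the same number of rows as temp_x. The loss array can then be plotted with plt.plot(loss) as in the original code.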