def get_batches(X, y, num_seqs, num_steps):
    """Yield mini-batches of consecutive, non-overlapping fixed-length sequences.

    The rows of ``X`` and ``y`` are cut into consecutive windows of
    ``num_steps`` rows each, and those windows are grouped ``num_seqs`` at a
    time into batches. Trailing rows that cannot fill a whole batch are
    dropped.

    Args:
        X: Inputs, one row per time step. A pandas Series/DataFrame or anything
            ``np.asarray`` accepts; a 1-D input is treated as one feature column.
        y: Targets aligned row-for-row with ``X`` (same accepted types).
        num_seqs: Number of sequences (windows) per batch.
        num_steps: Number of rows (time steps) per sequence.

    Yields:
        ``(x_batch, y_batch)`` NumPy arrays of shape
        ``(num_seqs, num_steps, *feature_shape)``.
    """
    per_batch = num_seqs * num_steps
    num_batches = len(X) // per_batch
    usable_rows = num_batches * per_batch

    def _to_windows(data):
        # Convert to an array of consecutive num_steps-row windows.
        # np.asarray keeps the columns' own dtype; the previous
        # reset_index().values[:, 1:] mixed the index into the array, which
        # could upcast to object dtype (e.g. DatetimeIndex) or leak extra
        # MultiIndex levels into the features.
        arr = np.asarray(data)
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        # Same result as stacking arr[i:i + num_steps] for i in steps of num_steps.
        return arr[:usable_rows].reshape((-1, num_steps) + arr.shape[1:])

    X, y = _to_windows(X), _to_windows(y)

    # Step by num_seqs (windows per batch), not num_steps: stepping by
    # num_steps overlapped or skipped windows whenever the two differed.
    for start in range(0, num_batches * num_seqs, num_seqs):
        yield X[start:start + num_seqs], y[start:start + num_seqs]
# enumerate(..., 1) numbers the batches starting at 1. Each item it produces is
# (index, (x_batch, y_batch)), so the nested tuple unpacks directly into x, y.
for ii, (x, y) in enumerate(get_batches(X_train, y_train, batch_size, num_steps), 1):
    # The real loop body is not part of this snippet; printing the batch number
    # and array shapes shows what each iteration receives.
    print(ii, x.shape, y.shape)
上面这个 `enumerate` 循环每次迭代会输出什么？
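`enumerate(iterable, start=0)` 在遍历可迭代对象（列表、生成器等）时，会同时给出每个元素的计数下标，每次产出一个 `(下标, 元素)` 元组。上面的循环传入了 `start=1`，所以 `ii` 从 1 开始计数；而 `get_batches` 每次 `yield` 的是一个 `(x, y)` 元组，因此写成 `ii, (x, y)` 就能同时解包出批次编号以及该批次的输入和标签。下面是一个基本示例：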
list1 = ["这", "是", "一个", "测试"]
# enumerate yields (index, item) pairs, counting from 0 by default;
# enumerate(list1, 1) would number the items 1..4 instead.
for index, item in enumerate(list1):
    print(index, item)
# Output:
# 0 这
# 1 是
# 2 一个
# 3 测试