这个代码,为什么定义了 task_type = ‘multiclass', 但是最后一行打印出来的score 还是以
这个得出来的呢?这个分数是 task_type =’regression‘时得出的分数
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score
下面是源代码
task_type = 'multiclass'
@torch.no_grad()
def evaluate(part, task_type):
model.eval()
prediction = []
for batch in delu.iter_batches(X[part], 1024):
prediction.append(apply_model(batch))
prediction = torch.cat(prediction).squeeze(1).cpu().numpy()
target = y[part].cpu().numpy()
if task_type == 'binclass':
prediction = np.round(scipy.special.expit(prediction))
score = sklearn.metrics.accuracy_score(target, prediction)
elif task_type == 'multiclass':
prediction = prediction.argmax(1)
score = sklearn.metrics.accuracy_score(target, prediction)
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score
# Create a dataloader for batches of indices
# Docs: https://yura52.github.io/zero/reference/api/zero.data.IndexLoader.html
batch_size = 256
train_loader = delu.data.IndexLoader(len(X['train']), batch_size, device=device)
# Create a progress tracker for early stopping
# Docs: https://yura52.github.io/zero/reference/api/zero.ProgressTracker.html
progress = delu.ProgressTracker(patience=100)
print(f'Test score before training: {evaluate("test", task_type):.4f}')
你在13行前面输出一下task_type看看值是什么啊。
既然有源码,在里面(比如第12行的位置)print一下task_type,不就可以发现问题了吗?
我的建议是分成两行输出,不要想着偷懒一行搞定,你这个写在f“”里面,不能保证你的task_type传入的参数一定是正确的。
score=evaluate("test", task_type)
print(f'Test score before training: {score:.4f}')