python程序问题,传递参数

这个代码,为什么定义了 task_type = ‘multiclass', 但是最后一行打印出来的score 还是以
这个得出来的呢?这个分数是 task_type =’regression‘时得出的分数
else:
assert task_type == 'regression'
score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
return score

img

下面是源代码


task_type = 'multiclass'

@torch.no_grad()
def evaluate(part, task_type):
    model.eval()
    prediction = []
    for batch in delu.iter_batches(X[part], 1024):
        prediction.append(apply_model(batch))
    prediction = torch.cat(prediction).squeeze(1).cpu().numpy()
    target = y[part].cpu().numpy()

    if task_type == 'binclass':
        prediction = np.round(scipy.special.expit(prediction))
        score = sklearn.metrics.accuracy_score(target, prediction)
    elif task_type == 'multiclass':
        prediction = prediction.argmax(1)
        score = sklearn.metrics.accuracy_score(target, prediction)
    else:
        assert task_type == 'regression'
        score = sklearn.metrics.mean_squared_error(target, prediction) ** 0.5 * y_std
    return score


# Create a dataloader for batches of indices
# Docs: https://yura52.github.io/zero/reference/api/zero.data.IndexLoader.html
batch_size = 256
train_loader = delu.data.IndexLoader(len(X['train']), batch_size, device=device)

# Create a progress tracker for early stopping
# Docs: https://yura52.github.io/zero/reference/api/zero.ProgressTracker.html
progress = delu.ProgressTracker(patience=100)
print(f'Test score before training: {evaluate("test", task_type):.4f}')

你在13行前面输出一下task_type看看值是什么啊。

既然有源码,在里面(比如第12行的位置)print一下task_type,不就可以发现问题了吗?

我的建议是分成两行输出,不要想着偷懒一行搞定,你这个写在f“”里面,不能保证你的task_type传入的参数一定是正确的。


score=evaluate("test", task_type)
print(f'Test score before training: {score:.4f}')