优化PyTorch代码:去除冗余变量,简洁计算准确率

在之前的回答中,我们使用了一个名为 'correct' 的变量来统计模型预测正确的样本数。为了使代码更加简洁易懂,我们可以直接使用 'correct_total' 变量来完成相同的任务,从而无需再引入 'correct' 变量。

以下是优化后的代码:pythoncorrect_total = 0total = 0

for i, input_tensor in enumerate(train_tensors): optimizer.zero_grad()

output = network(input_tensor)

loss = custom_loss(output, tensor_list[i])

loss.backward()    optimizer.step()

# 统计准确率    target_similarity = F.cosine_similarity(output, tensor_list[i].unsqueeze(0), dim=1)    label_list = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]), torch.tensor([1, 1, 1, 1])]    other_list = []    for label_tensor in label_list:        if not torch.all(torch.eq(tensor_list[i], label_tensor)):            other_list.append(label_tensor)

if target_similarity > torch.max([F.cosine_similarity(output, other.unsqueeze(0), dim=1) for other in other_list]):        correct_total += 1

total += 1

计算最终的正确率accuracy = correct_total / totalprint('Final Accuracy: %.2f%%' % (100 * accuracy))

在这个优化后的代码中,我们直接使用 'correct_total' 变量来累加预测正确的样本数,并在最后计算准确率时使用它。这样做不仅简化了代码,还减少了内存占用。

希望这次的优化能够帮助你更好地理解和编写PyTorch代码!如有任何疑问,请随时提问。

标签: 常规


原文地址: https://gggwd.com/t/topic/MSB 著作权归作者所有。请勿转载和采集!