优化PyTorch代码:去除冗余变量,简洁计算准确率
优化PyTorch代码:去除冗余变量,简洁计算准确率
在之前的回答中,我们使用了一个名为 'correct' 的变量来统计模型预测正确的样本数。为了使代码更加简洁易懂,我们可以直接使用 'correct_total' 变量来完成相同的任务,从而无需再引入 'correct' 变量。
以下是优化后的代码:pythoncorrect_total = 0total = 0
for i, input_tensor in enumerate(train_tensors): optimizer.zero_grad()
output = network(input_tensor)
loss = custom_loss(output, tensor_list[i])
loss.backward() optimizer.step()
# 统计准确率 target_similarity = F.cosine_similarity(output, tensor_list[i].unsqueeze(0), dim=1) label_list = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]), torch.tensor([1, 1, 1, 1])] other_list = [] for label_tensor in label_list: if not torch.all(torch.eq(tensor_list[i], label_tensor)): other_list.append(label_tensor)
if target_similarity > torch.max([F.cosine_similarity(output, other.unsqueeze(0), dim=1) for other in other_list]): correct_total += 1
total += 1
计算最终的正确率accuracy = correct_total / totalprint('Final Accuracy: %.2f%%' % (100 * accuracy))
在这个优化后的代码中,我们直接使用 'correct_total' 变量来累加预测正确的样本数,并在最后计算准确率时使用它。这样做不仅简化了代码,还减少了内存占用。
希望这次的优化能够帮助你更好地理解和编写PyTorch代码!如有任何疑问,请随时提问。
原文地址: https://gggwd.com/t/topic/MSB 著作权归作者所有。请勿转载和采集!