From 396111f60f0abb4c4116dd033bf27bab8cc85e31 Mon Sep 17 00:00:00 2001 From: JunYi <38555734+wujunyi627@users.noreply.github.com> Date: Sun, 10 Mar 2019 10:03:26 +0800 Subject: [PATCH] Update triplet.py --- reid/loss/triplet.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/reid/loss/triplet.py b/reid/loss/triplet.py index a019eee..663ad2e 100644 --- a/reid/loss/triplet.py +++ b/reid/loss/triplet.py @@ -3,7 +3,7 @@ import torch from torch import nn from torch.autograd import Variable - +version = torch.__version__ class TripletLoss(nn.Module): def __init__(self, margin=0): @@ -21,11 +21,18 @@ def forward(self, inputs, targets): # For each anchor, find the hardest positive and negative mask = targets.expand(n, n).eq(targets.expand(n, n).t()) dist_ap, dist_an = [], [] - for i in range(n): - dist_ap.append(dist[i][mask[i]].max()) - dist_an.append(dist[i][mask[i] == 0].min()) - dist_ap = torch.cat(dist_ap) - dist_an = torch.cat(dist_an) + if int(version[2]) ==3: + for i in range(n): + dist_ap.append(dist[i][mask[i]].max()) + dist_an.append(dist[i][mask[i] == 0].min()) + dist_ap = torch.cat(dist_ap) + dist_an = torch.cat(dist_an) + else: + for i in range(n): + dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) + dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) + dist_ap = torch.cat(dist_ap) + dist_an = torch.cat(dist_an) # Compute ranking hinge loss y = dist_an.data.new() y.resize_as_(dist_an.data)