-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfunctions.py
More file actions
64 lines (47 loc) · 2 KB
/
functions.py
File metadata and controls
64 lines (47 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torch.autograd import Function
class ReverseLayerF(Function):
    """Gradient-reversal layer (Ganin & Lempitsky, DANN).

    Acts as the identity on the forward pass; on the backward pass the
    incoming gradient is negated and scaled by ``alpha``, so upstream layers
    are trained adversarially against whatever sits downstream.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scaling factor for the backward pass.
        ctx.alpha = alpha
        # view_as(x) returns the tensor unchanged while keeping autograd happy.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient's sign and scale it; alpha itself gets no gradient.
        return grad_output.neg() * ctx.alpha, None
import torch
import torch.nn as nn
import torch.nn.functional as func
class difference_loss(nn.Module):
    """Orthogonality penalty between private and shared representations.

    Each input is flattened per sample, mean-centred over the batch and
    L2-normalised per row; the loss is the sum of squared entries of the
    cross-correlation matrix between the two (column-wise), driving the two
    feature subspaces towards orthogonality (as in Domain Separation
    Networks). With ``reduce=True`` (default) the column sums are averaged
    into a scalar; otherwise the per-column vector is returned.
    """

    def __init__(self, reduce=True):
        super(difference_loss, self).__init__()
        self.reduce = reduce

    def forward(self, private_samples, shared_samples):
        n = private_samples.size(0)
        priv = private_samples.view(n, -1)
        shar = shared_samples.view(n, -1)
        # Centre each feature dimension over the batch.
        priv = priv - priv.mean(0, keepdim=True)
        shar = shar - shar.mean(0, keepdim=True)
        # Row-wise L2 normalisation; the epsilon guards against zero rows.
        priv_norm = priv.norm(p=2, dim=1, keepdim=True)
        shar_norm = shar.norm(p=2, dim=1, keepdim=True)
        priv = priv / (priv_norm.expand_as(priv) + 1e-10)
        shar = shar / (shar_norm.expand_as(shar) + 1e-10)
        # Squared cross-correlation entries, summed over each column.
        loss = (shar.t().mm(priv)).pow(2).sum(0)
        return loss.mean() if self.reduce else loss
class mean_pairwise_square_loss(nn.Module):
    """Mean pairwise squared error between predictions and labels.

    Penalises differences between *pairs* of residuals rather than the
    residuals themselves, via the identity
    ``sum_{i,j}(d_i - d_j)^2 = 2n * sum(d^2) - 2 * (sum d)^2``
    (cf. the scale-invariant error of Eigen et al. and
    ``tf.losses.mean_pairwise_squared_error``): a constant offset between
    predictions and labels incurs no loss.

    Args:
        reduce: if True (default), return the normalised scalar loss;
            if False, return the same quantity scaled by the batch size.
    """

    def __init__(self, reduce=True):
        super(mean_pairwise_square_loss, self).__init__()
        self.reduce = reduce

    def forward(self, predictions, labels):
        # Per-element residuals, flattened to (batch, features).
        diff = (predictions - labels).view(predictions.size(0), -1)
        sum_square_diff = torch.sum(diff.pow(2))
        # Squares of the per-sample residual sums, summed over the batch.
        square_sum_diff = torch.sum(torch.sum(diff, 1).pow(2))
        num_present = diff.numel()  # .numel() replaces deprecated torch.numel(x.data)
        # BUG FIX: the second term must be SUBTRACTED — the pairwise/variance
        # identity cancels constant residual offsets. The original '+'
        # penalised constant offsets, contradicting "pairwise" semantics.
        loss = sum_square_diff / num_present - square_sum_diff / num_present / num_present
        if not self.reduce:
            loss = loss * diff.size(0)
        return loss