Coverage for hierarchicalsoftmax/loss.py: 100.00%

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
def focal_loss_with_smoothing(logits, label, weight=None, gamma=0.0, label_smoothing=0.0):
    """
    Focal loss with label smoothing.

    Adapted from https://github.com/Kageshimasu/focal-loss-with-smoothing
    and https://github.com/clcarwin/focal_loss_pytorch
    """
    # Per-class log-probabilities and the log-probability of the target class
    log_probabilities = F.log_softmax(logits, dim=-1)
    label = label.view(-1, 1)
    log_probability = log_probabilities.gather(1, label).squeeze()

    # Smoothed target distribution: uniform mass everywhere plus the remaining mass on the true class
    n_classes = logits.size(1)
    uniform_probability = label_smoothing / n_classes
    label_distribution = torch.full_like(logits, uniform_probability)
    label_distribution.scatter_(1, label, 1.0 - label_smoothing + uniform_probability)

    # Focal modulation: down-weight easy examples by (1 - p)^gamma
    # (detach() replaces the deprecated Variable(...data...) idiom without changing behaviour)
    probability = log_probability.detach().exp()
    difficulty_level = (1 - probability) ** gamma
    loss = -difficulty_level * torch.sum(log_probabilities * label_distribution, dim=1)
    # Optional per-class weights, normalised by their mean
    if weight is not None:
        weight = weight.to(logits.device)
        loss *= torch.gather(weight, -1, label.squeeze()) / weight.mean()

    return loss.mean()
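

# Illustrative usage sketch (not part of the library code above). The helper name,
# shapes and values below are assumptions chosen only to show how the function
# expects its arguments: logits of shape (batch, n_classes) and integer class labels.
def _example_focal_loss_usage():  # hypothetical helper, for illustration only
    example_logits = torch.randn(4, 3)               # (batch=4, n_classes=3)
    example_labels = torch.tensor([0, 2, 1, 0])      # (batch,) class indices
    example_weight = torch.tensor([1.0, 2.0, 1.0])   # optional per-class weights
    return focal_loss_with_smoothing(
        example_logits,
        example_labels,
        weight=example_weight,
        gamma=2.0,
        label_smoothing=0.1,
    )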
class HierarchicalSoftmaxLoss(nn.Module):
    """
    A module which sums the loss for each level of a hierarchical tree.
    """
    def __init__(
        self,
        root,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.root = root

        # Set the indexes of the tree if necessary
        self.root.set_indexes_if_unset()

        assert len(self.root.node_list) > 0
    def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor:
        target_nodes = (self.root.node_list[target] for target in targets)

        total_loss = 0.0
        device = targets.device

        for prediction, target_node in zip(batch_predictions, target_nodes):
            node = target_node
            # Walk up the tree, adding the loss for each parent's softmax block
            while node.parent:
                node.index_in_parent_tensor = node.index_in_parent_tensor.to(device)  # can this be done elsewhere?
                logits = torch.unsqueeze(prediction[node.parent.softmax_start_index:node.parent.softmax_end_index], dim=0)
                label = node.index_in_parent_tensor
                weight = node.parent.weight
                label_smoothing = node.parent.label_smoothing
                gamma = node.parent.gamma
                if gamma is not None and gamma > 0.0:
                    loss = focal_loss_with_smoothing(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                        gamma=gamma,
                    )
                else:
                    loss = F.cross_entropy(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                    )

                total_loss += node.parent.alpha * loss
                node = node.parent

        batch_size = len(targets)
        total_loss /= batch_size
        return total_loss
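

# Illustrative usage sketch (not part of the library code above). It assumes the
# package's SoftmaxNode tree API and a layer_size attribute on the root node; those
# names, the tree shape, and the tensor sizes are assumptions for demonstration only.
def _example_hierarchical_loss_usage():  # hypothetical helper, for illustration only
    from hierarchicalsoftmax import SoftmaxNode  # assumed top-level export

    # Build a small two-level tree: root -> {a, b}, a -> {a1, a2}
    root = SoftmaxNode("root")
    a = SoftmaxNode("a", parent=root)
    b = SoftmaxNode("b", parent=root)
    a1 = SoftmaxNode("a1", parent=a)
    a2 = SoftmaxNode("a2", parent=a)

    # Constructing the loss sets the tree indexes via set_indexes_if_unset()
    loss_fn = HierarchicalSoftmaxLoss(root=root)

    # One prediction vector per item, wide enough to hold every parent's softmax block
    predictions = torch.randn(2, root.layer_size)
    # Targets are positions in root.node_list (matching how forward() looks them up)
    targets = torch.tensor([root.node_list.index(a1), root.node_list.index(b)])
    return loss_fn(predictions, targets)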
# An earlier cross-entropy-only version, kept commented out:
#
# class HierarchicalSoftmaxLoss(nn.Module):
#     """
#     A module which sums the loss for each level of a hierarchical tree.
#     """
#     def __init__(
#         self,
#         root,
#         **kwargs
#     ):
#         super().__init__(**kwargs)
#         self.root = root
#
#         # Set the indexes of the tree if necessary
#         self.root.set_indexes_if_unset()
#
#         assert len(self.root.node_list) > 0
#
#     def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor:
#         target_nodes = (self.root.node_list[target] for target in targets)
#
#         loss = 0.0
#         device = targets.device
#
#         for prediction, target_node in zip(batch_predictions, target_nodes):
#             node = target_node
#             while node.parent:
#                 node.index_in_parent_tensor = node.index_in_parent_tensor.to(device)  # can this be done elsewhere?
#                 print(node.index_in_parent_tensor)
#                 loss += node.parent.alpha * F.cross_entropy(
#                     torch.unsqueeze(prediction[node.parent.softmax_start_index:node.parent.softmax_end_index], dim=0),
#                     node.index_in_parent_tensor,
#                     weight=node.parent.weight,
#                     label_smoothing=node.parent.label_smoothing,
#                 )
#                 node = node.parent
#
#         batch_size = len(targets)
#         loss /= batch_size
#         return loss