Coverage for hierarchicalsoftmax/loss.py: 100.00%
50 statements
coverage.py v7.8.0, created at 2025-07-02 01:49 +0000
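"""
Loss functions for hierarchical softmax: a focal loss with label smoothing, and
a module that sums the cross-entropy (or focal) loss over each level of a
hierarchical tree.
"""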
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
def focal_loss_with_smoothing(logits, label, weight=None, gamma=0.0, label_smoothing=0.0):
    """
    Focal loss with label smoothing.

    Adapted from https://github.com/Kageshimasu/focal-loss-with-smoothing
    and https://github.com/clcarwin/focal_loss_pytorch
    """
    log_probabilities = F.log_softmax(logits, dim=-1)
    label = label.view(-1, 1)
    log_probability = log_probabilities.gather(1, label).squeeze()
    n_classes = logits.size(1)

    # Build the smoothed target distribution: uniform mass of label_smoothing/n_classes
    # on every class, with the remaining mass placed on the true class.
    uniform_probability = label_smoothing / n_classes
    label_distribution = torch.full_like(logits, uniform_probability)
    label_distribution.scatter_(1, label, 1.0 - label_smoothing + uniform_probability)

    # Focal modulation: down-weight examples the model already predicts confidently.
    probability = log_probability.detach().exp()
    difficulty_level = (1 - probability) ** gamma
    loss = -difficulty_level * torch.sum(log_probabilities * label_distribution, dim=1)

    # Optional per-class weights, normalised by their mean.
    if weight is not None:
        weight = weight.to(logits.device)
        loss *= torch.gather(weight, -1, label.squeeze()) / weight.mean()

    return loss.mean()
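

# A minimal usage sketch, not part of the library API: the tensor shapes and the
# gamma/label_smoothing values below are illustrative assumptions only.
def _focal_loss_example():
    logits = torch.randn(4, 3)                  # batch of 4 examples, 3 classes
    labels = torch.tensor([0, 2, 1, 2])
    # With gamma > 0 the loss down-weights confidently classified examples;
    # label_smoothing spreads a little target mass over all classes.
    smoothed_focal = focal_loss_with_smoothing(logits, labels, gamma=2.0, label_smoothing=0.1)
    plain = F.cross_entropy(logits, labels, label_smoothing=0.1)
    return smoothed_focal, plain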
class HierarchicalSoftmaxLoss(nn.Module):
    """
    A module which sums the loss for each level of a hierarchical tree.
    """
    def __init__(
        self,
        root,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.root = root

        # Set the indexes of the tree if necessary
        self.root.set_indexes_if_unset()

        assert len(self.root.node_list) > 0
    def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor:
        target_nodes = (self.root.node_list[target] for target in targets)
        total_loss = 0.0
        device = targets.device

        for prediction, target_node in zip(batch_predictions, target_nodes):
            # Walk up the tree from the target node, adding one loss term for
            # each softmax group along the path to the root.
            node = target_node
            while node.parent:
                # If this is the sole child, there is no decision to learn at this level.
                if len(node.parent.children) == 1:
                    node = node.parent
                    continue

                node.index_in_parent_tensor = node.index_in_parent_tensor.to(device)  # can this be done elsewhere?
                logits = torch.unsqueeze(prediction[node.parent.softmax_start_index:node.parent.softmax_end_index], dim=0)
                label = node.index_in_parent_tensor
                weight = node.parent.weight
                label_smoothing = node.parent.label_smoothing
                gamma = node.parent.gamma
                if gamma is not None and gamma > 0.0:
                    loss = focal_loss_with_smoothing(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                        gamma=gamma,
                    )
                else:
                    loss = F.cross_entropy(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                    )

                # Each level's contribution is scaled by the parent's alpha weighting.
                total_loss += node.parent.alpha * loss
                node = node.parent

        batch_size = len(targets)
        total_loss /= batch_size
        return total_loss
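

# A minimal usage sketch for the module above. It assumes the package builds its
# tree from a `SoftmaxNode` class and that the root exposes a `layer_size` giving
# the width of the prediction tensor; treat those names as assumptions and adjust
# them to the actual tree API if they differ. Targets are integer positions of
# nodes in `root.node_list`, as used by `forward` above.
def _hierarchical_loss_example():
    from hierarchicalsoftmax import SoftmaxNode  # assumed import path

    root = SoftmaxNode("root")
    a = SoftmaxNode("a", parent=root)
    b = SoftmaxNode("b", parent=root)
    aa = SoftmaxNode("aa", parent=a)
    ab = SoftmaxNode("ab", parent=a)

    # Constructing the loss sets the tree indexes if they are not already set.
    loss_fn = HierarchicalSoftmaxLoss(root=root)

    predictions = torch.randn(2, root.layer_size)  # one score per softmax slot (assumed attribute)
    targets = torch.tensor([root.node_list.index(aa), root.node_list.index(b)])
    return loss_fn(predictions, targets)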