Coverage for hierarchicalsoftmax/loss.py: 100.00%

50 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

1import torch 

2from torch import nn 

3from torch import Tensor 

4import torch.nn.functional as F 

5from torch.autograd import Variable 

6 

7 

def focal_loss_with_smoothing(logits, label, weight=None, gamma=0.0, label_smoothing=0.0):
    """
    Focal loss combined with label smoothing, averaged over the batch.

    Adapted from https://github.com/Kageshimasu/focal-loss-with-smoothing
    and https://github.com/clcarwin/focal_loss_pytorch

    Args:
        logits: 2-D tensor of unnormalised scores, presumably (batch, n_classes);
            classes are read along dimension 1.
        label: integer class indices, one per row of ``logits`` (any shape that
            flattens to one index per row).
        weight: optional 1-D tensor of per-class weights. Each sample's loss is
            scaled by ``weight[label] / weight.mean()``.
        gamma: focusing parameter; each sample's loss is scaled by
            ``(1 - p_target) ** gamma``. ``gamma=0`` reduces to label-smoothed
            cross entropy.
        label_smoothing: probability mass spread uniformly over all classes;
            the target class receives ``1 - label_smoothing + label_smoothing / n_classes``.

    Returns:
        A 0-dim tensor holding the mean loss over the batch.
    """
    log_probabilities = F.log_softmax(logits, dim=-1)
    # Column vector of indices on the same device as the logits (needed by gather/scatter_).
    label = label.to(logits.device).view(-1, 1)
    n_classes = logits.size(1)

    # Smoothed target distribution: every class gets label_smoothing / n_classes,
    # and the true class additionally gets 1 - label_smoothing.
    uniform_probability = label_smoothing / n_classes
    label_distribution = torch.full_like(logits, uniform_probability)
    label_distribution.scatter_(1, label, 1.0 - label_smoothing + uniform_probability)

    # Probability assigned to the true class. It is detached so that no gradient
    # flows through the focal modulating factor; it acts as a constant per-sample scale.
    # squeeze(1) keeps the batch dimension even when the batch has a single row.
    target_log_probability = log_probabilities.gather(1, label).squeeze(1)
    probability = target_log_probability.detach().exp()
    difficulty_level = (1.0 - probability) ** gamma

    loss = -difficulty_level * torch.sum(log_probabilities * label_distribution, dim=1)

    if weight is not None:
        weight = weight.to(logits.device)
        # Scale by the target-class weight, normalised so the mean class weight is 1.
        loss = loss * weight[label.view(-1)] / weight.mean()

    return loss.mean()

31 

32 

class HierarchicalSoftmaxLoss(nn.Module):
    """
    A module which sums the loss for each level of a hierarchical tree.

    For every sample, the path from the target node up to the root is walked.
    At each parent with more than one child, ``parent.alpha`` times a
    classification loss over that parent's slice of the prediction row is
    added. The total is divided by the batch size.
    """

    def __init__(self, root, **kwargs):
        super().__init__(**kwargs)
        self.root = root

        # Set the indexes of the tree if necessary
        self.root.set_indexes_if_unset()

        assert len(self.root.node_list) > 0

    def _level_loss(self, prediction: Tensor, node, device) -> Tensor:
        """Return the loss for selecting ``node`` among its parent's children for one sample."""
        parent = node.parent

        # Store the moved label back on the node; `.to` is a no-op once it is on `device`.
        node.index_in_parent_tensor = node.index_in_parent_tensor.to(device)
        label = node.index_in_parent_tensor

        # The parent's logits are a contiguous slice of the prediction row; add a batch dim of 1.
        logits = torch.unsqueeze(prediction[parent.softmax_start_index:parent.softmax_end_index], dim=0)

        weight = parent.weight
        if weight is not None:
            weight = weight.to(device)

        gamma = parent.gamma
        if gamma is not None and gamma > 0.0:
            return focal_loss_with_smoothing(
                logits,
                label,
                weight=weight,
                label_smoothing=parent.label_smoothing,
                gamma=gamma,
            )

        # NOTE(review): with the default reduction='mean', F.cross_entropy divides by the
        # summed weight of the targets. For this single-sample call, that cancels the
        # class weight (apart from its effect on the label-smoothing term). The focal
        # branch instead scales by weight[label] / weight.mean(). Confirm which
        # weighting is intended.
        return F.cross_entropy(
            logits,
            label,
            weight=weight,
            label_smoothing=parent.label_smoothing,
        )

    def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the hierarchical loss for a batch.

        Args:
            batch_predictions: one prediction row per sample. Each parent node reads
                the slice ``[softmax_start_index:softmax_end_index]`` of the row.
            targets: indices into ``self.root.node_list``, one per sample.

        Returns:
            The summed per-level loss divided by the batch size, as a 0-dim tensor.
            An empty batch, or one whose targets only pass through single-child
            parents, yields a zero tensor.
        """
        batch_size = len(targets)
        # Labels and weights must live where the logits they are compared against live.
        device = batch_predictions.device
        # Start from a tensor so the return type is always a Tensor.
        total_loss = batch_predictions.new_zeros(())
        if batch_size == 0:
            return total_loss

        target_nodes = (self.root.node_list[target] for target in targets)
        for prediction, target_node in zip(batch_predictions, target_nodes):
            node = target_node
            while node.parent is not None:
                parent = node.parent
                # A sole child needs no decision at this level, so it contributes no loss.
                if len(parent.children) > 1:
                    total_loss = total_loss + parent.alpha * self._level_loss(prediction, node, device)
                node = parent

        return total_loss / batch_size

92 

93 

94 

95 

96# class HierarchicalSoftmaxLoss(nn.Module): 

97# """ 

98# A module which sums the loss for each level of a hiearchical tree. 

99# """ 

100# def __init__( 

101# self, 

102# root, 

103# **kwargs 

104# ): 

105# super().__init__(**kwargs) 

106# self.root = root 

107 

108# # Set the indexes of the tree if necessary 

109# self.root.set_indexes_if_unset() 

110 

111# assert len(self.root.node_list) > 0 

112 

113# def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor: 

114# target_nodes = (self.root.node_list[target] for target in targets) 

115 

116# loss = 0.0 

117# device = targets.device 

118 

119# for prediction, target_node in zip(batch_predictions, target_nodes): 

120# node = target_node 

121# while node.parent: 

122# node.index_in_parent_tensor = node.index_in_parent_tensor.to(device) # can this be done elsewhere? 

123# print(node.index_in_parent_tensor) 

124# loss += node.parent.alpha * F.cross_entropy( 

125# torch.unsqueeze(prediction[node.parent.softmax_start_index:node.parent.softmax_end_index], dim=0), 

126# node.index_in_parent_tensor, 

127# weight=node.parent.weight, 

128# label_smoothing=node.parent.label_smoothing, 

129# ) 

130# node = node.parent 

131 

132# batch_size = len(targets) 

133# loss /= batch_size 

134# return loss