import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


def focal_loss_with_smoothing(logits, label, weight=None, gamma=0.0, label_smoothing=0.0):
    """
    Focal loss combined with label smoothing.

    Adapted from https://github.com/Kageshimasu/focal-loss-with-smoothing
    and https://github.com/clcarwin/focal_loss_pytorch
    """
    log_probabilities = F.log_softmax(logits, dim=-1)
    label = label.view(-1, 1)
    log_probability = log_probabilities.gather(1, label).squeeze(1)  # squeeze(1) keeps a batch dim even when batch size is 1

    # Build the smoothed target distribution: uniform mass everywhere,
    # with the remaining mass placed on the true class.
    n_classes = logits.size(1)
    uniform_probability = label_smoothing / n_classes
    label_distribution = torch.full_like(logits, uniform_probability)
    label_distribution.scatter_(1, label, 1.0 - label_smoothing + uniform_probability)

    # Focal modulation: down-weight easy examples by (1 - p)^gamma.
    # The probability is detached so the modulating factor is treated as a constant.
    probability = log_probability.detach().exp()
    difficulty_level = (1 - probability) ** gamma
    loss = -difficulty_level * torch.sum(log_probabilities * label_distribution, dim=1)

    # Optional per-class weights, normalised by their mean
    if weight is not None:
        weight = weight.to(logits.device)
        loss *= torch.gather(weight, -1, label.squeeze(1)) / weight.mean()

    return loss.mean()
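
# Illustrative usage sketch (not part of the original module): the function
# expects raw logits of shape (batch, n_classes) and integer class labels.
# The tensor values below are made up for demonstration.
def _example_focal_loss_usage():
    logits = torch.randn(8, 5)              # batch of 8, 5 classes
    labels = torch.randint(0, 5, (8,))      # integer class targets
    weight = torch.ones(5)                  # optional per-class weights
    loss = focal_loss_with_smoothing(
        logits, labels, weight=weight, gamma=2.0, label_smoothing=0.1
    )
    # With gamma=0.0, label_smoothing=0.0 and no weights this reduces
    # to standard cross entropy.
    return loss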


class HierarchicalSoftmaxLoss(nn.Module):
    """
    A module which sums the loss for each level of a hierarchical tree.
    """
    def __init__(
        self,
        root,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.root = root

        # Set the indexes of the tree if necessary
        self.root.set_indexes_if_unset()

        assert len(self.root.node_list) > 0

    def forward(self, batch_predictions: Tensor, targets: Tensor) -> Tensor:
        target_nodes = (self.root.node_list[target] for target in targets)

        total_loss = 0.0
        device = targets.device

        for prediction, target_node in zip(batch_predictions, target_nodes):
            # Walk from the target node up to the root, accumulating a loss at each level
            node = target_node
            while node.parent:
                node.index_in_parent_tensor = node.index_in_parent_tensor.to(device)  # can this be done elsewhere?
                # Slice out the logits for the softmax over this node's siblings
                logits = torch.unsqueeze(prediction[node.parent.softmax_start_index:node.parent.softmax_end_index], dim=0)
                label = node.index_in_parent_tensor
                weight = node.parent.weight
                label_smoothing = node.parent.label_smoothing
                gamma = node.parent.gamma
                if gamma is not None and gamma > 0.0:
                    loss = focal_loss_with_smoothing(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                        gamma=gamma,
                    )
                else:
                    loss = F.cross_entropy(
                        logits,
                        label,
                        weight=weight,
                        label_smoothing=label_smoothing,
                    )

                # Each level's loss is scaled by the parent's alpha coefficient
                total_loss += node.parent.alpha * loss
                node = node.parent

        batch_size = len(targets)
        total_loss /= batch_size
        return total_loss
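
# Usage sketch (illustrative): this assumes a node class along the lines of a
# `SoftmaxNode`, where each node knows its `parent`, its
# `softmax_start_index`/`softmax_end_index` slice, and its `alpha`, `gamma`,
# `weight` and `label_smoothing` settings, as used in `forward` above.
# The constructor calls and `layer_size` below are assumptions, not confirmed API.
#
#   root = SoftmaxNode("root")
#   a = SoftmaxNode("a", parent=root)
#   b = SoftmaxNode("b", parent=root)
#   aa = SoftmaxNode("aa", parent=a)
#   ab = SoftmaxNode("ab", parent=a)
#
#   loss_fn = HierarchicalSoftmaxLoss(root=root)
#   predictions = torch.randn(2, root.layer_size)       # assumed total logit width
#   targets = torch.tensor([root.node_list.index(aa),   # targets index into root.node_list
#                           root.node_list.index(b)])
#   loss = loss_fn(predictions, targets)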
