Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from sklearn.metrics import f1_score 

2 

3 

def logit_accuracy(predictions, target):
    """
    Compute the fraction of elements whose logit-based prediction matches a binary target.

    Intended for networks that emit raw logits, e.g. models trained with
    ``BCEWithLogitsLoss``. A logit strictly above 0.0 (a sigmoid probability
    strictly above 0.5) counts as a positive prediction. A target strictly
    above 0.5 counts as a positive label.

    Args:
        predictions: Tensor of raw logits.
        target: Tensor of binary (or soft) labels, with the same shape as,
            or broadcastable against, ``predictions``.

    Returns:
        A 0-d float tensor holding the mean elementwise agreement, taken
        over every element (no per-dimension reduction).
    """
    predicted_positive = predictions > 0.0
    actual_positive = target > 0.5
    agreement = predicted_positive == actual_positive
    # Bool tensors cannot be averaged directly; cast to float first.
    return agreement.float().mean()

11 

12 

def logit_f1(logits, target):
    """
    Compute the binary F1 score from raw logits and a binary target.

    Intended for networks that emit raw logits, e.g. models trained with
    ``BCEWithLogitsLoss``. A logit strictly above 0.0 is a positive
    prediction and a target strictly above 0.5 is a positive label, which
    matches the thresholds used by ``logit_accuracy``.

    Both inputs are flattened, so every element is scored as an independent
    binary decision. This is the same all-elements convention
    ``logit_accuracy`` uses. It means ``(N,)``, ``(N, 1)`` and multi-label
    ``(N, C)`` outputs can all be scored. Unflattened 2-D inputs with several
    columns would otherwise be treated as multilabel targets, which sklearn's
    default ``average='binary'`` does not accept.

    Args:
        logits: Tensor of raw logits.
        target: Tensor of binary (or soft) labels with the same number of
            elements as ``logits``.

    Returns:
        The F1 score of the positive class, as returned by
        ``sklearn.metrics.f1_score``.
    """
    predictions = (logits > 0.0).reshape(-1)
    target_binary = (target > 0.5).reshape(-1)
    # sklearn operates on NumPy arrays; move to host memory and convert
    # explicitly rather than relying on implicit tensor coercion.
    return f1_score(target_binary.cpu().numpy(), predictions.cpu().numpy())