Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from enum import Enum 

2from torch import nn 

3 

4 

class ActivationError(Exception):
    """Error raised by the TorchApp Activation module, e.g. when an activation is unavailable in pytorch."""

9 

10 

class Activation(str, Enum):
    """
    Non-linear activation functions used in pytorch.

    See https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

    Excludes activation functions that require arguments (i.e. MultiheadAttention and Threshold).

    Each member's value is the name of the corresponding class in ``torch.nn``,
    so ``Activation.ReLU.module()`` builds ``nn.ReLU()``.
    """

    ELU = "ELU"
    Hardshrink = "Hardshrink"
    Hardsigmoid = "Hardsigmoid"
    Hardtanh = "Hardtanh"
    Hardswish = "Hardswish"
    LeakyReLU = "LeakyReLU"
    LogSigmoid = "LogSigmoid"
    PReLU = "PReLU"
    ReLU = "ReLU"
    ReLU6 = "ReLU6"
    RReLU = "RReLU"
    SELU = "SELU"
    CELU = "CELU"
    GELU = "GELU"
    Sigmoid = "Sigmoid"
    SiLU = "SiLU"
    Mish = "Mish"
    Softplus = "Softplus"
    Softshrink = "Softshrink"
    Softsign = "Softsign"
    Tanh = "Tanh"
    Tanhshrink = "Tanhshrink"
    GLU = "GLU"

    def __str__(self):
        # Return the bare value (e.g. "ReLU") instead of the default Enum
        # str() form ("Activation.ReLU").
        return self.value

    def module(self, *args, **kwargs):
        """
        Returns the pytorch module for this activation function.

        The class is looked up by name on ``torch.nn`` at call time. A member
        whose class is missing from the installed pytorch version therefore
        raises ActivationError here rather than failing elsewhere.

        Args:
            args: Positional arguments passed to the module's constructor.
            kwargs: Keyword arguments passed to the module's constructor.

        Raises:
            ActivationError: If the activation function is not available in pytorch.

        Returns:
            nn.Module: The pytorch module for this activation function.
        """
        # Single lookup with a default, instead of hasattr() followed by getattr().
        module_class = getattr(nn, self.value, None)
        if module_class is None:
            raise ActivationError(f"Activation function '{self.value}' not available.")

        return module_class(*args, **kwargs)

    @classmethod
    def default_tune_choices(cls):
        """
        Returns the default activations offered as tuning choices.

        Returns:
            list[Activation]: ReLU, LeakyReLU, RReLU, ELU and Hardsigmoid, in that order.
        """
        return [
            cls.ReLU,
            cls.LeakyReLU,
            cls.RReLU,
            cls.ELU,
            cls.Hardsigmoid,
        ]