from enum import Enum

from torch import nn


class ActivationError(Exception):
    """An exception used in the TorchApp Activation module."""

    pass


class Activation(str, Enum):
    """
    Non-linear activation functions used in PyTorch.

    See https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

    Excludes activation functions that require arguments (e.g. MultiheadAttention and Threshold).
    """

    ELU = "ELU"
    Hardshrink = "Hardshrink"
    Hardsigmoid = "Hardsigmoid"
    Hardtanh = "Hardtanh"
    Hardswish = "Hardswish"
    LeakyReLU = "LeakyReLU"
    LogSigmoid = "LogSigmoid"
    PReLU = "PReLU"
    ReLU = "ReLU"
    ReLU6 = "ReLU6"
    RReLU = "RReLU"
    SELU = "SELU"
    CELU = "CELU"
    GELU = "GELU"
    Sigmoid = "Sigmoid"
    SiLU = "SiLU"
    Mish = "Mish"
    Softplus = "Softplus"
    Softshrink = "Softshrink"
    Softsign = "Softsign"
    Tanh = "Tanh"
    Tanhshrink = "Tanhshrink"
    GLU = "GLU"

    def __str__(self):
        return self.value

    def module(self, *args, **kwargs):
        """
        Returns the PyTorch module for this activation function.

        Args:
            args: Arguments to pass to the function to create the module.
            kwargs: Keyword arguments to pass to the function to create the module.

        Raises:
            ActivationError: If the activation function is not available in PyTorch.

        Returns:
            nn.Module: The PyTorch module for this activation function.
        """
        if not hasattr(nn, self.value):
            raise ActivationError(f"Activation function '{self.value}' not available.")

        return getattr(nn, self.value)(*args, **kwargs)

    @classmethod
    def default_tune_choices(cls):
        """Returns a default shortlist of activation functions to try when tuning."""
        return [
            cls.ReLU,
            cls.LeakyReLU,
            cls.RReLU,
            cls.ELU,
            cls.Hardsigmoid,
        ]
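

# A minimal usage sketch (an illustrative addition, not part of the original
# module); it assumes only that torch is installed. Each enum member resolves
# to the torch.nn class of the same name, and any constructor arguments are
# forwarded through module().
if __name__ == "__main__":
    print(Activation.ReLU.module())           # -> ReLU()
    print(Activation.LeakyReLU.module(0.1))   # -> LeakyReLU(negative_slope=0.1)
    print(str(Activation.GELU))               # -> "GELU", via __str__
    print(Activation.default_tune_choices())  # default shortlist for tuning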