Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import enum 

2import types 

3from pathlib import Path 

4from typing import get_type_hints, List 

5import torchvision.models as models 

6from torch import nn 

7from fastai.vision.learner import cnn_learner, unet_learner 

8 

9from .apps import TorchApp 

10from .params import Param 

11 

12 

def torchvision_model_choices(module=None) -> List[str]:
    """
    Returns a list of function names in torchvision.models which can produce torch modules.

    A function qualifies when its return type hint is a class that subclasses
    ``torch.nn.Module``. The first entry is always an empty string so that a
    blank option is available.

    For more information see: https://pytorch.org/vision/stable/models.html

    Args:
        module: The module to scan. Defaults to ``torchvision.models``.

    Returns:
        List[str]: ``""`` followed by the qualifying function names, in ``dir()``
        (alphabetical) order.
    """
    if module is None:
        module = models

    model_choices = [""]  # Allow for blank option
    for item in dir(module):
        obj = getattr(module, item)

        # Only accept plain Python functions (skips classes, submodules, constants).
        if not isinstance(obj, types.FunctionType):
            continue

        # Annotations that cannot be resolved (e.g. forward references to names
        # that are not importable at runtime) must not break module import.
        try:
            hints = get_type_hints(obj)
        except (NameError, TypeError):
            continue

        # Only accept if the return value is a pytorch module class. A missing
        # hint (None) or a typing construct such as Type[...] / Optional[...] is
        # not a class and is skipped instead of raising.
        return_value = hints.get("return")
        if isinstance(return_value, type) and issubclass(return_value, nn.Module):
            model_choices.append(item)

    return model_choices

37 

# Enum of selectable torchvision model names. The blank choice ("") cannot be
# used as a member name, so it is exposed as the member ``default``; every other
# member's name equals its value.
TorchvisionModelEnum = enum.Enum(
    "TorchvisionModelName",
    {(name or "default"): name for name in torchvision_model_choices()},
)

42 

43 

class VisionApp(TorchApp):
    """
    A TorchApp which uses a model from torchvision.

    The default base torchvision model is resnet18.
    """

    def default_model_name(self):
        """Returns the torchvision model name used when no ``model_name`` is given."""
        return "resnet18"

    def model(
        self,
        model_name: TorchvisionModelEnum = Param(
            default="",
            help="The name of a model architecture in torchvision.models (https://pytorch.org/vision/stable/models.html). If not given, then it is given by `default_model_name`",
        ),
    ):
        """
        Returns the attribute of torchvision.models named by ``model_name``.

        Args:
            model_name: A TorchvisionModelEnum member or a plain string naming a
                function in torchvision.models. A blank value (``""`` or the
                ``default`` enum member) falls back to ``default_model_name()``.

        Returns:
            The torchvision.models attribute itself (the constructor function),
            not an instantiated model.

        Raises:
            ValueError: If torchvision.models has no attribute with that name.
        """
        # The parameter is annotated as an Enum, so a member may be passed in.
        # Enum members are always truthy and are not valid attribute names, so
        # unwrap to the underlying string before testing or looking it up.
        if isinstance(model_name, enum.Enum):
            model_name = model_name.value

        if not model_name:
            model_name = self.default_model_name()

        try:
            return getattr(models, model_name)
        except AttributeError:
            raise ValueError(f"Model '{model_name}' not recognized.") from None

    def build_learner_func(self):
        """Returns fastai's ``cnn_learner`` as this app's learner-building function."""
        return cnn_learner

    def learner_kwargs(
        self,
        output_dir: Path = Param("./outputs", help="The location of the output directory."),
        pretrained: bool = Param(default=True, help="Whether or not to use the pretrained weights."),
        weight_decay: float = Param(
            None, help="The amount of weight decay. If None then it uses the default amount of weight decay in fastai."
        ),
        **kwargs,
    ):
        """
        Returns keyword arguments for building the learner.

        Starts from the parent class's ``learner_kwargs`` (given ``output_dir``,
        ``weight_decay`` and any extra ``kwargs``) and adds ``pretrained``.

        Side effect: sets ``self.fine_tune`` to ``pretrained``.
        """
        kwargs = super().learner_kwargs(output_dir=output_dir, weight_decay=weight_decay, **kwargs)
        kwargs['pretrained'] = pretrained
        # NOTE(review): presumably the parent reads ``fine_tune`` to decide how to
        # train when starting from pretrained weights — confirm in TorchApp.
        self.fine_tune = pretrained
        return kwargs

85 

86 

class UNetApp(VisionApp):
    """
    A TorchApp that builds a U-Net from a modified torchvision base model.

    Suited to tasks such as image segmentation, super-resolution and colorization.
    The base model defaults to resnet18, inherited from VisionApp.

    References:
        Olaf Ronneberger, Philipp Fischer, Thomas Brox,
        U-Net: Convolutional Networks for Biomedical Image Segmentation,
        https://arxiv.org/abs/1505.04597
        https://github.com/fastai/fastbook/blob/master/15_arch_details.ipynb
    """

    def build_learner_func(self):
        """
        Supplies fastai's ``unet_learner`` in place of the parent's ``cnn_learner``.

        See: https://docs.fast.ai/vision.learner.html#unet_learner
        """
        learner_factory = unet_learner
        return learner_factory