Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import enum
2import types
3from pathlib import Path
4from typing import get_type_hints, List
5import torchvision.models as models
6from torch import nn
7from fastai.vision.learner import cnn_learner, unet_learner
9from .apps import TorchApp
10from .params import Param
def torchvision_model_choices(module=None) -> List[str]:
    """
    Returns a list of function names in a torchvision model module which can produce torch modules.

    A function qualifies when its return type annotation is ``nn.Module`` or a subclass of it.
    Functions without a return annotation, whose annotations cannot be resolved, or whose return
    annotation is not an ``nn.Module`` class (e.g. ``List[str]`` or ``Type[...]``) are skipped.

    The first entry is always the empty string so that a blank option is available.

    For more information see: https://pytorch.org/vision/stable/models.html

    Args:
        module: The module to scan. Defaults to ``torchvision.models``.

    Returns:
        ``[""]`` followed by the qualifying function names, in ``dir()`` order.
    """
    if module is None:
        module = models

    model_choices = [""]  # Allow for blank option
    for item in dir(module):
        obj = getattr(module, item)

        # Only accept plain Python functions (skips classes, submodules, constants, ...)
        if not isinstance(obj, types.FunctionType):
            continue

        # If the annotations cannot be resolved (e.g. a dangling forward reference) we cannot
        # tell what the function returns, so it is not offered as a choice.
        try:
            hints = get_type_hints(obj)
        except (NameError, TypeError):
            continue

        # Only accept if the return value is a pytorch module.
        # Reading ``__mro__`` with a default (rather than calling ``.mro()``) means a missing
        # annotation or a typing construct such as ``Type[X]`` is skipped instead of raising.
        return_type = hints.get("return")
        if nn.Module in getattr(return_type, "__mro__", ()):
            model_choices.append(item)

    return model_choices
# Enum of every selectable torchvision architecture name. Each member's value is the name
# itself; the blank choice ("") cannot be a member name, so it is exposed as `default`.
TorchvisionModelEnum = enum.Enum(
    "TorchvisionModelName",
    {(choice or "default"): choice for choice in torchvision_model_choices()},
)
class VisionApp(TorchApp):
    """
    A TorchApp which uses a model from torchvision.

    The default base torchvision model is resnet18.
    """

    def default_model_name(self):
        """Return the torchvision architecture name used when no model name is given."""
        return "resnet18"

    def model(
        self,
        model_name: TorchvisionModelEnum = Param(
            default="",
            help="The name of a model architecture in torchvision.models (https://pytorch.org/vision/stable/models.html). If not given, then it is given by `default_model_name`",
        ),
    ):
        """
        Look up a model-building function in torchvision.models by name.

        Args:
            model_name: A TorchvisionModelEnum member or a plain string naming an attribute of
                torchvision.models. A blank value falls back to `default_model_name()`.

        Returns:
            The attribute of torchvision.models with that name (returned as-is, not called).

        Raises:
            ValueError: If torchvision.models has no attribute with that name.
        """
        # The parameter is annotated with an Enum, so callers may hand us an Enum member rather
        # than a string. Plain Enum members are always truthy and are not valid attribute names
        # for hasattr/getattr, so unwrap them to their string value first.
        if isinstance(model_name, enum.Enum):
            model_name = model_name.value

        if not model_name:
            model_name = self.default_model_name()

        if not hasattr(models, model_name):
            raise ValueError(f"Model '{model_name}' not recognized.")

        return getattr(models, model_name)

    def build_learner_func(self):
        """Return fastai's `cnn_learner` as the function used to build this app's Learner."""
        return cnn_learner

    def learner_kwargs(
        self,
        output_dir: Path = Param("./outputs", help="The location of the output directory."),
        pretrained: bool = Param(default=True, help="Whether or not to use the pretrained weights."),
        weight_decay: float = Param(
            None, help="The amount of weight decay. If None then it uses the default amount of weight decay in fastai."
        ),
        **kwargs,
    ):
        """
        Build the keyword arguments for the learner, adding the `pretrained` flag.

        Delegates to the parent class for `output_dir`, `weight_decay` and any extra keyword
        arguments, then sets `pretrained` in the returned dict.

        Side effect: stores `pretrained` on `self.fine_tune` (presumably so training fine-tunes
        pretrained weights rather than training from scratch — TODO confirm against TorchApp).
        """
        kwargs = super().learner_kwargs(output_dir=output_dir, weight_decay=weight_decay, **kwargs)
        kwargs['pretrained'] = pretrained
        self.fine_tune = pretrained
        return kwargs
class UNetApp(VisionApp):
    """
    A VisionApp whose torchvision base model is modified into a U-Net.

    Intended for tasks such as image segmentation, super-resolution or colorization.
    The base torchvision model defaults to resnet18, as inherited from VisionApp.

    References:
        Olaf Ronneberger, Philipp Fischer, Thomas Brox,
        U-Net: Convolutional Networks for Biomedical Image Segmentation,
        https://arxiv.org/abs/1505.04597
        https://github.com/fastai/fastbook/blob/master/15_arch_details.ipynb
    """

    def build_learner_func(self):
        """
        Provide fastai's `unet_learner` as the factory for this app's Learner.

        See https://docs.fast.ai/vision.learner.html#unet_learner for details.
        """
        learner_factory = unet_learner
        return learner_factory