import torch
from torch import nn
from torch import Tensor
import numpy as np
import math


@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
    """
    Center-crops the encoder_layer to the size of the decoder_layer,
    so that merging (concatenation) between levels/blocks is possible.
    This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding.

    Taken from https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862
    """
    if encoder_layer.shape[2:] != decoder_layer.shape[2:]:
        # es: encoder (source) spatial shape, ds: decoder (target) spatial shape
        es = encoder_layer.shape[2:]
        ds = decoder_layer.shape[2:]
        assert es[0] >= ds[0]
        assert es[1] >= ds[1]
        if encoder_layer.dim() == 4:  # 2D
            encoder_layer = encoder_layer[
                :,
                :,
                ((es[0] - ds[0]) // 2):((es[0] + ds[0]) // 2),
                ((es[1] - ds[1]) // 2):((es[1] + ds[1]) // 2)
            ]
        elif encoder_layer.dim() == 5:  # 3D
            assert es[2] >= ds[2]
            encoder_layer = encoder_layer[
                :,
                :,
                ((es[0] - ds[0]) // 2):((es[0] + ds[0]) // 2),
                ((es[1] - ds[1]) // 2):((es[1] + ds[1]) // 2),
                ((es[2] - ds[2]) // 2):((es[2] + ds[2]) // 2),
            ]
    return encoder_layer, decoder_layer
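
# Illustrative usage (not part of the original module; shapes are assumptions):
# cropping a larger encoder feature map down to a smaller decoder feature map.
#
#   enc = torch.randn(1, 8, 68, 68)
#   dec = torch.randn(1, 8, 64, 64)
#   enc, dec = autocrop(enc, dec)  # enc is center-cropped to (1, 8, 64, 64); dec is unchanged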


def Conv(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.Conv2d(*args, **kwargs)
    if dim == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def BatchNorm(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.BatchNorm2d(*args, **kwargs)
    if dim == 3:
        return nn.BatchNorm3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def ConvTranspose(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.ConvTranspose2d(*args, **kwargs)
    if dim == 3:
        return nn.ConvTranspose3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def AdaptiveAvgPool(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.AdaptiveAvgPool2d(*args, **kwargs)
    if dim == 3:
        return nn.AdaptiveAvgPool3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")
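
# Illustrative usage (a sketch, not from the original source): these factories
# dispatch on `dim` so the same model code can build 2D or 3D networks.
#
#   conv2d = Conv(3, 16, kernel_size=3, padding=1, dim=2)  # an nn.Conv2d
#   norm3d = BatchNorm(16, dim=3)                          # an nn.BatchNorm3d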


class PositionalEncoding(nn.Module):
    """
    Transforms time/noise values into embedding vectors.

    Taken from: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/blob/master/model/sr3_modules/unet.py#L18
    """
    def __init__(self, embedding_dim):
        """
        Arguments:
            embedding_dim:
                The dimension of the output positional embedding.
        """
        super(PositionalEncoding, self).__init__()
        self.embedding_dim = embedding_dim

    def forward(self, position_level):
        """
        Arguments:
            position_level:
                The positional information to be encoded. Can be either a time value or a noise variance value.
        """
        count = self.embedding_dim // 2
        step = torch.arange(count, dtype=position_level.dtype, device=position_level.device) / count
        # scale each position by a geometric progression of frequencies, then take sin/cos
        encoding = position_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding
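
# Illustrative usage (a sketch; the embedding dimension is an assumption):
# each scalar position becomes a vector of sines and cosines at geometrically
# spaced frequencies, as in transformer/SR3-style embeddings.
#
#   pe = PositionalEncoding(embedding_dim=128)
#   emb = pe(torch.tensor([0.1, 0.5, 0.9]))  # emb.shape == (3, 128)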


class FeatureWiseAffine(nn.Module):
    """
    FiLM layer that integrates noise/time information into the input image.

    Taken from: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/blob/master/model/sr3_modules/unet.py#L34
    Based on: https://distill.pub/2018/feature-wise-transformations/
    """
    def __init__(self, dim: int, embedding_dim: int, image_channels: int, use_affine: bool):
        """
        Arguments:
            dim:
                the dimension of the image. Value should be 2 or 3.
            embedding_dim:
                the length of the noise embedding.
            image_channels:
                the number of channels of the input image.
            use_affine:
                Whether to use a feature-wise affine transform (scale and shift) to integrate the
                noise information. If False, the noise embedding will simply be projected and
                reshaped to match the image size, then added to the image.
        """
        super(FeatureWiseAffine, self).__init__()
        self.dim = dim  # dimension of the image: 2D or 3D
        self.use_affine = use_affine
        # project the embedding to one set of per-channel values (shift only),
        # or two sets (scale and shift) when use_affine is True
        self.noise_func = nn.Sequential(
            nn.Linear(embedding_dim, image_channels * (1 + use_affine))
        )

    def forward(self, x, position_emb):
        """
        Arguments:
            x:
                the target image that the function is altering
            position_emb:
                the vector representation of the position level information
        Return:
            x:
                the image altered based on the position level information
        """
        batch = x.shape[0]
        if self.dim == 2:
            position_emb = self.noise_func(position_emb).view(batch, -1, 1, 1)
        elif self.dim == 3:
            position_emb = self.noise_func(position_emb).view(batch, -1, 1, 1, 1)

        if self.use_affine:
            gamma, beta = position_emb.chunk(2, dim=1)
            x = gamma * x + beta
        else:
            x = x + position_emb

        return x
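
# Illustrative usage (a sketch; sizes are assumptions): with use_affine=True the
# projected embedding is split into a per-channel scale (gamma) and shift (beta).
#
#   film = FeatureWiseAffine(dim=2, embedding_dim=128, image_channels=16, use_affine=True)
#   x = film(torch.randn(4, 16, 32, 32), torch.randn(4, 128))  # shape preserved: (4, 16, 32, 32)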


class SelfAttention(nn.Module):
    def __init__(self, dim, in_channels, num_heads: int = 1) -> None:
        """
        Arguments:
            dim:
                the dimension of the image. Value should be 2 or 3.
            in_channels:
                the number of channels of the input the module self-attends over
            num_heads:
                the number of heads used in the self-attention module
        """
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.norm = BatchNorm(in_channels, dim=dim)
        self.qkv_generator = Conv(in_channels, in_channels * 3, kernel_size=1, stride=1, dim=dim)
        self.output = Conv(in_channels, in_channels, kernel_size=1, dim=dim)

        # einsum equations for the attention mask and the attended values,
        # for 2D (b, heads, channels, h, w) and 3D (b, heads, channels, d, h, w) inputs
        if dim == 2:
            self.attn_mask_eq = "bnchw, bncyx -> bnhwyx"
            self.attn_value_eq = "bnhwyx, bncyx -> bnchw"
        elif dim == 3:
            self.attn_mask_eq = "bncdhw, bnczyx -> bndhwzyx"
            self.attn_value_eq = "bndhwzyx, bnczyx -> bncdhw"

    def forward(self, x):
        head_dim = x.shape[1] // self.num_heads
        normalised_x = self.norm(x)

        # compute query, key and value tensors
        qkv = self.qkv_generator(normalised_x).view(x.shape[0], self.num_heads, head_dim * 3, *x.shape[2:])
        query, key, value = qkv.chunk(3, dim=2)  # split qkv along the head_dim axis

        # compute attention mask
        attn_mask = torch.einsum(self.attn_mask_eq, query, key) / math.sqrt(x.shape[1])
        attn_mask = attn_mask.view(x.shape[0], self.num_heads, *x.shape[2:], -1)
        attn_mask = torch.softmax(attn_mask, -1)
        attn_mask = attn_mask.view(x.shape[0], self.num_heads, *x.shape[2:], *x.shape[2:])

        # compute attention value
        attn_value = torch.einsum(self.attn_value_eq, attn_mask, value)
        attn_value = attn_value.view(*x.shape)

        return x + self.output(attn_value)
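
# Illustrative usage (a sketch; sizes are assumptions): the module is residual,
# so the output keeps the input shape while mixing information across all
# spatial positions.
#
#   attn = SelfAttention(dim=2, in_channels=16, num_heads=2)
#   y = attn(torch.randn(2, 16, 8, 8))  # y.shape == (2, 16, 8, 8)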


class ResBlock(nn.Module):
    """
    Based on
    https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
    """
    def __init__(
        self,
        dim: int,
        in_channels: int,
        out_channels: int,
        downsample: bool,
        kernel_size: int = 3,
        position_emb_dim: int = None,
        use_affine: bool = False,
        use_attn: bool = False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.affine = use_affine
        self.use_attn = use_attn

        # calculate the padding that keeps the output the same size as a kernel size of 1 with zero padding;
        # this has to be computed by hand because padding="same" doesn't work with a stride
        padding = (kernel_size - 1) // 2

        # position_emb_dim doubles as the indicator for incorporating position information or not
        self.position_emb_dim = position_emb_dim
        if position_emb_dim is not None:
            self.noise_func = FeatureWiseAffine(
                dim=dim,
                embedding_dim=position_emb_dim,
                image_channels=out_channels,
                use_affine=use_affine
            )

        if downsample:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, dim=dim)
            self.shortcut = nn.Sequential(
                Conv(in_channels, out_channels, kernel_size=1, stride=2, dim=dim),
                BatchNorm(out_channels, dim=dim)
            )
        else:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dim=dim)
            self.shortcut = nn.Sequential()

        self.conv2 = Conv(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dim=dim)
        self.bn1 = BatchNorm(out_channels, dim=dim)
        self.bn2 = BatchNorm(out_channels, dim=dim)
        self.relu = nn.ReLU(inplace=True)

        if use_attn:
            self.attn = SelfAttention(dim=dim, in_channels=out_channels)

    def forward(self, x: Tensor, position_emb: Tensor = None):
        shortcut = self.shortcut(x)
        x = self.relu(self.bn1(self.conv1(x)))

        # incorporate position information only if position_emb is provided and noise_func exists
        if position_emb is not None and self.position_emb_dim is not None:
            x = self.noise_func(x, position_emb)

        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(x + shortcut)

        if self.use_attn:
            x = self.attn(x)

        return x
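
# Illustrative usage (a sketch; sizes are assumptions): with downsample=True the
# block halves the spatial size while changing the channel count, and a provided
# position embedding is injected between the two convolutions.
#
#   block = ResBlock(dim=2, in_channels=16, out_channels=32, downsample=True,
#                    position_emb_dim=128, use_affine=True)
#   y = block(torch.randn(4, 16, 32, 32), torch.randn(4, 128))  # y.shape == (4, 32, 16, 16)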


class DownBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int = 1,
        downsample: bool = True,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        position_emb_dim: int = None,
        use_affine: bool = False,
        use_attn: bool = False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine
        self.use_attn = use_attn

        if downsample:
            self.out_channels = int(growth_factor * self.out_channels)

        self.block1 = ResBlock(
            dim=dim,
            in_channels=in_channels,
            out_channels=self.out_channels,
            downsample=downsample,
            kernel_size=kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn,
        )
        self.block2 = ResBlock(
            dim=dim,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            downsample=False,
            kernel_size=kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn,
        )

    def forward(self, x: Tensor, position_emb: Tensor = None) -> Tensor:
        x = self.block1(x, position_emb)
        x = self.block2(x, position_emb)
        return x
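
# Illustrative usage (a sketch; sizes are assumptions): a DownBlock is two
# ResBlocks; when downsampling, the channel count grows by growth_factor.
#
#   down = DownBlock(dim=2, in_channels=32, downsample=True, growth_factor=2.0)
#   y = down(torch.randn(4, 32, 32, 32))  # y.shape == (4, 64, 16, 16)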


class UpBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int,
        out_channels: int,
        resblock_kernel_size: int = 3,
        upsample_kernel_size: int = 2,
        position_emb_dim: int = None,
        use_affine: bool = False,
        use_attn: bool = False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine
        self.use_attn = use_attn

        self.upsample = ConvTranspose(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=upsample_kernel_size,
            stride=2,
            dim=dim)

        self.block1 = ResBlock(
            dim=dim,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            downsample=False,
            kernel_size=resblock_kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn
        )
        # self.block2 = ResBlock(in_channels=self.out_channels, out_channels=self.out_channels, downsample=False, dim=dim, kernel_size=resblock_kernel_size)

    def forward(self, x: Tensor, shortcut: Tensor, position_emb: Tensor = None) -> Tensor:
        x = self.upsample(x)
        # crop the upsampled tensor in case its size differs from the shortcut connection
        x, shortcut = autocrop(x, shortcut)
        # NOTE: a standard U-Net would concatenate here; this implementation adds the shortcut instead
        x += shortcut
        x = self.block1(x, position_emb)
        # x = self.block2(x)
        return x
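
# Illustrative usage (a sketch; sizes are assumptions): the transposed
# convolution doubles the spatial size, then the skip connection is added.
#
#   up = UpBlock(dim=2, in_channels=64, out_channels=32)
#   y = up(torch.randn(4, 64, 16, 16), torch.randn(4, 32, 32, 32))  # y.shape == (4, 32, 32, 32)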


class ResNetBody(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int = 1,
        initial_features: int = 64,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        stub_kernel_size: int = 7,
        layers: int = 4,
        attn_layers=(3,),
        position_emb_dim: int = None,
        use_affine: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.in_channels = in_channels
        self.initial_features = initial_features
        self.growth_factor = growth_factor
        self.kernel_size = kernel_size
        self.stub_kernel_size = stub_kernel_size
        self.layers = layers
        self.attn_layers = attn_layers
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine

        current_num_features = initial_features
        padding = (stub_kernel_size - 1) // 2

        self.stem = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=current_num_features, kernel_size=stub_kernel_size, stride=2, padding=padding, dim=dim),
            BatchNorm(num_features=current_num_features, dim=dim),
            nn.ReLU(inplace=True),
        )

        self.downblock_layers = nn.ModuleList()
        for layer_idx in range(layers):
            downblock = DownBlock(
                dim=dim,
                in_channels=current_num_features,
                downsample=True,
                growth_factor=growth_factor,
                kernel_size=kernel_size,
                position_emb_dim=position_emb_dim,
                use_affine=use_affine,
                use_attn=(layer_idx in attn_layers),
            )
            self.downblock_layers.append(downblock)
            current_num_features = downblock.out_channels

        self.output_features = current_num_features

    def forward(self, x: Tensor, position_emb: Tensor = None) -> Tensor:
        x = self.stem(x)
        for layer in self.downblock_layers:
            x = layer(x, position_emb)
        return x

    def macs(self):
        return resnetbody_macs(
            dim=self.dim,
            growth_factor=self.growth_factor,
            kernel_size=self.kernel_size,
            stub_kernel_size=self.stub_kernel_size,
            initial_features=self.initial_features,
            downblock_layers=self.layers,
        )
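
# Illustrative usage (a sketch; sizes are assumptions): the stem halves the
# spatial size once, and each of the `layers` DownBlocks halves it again while
# growing the channels, so 256x256 with 4 layers ends at 8x8.
#
#   body = ResNetBody(dim=2, in_channels=1, initial_features=64, layers=4)
#   feats = body(torch.randn(2, 1, 256, 256))  # feats.shape == (2, 1024, 8, 8)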


class ResNet(nn.Module):
    def __init__(
        self,
        dim: int,
        num_classes: int = 1,
        body: ResNetBody = None,
        in_channels: int = 1,
        initial_features: int = 64,
        growth_factor: float = 2.0,
        layers: int = 4,
        attn_layers=(),
        position_emb_dim: int = None,
        use_affine: bool = False,
    ):
        super().__init__()

        self.position_emb_dim = position_emb_dim
        if position_emb_dim is not None:
            self.position_encoder = PositionalEncoding(position_emb_dim)

        self.body = body if body is not None else ResNetBody(
            dim=dim,
            in_channels=in_channels,
            initial_features=initial_features,
            growth_factor=growth_factor,
            layers=layers,
            attn_layers=attn_layers,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine
        )
        assert in_channels == self.body.in_channels
        assert initial_features == self.body.initial_features
        assert growth_factor == self.body.growth_factor
        assert layers == self.body.layers
        assert attn_layers == self.body.attn_layers
        assert position_emb_dim == self.body.position_emb_dim
        assert use_affine == self.body.use_affine

        self.global_average_pool = AdaptiveAvgPool(1, dim=dim)
        self.final_layer = torch.nn.Linear(self.body.output_features, num_classes)

    def forward(self, x: Tensor, position: Tensor = None) -> Tensor:
        if self.position_emb_dim is not None and position is not None:
            position_emb = self.position_encoder(position)
        else:
            position_emb = None

        x = self.body(x, position_emb)

        # final layer
        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        output = self.final_layer(x)
        return output
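
# Illustrative usage (a sketch; sizes are assumptions): a plain classifier
# built on the encoder body, with global average pooling and a linear head.
#
#   model = ResNet(dim=2, num_classes=10, in_channels=3)
#   logits = model(torch.randn(2, 3, 64, 64))  # logits.shape == (2, 10)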


class ResidualUNet(nn.Module):
    def __init__(
        self,
        dim: int,
        body: ResNetBody = None,
        in_channels: int = 1,
        initial_features: int = 64,
        out_channels: int = 1,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        downblock_layers: int = 4,
        attn_layers=(3,),
        position_emb_dim: int = None,
        use_affine: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.attn_layers = attn_layers
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine

        if position_emb_dim is not None:
            self.position_encoder = PositionalEncoding(position_emb_dim)

        self.body = body if body is not None else ResNetBody(
            dim=dim,
            in_channels=in_channels,
            initial_features=initial_features,
            growth_factor=growth_factor,
            kernel_size=kernel_size,
            layers=downblock_layers,
            attn_layers=attn_layers,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine
        )
        assert in_channels == self.body.in_channels
        assert initial_features == self.body.initial_features
        assert growth_factor == self.body.growth_factor
        assert kernel_size == self.body.kernel_size
        assert downblock_layers == self.body.layers
        assert attn_layers == self.body.attn_layers
        assert position_emb_dim == self.body.position_emb_dim
        assert use_affine == self.body.use_affine

        self.upblock_layers = nn.ModuleList()
        for downblock in reversed(self.body.downblock_layers):
            upblock = UpBlock(
                dim=dim,
                in_channels=downblock.out_channels,
                out_channels=downblock.in_channels,
                resblock_kernel_size=kernel_size,
                position_emb_dim=position_emb_dim,
                use_affine=use_affine,
                use_attn=downblock.use_attn
            )
            self.upblock_layers.append(upblock)

        self.final_upsample_dims = self.upblock_layers[-1].out_channels // 2
        self.final_upsample = ConvTranspose(
            in_channels=self.upblock_layers[-1].out_channels,
            out_channels=self.final_upsample_dims,
            kernel_size=2,
            stride=2,
            dim=dim,
        )

        self.final_layer = Conv(
            in_channels=self.final_upsample_dims + in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            dim=dim,
        )

    def forward(self, x: Tensor, position: Tensor = None) -> Tensor:
        if self.position_emb_dim is not None and position is not None:
            position_emb = self.position_encoder(position)
        else:
            position_emb = None

        x = x.float()
        original_input = x
        encoded_list = []
        x = self.body.stem(x)
        for downblock in self.body.downblock_layers:
            encoded_list.append(x)
            x = downblock(x, position_emb)

        for encoded, upblock in zip(reversed(encoded_list), self.upblock_layers):
            x = upblock(x, encoded, position_emb)

        x = self.final_upsample(x)
        x = torch.cat([original_input, x], dim=1)
        x = self.final_layer(x)
        # activation?
        return x

    def macs(self) -> float:
        return residualunet_macs(
            dim=self.dim,
            growth_factor=self.body.growth_factor,
            kernel_size=self.body.kernel_size,
            stub_kernel_size=self.body.stub_kernel_size,
            initial_features=self.body.initial_features,
            downblock_layers=self.body.layers,
        )
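
# Illustrative usage (a sketch; sizes and the interpretation of `position` as a
# diffusion noise level are assumptions):
#
#   unet = ResidualUNet(dim=2, in_channels=1, out_channels=1, position_emb_dim=128)
#   x = torch.randn(2, 1, 64, 64)
#   noise_level = torch.rand(2)
#   pred = unet(x, noise_level)  # pred.shape == (2, 1, 64, 64)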


def residualunet_macs(
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    initial_features: int,
    downblock_layers: int,
) -> float:
    r"""
    Estimates the multiply-accumulate operations (MACs) per input element,
    assuming a single input channel:

        M = L + U_0 + \sum_{i=0}^{n} D_i + \sum_{i=1}^{n} U_i
        D_0 = \frac{\kappa^d}{2^d} f
        D_i = \frac{1}{2^{d(i+1)}} g^{2i-1} (k^d (3g + 1) + 1) f^2
        U_i = \frac{f^2}{2^{di}} (2^d g^{2i-1} + 2 k^d g^{2i-2})
        U_0 = 2^{d-1} f^2
        L = \frac{f}{2} + 1

    where f is initial_features, g is growth_factor, k is kernel_size,
    \kappa is stub_kernel_size, d is dim and n is downblock_layers.
    """
    stride = 2
    U_0 = 2 ** (dim - 1) * initial_features ** 2
    L = initial_features / 2 + 1
    body_macs = resnetbody_macs(
        dim=dim,
        growth_factor=growth_factor,
        kernel_size=kernel_size,
        stub_kernel_size=stub_kernel_size,
        initial_features=initial_features,
        downblock_layers=downblock_layers,
    )
    M = L + body_macs + U_0

    for i in range(1, downblock_layers + 1):
        U_i = initial_features ** 2 / (stride ** (dim * i)) * (
            2 ** dim * growth_factor ** (2 * i - 1) +
            2 * kernel_size ** dim * growth_factor ** (2 * i - 2)
        )
        M += U_i

    return M


def resnetbody_macs(
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    initial_features: int,
    downblock_layers: int,
) -> float:
    r"""
    Estimates the multiply-accumulate operations (MACs) per input element of the
    encoder body, assuming a single input channel:

        M = \sum_{i=0}^{n} D_i
        D_0 = \frac{\kappa^d}{2^d} f
        D_i = \frac{1}{2^{d(i+1)}} g^{2i-1} (k^d (3g + 1) + 1) f^2

    where f is initial_features, g is growth_factor, k is kernel_size,
    \kappa is stub_kernel_size, d is dim and n is downblock_layers.
    """
    stride = 2
    D_0 = stub_kernel_size ** dim * initial_features / (stride ** dim)
    M = D_0

    for i in range(1, downblock_layers + 1):
        D_i = initial_features ** 2 / (stride ** (dim * (i + 1))) * growth_factor ** (2 * i - 1) * (
            kernel_size ** dim * (3 * growth_factor + 1) + 1
        )
        M += D_i

    return M


def calc_initial_features_residualunet(
    macc: int,
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    downblock_layers: int,
) -> int:
    """
    Inverts residualunet_macs: finds the initial feature count f whose MAC estimate
    is closest to the target budget `macc`. Since the estimate is quadratic in f,
    M(f) = a f^2 + b f + 1, this solves a f^2 + b f + (1 - macc) = 0 with the
    quadratic formula and rounds to the nearest integer.
    """
    stride = 2
    # accumulate the coefficient of f^2: U_0 plus every D_i and U_i divided by f^2
    a = 2 ** (dim - 1)
    for i in range(1, downblock_layers + 1):
        D_i_over_f2 = 1 / (stride ** (dim * (i + 1))) * growth_factor ** (2 * i - 1) * (
            kernel_size ** dim * (3 * growth_factor + 1) + 1
        )
        U_i_over_f2 = 1 / (stride ** (dim * i)) * (
            2 ** dim * growth_factor ** (2 * i - 1) +
            2 * kernel_size ** dim * growth_factor ** (2 * i - 2)
        )
        a += D_i_over_f2 + U_i_over_f2

    # coefficient of f (the stem D_0 and the f/2 term of L) and the constant term
    b = stub_kernel_size ** dim / (stride ** dim) + 0.5
    c = -macc + 1

    initial_features = (-b + np.sqrt(b ** 2 - 4 * a * c)) / (2 * a)
    return int(initial_features + 0.5)
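
# Illustrative round trip (a sketch; the budget of 1e6 MACs is an assumption):
# solve for the feature count, then check the resulting estimate.
#
#   f = calc_initial_features_residualunet(macc=1_000_000, dim=2, growth_factor=2.0,
#                                          kernel_size=3, stub_kernel_size=7,
#                                          downblock_layers=4)
#   m = residualunet_macs(dim=2, growth_factor=2.0, kernel_size=3, stub_kernel_size=7,
#                         initial_features=f, downblock_layers=4)
#   # m is approximately 1e6 (up to the rounding of f to the nearest integer)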