import torch
from torch import nn
from torch import Tensor
import numpy as np
import math
from typing import Optional


@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
    """
    Center-crops the encoder_layer to the size of the decoder_layer,
    so that merging (concatenation) between levels/blocks is possible.
    This is only necessary for input sizes != 2**n with 'same' padding,
    and is always required for 'valid' padding.

    Taken from https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862
    """
    if encoder_layer.shape[2:] != decoder_layer.shape[2:]:
        es = encoder_layer.shape[2:]  # encoder spatial size
        ds = decoder_layer.shape[2:]  # decoder spatial size
        assert es[0] >= ds[0]
        assert es[1] >= ds[1]
        if encoder_layer.dim() == 4:  # 2D
            encoder_layer = encoder_layer[
                :,
                :,
                ((es[0] - ds[0]) // 2):((es[0] + ds[0]) // 2),
                ((es[1] - ds[1]) // 2):((es[1] + ds[1]) // 2),
            ]
        elif encoder_layer.dim() == 5:  # 3D
            assert es[2] >= ds[2]
            encoder_layer = encoder_layer[
                :,
                :,
                ((es[0] - ds[0]) // 2):((es[0] + ds[0]) // 2),
                ((es[1] - ds[1]) // 2):((es[1] + ds[1]) // 2),
                ((es[2] - ds[2]) // 2):((es[2] + ds[2]) // 2),
            ]
    return encoder_layer, decoder_layer


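# A minimal sanity check of autocrop (the tensor sizes here are illustrative
# assumptions, not from the original code): the larger encoder map is
# centre-cropped to the decoder map's spatial size.
def _demo_autocrop():
    enc = torch.randn(1, 8, 64, 64)
    dec = torch.randn(1, 8, 60, 60)
    enc, dec = autocrop(enc, dec)
    assert enc.shape[2:] == dec.shape[2:]  # both are now 60 x 60

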

def Conv(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.Conv2d(*args, **kwargs)
    if dim == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def BatchNorm(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.BatchNorm2d(*args, **kwargs)
    if dim == 3:
        return nn.BatchNorm3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def ConvTranspose(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.ConvTranspose2d(*args, **kwargs)
    if dim == 3:
        return nn.ConvTranspose3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


def AdaptiveAvgPool(*args, dim: int, **kwargs):
    if dim == 2:
        return nn.AdaptiveAvgPool2d(*args, **kwargs)
    if dim == 3:
        return nn.AdaptiveAvgPool3d(*args, **kwargs)
    raise ValueError(f"dimension {dim} not supported")


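# The factories above let the rest of the module stay dimension-agnostic:
# the same call site builds a 2D or 3D layer depending on `dim`. A small
# sketch of the dispatch (argument values are arbitrary):
def _demo_factories():
    conv2d = Conv(3, 16, kernel_size=3, dim=2)
    conv3d = Conv(3, 16, kernel_size=3, dim=3)
    assert isinstance(conv2d, nn.Conv2d)
    assert isinstance(conv3d, nn.Conv3d)

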

class PositionalEncoding(nn.Module):
    """
    Transforms time/noise values into sinusoidal embedding vectors.

    Taken from: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/blob/master/model/sr3_modules/unet.py#L18
    """

    def __init__(self, embedding_dim):
        """
        Arguments:
            embedding_dim:
                The dimension of the output positional embedding.
        """
        super(PositionalEncoding, self).__init__()
        self.embedding_dim = embedding_dim

    def forward(self, position_level):
        """
        Arguments:
            position_level:
                The positional information to be encoded. Can be either a time value or a noise variance value.
        """
        count = self.embedding_dim // 2
        step = torch.arange(count, dtype=position_level.dtype, device=position_level.device) / count

        encoding = position_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)

        return encoding


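# Sketch of the encoding's shape contract (batch size and embedding_dim are
# assumed values): a batch of scalar positions maps to a
# (batch, embedding_dim) matrix, half sine and half cosine features.
def _demo_positional_encoding():
    encoder = PositionalEncoding(embedding_dim=128)
    t = torch.rand(4)  # e.g. four noise-variance values in [0, 1)
    emb = encoder(t)
    assert emb.shape == (4, 128)

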

class FeatureWiseAffine(nn.Module):
    """
    FiLM layer that integrates noise/time information into the input feature map.

    Taken from: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/blob/master/model/sr3_modules/unet.py#L34
    Based on: https://distill.pub/2018/feature-wise-transformations/
    """

    def __init__(self, dim: int, embedding_dim: int, image_channels: int, use_affine: bool):
        """
        Arguments:
            dim:
                the dimension of the image. Value should be 2 or 3
            embedding_dim:
                the length of the noise embedding
            image_channels:
                the number of channels of the input image/feature map
            use_affine:
                Whether to use FeatureWiseAffine to integrate the noise information. If False, the noise_emb will
                simply be projected and reshaped to match the image size, then added to the image.
        """
        super(FeatureWiseAffine, self).__init__()
        self.dim = dim  # dimension of the image: 2D or 3D
        self.use_affine = use_affine
        self.noise_func = nn.Sequential(
            nn.Linear(embedding_dim, image_channels * (1 + use_affine))
        )

    def forward(self, x, position_emb):
        """
        Arguments:
            x:
                the target image that the function is altering
            position_emb:
                the vector representation of the position level information
        Returns:
            x:
                the image altered based on the position level information
        """
        batch = x.shape[0]

        if self.dim == 2:
            position_emb = self.noise_func(position_emb).view(batch, -1, 1, 1)
        elif self.dim == 3:
            position_emb = self.noise_func(position_emb).view(batch, -1, 1, 1, 1)

        if self.use_affine:
            gamma, beta = position_emb.chunk(2, dim=1)
            x = gamma * x + beta
        else:
            x = x + position_emb

        return x


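# A small FiLM sketch (shapes are assumed for illustration): with
# use_affine=True the embedding is projected to a per-channel (gamma, beta)
# pair and applied as gamma * x + beta.
def _demo_featurewise_affine():
    film = FeatureWiseAffine(dim=2, embedding_dim=128, image_channels=32, use_affine=True)
    x = torch.randn(4, 32, 16, 16)
    emb = torch.randn(4, 128)
    out = film(x, emb)
    assert out.shape == x.shape  # conditioning is shape-preserving

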

class SelfAttention(nn.Module):
    def __init__(self, dim, in_channels, num_heads: int = 1) -> None:
        """
        Arguments:
            dim:
                the dimension of the image. Value should be 2 or 3
            in_channels:
                the number of channels of the image the module is self-attending to
            num_heads:
                the number of heads used in the self-attention module
        """
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.norm = BatchNorm(in_channels, dim=dim)
        self.qkv_generator = Conv(in_channels, in_channels * 3, kernel_size=1, stride=1, dim=dim)
        self.output = Conv(in_channels, in_channels, kernel_size=1, dim=dim)

        if dim == 2:
            self.attn_mask_eq = "bnchw, bncyx -> bnhwyx"
            self.attn_value_eq = "bnhwyx, bncyx -> bnchw"
        elif dim == 3:
            self.attn_mask_eq = "bncdhw, bnczyx -> bndhwzyx"
            self.attn_value_eq = "bndhwzyx, bnczyx -> bncdhw"

    def forward(self, x):
        head_dim = x.shape[1] // self.num_heads

        normalised_x = self.norm(x)

        # compute query, key and value vectors
        qkv = self.qkv_generator(normalised_x).view(x.shape[0], self.num_heads, head_dim * 3, *x.shape[2:])
        query, key, value = qkv.chunk(3, dim=2)  # split qkv along the head_dim axis

        # compute the attention mask, scaled by the per-head channel dimension
        # (equivalent to the original sqrt(x.shape[1]) when num_heads == 1)
        attn_mask = torch.einsum(self.attn_mask_eq, query, key) / math.sqrt(head_dim)
        attn_mask = attn_mask.view(x.shape[0], self.num_heads, *x.shape[2:], -1)
        attn_mask = torch.softmax(attn_mask, -1)
        attn_mask = attn_mask.view(x.shape[0], self.num_heads, *x.shape[2:], *x.shape[2:])

        # compute the attention value
        attn_value = torch.einsum(self.attn_value_eq, attn_mask, value)
        attn_value = attn_value.view(*x.shape)

        return x + self.output(attn_value)


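# Attention here is residual and shape-preserving; a sketch with assumed
# sizes (note the mask is quadratic in the number of spatial positions,
# so it is only practical on small feature maps):
def _demo_self_attention():
    attn = SelfAttention(dim=2, in_channels=16)
    x = torch.randn(2, 16, 8, 8)
    out = attn(x)
    assert out.shape == x.shape

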

class ResBlock(nn.Module):
    """
    Based on
    https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
    https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
    """
    def __init__(
        self,
        dim: int,
        in_channels: int,
        out_channels: int,
        downsample: bool,
        kernel_size: int = 3,
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
        use_attn: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.affine = use_affine
        self.use_attn = use_attn

        # calculate the padding that makes the output the same size as a kernel size of 1 with
        # zero padding; this has to be computed explicitly because padding="same" doesn't work with a stride
        padding = (kernel_size - 1) // 2

        # position_emb_dim acts as an indicator of whether to incorporate position information
        self.position_emb_dim = position_emb_dim
        if position_emb_dim is not None:
            self.noise_func = FeatureWiseAffine(
                dim=dim,
                embedding_dim=position_emb_dim,
                image_channels=out_channels,
                use_affine=use_affine,
            )

        if downsample:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, dim=dim)
            self.shortcut = nn.Sequential(
                Conv(in_channels, out_channels, kernel_size=1, stride=2, dim=dim),
                BatchNorm(out_channels, dim=dim),
            )
        else:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dim=dim)
            self.shortcut = nn.Sequential()

        self.conv2 = Conv(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dim=dim)
        self.bn1 = BatchNorm(out_channels, dim=dim)
        self.bn2 = BatchNorm(out_channels, dim=dim)
        self.relu = nn.ReLU(inplace=True)

        if use_attn:
            self.attn = SelfAttention(dim=dim, in_channels=out_channels)

    def forward(self, x: Tensor, position_emb: Optional[Tensor] = None):
        shortcut = self.shortcut(x)
        x = self.relu(self.bn1(self.conv1(x)))

        # incorporate position information only if position_emb is provided and noise_func exists
        if position_emb is not None and self.position_emb_dim is not None:
            x = self.noise_func(x, position_emb)

        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(x + shortcut)

        if self.use_attn:
            x = self.attn(x)

        return x


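# Sketch of the downsampling ResBlock configuration (channel and spatial
# sizes are assumed): stride 2 halves the spatial size, and the shortcut is
# projected to the new channel count so the residual sum still lines up.
def _demo_resblock():
    block = ResBlock(dim=2, in_channels=16, out_channels=32, downsample=True)
    x = torch.randn(2, 16, 32, 32)
    out = block(x)
    assert out.shape == (2, 32, 16, 16)

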

class DownBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int = 1,
        downsample: bool = True,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
        use_attn: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine
        self.use_attn = use_attn

        if downsample:
            self.out_channels = int(growth_factor * self.out_channels)

        self.block1 = ResBlock(
            dim=dim,
            in_channels=in_channels,
            out_channels=self.out_channels,
            downsample=downsample,
            kernel_size=kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn,
        )
        self.block2 = ResBlock(
            dim=dim,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            downsample=False,
            kernel_size=kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn,
        )

    def forward(self, x: Tensor, position_emb: Optional[Tensor] = None) -> Tensor:
        x = self.block1(x, position_emb)
        x = self.block2(x, position_emb)
        return x



class UpBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int,
        out_channels: int,
        resblock_kernel_size: int = 3,
        upsample_kernel_size: int = 2,
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
        use_attn: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine
        self.use_attn = use_attn

        self.upsample = ConvTranspose(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=upsample_kernel_size,
            stride=2,
            dim=dim)

        self.block1 = ResBlock(
            dim=dim,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            downsample=False,
            kernel_size=resblock_kernel_size,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
            use_attn=use_attn,
        )
        # self.block2 = ResBlock(in_channels=self.out_channels, out_channels=self.out_channels, downsample=False, dim=dim, kernel_size=resblock_kernel_size)

    def forward(self, x: Tensor, shortcut: Tensor, position_emb: Optional[Tensor] = None) -> Tensor:
        x = self.upsample(x)
        # crop the upsampled tensor in case its size differs from the shortcut connection
        x, shortcut = autocrop(x, shortcut)
        # NOTE: a standard U-Net concatenates the shortcut here; this implementation
        # adds it element-wise instead, and it is unclear whether that is intentional
        x += shortcut
        x = self.block1(x, position_emb)
        # x = self.block2(x)
        return x


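# A DownBlock/UpBlock round trip (sizes are assumed for illustration):
# downsampling doubles the channels and halves the spatial size, and the
# matching UpBlock undoes both while mixing in the skip connection.
def _demo_down_up_roundtrip():
    down = DownBlock(dim=2, in_channels=64)
    up = UpBlock(dim=2, in_channels=down.out_channels, out_channels=down.in_channels)
    x = torch.randn(2, 64, 32, 32)
    encoded = down(x)         # (2, 128, 16, 16)
    decoded = up(encoded, x)  # back to (2, 64, 32, 32)
    assert decoded.shape == x.shape

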

class ResNetBody(nn.Module):
    def __init__(
        self,
        dim: int,
        in_channels: int = 1,
        initial_features: int = 64,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        stub_kernel_size: int = 7,
        layers: int = 4,
        attn_layers=(3,),
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.in_channels = in_channels
        self.initial_features = initial_features
        self.growth_factor = growth_factor
        self.kernel_size = kernel_size
        self.stub_kernel_size = stub_kernel_size
        self.layers = layers
        self.attn_layers = attn_layers
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine

        current_num_features = initial_features
        padding = (stub_kernel_size - 1) // 2

        self.stem = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=current_num_features, kernel_size=stub_kernel_size, stride=2, padding=padding, dim=dim),
            BatchNorm(num_features=current_num_features, dim=dim),
            nn.ReLU(inplace=True),
        )

        self.downblock_layers = nn.ModuleList()
        for layer_idx in range(layers):
            downblock = DownBlock(
                dim=dim,
                in_channels=current_num_features,
                downsample=True,
                growth_factor=growth_factor,
                kernel_size=kernel_size,
                position_emb_dim=position_emb_dim,
                use_affine=use_affine,
                use_attn=(layer_idx in attn_layers),
            )
            self.downblock_layers.append(downblock)
            current_num_features = downblock.out_channels

        self.output_features = current_num_features

    def forward(self, x: Tensor, position_emb: Optional[Tensor] = None) -> Tensor:
        x = self.stem(x)
        for layer in self.downblock_layers:
            x = layer(x, position_emb)
        return x

    def macs(self):
        return resnetbody_macs(
            dim=self.dim,
            growth_factor=self.growth_factor,
            kernel_size=self.kernel_size,
            stub_kernel_size=self.stub_kernel_size,
            initial_features=self.initial_features,
            downblock_layers=self.layers,
        )



class ResNet(nn.Module):
    def __init__(
        self,
        dim: int,
        num_classes: int = 1,
        body: Optional[ResNetBody] = None,
        in_channels: int = 1,
        initial_features: int = 64,
        growth_factor: float = 2.0,
        layers: int = 4,
        attn_layers=(),
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
    ):
        super().__init__()

        self.position_emb_dim = position_emb_dim

        if position_emb_dim is not None:
            self.position_encoder = PositionalEncoding(position_emb_dim)

        self.body = body if body is not None else ResNetBody(
            dim=dim,
            in_channels=in_channels,
            initial_features=initial_features,
            growth_factor=growth_factor,
            layers=layers,
            attn_layers=attn_layers,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
        )
        assert in_channels == self.body.in_channels
        assert initial_features == self.body.initial_features
        assert growth_factor == self.body.growth_factor
        assert layers == self.body.layers
        assert attn_layers == self.body.attn_layers
        assert position_emb_dim == self.body.position_emb_dim
        assert use_affine == self.body.use_affine

        self.global_average_pool = AdaptiveAvgPool(1, dim=dim)
        self.final_layer = nn.Linear(self.body.output_features, num_classes)

    def forward(self, x: Tensor, position: Optional[Tensor] = None) -> Tensor:
        if self.position_emb_dim is not None and position is not None:
            position_emb = self.position_encoder(position)
        else:
            position_emb = None

        x = self.body(x, position_emb)

        # final layers: pool, flatten, classify
        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        output = self.final_layer(x)
        return output


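# Classifier sketch (input size, channel count and class count are assumed):
# the stride-2 stem plus four stride-2 downblocks reduce a 64x64 input to
# 2x2 before global average pooling and the linear head.
def _demo_resnet():
    model = ResNet(dim=2, num_classes=10, in_channels=1)
    x = torch.randn(2, 1, 64, 64)
    logits = model(x)
    assert logits.shape == (2, 10)

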

class ResidualUNet(nn.Module):
    def __init__(
        self,
        dim: int,
        body: Optional[ResNetBody] = None,
        in_channels: int = 1,
        initial_features: int = 64,
        out_channels: int = 1,
        growth_factor: float = 2.0,
        kernel_size: int = 3,
        downblock_layers: int = 4,
        attn_layers=(3,),
        position_emb_dim: Optional[int] = None,
        use_affine: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.attn_layers = attn_layers
        self.position_emb_dim = position_emb_dim
        self.use_affine = use_affine

        if position_emb_dim is not None:
            self.position_encoder = PositionalEncoding(position_emb_dim)

        self.body = body if body is not None else ResNetBody(
            dim=dim,
            in_channels=in_channels,
            initial_features=initial_features,
            growth_factor=growth_factor,
            kernel_size=kernel_size,
            layers=downblock_layers,
            attn_layers=attn_layers,
            position_emb_dim=position_emb_dim,
            use_affine=use_affine,
        )
        assert in_channels == self.body.in_channels
        assert initial_features == self.body.initial_features
        assert growth_factor == self.body.growth_factor
        assert kernel_size == self.body.kernel_size
        assert downblock_layers == self.body.layers
        assert attn_layers == self.body.attn_layers
        assert position_emb_dim == self.body.position_emb_dim
        assert use_affine == self.body.use_affine

        self.upblock_layers = nn.ModuleList()
        for downblock in reversed(self.body.downblock_layers):
            upblock = UpBlock(
                dim=dim,
                in_channels=downblock.out_channels,
                out_channels=downblock.in_channels,
                resblock_kernel_size=kernel_size,
                position_emb_dim=position_emb_dim,
                use_affine=use_affine,
                use_attn=downblock.use_attn,
            )
            self.upblock_layers.append(upblock)

        self.final_upsample_dims = self.upblock_layers[-1].out_channels // 2
        self.final_upsample = ConvTranspose(
            in_channels=self.upblock_layers[-1].out_channels,
            out_channels=self.final_upsample_dims,
            kernel_size=2,
            stride=2,
            dim=dim,
        )

        self.final_layer = Conv(
            in_channels=self.final_upsample_dims + in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            dim=dim,
        )

    def forward(self, x: Tensor, position: Optional[Tensor] = None) -> Tensor:
        if self.position_emb_dim is not None and position is not None:
            position_emb = self.position_encoder(position)
        else:
            position_emb = None

        x = x.float()
        original_input = x
        encoded_list = []
        x = self.body.stem(x)
        for downblock in self.body.downblock_layers:
            encoded_list.append(x)
            x = downblock(x, position_emb)

        for encoded, upblock in zip(reversed(encoded_list), self.upblock_layers):
            x = upblock(x, encoded, position_emb)

        x = self.final_upsample(x)
        x = torch.cat([original_input, x], dim=1)
        x = self.final_layer(x)
        # activation?
        return x

    def macs(self) -> float:
        return residualunet_macs(
            dim=self.dim,
            growth_factor=self.body.growth_factor,
            kernel_size=self.body.kernel_size,
            stub_kernel_size=self.body.stub_kernel_size,
            initial_features=self.body.initial_features,
            downblock_layers=self.body.layers,
        )


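# End-to-end U-Net sketch (batch and image sizes are assumed): the decoder
# mirrors the encoder, so the output keeps the input's spatial size with
# out_channels channels. The spatial dims here are divisible by 2**5 (the
# stem plus four downblocks), so autocrop never has to trim anything.
def _demo_residual_unet():
    model = ResidualUNet(dim=2, in_channels=1, out_channels=1)
    x = torch.randn(2, 1, 64, 64)
    out = model(x)
    assert out.shape == x.shape

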

def residualunet_macs(
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    initial_features: int,
    downblock_layers: int,
) -> float:
    r"""
    Closed-form MAC estimate for ResidualUNet, where f = initial_features,
    g = growth_factor, k = kernel_size, \kappa = stub_kernel_size, d = dim
    and n = downblock_layers:

    M = L + \sum_{i=0}^{n} (D_i + U_i)
    D_0 = \frac{\kappa^d}{2^d} f
    D_i = \frac{1}{2^{d(i+1)}} g^{2i-1} (k^d(3g + 1) + 1) f^2
    U_0 = 2^{d-1} f^2
    U_i = \frac{f^2}{2^{di}} (2^d g^{2i-1} + 2 k^d g^{2i-2})
    L = \frac{f}{2} + 1
    """
    stride = 2
    U_0 = 2 ** (dim - 1) * initial_features ** 2
    L = initial_features / 2 + 1
    body_macs = resnetbody_macs(
        dim=dim,
        growth_factor=growth_factor,
        kernel_size=kernel_size,
        stub_kernel_size=stub_kernel_size,
        initial_features=initial_features,
        downblock_layers=downblock_layers,
    )
    M = L + body_macs + U_0

    for i in range(1, downblock_layers + 1):
        U_i = initial_features ** 2 / (stride ** (dim * i)) * (
            2 ** dim * growth_factor ** (2 * i - 1) +
            2 * kernel_size ** dim * growth_factor ** (2 * i - 2)
        )

        M += U_i

    return M



def resnetbody_macs(
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    initial_features: int,
    downblock_layers: int,
) -> float:
    r"""
    Closed-form MAC estimate for ResNetBody, using the same symbols as
    residualunet_macs:

    M = \sum_{i=0}^{n} D_i
    D_0 = \frac{\kappa^d}{2^d} f
    D_i = \frac{1}{2^{d(i+1)}} g^{2i-1} (k^d(3g + 1) + 1) f^2
    """
    stride = 2
    D_0 = stub_kernel_size ** dim * initial_features / (stride ** dim)
    M = D_0

    for i in range(1, downblock_layers + 1):
        D_i = initial_features ** 2 / (stride ** (dim * (i + 1))) * growth_factor ** (2 * i - 1) * (
            kernel_size ** dim * (3 * growth_factor + 1) + 1
        )
        M += D_i

    return M



def calc_initial_features_residualunet(
    macc: int,
    dim: int,
    growth_factor: float,
    kernel_size: int,
    stub_kernel_size: int,
    downblock_layers: int,
) -> int:
    """
    Inverts residualunet_macs: the MAC estimate is quadratic in the number of
    initial features f (M = a f^2 + b f + 1), so solve
    a f^2 + b f + (1 - macc) = 0 for f and round to the nearest integer.
    """
    stride = 2
    a = 2 ** (dim - 1)
    for i in range(1, downblock_layers + 1):
        D_i_over_f2 = 1 / (stride ** (dim * (i + 1))) * growth_factor ** (2 * i - 1) * (
            kernel_size ** dim * (3 * growth_factor + 1) + 1
        )
        U_i_over_f2 = 1 / (stride ** (dim * i)) * (
            2 ** dim * growth_factor ** (2 * i - 1) +
            2 * kernel_size ** dim * growth_factor ** (2 * i - 2)
        )

        a += D_i_over_f2 + U_i_over_f2

    b = stub_kernel_size ** dim / (stride ** dim) + 0.5
    c = -macc + 1

    initial_features = (-b + np.sqrt(b ** 2 - 4 * a * c)) / (2 * a)

    return int(initial_features + 0.5)


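# Round-trip sketch of the two formulas (the target budget is an assumed
# value): pick initial_features for a MAC budget, then confirm the forward
# estimate lands close to it; the small gap comes from rounding f to an
# integer.
def _demo_macs_budget():
    target = 1e6
    f = calc_initial_features_residualunet(
        macc=target, dim=2, growth_factor=2.0, kernel_size=3,
        stub_kernel_size=7, downblock_layers=4,
    )
    estimate = residualunet_macs(
        dim=2, growth_factor=2.0, kernel_size=3, stub_kernel_size=7,
        initial_features=f, downblock_layers=4,
    )
    assert abs(estimate - target) / target < 0.05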