Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import re 

2import torchapp as ta 

3from typing import List 

4import torch 

5import random 

6from pathlib import Path 

7from fastai.callback.core import Callback, CancelBatchException 

8from fastai.data.block import DataBlock, TransformBlock 

9from fastai.data.core import DataLoaders, DisplayedTransform 

10import torch.nn.functional as F 

11from rich.progress import track 

12from fastai.data.transforms import get_image_files 

13import torchvision.transforms as T 

14from fastai.data.transforms import ToTensor 

15from fastcore.transform import Pipeline 

16from fastai.vision.augment import Resize 

17from fastai.data.transforms import FuncSplitter 

18from fastai.learner import load_learner 

19from PIL import Image 

20from functools import partial 

21from fastai.vision.data import ImageBlock, TensorImage 

22from fastai.vision.core import PILImageBW, TensorImageBW 

23from supercat.noise.apps import * # remove this 

24 

25from supercat.models import ResidualUNet, calc_initial_features_residualunet 

26from supercat.transforms import ImageBlock3D, RescaleImage, write3D, read3D, InterpolateTransform, RescaleImageMinMax, CropTransform 

27from supercat.enums import DownsampleScale, DownsampleMethod 

28from supercat.diffusion import DDPMCallback, DDPMSamplerCallback 

29from skimage.transform import resize as skresize 

30 

31from rich.console import Console 

32console = Console() 

33 

34 

def is_validation_image(item: Path) -> bool:
    """
    Decide whether an item belongs to the validation set.

    An item is a validation item when the name of its immediate parent
    directory does not contain the string `_train_`. Only the parent directory
    name is inspected; the file name itself is ignored.

    Args:
        item: Path to the image/volume file. A plain string path is also accepted.

    Returns:
        True if the item should be part of the validation set, False otherwise.
    """
    # Coerce so that string paths behave the same as Path objects.
    return "_train_" not in Path(item).parent.name

38 

39 

def get_y(item, pattern=r"_BI_.*"):
    """
    Map the path of an upscaled item to the path of its high-resolution target.

    The part of the parent directory name matched by `pattern` is replaced with
    `_HR`; the target lives in that sibling directory under the same file name.
    """
    source_dir = item.parent
    highres_dir_name = re.compile(pattern).sub("_HR", source_dir.name)
    return source_dir.parent.joinpath(highres_dir_name, item.name)

43 

44 

class Supercat(ta.TorchApp):
    """
    A torchapp application that trains a ResidualUNet to map interpolated
    (upscaled) low-resolution images or volumes to their high-resolution
    counterparts, and uses it to upscale new items.

    `dataloaders` or `inference_dataloader` must be called before any method
    that reads `self.dim`, `self.items` or `self.rescaling`.
    """

    # Number of channels given to the network as input.
    in_channels = 1

    def get_items(self, directory):
        """
        List the input files found in `directory`.

        For 2D data (`self.dim == 2`) this delegates to fastai's
        `get_image_files`; otherwise it returns the `*.mat` files directly
        inside `directory`.
        """
        if self.dim == 2:
            return get_image_files(directory)

        directory = Path(directory)
        return list(directory.glob("*.mat"))

    @staticmethod
    def _load_single_channel(path):
        """
        Load the image at `path` fully into memory as a non-RGB image.

        If the file on disk is RGB, only its red channel is kept and the file
        is overwritten with that single-channel version.
        """
        with Image.open(path) as opened:
            was_rgb = opened.mode == "RGB"
            # Both getchannel() and copy() load the pixel data, so the result
            # stays usable after the file handle is closed.
            image = opened.getchannel("R") if was_rgb else opened.copy()

        if was_rgb:
            image.save(path)
        return image

    def _upscale_2d(self, highres_path, lowres_path, upscale_path):
        """Bicubically resize the low-res image to the high-res image's size and save it."""
        highres_img = self._load_single_channel(highres_path)
        lowres_img = self._load_single_channel(lowres_path)

        upscale_img = lowres_img.resize(highres_img.size, Image.BICUBIC)
        upscale_img.save(upscale_path)

    @staticmethod
    def _upscale_3d(highres_path, lowres_path, upscale_path):
        """Resize the low-res volume to the high-res volume's shape (spline order 3) and save it."""
        highres_volume = read3D(highres_path)
        lowres_volume = read3D(lowres_path)

        tricubic_volume = skresize(lowres_volume, highres_volume.shape, order=3)
        write3D(upscale_path, tricubic_volume)

    def dataloaders(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
        deeprock:Path = ta.Param(help="The path to the DeepRockSR dataset."),
        downsample_scale:DownsampleScale = ta.Param(DownsampleScale.X4.value, help="Should it use the 2x or 4x downsampled images.", case_sensitive=False),
        downsample_method:DownsampleMethod = ta.Param(DownsampleMethod.UNKNOWN.value, help="Should it use the default method to downsample (bicubic) or a random kernel (UNKNOWN)."),
        batch_size:int = ta.Param(default=10, help="The batch size."),
        force:bool = ta.Param(default=False, help="Whether or not to force the conversion of the bicubic upscaling."),
        max_samples:int = ta.Param(default=None, help="If set, then the number of input samples for training/validation is truncated at this number."),
        include_sand:bool = ta.Param(default=False, help="Including DeepSand-SR dataset."),
    ) -> DataLoaders:
        """
        Creates a FastAI DataLoaders object which Supercat uses in training and prediction.

        For every source material and split ("train" and "valid") the
        low-resolution items are upscaled to the size of their high-resolution
        counterparts and cached in a sibling `..._BI_...` (2D) or `..._TRI_...`
        (3D) directory. Cached files are reused unless `force` is set. The
        upscaled files are the inputs; `get_y` maps each one to its
        high-resolution target.

        `max_samples`, when set to a non-zero value, limits the number of items
        taken from each source/split directory.

        Raises:
            ValueError: If no items were found under `deeprock`.
        """
        assert deeprock is not None

        self.dim = dim
        deeprock = Path(deeprock)
        upscaled = []

        sources = [f"carbonate{dim}D", f"coal{dim}D", f"sandstone{dim}D"]
        if include_sand:
            sources.append(f"sand{dim}D")

        # The parameters may arrive either as enum members or as their raw string values.
        if isinstance(downsample_method, DownsampleMethod):
            downsample_method = downsample_method.value

        if isinstance(downsample_scale, DownsampleScale):
            downsample_scale = downsample_scale.value

        split_types = ["train", "valid"]  # There is also "test"

        # Tag used in the name of the directory that caches the upscaled items.
        upscale_tag = "BI" if dim == 2 else "TRI"

        for source in sources:
            for split_type in split_types:
                highres_dir = deeprock/source/f"{source}_{split_type}_HR"
                highres_split = self.get_items(highres_dir)
                if max_samples:
                    # Truncate up front so exactly `max_samples` items are kept.
                    highres_split = highres_split[:max_samples]

                lowres_dir = deeprock/source/f"{source}_{split_type}_LR_{downsample_method}_{downsample_scale}"

                # We will save upscaled images
                upscale_dir = deeprock/source/f"{source}_{split_type}_{upscale_tag}_{downsample_method}_{downsample_scale}"
                upscale_dir.mkdir(exist_ok=True)

                for highres_path in highres_split:
                    upscale_path = upscale_dir/highres_path.name

                    if force or not upscale_path.exists():
                        # Low-res files are named like the high-res file with the
                        # lower-cased scale appended to the stem, e.g. `0001x4.png`.
                        lowres_name = f"{highres_path.stem}{downsample_scale.lower()}{highres_path.suffix}"
                        lowres_path = lowres_dir/lowres_name
                        print(split_type, highres_path, upscale_path, lowres_path)

                        print("Upscaling")
                        if dim == 2:
                            self._upscale_2d(highres_path, lowres_path, upscale_path)
                        else:
                            print("Upscaling with tricubic")
                            self._upscale_3d(highres_path, lowres_path, upscale_path)

                    upscaled.append(upscale_path)

        if not upscaled:
            raise ValueError("No images found.")

        if dim == 2:
            blocks = (ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW))
        else:
            blocks = (ImageBlock3D, ImageBlock3D,)

        datablock = DataBlock(
            blocks=blocks,
            splitter=FuncSplitter(is_validation_image),
            get_y=get_y if dim == 2 else partial(get_y, pattern=r"_TRI_.*"),
            batch_tfms=[RescaleImage],
        )

        dataloaders = DataLoaders.from_dblock(
            datablock,
            source=upscaled,
            bs=batch_size,
        )

        dataloaders.c = 1

        return dataloaders

    def model(
        self,
        pretrained:Path=None,
        initial_features:int = ta.Param(
            None,
            help="The number of features after the initial CNN layer. If not set then it is derived from the MACC."
        ),
        growth_factor:float = ta.Param(
            2.0,
            tune=True,
            tune_min=1.0,
            tune_max=4.0,
            tune_log=True,
            help="The factor to grow the number of convolutional filters each time the model downscales."
        ),
        kernel_size:int = ta.Param(
            3,
            tune=True,
            tune_choices=[3,5,7],
            help="The size of the kernel in the convolutional layers."
        ),
        stub_kernel_size:int = ta.Param(
            7,
            tune=True,
            tune_choices=[5,7,9],
            help="The size of the kernel in the initial stub convolutional layer."
        ),
        downblock_layers:int = ta.Param(
            4,
            tune=True,
            tune_min=2,
            tune_max=5,
            help="The number of layers to downscale (and upscale) in the UNet."
        ),
        macc:int = ta.Param(
            default=132_000,
            help=(
                "The approximate number of multiply or accumulate operations in the model per pixel/voxel. " +
                "Used to set initial_features if it is not provided explicitly."
            ),
        ),
    ):
        """
        Builds the model to train, or loads it from a pretrained learner.

        If `pretrained` is given, the model of that exported learner is
        returned and all other arguments are ignored. Otherwise a ResidualUNet
        is created, with `initial_features` derived from `macc` when it is not
        given explicitly.
        """
        if pretrained:
            learner = load_learner(pretrained)
            return learner.model

        # Falls back to 3 when neither dataloaders() nor inference_dataloader() has set self.dim.
        dim = getattr(self, "dim", 3)

        if not initial_features:
            assert macc

            initial_features = calc_initial_features_residualunet(
                macc=macc,
                dim=dim,
                growth_factor=growth_factor,
                kernel_size=kernel_size,
                stub_kernel_size=stub_kernel_size,
                downblock_layers=downblock_layers,
            )

        # NOTE(review): growth_factor, kernel_size and downblock_layers are only used to
        # derive initial_features above; they are not forwarded to ResidualUNet
        # (the corresponding keyword arguments were commented out) - confirm this is intended.
        return ResidualUNet(
            dim=dim,
            in_channels=self.in_channels,
            out_channels=1,
            initial_features=initial_features,
        )

    def loss_func(self):
        """
        Returns the loss function to use with the model.
        """
        return F.smooth_l1_loss

    def inference_dataloader(
        self,
        learner,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
        items:List[Path] = None,
        item_dir: Path = ta.Param(None, help="A directory with images to upscale."),
        width:int = ta.Param(500, help="The width of the final image/volume."),
        height:int = ta.Param(None, help="The height of the final image/volume."),
        depth:int = ta.Param(None, help="The depth of the final image/volume."),
        start_x:int=None,
        end_x:int=None,
        start_y:int=None,
        end_y:int=None,
        start_z:int=None,
        end_z:int=None,
        **kwargs
    ):
        """
        Builds the dataloader used for inference.

        The items to upscale are `items` (a single path or a list of paths)
        plus everything `get_items` finds in `item_dir`. Each item is cropped,
        interpolated to the requested size, min/max rescaled and converted to a
        tensor. `height` and `depth` default to `width`. Extra keyword
        arguments are passed to `learner.dls.test_dl`.

        Side effects: sets `self.dim`, `self.items` and `self.rescaling`, which
        `output_results` relies on.
        """
        self.dim = dim

        if not items:
            items = []
        elif isinstance(items, (Path, str)):
            items = [items]
        else:
            # Work on a copy so the caller's list is never mutated.
            items = list(items)

        if item_dir:
            items.extend(self.get_items(item_dir))

        items = [Path(item) for item in items]
        self.items = items
        dataloader = learner.dls.test_dl(items, with_labels=True, **kwargs)
        dataloader.transform = dataloader.transform[:1] # ignore the get_y function
        height = height or width
        depth = depth or width

        interpolation = InterpolateTransform(depth=depth, height=height, width=width, dim=dim)
        crop_transform = CropTransform(
            start_x=start_x, end_x=end_x,
            start_y=start_y, end_y=end_y,
            start_z=start_z, end_z=end_z,
        )
        self.rescaling = RescaleImageMinMax()
        dataloader.after_item = Pipeline( [crop_transform, interpolation, self.rescaling, ToTensor] )
        # The items are already rescaled by RescaleImageMinMax above, so drop the
        # RescaleImage batch transform inherited from training.
        if isinstance(dataloader.after_batch[1], RescaleImage):
            dataloader.after_batch = Pipeline( *(dataloader.after_batch[:1] + dataloader.after_batch[2:]) )

        return dataloader

    def output_results(
        self,
        results,
        return_data:bool=False,
        output_dir: Path = ta.Param(None, help="The location of the output directory. If not given then it uses the directory of the item."),
        suffix:str = ta.Param("", help="The file extension for the output file."),
        **kwargs,
    ):
        """
        Writes the upscaled results to disk.

        Each entry of `results[0]` is paired with the corresponding path in
        `self.items`, mapped back to the value range of the original item and
        saved as `<item name>.upscaled<suffix>`. 2D results are clipped to
        0-255 and saved as 8-bit images; other results are written with
        `write3D`.

        Returns:
            A list with one element per item: the output path, or the rescaled
            data itself if `return_data` is set.
        """
        list_to_return = []
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(exist_ok=True, parents=True)

        for item, result in zip(self.items, results[0]):
            my_suffix = suffix or item.suffix
            # Guard against an empty suffix (no `suffix` given and an item without extension).
            if my_suffix and not my_suffix.startswith("."):
                my_suffix = "." + my_suffix

            new_name = item.with_suffix("").name + f".upscaled{my_suffix}"
            my_output_dir = output_dir or item.parent
            new_path = my_output_dir/new_name

            dim = len(result.shape) - 1
            if dim == 2:
                # hack get extrema to rescale
                with Image.open(item) as source_image:
                    min_value, max_value = source_image.convert('L').getextrema()
                result[0] = self.rescaling.decodes(result[0], min_value, max_value)

                pixels = torch.clip(result[0], min=0, max=255)
                im = Image.fromarray( pixels.cpu().detach().numpy().astype('uint8') )
                im.save(new_path)
            else:
                # hack get extrema to rescale
                data = read3D(item)
                min_value, max_value = data.min(), data.max()
                result[0] = self.rescaling.decodes(result[0], min_value, max_value)

                write3D(new_path, result[0].cpu().detach().numpy())

            list_to_return.append(result[0] if return_data else new_path)
            console.print(f"Upscaled '{item}' ⮕ '{new_path}'")

        return list_to_return

    def pretrained_location(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
    ) -> str:
        """Returns the URL of the released pretrained model for the given dimension."""
        assert dim in [2,3]
        if dim == 2:
            return f"https://github.com/rbturnbull/supercat/releases/download/v0.2.1/supercat-{dim}D.0.2.pkl"
        return f"https://github.com/rbturnbull/supercat/releases/download/v0.3.0/supercat-{dim}D.0.3.pkl"

352 

353 

class SupercatDiffusion(Supercat):
    """
    Variant of Supercat that uses two input channels and adds the DDPM
    callbacks for training and for sampling at inference time.
    """

    in_channels = 2

    def extra_callbacks(self):
        """Callbacks added for training: a single DDPMCallback."""
        callbacks = [DDPMCallback()]
        return callbacks

    def inference_callbacks(self):
        """Callbacks added for inference: a single DDPMSamplerCallback."""
        callbacks = [DDPMSamplerCallback()]
        return callbacks

    def pretrained_location(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
    ) -> str:
        """Returns the URL of the released pretrained diffusion model for the given dimension."""
        assert dim in [2,3]
        if dim != 2:
            return f"https://github.com/rbturnbull/supercat/releases/download/v0.3.0/supercat-diffusion-{dim}D.0.3.pkl"
        return f"https://github.com/rbturnbull/supercat/releases/download/v0.2.1/supercat-diffusion-{dim}D.0.2.pkl"

    # Disabled draft of an `output_results` override that can also save the
    # diffusion steps as an animated gif. Kept for reference.
    #
    # def output_results(
    #     self,
    #     results,
    #     output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
    #     diffusion_gif:bool=False,
    #     diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."),
    #     **kwargs,
    # ):
    #     breakpoint()
    #     # final_results = [[result[-1] for result in results[0][0]]]
    #     to_return = super().output_results(results, output_dir=output_dir, **kwargs)
    #
    #     if diffusion_gif:
    #         assert self.dim == 2
    #
    #         output_dir = Path(output_dir)
    #         print(f"Saving {len(results[0])} generated images:")
    #
    #         transform = T.ToPILImage()
    #         output_dir.mkdir(exist_ok=True, parents=True)
    #         images = []
    #         for index, image in enumerate(results[0][0]):
    #             path = output_dir/f"image.{index}.png"
    #
    #             image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
    #             images.append(image)
    #             print(f"\t{path}")
    #         images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps)
    #
    #     return to_return

402 

# When executed as a script, run the diffusion variant of the app
# (presumably torchapp's command-line entry point - confirm against torchapp).
if __name__ == "__main__":
    SupercatDiffusion.main()