Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import re
2import torchapp as ta
3from typing import List
4import torch
5import random
6from pathlib import Path
7from fastai.callback.core import Callback, CancelBatchException
8from fastai.data.block import DataBlock, TransformBlock
9from fastai.data.core import DataLoaders, DisplayedTransform
10import torch.nn.functional as F
11from rich.progress import track
12from fastai.data.transforms import get_image_files
13import torchvision.transforms as T
14from fastai.data.transforms import ToTensor
15from fastcore.transform import Pipeline
16from fastai.vision.augment import Resize
17from fastai.data.transforms import FuncSplitter
18from fastai.learner import load_learner
19from PIL import Image
20from functools import partial
21from fastai.vision.data import ImageBlock, TensorImage
22from fastai.vision.core import PILImageBW, TensorImageBW
23from supercat.noise.apps import * # remove this
25from supercat.models import ResidualUNet, calc_initial_features_residualunet
26from supercat.transforms import ImageBlock3D, RescaleImage, write3D, read3D, InterpolateTransform, RescaleImageMinMax, CropTransform
27from supercat.enums import DownsampleScale, DownsampleMethod
28from supercat.diffusion import DDPMCallback, DDPMSamplerCallback
29from skimage.transform import resize as skresize
31from rich.console import Console
# Module-level rich console; `Supercat.output_results` reports each file it writes through it.
console = Console()
def is_validation_image(item: Path) -> bool:
    """
    Returns True if this image should be part of the validation set.

    An item is a validation item when its parent directory name does not contain the
    string `_train_` (e.g. `carbonate2D_valid_BI_...` is validation,
    `carbonate2D_train_BI_...` is training).

    Args:
        item: Path to an input image; only its parent directory name is inspected.
    """
    return "_train_" not in item.parent.name
def get_y(item, pattern=r"_BI_.*"):
    """
    Maps an upscaled input image to the path of its high-resolution target.

    The part of the parent directory's name matched by `pattern` is replaced with `_HR`;
    the target is the file with the same name inside that sibling directory.
    """
    source_dir = item.parent
    highres_dir_name = re.sub(pattern, "_HR", source_dir.name)
    return source_dir.parent / highres_dir_name / item.name
class Supercat(ta.TorchApp):
    """
    A torchapp for super-resolution of 2D images and 3D volumes from the DeepRockSR dataset.

    The network input is the low-resolution image upscaled (bicubic in 2D, tricubic in 3D)
    to the high-resolution size; the high-resolution image is the training target.
    """

    # Number of channels fed to the network.
    in_channels = 1

    def get_items(self, directory):
        """
        Lists the input files in `directory`.

        Uses fastai's image-file discovery when `self.dim == 2`, otherwise every `*.mat`
        file in the directory (3D volumes).
        """
        if self.dim == 2:
            return get_image_files(directory)

        directory = Path(directory)
        return list(directory.glob("*.mat"))

    def dataloaders(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
        deeprock:Path = ta.Param(help="The path to the DeepRockSR dataset."),
        downsample_scale:DownsampleScale = ta.Param(DownsampleScale.X4.value, help="Should it use the 2x or 4x downsampled images.", case_sensitive=False),
        downsample_method:DownsampleMethod = ta.Param(DownsampleMethod.UNKNOWN.value, help="Should it use the default method to downsample (bicubic) or a random kernel (UNKNOWN)."),
        batch_size:int = ta.Param(default=10, help="The batch size."),
        force:bool = ta.Param(default=False, help="Whether or not to force the conversion of the bicubic upscaling."),
        max_samples:int = ta.Param(default=None, help="If set, then the number of input samples for training/validation is truncated at this number."),
        include_sand:bool = ta.Param(default=False, help="Including DeepSand-SR dataset."),
    ) -> DataLoaders:
        """
        Creates a FastAI DataLoaders object which Supercat uses in training and prediction.

        For every source and split, each low-resolution file is upscaled to the size of its
        high-resolution partner and cached in a sibling `*_BI_*` (2D) or `*_TRI_*` (3D)
        directory. Cached files are reused unless `force` is set. The cached upscaled files
        are the model inputs, and `get_y` maps each back to its high-resolution target.

        Raises:
            AssertionError: if `deeprock` is not given.
            ValueError: if no images were found.
        """
        assert deeprock is not None

        self.dim = dim
        deeprock = Path(deeprock)
        upscaled = []

        sources = [f"carbonate{dim}D", f"coal{dim}D", f"sandstone{dim}D"]
        if include_sand:
            sources.append(f"sand{dim}D")

        if isinstance(downsample_method, DownsampleMethod):
            downsample_method = downsample_method.value
        if isinstance(downsample_scale, DownsampleScale):
            downsample_scale = downsample_scale.value

        split_types = ["train", "valid"]  # There is also "test"

        # Directory tag for the cached upscaled inputs: bicubic (2D) or tricubic (3D)
        upscale_tag = "BI" if dim == 2 else "TRI"
        # Low-res files are named like the high-res ones with the lower-cased scale appended to the stem
        lowres_tag = downsample_scale.lower()

        for source in sources:
            for split_type in split_types:
                highres_dir = deeprock/source/f"{source}_{split_type}_HR"
                highres_split = self.get_items(highres_dir)
                if max_samples:
                    # Keep exactly `max_samples` items for this source/split
                    highres_split = highres_split[:max_samples]

                lowres_dir = deeprock/source/f"{source}_{split_type}_LR_{downsample_method}_{downsample_scale}"

                # We will save upscaled images
                upscale_dir = deeprock/source/f"{source}_{split_type}_{upscale_tag}_{downsample_method}_{downsample_scale}"
                upscale_dir.mkdir(exist_ok=True)

                for highres_path in highres_split:
                    upscale_path = upscale_dir/highres_path.name

                    if not upscale_path.exists() or force:
                        lowres_path = lowres_dir/f"{highres_path.stem}{lowres_tag}{highres_path.suffix}"
                        print(split_type, highres_path, upscale_path, lowres_path)

                        print("Upscaling")
                        if dim == 2:
                            self._upscale_2d(highres_path, lowres_path, upscale_path)
                        else:
                            self._upscale_3d(highres_path, lowres_path, upscale_path)

                    upscaled.append(upscale_path)

        if not upscaled:
            raise ValueError("No images found.")

        if dim == 2:
            blocks = (ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW))
        else:
            blocks = (ImageBlock3D, ImageBlock3D,)

        datablock = DataBlock(
            blocks=blocks,
            splitter=FuncSplitter(is_validation_image),
            get_y=get_y if dim == 2 else partial(get_y, pattern=r"_TRI_.*"),
            batch_tfms=[RescaleImage],
        )

        dataloaders = DataLoaders.from_dblock(
            datablock,
            source=upscaled,
            bs=batch_size,
        )

        dataloaders.c = 1

        return dataloaders

    @staticmethod
    def _upscale_2d(highres_path, lowres_path, upscale_path):
        """
        Bicubically resizes the 2D low-res image to the high-res size and saves it to `upscale_path`.

        RGB inputs are reduced to their red channel, and that single-channel version is
        written back over the source file.
        """
        with Image.open(highres_path) as highres_file, Image.open(lowres_path) as lowres_file:
            highres_img, lowres_img = highres_file, lowres_file

            # Convert to single channel (getchannel loads the pixels before the overwrite)
            if lowres_img.mode == "RGB":
                lowres_img = lowres_img.getchannel('R')
                lowres_img.save(lowres_path)
            if highres_img.mode == "RGB":
                highres_img = highres_img.getchannel('R')
                highres_img.save(highres_path)

            # resize returns a new in-memory image, so it outlives the closed files
            upscale_img = lowres_img.resize(highres_img.size, Image.BICUBIC)

        if upscale_img.mode == "RGB":
            upscale_img = upscale_img.getchannel('R')

        upscale_img.save(upscale_path)

    @staticmethod
    def _upscale_3d(highres_path, lowres_path, upscale_path):
        """Tricubically (order-3 spline) resizes the 3D low-res volume to the high-res shape and writes it."""
        print("Upscaling with tricubic")
        highres_img = read3D(highres_path)
        lowres_img = read3D(lowres_path)

        tricubic_img = skresize(lowres_img, highres_img.shape, order=3)
        write3D(upscale_path, tricubic_img)

    def model(
        self,
        pretrained:Path=None,
        initial_features:int = ta.Param(
            None,
            help="The number of features after the initial CNN layer. If not set then it is derived from the MACC."
        ),
        growth_factor:float = ta.Param(
            2.0,
            tune=True,
            tune_min=1.0,
            tune_max=4.0,
            tune_log=True,
            help="The factor to grow the number of convolutional filters each time the model downscales."
        ),
        kernel_size:int = ta.Param(
            3,
            tune=True,
            tune_choices=[3,5,7],
            help="The size of the kernel in the convolutional layers."
        ),
        stub_kernel_size:int = ta.Param(
            7,
            tune=True,
            tune_choices=[5,7,9],
            help="The size of the kernel in the initial stub convolutional layer."
        ),
        downblock_layers:int = ta.Param(
            4,
            tune=True,
            tune_min=2,
            tune_max=5,
            help="The number of layers to downscale (and upscale) in the UNet."
        ),
        macc:int = ta.Param(
            default=132_000,
            help=(
                "The approximate number of multiply or accumulate operations in the model per pixel/voxel. " +
                "Used to set initial_features if it is not provided explicitly."
            ),
        ),
    ):
        """
        Builds the ResidualUNet, or returns the model of a saved learner when `pretrained` is given.

        When `initial_features` is not given, it is derived from the `macc` budget.
        """
        if pretrained:
            learner = load_learner(pretrained)
            return learner.model

        # Falls back to 3D when neither dataloaders nor inference_dataloader has set `self.dim`
        dim = getattr(self, "dim", 3)

        if not initial_features:
            assert macc

            initial_features = calc_initial_features_residualunet(
                macc=macc,
                dim=dim,
                growth_factor=growth_factor,
                kernel_size=kernel_size,
                stub_kernel_size=stub_kernel_size,
                downblock_layers=downblock_layers,
            )

        # NOTE(review): growth_factor, kernel_size, stub_kernel_size and downblock_layers feed the
        # MACC estimate above but are not forwarded to ResidualUNet, which presumably uses its own
        # defaults for them. Confirm this is intended.
        return ResidualUNet(
            dim=dim,
            in_channels=self.in_channels,
            out_channels=1,
            initial_features=initial_features,
            # growth_factor=growth_factor,
            # kernel_size=kernel_size,
            # downblock_layers=downblock_layers,
        )

    def loss_func(self):
        """
        Returns the loss function to use with the model.
        """
        return F.smooth_l1_loss

    def inference_dataloader(
        self,
        learner,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
        items:List[Path] = None,
        item_dir: Path = ta.Param(None, help="A directory with images to upscale."),
        width:int = ta.Param(500, help="The width of the final image/volume."),
        height:int = ta.Param(None, help="The height of the final image/volume."),
        depth:int = ta.Param(None, help="The depth of the final image/volume."),
        start_x:int=None,
        end_x:int=None,
        start_y:int=None,
        end_y:int=None,
        start_z:int=None,
        end_z:int=None,
        **kwargs
    ):
        """
        Builds a test dataloader over `items` and/or the files in `item_dir`.

        Each item is cropped, interpolated to `width` x `height` (x `depth`), min-max
        rescaled and converted to a tensor. `height` and `depth` default to `width`. The
        item list and the rescaling transform are stored on `self` for `output_results`.
        """
        self.dim = dim

        # Always build a new list so a caller-supplied list is not mutated by `+=` below
        if not items:
            items = []
        elif isinstance(items, (Path, str)):
            items = [items]
        else:
            items = list(items)
        if item_dir:
            items += self.get_items(item_dir)

        items = [Path(item) for item in items]
        self.items = items
        dataloader = learner.dls.test_dl(items, with_labels=True, **kwargs)
        dataloader.transform = dataloader.transform[:1]  # ignore the get_y function
        height = height or width
        depth = depth or width

        interpolation = InterpolateTransform(depth=depth, height=height, width=width, dim=dim)
        crop_transform = CropTransform(
            start_x=start_x, end_x=end_x,
            start_y=start_y, end_y=end_y,
            start_z=start_z, end_z=end_z,
        )
        self.rescaling = RescaleImageMinMax()
        dataloader.after_item = Pipeline([crop_transform, interpolation, self.rescaling, ToTensor])

        # Items are already min-max rescaled above, so drop the RescaleImage batch transform
        # when it sits second in the pipeline. Pass a list: `Pipeline(*tfms)` would bind the
        # second transform to Pipeline's `split_idx` argument.
        batch_tfms = list(dataloader.after_batch.fs)
        if len(batch_tfms) > 1 and isinstance(batch_tfms[1], RescaleImage):
            dataloader.after_batch = Pipeline(batch_tfms[:1] + batch_tfms[2:])

        return dataloader

    def output_results(
        self,
        results,
        return_data:bool=False,
        output_dir: Path = ta.Param(None, help="The location of the output directory. If not given then it uses the directory of the item."),
        suffix:str = ta.Param("", help="The file extension for the output file."),
        **kwargs,
    ):
        """
        Writes each prediction to `<stem>.upscaled<suffix>` in `output_dir` (or beside its input).

        Predictions are mapped back to the value range of their input (from its extrema)
        before saving. 2D results are clipped to 0-255 and saved as 8-bit images. 3D results
        are written with `write3D`.

        Returns:
            list: the rescaled result tensors if `return_data`, otherwise the written paths.
        """
        list_to_return = []
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(exist_ok=True, parents=True)

        for item, result in zip(self.items, results[0]):
            my_suffix = suffix or item.suffix
            # An input without an extension yields an empty suffix; don't index into it
            if my_suffix and not my_suffix.startswith("."):
                my_suffix = "." + my_suffix

            new_name = item.with_suffix("").name + f".upscaled{my_suffix}"
            my_output_dir = output_dir or item.parent
            new_path = my_output_dir/new_name

            dim = len(result.shape) - 1
            if dim == 2:
                # hack get extrema to rescale
                with Image.open(item) as source_img:
                    lo, hi = source_img.convert('L').getextrema()
                result[0] = self.rescaling.decodes(result[0], lo, hi)

                pixels = torch.clip(result[0], min=0, max=255)
                im = Image.fromarray(pixels.cpu().detach().numpy().astype('uint8'))
                im.save(new_path)
            else:
                # hack get extrema to rescale
                data = read3D(item)
                lo, hi = data.min(), data.max()
                result[0] = self.rescaling.decodes(result[0], lo, hi)

                write3D(new_path, result[0].cpu().detach().numpy())

            list_to_return.append(result[0] if return_data else new_path)
            console.print(f"Upscaled '{item}' ⮕ '{new_path}'")

        return list_to_return

    def pretrained_location(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
    ) -> str:
        """Returns the URL of the released pretrained Supercat weights for `dim` (2 or 3)."""
        assert dim in [2,3]
        if dim == 2:
            return f"https://github.com/rbturnbull/supercat/releases/download/v0.2.1/supercat-{dim}D.0.2.pkl"
        return f"https://github.com/rbturnbull/supercat/releases/download/v0.3.0/supercat-{dim}D.0.3.pkl"
class SupercatDiffusion(Supercat):
    """
    Diffusion variant of Supercat.

    Uses two network input channels and plugs in the DDPM callbacks: `DDPMCallback` via
    `extra_callbacks` and `DDPMSamplerCallback` via `inference_callbacks`.
    """

    in_channels = 2

    def extra_callbacks(self):
        """Returns a list holding a fresh `DDPMCallback`."""
        ddpm = DDPMCallback()
        return [ddpm]

    def inference_callbacks(self):
        """Returns a list holding a fresh `DDPMSamplerCallback`."""
        sampler = DDPMSamplerCallback()
        return [sampler]

    def pretrained_location(
        self,
        dim:int = ta.Param(default=2, help="The dimension of the dataset. 2 or 3."),
    ) -> str:
        """Returns the URL of the released pretrained diffusion weights for `dim` (2 or 3)."""
        assert dim in [2,3]
        release, version = ("v0.2.1", "0.2") if dim == 2 else ("v0.3.0", "0.3")
        return f"https://github.com/rbturnbull/supercat/releases/download/{release}/supercat-diffusion-{dim}D.{version}.pkl"

    # NOTE: disabled draft of a diffusion-specific `output_results` that would also save the
    # generated images as a GIF. Kept for reference only; it is not active code.
    # def output_results(
    #     self,
    #     results,
    #     output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
    #     diffusion_gif:bool=False,
    #     diffusion_gif_fps:float=ta.Param(120.0, help="The frames per second to use when generating the gif."),
    #     **kwargs,
    # ):
    #     breakpoint()
    #     # final_results = [[result[-1] for result in results[0][0]]]
    #     to_return = super().output_results(results, output_dir=output_dir, **kwargs)
    #
    #     if diffusion_gif:
    #         assert self.dim == 2
    #
    #         output_dir = Path(output_dir)
    #         print(f"Saving {len(results[0])} generated images:")
    #
    #         transform = T.ToPILImage()
    #         output_dir.mkdir(exist_ok=True, parents=True)
    #         images = []
    #         for index, image in enumerate(results[0][0]):
    #             path = output_dir/f"image.{index}.png"
    #
    #             image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
    #             images.append(image)
    #             print(f"\t{path}")
    #         images[0].save(output_dir/f"image.gif", save_all=True, append_images=images[1:], fps=diffusion_gif_fps)
    #
    #     return to_return
if __name__ == "__main__":
    # Script entry point: delegates to the diffusion variant's `main()`, not plain `Supercat`'s.
    SupercatDiffusion.main()