import torchapp as ta
import torch
from pathlib import Path
from fastai.callback.core import Callback
from fastai.data.block import DataBlock, TransformBlock
from fastai.data.core import DataLoaders
from fastai.data.transforms import IndexSplitter
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as T

from supercat.models import ResidualUNet
from supercat.diffusion import DDPMCallback, DDPMSamplerCallback, wandb_process

from supercat.noise.fractal import FractalNoiseTensor
from supercat.noise.worley import WorleyNoiseTensor


class ShrinkCallBack(Callback):
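    """
    A fastai callback that builds (input, target) pairs for super-resolution training on the fly.

    Each incoming batch is treated as the high-resolution target; a pseudo high-resolution
    input is produced by downsampling the batch and upsampling it back (see `before_batch`).
    """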

    def __init__(self, factor:int=4, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def before_batch(self):
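        """
        Replaces the learner's input with a degraded copy of the batch and keeps the
        original tensors as the target.
        """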
        hr = self.xb[0]
        lr_shape = tuple(s//self.factor for s in hr.shape[2:])
        mode = "bilinear" if len(lr_shape) == 2 else "trilinear"
        lr = F.interpolate(hr, lr_shape, mode=mode)
        pseudo_hr = F.interpolate(lr, hr.shape[2:], mode=mode)

        self.learn.xb = (pseudo_hr,)
        self.learn.yb = (hr,)


class NoiseTensorGenerator():
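    """
    Generates random noise tensors of a fixed shape, sampling fractal noise with
    probability `fractal_proportion` and Worley noise otherwise.
    """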

    def __init__(self, shape, worley_density:int=0, fractal_proportion:float=0.5):
        self.fractal = FractalNoiseTensor(shape)

        if not worley_density:
            worley_density = 200 if len(shape) == 2 else 40

        self.worley = WorleyNoiseTensor(shape, density=worley_density)
        self.shape = shape
        self.fractal_proportion = fractal_proportion

    def __call__(self, *args, **kwargs):
        return self.fractal(*args, **kwargs) if np.random.rand() < self.fractal_proportion else self.worley(*args, **kwargs)


class NoiseSR(ta.TorchApp):
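    """
    A torchapp app that trains a super-resolution model on synthetic noise
    (fractal and Worley), optionally as a DDPM-style diffusion model.
    """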

    def dataloaders(
        self,
        dim:int=2,
        depth:int=500,
        width:int=500,
        height:int=500,
        batch_size:int=16,
        item_count:int=1024,
        worley_density:int=0,
        fractal_proportion:float=0.5,
    ):
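        """
        Builds dataloaders over synthetic noise images.

        Items are generated on demand by a NoiseTensorGenerator; the first
        `batch_size` indices are held out as the validation set.
        """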
        shape = (height, width) if dim == 2 else (depth, height, width)
        self.shape = shape
        self.dim = dim

        datablock = DataBlock(
            blocks=(TransformBlock,),
            get_x=NoiseTensorGenerator(shape, worley_density=worley_density, fractal_proportion=fractal_proportion),
            splitter=IndexSplitter(list(range(batch_size))),
        )

        dataloaders = DataLoaders.from_dblock(
            datablock,
            source=range(item_count),
            bs=batch_size,
        )

        return dataloaders

    def extra_callbacks(self, diffusion:bool=True):
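        """
        Training callbacks: ShrinkCallBack to create input/target pairs and,
        when `diffusion` is enabled, the DDPM training callback.
        """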
        self.diffusion = diffusion
        callbacks = [ShrinkCallBack(factor=4)]
        if self.diffusion:
            callbacks.append(DDPMCallback())
        return callbacks

    def inference_callbacks(self, diffusion:bool=True):
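        """
        Inference callbacks: ShrinkCallBack plus the DDPM sampling callback
        when `diffusion` is enabled.
        """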
        callbacks = [ShrinkCallBack(factor=4)]
        if diffusion:
            callbacks.append(DDPMSamplerCallback())
        return callbacks

    def model(self):
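        """
        Builds a ResidualUNet matching the dimensionality set in `dataloaders`.

        With diffusion enabled the network takes two input channels (presumably the
        noised sample alongside the conditioning image); otherwise a single channel.
        """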
        dim = getattr(self, "dim", 2)
        diffusion = getattr(self, "diffusion", False)
        return ResidualUNet(dim=dim, in_channels=2 if diffusion else 1)

    def loss_func(self):
        """
        Returns the loss function to use with the model.
        """
        return F.smooth_l1_loss

    def inference_dataloader(self, learner, **kwargs):
        dataloader = learner.dls.test_dl([0], **kwargs)  # output single test image
        return dataloader

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
        fps:float=ta.Param(30.0, help="The frames per second to use when generating the gif."),
        **kwargs,
    ):
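        """
        Saves each generated image as a JPEG and combines them into an animated GIF.

        Pixel values are rescaled from [-1, 1] to [0, 1] before conversion.
        """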
        output_dir = Path(output_dir)
        print(f"Saving {len(results)} generated images:")

        transform = T.ToPILImage()
        output_dir.mkdir(exist_ok=True, parents=True)
        images = []
        for index, image in enumerate(results[0]):
            path = output_dir/f"image.{index}.jpg"

            image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
            image.save(path)
            images.append(image)
            print(f"\t{path}")

        # PIL's GIF writer expects a per-frame duration in milliseconds rather than an fps value
        images[0].save(
            output_dir/"image.gif",
            save_all=True,
            append_images=images[1:],
            duration=int(1000/fps),
        )

    def monitor(self):
        return "train_loss"


if __name__ == "__main__":
    NoiseSR.main()