Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torchapp as ta 

2import torch 

3from pathlib import Path 

4from fastai.callback.core import Callback 

5from fastai.data.block import DataBlock, TransformBlock 

6from fastai.data.core import DataLoaders 

7from fastai.data.transforms import IndexSplitter 

8import torch.nn.functional as F 

9import numpy as np 

10import torchvision.transforms as T 

11 

12from supercat.models import ResidualUNet 

13from supercat.diffusion import DDPMCallback, DDPMSamplerCallback, wandb_process 

14 

15from supercat.noise.fractal import FractalNoiseTensor 

16from supercat.noise.worley import WorleyNoiseTensor 

17 

class ShrinkCallBack(Callback):
    """Fastai callback that turns each batch into a super-resolution task.

    The incoming high-resolution batch is downsampled by ``factor`` and then
    upsampled back to its original size; the blurry result becomes the model
    input while the untouched original becomes the target.
    """

    def __init__(self, factor:int=4, **kwargs):
        super().__init__(**kwargs)
        # Spatial downscaling factor applied to every non-batch/channel dim.
        self.factor = factor

    def before_batch(self):
        hr = self.xb[0]
        spatial = hr.shape[2:]
        shrunk = tuple(size // self.factor for size in spatial)
        # 2D inputs use bilinear interpolation; anything else uses trilinear.
        interp_mode = "trilinear" if len(shrunk) != 2 else "bilinear"

        lr = F.interpolate(hr, shrunk, mode=interp_mode)
        pseudo_hr = F.interpolate(lr, spatial, mode=interp_mode)

        # Rewire the batch: degraded image in, pristine image as the label.
        self.learn.xb = (pseudo_hr,)
        self.learn.yb = (hr,)

32 

33 

class NoiseTensorGenerator():
    """Callable that yields either fractal or Worley noise tensors at random.

    Each invocation draws fractal noise with probability
    ``fractal_proportion`` and Worley noise otherwise.
    """

    def __init__(self, shape, worley_density:int=0, fractal_proportion:float=0.5):
        self.fractal = FractalNoiseTensor(shape)

        # A falsy density selects a dimensionality-based default
        # (denser for 2D images than for 3D volumes).
        if not worley_density:
            worley_density = 200 if len(shape) == 2 else 40

        self.worley = WorleyNoiseTensor(shape, density=worley_density)
        self.shape = shape
        self.fractal_proportion = fractal_proportion

    def __call__(self, *args, **kwargs):
        pick_fractal = np.random.rand() < self.fractal_proportion
        source = self.fractal if pick_fractal else self.worley
        return source(*args, **kwargs)

47 

48 

class NoiseSR(ta.TorchApp):
    """Super-resolution app trained on procedurally generated noise.

    Training items are fractal/Worley noise tensors generated on the fly;
    ``ShrinkCallBack`` degrades each batch into pseudo-low-resolution inputs,
    and a residual U-Net (optionally trained as a DDPM diffusion model)
    learns to restore the high-resolution target.
    """

    def dataloaders(
        self,
        dim:int=2,
        depth:int=500,
        width:int=500,
        height:int=500,
        batch_size:int=16,
        item_count:int=1024,
        worley_density:int=0,
        fractal_proportion:float=0.5,
    ):
        """
        Build dataloaders over synthetic noise tensors.

        Args:
            dim: 2 for 2D images; any other value generates 3D volumes.
            depth: volume depth (only used when dim != 2).
            width: image/volume width.
            height: image/volume height.
            batch_size: items per batch.
            item_count: number of virtual items per epoch.
            worley_density: Worley noise density; 0 picks a default based
                on dimensionality.
            fractal_proportion: probability of fractal (vs Worley) noise.

        Returns:
            fastai ``DataLoaders`` whose items are generated noise tensors.
        """
        shape = (height, width) if dim == 2 else (depth, height, width)
        self.shape = shape
        self.dim = dim

        datablock = DataBlock(
            blocks=(TransformBlock),
            # Items are generated on the fly; the integer "source" indices
            # passed to get_x are ignored by the generator.
            get_x=NoiseTensorGenerator(shape, worley_density=worley_density, fractal_proportion=fractal_proportion),
            # Reserve the first batch's worth of indices as the validation set.
            splitter=IndexSplitter(list(range(batch_size))),
        )

        dataloaders = DataLoaders.from_dblock(
            datablock,
            source=range(item_count),
            bs=batch_size,
        )

        return dataloaders

    def extra_callbacks(self, diffusion:bool=True):
        """Training callbacks: batch shrinking plus optional DDPM training."""
        # Remember the choice so model() can size its input channels.
        self.diffusion = diffusion
        callbacks = [ShrinkCallBack(factor=4)]
        if self.diffusion:
            callbacks.append(DDPMCallback())
        return callbacks

    def inference_callbacks(self, diffusion:bool=True):
        """Inference-time callbacks mirroring ``extra_callbacks``."""
        callbacks = [ShrinkCallBack(factor=4)]
        if diffusion:
            callbacks.append(DDPMSamplerCallback())
        return callbacks

    def model(self):
        """Build the residual U-Net.

        Diffusion models take 2 input channels (image + noise); plain
        super-resolution takes 1. Defaults apply when dataloaders/callbacks
        have not run yet (e.g. during inference setup).
        """
        dim = getattr(self, "dim", 2)
        diffusion = getattr(self, "diffusion", False)
        return ResidualUNet(dim=dim, in_channels=2 if diffusion else 1)

    def loss_func(self):
        """
        Returns the loss function to use with the model.
        """
        return F.smooth_l1_loss

    def inference_dataloader(self, learner, **kwargs):
        """Return a dataloader that produces a single test item."""
        dataloader = learner.dls.test_dl([0], **kwargs) # output single test image
        return dataloader

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
        fps:float=ta.Param(30.0, help="The frames per second to use when generating the gif."),
        **kwargs,
    ):
        """Save each generated image as a JPEG and assemble them into a GIF.

        Args:
            results: model outputs; ``results[0]`` iterates over generated
                images, values assumed in [-1, 1] — TODO confirm against the
                sampler callback.
            output_dir: directory for the written files (created if needed).
            fps: frames per second of the animated GIF.
        """
        output_dir = Path(output_dir)
        print(f"Saving {len(results)} generated images:")

        transform = T.ToPILImage()
        output_dir.mkdir(exist_ok=True, parents=True)
        images = []
        for index, image in enumerate(results[0]):
            path = output_dir/f"image.{index}.jpg"

            # Map from [-1, 1] to [0, 1] and clamp before converting to PIL.
            image = transform(torch.clip(image[0]/2.0 + 0.5, min=0.0, max=1.0))
            image.save(path)
            images.append(image)
            print(f"\t{path}")

        # BUGFIX: PIL's GIF writer takes `duration` (milliseconds per frame),
        # not `fps` — the old `fps=` keyword was silently ignored. Guard
        # against non-positive fps and an empty result set (IndexError before).
        if images:
            duration_ms = max(1, round(1000.0 / fps)) if fps > 0 else 100
            images[0].save(
                output_dir/"image.gif",
                save_all=True,
                append_images=images[1:],
                duration=duration_ms,
            )

    def monitor(self):
        """Metric monitored for checkpointing/early stopping."""
        return "train_loss"

132 

133 

# Entry point: launch the torchapp command-line interface for this app.
if __name__ == "__main__":
    NoiseSR.main()