Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torch 

2from rich.progress import track 

3from fastai.callback.core import Callback, CancelBatchException 

4import wandb 

5from plotly.subplots import make_subplots 

6import plotly.graph_objects as go 

7from fastcore.dispatch import typedispatch 

8 

9from supercat.visualization import add_volume_face_traces 

10 

@typedispatch
def wandb_process(x, y, samples, outs, preds):
    """
    Build a Weights & Biases table with one four-panel figure per sample.

    For every ``(sample, output)`` pair a plotly figure is rendered with the
    panels "Low Resolution", the noise target, the single-shot prediction of
    the high-resolution image and the ground-truth high-resolution image.
    Each figure is written to ``plotly{index}.png`` in the current working
    directory and added to the table as a ``wandb.Image``.

    The parameters are left un-annotated on purpose: ``typedispatch`` selects
    an implementation from the annotations, so adding any would change which
    inputs this function is dispatched for.

    Args:
        x: Unused here.
        y: Unused here.
        samples: Iterable of ``(sample_input, sample_target)`` pairs.
        outs: Iterable of model outputs, aligned with ``samples``.
        preds: Unused here.

    Returns:
        ``{"Predictions": table}`` where ``table`` has a single "Sample" column.
    """
    table = wandb.Table(columns=["Sample"])
    titles = (
        "Low Resolution",
        "Noise: {alpha_bar_t}",
        "Prediction (Single Shot)",
        "Ground Truth",
    )

    for index, ((sample_input, sample_target), prediction) in enumerate(zip(samples, outs)):
        # NOTE(review): assumes sample_input is a tensor whose leading index
        # selects the noised image (0), the low-resolution image (1) and a
        # constant-filled noise-level plane (2) — confirm against the caller.
        xt = sample_input[0]
        lr = sample_input[1]
        alpha_bar_t = sample_input[2, 0, 0]
        noise = sample_target[0]
        dim = len(xt.shape)

        if dim == 3:
            # Volumes have one more axis to strip to reach a single value.
            alpha_bar_t = alpha_bar_t[0]

        # Invert the forward process
        #   xt = sqrt(alpha_bar_t) * hr + sqrt(1 - alpha_bar_t) * noise
        # once with the true noise and once with the predicted noise.
        signal_scale = torch.sqrt(alpha_bar_t)
        noise_scale = torch.sqrt(1 - alpha_bar_t)
        hr = (xt - noise_scale * noise) / signal_scale
        hr_predicted = (xt - noise_scale * prediction[0][0]) / signal_scale

        panels = (lr, noise, hr_predicted, hr)

        # 3D data is drawn as surfaces, which needs an explicit subplot type.
        specs = None
        if dim == 3:
            specs = [[{'type': "surface"} for _ in panels]]

        fig = make_subplots(
            rows=1,
            cols=4,
            subplot_titles=(
                titles[0],
                f"Noise: {alpha_bar_t}",
                titles[2],
                titles[3],
            ),
            vertical_spacing=0.02,
            horizontal_spacing=0.02,
            specs=specs,
        )

        if dim == 2:
            for col, panel in enumerate(panels, start=1):
                # Plot the panel's own data (previously every column showed
                # `lr`) on a fixed [-1, 1] colour range (previously overridden
                # afterwards with zmin/zmax swapped).
                fig.add_trace(
                    go.Heatmap(z=panel, zmin=-1.0, zmax=1.0, autocolorscale=False, showscale=False),
                    row=1,
                    col=col,
                )
            fig.update_xaxes(showticklabels=False)
            fig.update_yaxes(showticklabels=False)
        else:
            for col, panel in enumerate(panels, start=1):
                add_volume_face_traces(fig, panel, row=1, col=col)
            fig.update_layout(
                coloraxis=dict(colorscale='gray', cmin=-1.0, cmax=1.0, showscale=False),
                showlegend=False,
            )
            axis = dict(showgrid=False, showticklabels=False, showaxeslabels=False, title="", showbackground=False)
            scene = dict(
                xaxis=axis,
                yaxis=axis,
                zaxis=axis,
            )
            fig.update_layout(
                scene1=scene,
                scene2=scene,
                scene3=scene,
                scene4=scene,
            )

        fig.update_layout(
            height=400,
            width=1200,
        )

        filename = f"plotly{index}.png"
        fig.write_image(filename)
        table.add_data(
            wandb.Image(filename),
        )

    return {"Predictions": table}

95 

96 

class DDPMCallback(Callback):
    """
    Training callback that turns each batch into a noise-prediction task.

    Derived from https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#using-fastai-to-train-your-diffusion-model

    Attributes set in ``__init__`` (all 1-D tensors of length ``n_steps``):
        alpha_bar: ``cos(((t / n_steps) + s) / (1 + s) * pi / 2) ** 2``.
        alpha: ``alpha_bar[t] / alpha_bar[t - 1]`` (with ``alpha_bar[-1]``
            taken as 1 for ``t == 0``).
        beta: ``1 - alpha``.
        sigma: ``sqrt(beta)``.
    """
    def __init__(self, n_steps:int=1000, s:float = 0.008):
        """
        Args:
            n_steps: Number of diffusion timesteps in the schedule.
            s: Offset used inside the cosine of the schedule.
        """
        self.n_steps = n_steps
        self.s = s

        t = torch.arange(self.n_steps)
        self.alpha_bar = torch.cos((t/self.n_steps+self.s)/(1+self.s) * torch.pi * 0.5)**2
        # Per-step alpha is the ratio of consecutive cumulative values.
        self.alpha = self.alpha_bar/torch.cat([torch.ones(1), self.alpha_bar[:-1]])
        self.beta = 1.0 - self.alpha
        self.sigma = torch.sqrt(self.beta)

    def before_batch(self):
        """
        Replace the batch with a noised input and the noise as the target.

        Reads the low-resolution input from ``self.xb[0]`` and the
        high-resolution target from ``self.yb[0]``, e.g. of shape
        (batch_size, c, d, h, w) for volumes. Afterwards ``self.learn.xb`` is
        ``(cat([xt, lr], dim=1), alpha_bar_t)`` with ``alpha_bar_t`` of shape
        ``(batch_size, 1)`` and ``self.learn.yb`` is ``(noise,)``.
        """
        lr = self.xb[0]
        hr = self.yb[0]

        noise = torch.randn_like(hr)

        batch_size = hr.shape[0]

        # Draw one random timestep per batch element and look up its noise
        # level. The schedule tensors are indexed where they were created,
        # then the result is moved next to the data.
        t = torch.randint(0, self.n_steps, (batch_size,), dtype=torch.long)
        alpha_bar_t = self.alpha_bar[t].to(hr.device)

        # One singleton axis for every non-batch axis of `hr`, so the noise
        # level broadcasts per sample whatever the spatial rank is.
        broadcast_shape = (batch_size,) + (1,) * (hr.dim() - 1)
        alpha_bar_t = alpha_bar_t.reshape(broadcast_shape)

        # noisify the image
        xt = torch.sqrt(alpha_bar_t) * hr + torch.sqrt(1-alpha_bar_t) * noise

        # Stack input with low-resolution image (upscaled) at channel dim,
        # then pass the stacked image along with the noise level as tuple to the model
        self.learn.xb = (torch.cat([xt, lr], dim=1), alpha_bar_t.reshape(batch_size, 1))
        self.learn.yb = (noise,) # we are trying to predict the noise

138 

139 

class DDPMSamplerCallback(DDPMCallback):
    """
    Callback that replaces the forward pass with a full reverse-diffusion run.

    Starting from random noise, the model is called once per timestep of the
    inherited schedule (last step first) to produce a sample conditioned on
    the low-resolution input. The result is stored in ``self.learn.pred`` and
    the rest of the batch is cancelled.
    """
    def before_batch(self):
        """
        Sample from the model and store the result as the batch prediction.

        Reads the low-resolution input from ``self.xb[0]``, sets
        ``self.learn.pred`` to ``(xt,)`` where ``xt`` is the final sample.

        Raises:
            CancelBatchException: always, after the prediction has been set.
        """
        lr = self.xb[0]
        batch_size = lr.shape[0]
        device = lr.device

        # Move the schedule next to the data once instead of on every step.
        alpha = self.alpha.to(device)
        alpha_bar = self.alpha_bar.to(device)
        sigma = self.sigma.to(device)

        # Generate a batch of random noise to start with
        xt = torch.randn_like(lr)

        # Sampling never needs gradients; intermediate samples are not kept.
        with torch.no_grad():
            for t in track(reversed(range(self.n_steps)), total=self.n_steps, description="Performing diffusion steps for batch:"):
                # No noise is added on the final step (t == 0).
                z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
                alpha_t = alpha[t] # get noise level at current timestep
                alpha_bar_t = alpha_bar[t]
                sigma_t = sigma[t]

                # `alpha_bar_t` is a single value; the model is given one copy
                # per batch element, shaped (batch_size, 1).
                noise_level = alpha_bar_t.repeat(batch_size, 1)
                predicted_noise = self.model(torch.cat([xt, lr], dim=1), noise_level)

                # predict x_(t-1) in accordance to Algorithm 2 in paper
                xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * predicted_noise) + sigma_t*z

        self.learn.pred = (xt,)

        raise CancelBatchException

165