Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch
2from rich.progress import track
3from fastai.callback.core import Callback, CancelBatchException
4import wandb
5from plotly.subplots import make_subplots
6import plotly.graph_objects as go
7from fastcore.dispatch import typedispatch
9from supercat.visualization import add_volume_face_traces
@typedispatch
def wandb_process(x, y, samples, outs, preds):
    """
    Build a Weights & Biases table with one four-panel figure per sample.

    For every ``((sample_input, sample_target), prediction)`` triple the
    figure shows, left to right: the low-resolution input, the noise target,
    the single-shot reconstruction obtained from the predicted noise, and the
    ground truth recovered from the true noise.

    Args:
        x: Unused here; part of the dispatch signature.
        y: Unused here; part of the dispatch signature.
        samples: Iterable of ``(sample_input, sample_target)`` pairs.
            ``sample_input[0]`` is read as the noised image ``xt``,
            ``sample_input[1]`` as the low-resolution image and
            ``sample_input[2, 0, 0]`` as the noise level ``alpha_bar_t``.
            ``sample_target[0]`` is read as the noise that was added.
        outs: Iterable aligned with ``samples``; ``prediction[0][0]`` is read
            as the predicted noise.
        preds: Unused here; part of the dispatch signature.

    Returns:
        dict: ``{"Predictions": table}`` where ``table`` has a single
        ``"Sample"`` column of ``wandb.Image`` rows.

    Side effects:
        Writes ``plotly{index}.png`` into the current working directory for
        every sample.
    """
    table = wandb.Table(columns=["Sample"])
    for index, ((sample_input, sample_target), prediction) in enumerate(zip(samples, outs)):
        xt = sample_input[0]
        lr = sample_input[1]
        alpha_bar_t = sample_input[2, 0, 0]
        noise = sample_target[0]
        dim = len(xt.shape)

        if dim == 3:
            # Volumes carry one more spatial axis, so one more index is needed
            # to reach the scalar noise level.
            alpha_bar_t = alpha_bar_t[0]

        # Invert  xt = sqrt(alpha_bar_t) * hr + sqrt(1 - alpha_bar_t) * noise
        hr = (xt - torch.sqrt(1 - alpha_bar_t) * noise) / torch.sqrt(alpha_bar_t)
        hr_predicted = (xt - torch.sqrt(1 - alpha_bar_t) * prediction[0][0]) / torch.sqrt(alpha_bar_t)

        # 3D panels need "surface" subplots; 2D heatmaps use the default type.
        specs = None
        if dim == 3:
            specs = [[{'type': "surface"} for _ in range(4)]]

        fig = make_subplots(
            rows=1,
            cols=4,
            subplot_titles=(
                "Low Resolution",
                f"Noise: {alpha_bar_t}",
                "Prediction (Single Shot)",
                "Ground Truth",
            ),
            vertical_spacing=0.02,
            horizontal_spacing=0.02,
            specs=specs,
        )

        # Panel contents in column order (columns are 1-based in plotly).
        panels = (lr, noise, hr_predicted, hr)

        if dim == 2:
            for col, z in enumerate(panels, start=1):
                # Plot the panel's own data (previously every panel drew `lr`).
                # The colour range is fixed to [-1, 1] here, so no later
                # update_traces call is needed (the old one swapped zmin/zmax).
                fig.add_trace(
                    go.Heatmap(z=z, zmin=-1.0, zmax=1.0, autocolorscale=False, showscale=False),
                    row=1,
                    col=col,
                )
            fig.update_xaxes(showticklabels=False)
            fig.update_yaxes(showticklabels=False)
        else:
            for col, volume in enumerate(panels, start=1):
                add_volume_face_traces(fig, volume, row=1, col=col)
            fig.update_layout(
                coloraxis=dict(colorscale='gray', cmin=-1.0, cmax=1.0, showscale=False),
                showlegend=False,
            )
            axis = dict(showgrid=False, showticklabels=False, showaxeslabels=False, title="", showbackground=False)
            scene = dict(
                xaxis=axis,
                yaxis=axis,
                zaxis=axis,
            )
            fig.update_layout(
                scene1=scene,
                scene2=scene,
                scene3=scene,
                scene4=scene,
            )

        fig.update_layout(
            height=400,
            width=1200,
        )

        # W&B ingests the figure as a static image written to disk.
        filename = f"plotly{index}.png"
        fig.write_image(filename)
        table.add_data(
            wandb.Image(filename),
        )

    return {"Predictions": table}
class DDPMCallback(Callback):
    """
    Training callback that turns each batch into a noise-prediction task.

    Derived from https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#using-fastai-to-train-your-diffusion-model

    Attributes set in ``__init__`` (all 1-D tensors of length ``n_steps``):
        alpha_bar: squared-cosine schedule evaluated at each timestep.
        alpha: ratio of consecutive ``alpha_bar`` values (first step is
            divided by 1).
        beta: ``1 - alpha``.
        sigma: ``sqrt(beta)``.
    """
    def __init__(self, n_steps:int=1000, s:float = 0.008):
        """
        Args:
            n_steps: Number of diffusion timesteps in the schedule.
            s: Offset added inside the cosine schedule.
        """
        super().__init__()
        self.n_steps = n_steps
        self.s = s

        t = torch.arange(self.n_steps)
        self.alpha_bar = torch.cos((t/self.n_steps+self.s)/(1+self.s) * torch.pi * 0.5)**2
        # alpha_t = alpha_bar_t / alpha_bar_(t-1), with alpha_bar_(-1) taken as 1.
        self.alpha = self.alpha_bar/torch.cat([torch.ones(1), self.alpha_bar[:-1]])
        self.beta = 1.0 - self.alpha
        self.sigma = torch.sqrt(self.beta)

    def before_batch(self):
        """
        Replace the batch with (noised image + conditioning, noise level) -> noise.

        Reads ``self.xb[0]`` as the low-resolution image and ``self.yb[0]`` as
        the high-resolution target, e.g. (batch_size, c, d, h, w) for volumes.
        Afterwards ``self.learn.xb`` is ``(cat([xt, lr], dim=1), alpha_bar_t)``
        with ``alpha_bar_t`` shaped ``(batch_size, 1)``, and ``self.learn.yb``
        is ``(noise,)``.
        """
        lr = self.xb[0]
        hr = self.yb[0]

        noise = torch.randn_like(hr)

        batch_size = hr.shape[0]

        # lookup noise schedule: one random timestep per batch element
        t = torch.randint(0, self.n_steps, (batch_size,), dtype=torch.long)

        # One trailing singleton axis per non-batch axis of `hr`, so the noise
        # level broadcasts for any spatial rank rather than only 2D or 3D.
        broadcast_shape = (batch_size,) + (1,) * (hr.dim() - 1)
        # Place the noise level on the batch's own device; it is combined
        # element-wise with `hr` immediately below.
        alpha_bar_t = self.alpha_bar[t].view(broadcast_shape).to(hr.device)

        # noisify the image
        xt = torch.sqrt(alpha_bar_t) * hr + torch.sqrt(1-alpha_bar_t) * noise

        # Stack input with low-resolution image (upscaled) at channel dim,
        # then pass the stacked image along with the noise level as tuple to the model
        self.learn.xb = (torch.cat([xt, lr], dim=1), alpha_bar_t.view((batch_size, 1)))
        self.learn.yb = (noise,) # we are trying to predict the noise
class DDPMSamplerCallback(DDPMCallback):
    """
    Inference callback that replaces the forward pass with reverse diffusion.

    Starting from random noise shaped like the low-resolution input, the model
    is called once per timestep (``n_steps`` times per batch) and the final
    sample is stored as the prediction.
    """
    def before_batch(self):
        """
        Run the full sampling loop, set ``self.learn.pred`` and cancel the batch.

        Reads ``self.xb[0]`` as the low-resolution conditioning image.

        Raises:
            CancelBatchException: always, so the rest of the normal batch
                processing is skipped once the prediction has been set.
        """
        lr = self.xb[0]
        batch_size = lr.shape[0]

        # Generate a batch of random noise to start with
        xt = torch.randn_like(lr)

        # Sampling never needs gradients; without this, every one of the
        # n_steps model calls could extend the autograd graph.
        with torch.no_grad():
            for t in track(reversed(range(self.n_steps)), total=self.n_steps, description="Performing diffusion steps for batch:"):
                # No fresh noise is injected on the final step (t == 0).
                z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
                alpha_t = self.alpha[t] # get noise level at current timestep
                alpha_bar_t = self.alpha_bar[t]
                sigma_t = self.sigma[t]

                # `alpha_bar_t` is a 0-dim tensor, so it has to be repeated (not
                # viewed) to give the model one noise level per batch element,
                # on the same device as the image input.
                noise_level = alpha_bar_t.to(xt.device).repeat(batch_size, 1)
                predicted_noise = self.model(torch.cat([xt, lr], dim=1), noise_level)

                # predict x_(t-1) in accordance to Algorithm 2 in paper
                xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * predicted_noise) + sigma_t*z

        # Only the final sample is kept; intermediate steps are not retained.
        self.learn.pred = (xt,)

        raise CancelBatchException