#!/usr/bin/env python3
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T

import torchapp as ta
from fastai.callback.core import Callback, CancelBatchException
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.external import untar_data, URLs
from fastai.data.load import DataLoader
from fastai.data.transforms import get_image_files, parent_label
from fastai.torch_core import TensorImage
from fastai.vision.augment import Resize
from fastai.vision.data import ImageBlock
from fastcore.basics import store_attr
from rich.progress import track


def one_param(m):
    "Get the first parameter of a model (used to infer its device)."
    return next(iter(m.parameters()))


class EMA:
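    """
    Exponential moving average (EMA) of model weights.

    Maintains a shadow copy of the model whose parameters are updated as
    `beta * old + (1 - beta) * new` after each training step; sampling from
    the EMA weights is a common trick for smoother diffusion outputs.
    """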

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # For the first step_start_ema steps, keep the EMA model synced to the raw model
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())


class SelfAttention(nn.Module):
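    """
    Multi-head self-attention over the spatial positions of a feature map,
    followed by a pre-norm feed-forward block, both with residual connections.
    """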

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        size = x.shape[-1]
        # Flatten the spatial dims into a sequence: (B, C, H, W) -> (B, H*W, C)
        x = x.view(-1, self.channels, size * size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, size, size)


class DoubleConv(nn.Module):
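    """
    Two 3x3 convolutions, each followed by GroupNorm (with a GELU after the
    first). With residual=True the input is added back before a final GELU,
    so in_channels must equal out_channels in that case.
    """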

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)


class Down(nn.Module):
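    """
    Downsampling block: max-pool by 2 followed by two DoubleConvs, with a
    projected timestep embedding added to the result.
    """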

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # Project the timestep embedding to out_channels and broadcast it spatially
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class Up(nn.Module):
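    """
    Upsampling block: bilinear upsample by 2, concatenation with the skip
    connection, two DoubleConvs, and a projected timestep embedding.
    """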

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class UNet(nn.Module):
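    """
    U-Net noise predictor for the diffusion model: three Down/SelfAttention
    stages, a bottleneck, and three Up/SelfAttention stages, conditioned on a
    sinusoidal encoding of the timestep.
    """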

    def __init__(self, c_in=3, c_out=3, time_dim=256):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256)

        # Bottleneck. The original configuration widened to 512 channels
        # (DoubleConv(256, 512), DoubleConv(512, 512), DoubleConv(512, 256));
        # this slimmer variant keeps 256 channels and omits the middle block.
        self.bot1 = DoubleConv(256, 256)
        self.bot3 = DoubleConv(256, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        # Standard sinusoidal position encoding applied to the timestep t
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=one_param(self).device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def unet_forward(self, x, t):
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

    def forward(self, x, t):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)
        return self.unet_forward(x, t)


class UNetConditional(UNet):
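    """
    Class-conditional variant of UNet: an embedding of the class label is
    added to the timestep encoding before the U-Net forward pass.
    """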

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
        super().__init__(c_in, c_out, time_dim)
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def forward(self, x, t, y=None):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        return self.unet_forward(x, t)


class SampleDataloader(DataLoader):
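    """
    Stub dataloader for sampling: yields a single batch whose targets are all
    the requested category index and whose inputs are placeholders, since the
    sampling callback generates the images from random noise.
    """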

    def __init__(self, *args, category_index, **kwargs):
        super().__init__(*args, **kwargs)
        self.category_index = category_index

    def __iter__(self):
        y = torch.tensor([self.category_index] * self.bs, dtype=torch.long)
        x = [None] * self.bs
        yield x, y


class ConditionalDDPMCallback(Callback):
    """
    Callback that turns a fastai Learner into a class-conditional DDPM trainer and sampler.

    Derived from https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#using-fastai-to-train-your-diffusion-model
    """
    def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage, size: int = 32):
        store_attr()  # stores n_steps, beta_min, beta_max, tensor_type and size on self

    def before_fit(self):
        # Variance schedule beta_t, linearly increasing with the timestep,
        # with alpha_t = 1 - beta_t and alpha_bar_t the cumulative product of the alphas
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def before_batch_training(self):
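        # Forward (noising) process in closed form: for each image pick a random
        # timestep t and set
        #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
        # The model is then trained to predict eps from (x_t, t, label).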

        x0 = self.xb[0]  # original images, x_0
        eps = self.tensor_type(torch.randn(x0.shape, device=x0.device))  # the noise to be predicted
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)  # random timesteps
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps  # noisify the image
        self.learn.xb = (xt, t, self.yb[0])  # model input: noisy image, timestep and label
        self.learn.yb = (eps,)  # ground truth: the noise

    def before_batch_sampling(self):
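        # Reverse (denoising) process, following Algorithm 2 of the DDPM paper
        # (Ho et al. 2020): starting from pure noise x_T, repeatedly apply
        #     x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t, y)) + sigma_t * z
        # with z ~ N(0, I) for t > 0 and z = 0 at the final step.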

        # Use the labels in yb as the classes to generate
        label = self.yb[0]
        batch_size = label.shape[0]

        # Start from a batch of pure random noise; the self.xb[0] data can be ignored
        xt = self.tensor_type(torch.randn((batch_size, self.model.c_out, self.size, self.size), device=label.device))

        for t in track(reversed(range(self.n_steps)), total=self.n_steps, description="Performing diffusion steps for batch:"):
            t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
            z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
            alpha_t = self.alpha[t]  # noise level at the current timestep
            alpha_bar_t = self.alpha_bar[t]
            sigma_t = self.sigma[t]
            # Predict x_(t-1) in accordance with Algorithm 2 in the paper
            xt = 1 / torch.sqrt(alpha_t) * (xt - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * self.model(xt, t_batch, y=label)) + sigma_t * z
        self.learn.pred = (xt,)

        raise CancelBatchException

    def before_batch(self):
        # fastai sets gather_preds on the learner during get_preds, i.e. when sampling
        if not hasattr(self, 'gather_preds'):
            self.before_batch_training()
        else:
            self.before_batch_sampling()


class DiffusionGeneratorCIFAR10(ta.TorchApp):
    """
    Generates images of a requested CIFAR-10 category with a class-conditional DDPM.

    https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#sampling-images
    """

    def dataloaders(
        self,
        batch_size: int = ta.Param(default=64, tune_min=8, tune_max=128, log=True, tune=True),
        size: int = ta.Param(default=32),
    ):
        print("Getting CIFAR10")
        path = untar_data(URLs.CIFAR)
        dblock = DataBlock(
            blocks=(ImageBlock(), CategoryBlock()),
            get_items=get_image_files,
            get_y=parent_label,
            item_tfms=Resize(size),
        )
        self.size = size

        return dblock.dataloaders(path, bs=batch_size)

    def model(self):
        return UNetConditional(c_out=3, num_classes=10)

    def loss_func(self):
        return nn.MSELoss()

    def extra_callbacks(self):
        return [ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02)]

    def inference_callbacks(self):
        return self.extra_callbacks()

    def inference_dataloader(
        self,
        learner,
        count: int = 1,
        category: str = "",
        **kwargs,
    ):
        if category not in learner.dls.vocab:
            raise ValueError(f"Please provide a category to generate from this list: {learner.dls.vocab}")

        self.inference_category = category
        return SampleDataloader(
            bs=32,
            category_index=learner.dls.vocab.o2i[category],
            n=count,
        )

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
        **kwargs,
    ):
        output_dir = Path(output_dir)
        print(f"Saving {len(results)} generated {self.inference_category} images:")

        transform = T.ToPILImage()
        output_dir.mkdir(exist_ok=True, parents=True)
        for index, image in enumerate(results[0]):
            path = output_dir / f"{self.inference_category}.{index}.jpg"
            print(f"\t{path}")
            transform(image).save(path)


if __name__ == "__main__":
    DiffusionGeneratorCIFAR10.main()
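
# Hypothetical command-line usage via the torchapp-generated CLI, assuming the
# file is saved as diffusion_cifar10.py. Exact subcommand and flag names depend
# on the installed torchapp version; this sketch is illustrative only:
#   python diffusion_cifar10.py train --batch-size 64
#   python diffusion_cifar10.py infer --category airplane --count 4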