#!/usr/bin/env python3
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
from rich.progress import track
from fastcore.basics import store_attr

import torchapp as ta
from fastai.data.external import untar_data, URLs
from fastai.callback.core import Callback, CancelBatchException
from fastai.torch_core import TensorImage
from fastai.vision.augment import Resize
from fastai.vision.data import ImageBlock
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import get_image_files, parent_label
from fastai.data.load import DataLoader


def one_param(m):
    "Get the first parameter of the model."
    return next(iter(m.parameters()))


class EMA:
    "Maintains an exponential moving average (EMA) of a model's parameters."

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Keep the EMA model synced to the raw model for the first
        # `step_start_ema` steps, then start averaging.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())
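
# How this EMA helper is typically wired into a training loop (a sketch only:
# EMA is defined here but not used by the app below, and `model` and
# `dataloader` are hypothetical; `copy` would need importing):
#
#     ema = EMA(beta=0.995)
#     ema_model = copy.deepcopy(model).eval().requires_grad_(False)
#     for batch in dataloader:
#         ...  # one optimisation step on `model`
#         ema.step_ema(ema_model, model)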


class SelfAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        size = x.shape[-1]
        x = x.view(-1, self.channels, size * size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, size, size)
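
# Note: SelfAttention.forward assumes square feature maps: it flattens
# (B, C, H, W) with H == W == size into (B, H*W, C) token sequences before
# applying multi-head attention, then reshapes back to the image layout.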


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
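
# Note: Up concatenates the skip connection with the upsampled features along
# the channel dimension, so UNet below builds each Up with in_channels equal
# to the sum of the two paths (e.g. Up(512, 128) receives 256 + 256 channels).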


class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256)

        self.bot1 = DoubleConv(256, 256)
        self.bot3 = DoubleConv(256, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=one_param(self).device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc
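
    # Note: pos_encoding is the standard sinusoidal embedding from the
    # Transformer paper: for each even channel index 2k the timestep t is
    # encoded as sin(t / 10000^(2k/channels)) and cos(t / 10000^(2k/channels)),
    # with the sine and cosine halves concatenated along the last dimension.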

    def unet_forward(self, x, t):
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

    def forward(self, x, t):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)
        return self.unet_forward(x, t)


class UNetConditional(UNet):
    "A UNet that can be conditioned on a class label via a learned embedding added to the time embedding."

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
        super().__init__(c_in, c_out, time_dim)
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def forward(self, x, t, y=None):
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        return self.unet_forward(x, t)
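
# A quick shape check for the conditional U-Net (a sketch only; the batch size
# of 4 is arbitrary, and 32x32 matches the CIFAR10 images used below):
#
#     net = UNetConditional(c_in=3, c_out=3, num_classes=10)
#     x = torch.randn(4, 3, 32, 32)
#     t = torch.randint(0, 1000, (4,))
#     y = torch.randint(0, 10, (4,))
#     assert net(x, t, y).shape == x.shape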


class SampleDataloader(DataLoader):
    "A dataloader that yields placeholder batches tagged with a fixed category, used to drive sampling."

    def __init__(self, *args, category_index, **kwargs):
        super().__init__(*args, **kwargs)
        self.category_index = category_index

    def __iter__(self):
        y = torch.tensor([self.category_index] * self.bs, dtype=torch.long)
        x = [None] * self.bs
        yield x, y


class ConditionalDDPMCallback(Callback):
    """
    Callback that turns a fastai Learner into a conditional DDPM trainer and sampler.

    Derived from https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#using-fastai-to-train-your-diffusion-model
    """

    def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage, size: int = 32):
        store_attr()

    def before_fit(self):
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device)  # variance schedule, increasing linearly with timestep
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)
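
    # With this linear schedule, alpha_t = 1 - beta_t and
    # alpha_bar_t = prod_{s<=t} alpha_s, so alpha_bar decays towards zero and
    # an image noised for the full n_steps is close to pure Gaussian noise.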

    def before_batch_training(self):
        x0 = self.xb[0]  # original images, x_0
        eps = self.tensor_type(torch.randn(x0.shape, device=x0.device))  # noise, epsilon
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)  # select random timesteps
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps  # noisify the image
        self.learn.xb = (xt, t, self.yb[0])  # model input: noisy image, timestep, and label
        self.learn.yb = (eps,)  # ground truth is the noise
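
    # The noising step above uses the closed-form forward process from the
    # DDPM paper (Ho et al. 2020): q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0,
    # (1 - alpha_bar_t) * I), which lets a training pair be built for any
    # random t in a single step instead of noising iteratively.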

    def before_batch_sampling(self):
        # Use the labels in yb as the classes we want to generate
        label = self.yb[0]
        batch_size = label.shape[0]

        # Start from a batch of pure random noise; the xb data can be ignored here.
        xt = self.tensor_type(torch.randn((batch_size, self.model.c_out, self.size, self.size), device=label.device))

        for t in track(reversed(range(self.n_steps)), total=self.n_steps, description="Performing diffusion steps for batch:"):
            t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
            z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
            alpha_t = self.alpha[t]  # noise level at the current timestep
            alpha_bar_t = self.alpha_bar[t]
            sigma_t = self.sigma[t]
            # Predict x_(t-1) following Algorithm 2 (Sampling) in the DDPM paper
            xt = 1 / torch.sqrt(alpha_t) * (xt - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * self.model(xt, t_batch, y=label)) + sigma_t * z
        self.learn.pred = (xt,)

        raise CancelBatchException
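
    # The update in the loop above is Algorithm 2 from Ho et al. 2020:
    # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t)
    # * eps_theta(x_t, t, y)) + sigma_t * z, with z ~ N(0, I) for t > 0 and
    # z = 0 at the final step.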

    def before_batch(self):
        # The learner only has a `gather_preds` callback while running
        # Learner.get_preds, i.e. during sampling.
        if not hasattr(self, 'gather_preds'):
            self.before_batch_training()
        else:
            self.before_batch_sampling()


class DiffusionGeneratorCIFAR10(ta.TorchApp):
    """
    Generates images for CIFAR10 categories with a conditional diffusion model.

    https://wandb.ai/capecape/train_sd/reports/How-To-Train-a-Conditional-Diffusion-Model-From-Scratch--VmlldzoyNzIzNTQ1#sampling-images
    """

    def dataloaders(
        self,
        batch_size: int = ta.Param(default=64, tune_min=8, tune_max=128, log=True, tune=True),
        size: int = ta.Param(default=32),
    ):
        print("Getting CIFAR10")
        path = untar_data(URLs.CIFAR)
        dblock = DataBlock(
            blocks=(ImageBlock(), CategoryBlock()),
            get_items=get_image_files,
            get_y=parent_label,
            item_tfms=Resize(size),
        )
        self.size = size

        return dblock.dataloaders(path, bs=batch_size)

    def model(self):
        return UNetConditional(c_out=3, num_classes=10)

    def loss_func(self):
        return nn.MSELoss()

    def extra_callbacks(self):
        return [ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02)]

    def inference_callbacks(self):
        return self.extra_callbacks()

    def inference_dataloader(
        self,
        learner,
        count: int = 1,
        category: str = "",
        **kwargs,
    ):
        if category not in learner.dls.vocab:
            raise ValueError(f"Please provide a category to generate from this list: {learner.dls.vocab}")

        self.inference_category = category
        return SampleDataloader(
            bs=32,
            category_index=learner.dls.vocab.o2i[category],
            n=count,
        )

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param("./outputs", help="The location of the output directory."),
        **kwargs,
    ):
        output_dir = Path(output_dir)
        print(f"Saving {len(results)} generated {self.inference_category} images:")

        transform = T.ToPILImage()
        output_dir.mkdir(exist_ok=True, parents=True)
        for index, image in enumerate(results[0]):
            path = output_dir / f"{self.inference_category}.{index}.jpg"
            print(f"\t{path}")
            transform(image).save(path)


if __name__ == "__main__":
    DiffusionGeneratorCIFAR10.main()