from typing import Callable, Optional
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from rich.console import Console

console = Console()


class PositionalEncoding(nn.Module):
    """
    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
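

# A minimal usage sketch (illustrative sizes, not part of the original module):
# PositionalEncoding expects input of shape [seq_len, batch_size, d_model],
# adds the sinusoidal table and applies dropout, preserving the shape.
def _example_positional_encoding() -> Tensor:
    pos_enc = PositionalEncoding(d_model=64, dropout=0.1, max_len=512)
    tokens = torch.zeros(100, 4, 64)  # [seq_len=100, batch_size=4, d_model=64]
    return pos_enc(tokens)            # same shape, with positions encoded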


def conv3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d:
    """convolution of width 3 with padding"""
    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
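

# A minimal shape sketch (illustrative, not part of the original module): with the
# default stride and dilation, padding equal to `dilation` keeps the sequence length unchanged.
def _example_conv3() -> Tensor:
    conv = conv3(in_planes=8, out_planes=16)
    return conv(torch.randn(1, 8, 100))  # -> [1, 16, 100]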


class ResidualBlock1D(nn.Module):
    """Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py"""

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        # A 1x1 projection is needed whenever the residual shape changes
        # (stride > 1 or a change in the number of channels).
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes),
            )

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
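

# A minimal shape-check sketch (illustrative, not part of the original module):
# with stride=2 the block halves the sequence length and projects the channel
# count through the 1x1 downsample convolution, so the residual addition still matches.
def _example_residual_block() -> Tensor:
    block = ResidualBlock1D(inplanes=8, planes=128, stride=2)
    x = torch.randn(4, 8, 1000)  # [batch, channels, seq_len]
    return block(x)              # -> [4, 128, 500]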


class ConvRecurrantClassifier(nn.Module):
    def __init__(
        self,
        num_classes,
        embedding_dim: int = 8,
        filters: int = 256,
        cnn_layers: int = 6,
        kernel_size_cnn: int = 9,
        lstm_dims: int = 256,
        final_layer_dims: int = 0,  # If this is zero then it isn't used.
        dropout: float = 0.5,
        kernel_size_maxpool: int = 2,
        residual_blocks: bool = False,
        final_bias: bool = True,
        multi_kernel_sizes: bool = True,
    ):
        super().__init__()

        num_embeddings = 5  # i.e. the size of the vocab which is N, A, C, G, T

        self.num_classes = num_classes
        self.num_embeddings = num_embeddings

        ########################
        ## Embedding
        ########################
        self.embedding_dim = embedding_dim
        self.embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.dropout = nn.Dropout(dropout)

        ########################
        ## Convolutional Layer
        ########################

        self.multi_kernel_sizes = multi_kernel_sizes
        if multi_kernel_sizes:
            kernel_size = 5
            convolutions = []
            for _ in range(cnn_layers):
                convolutions.append(
                    nn.Conv1d(in_channels=embedding_dim, out_channels=filters, kernel_size=kernel_size, padding='same')
                )
                kernel_size += 2

            self.convolutions = nn.ModuleList(convolutions)
            self.pool = nn.MaxPool1d(kernel_size=kernel_size_maxpool)
            current_dims = filters * cnn_layers
        else:
            self.filters = filters
            self.residual_blocks = residual_blocks
            self.intermediate_filters = 128
            if residual_blocks:
                self.cnn_layers = nn.Sequential(
                    ResidualBlock1D(embedding_dim, embedding_dim),
                    ResidualBlock1D(embedding_dim, self.intermediate_filters, 2),
                    ResidualBlock1D(self.intermediate_filters, self.intermediate_filters),
                    ResidualBlock1D(self.intermediate_filters, filters, 2),
                    ResidualBlock1D(filters, filters),
                )
            else:
                self.kernel_size_cnn = kernel_size_cnn
                self.cnn_layers = nn.Sequential(
                    nn.Conv1d(in_channels=embedding_dim, out_channels=filters, kernel_size=kernel_size_cnn),
                    nn.MaxPool1d(kernel_size=kernel_size_maxpool),
                )
            current_dims = filters

        ########################
        ## Recurrent Layer
        ########################
        self.lstm_dims = lstm_dims
        if lstm_dims:
            self.bi_lstm = nn.LSTM(
                input_size=current_dims,  # must match the channel dimension coming out of the convolutional stack / maxpool
                hidden_size=lstm_dims,
                bidirectional=True,
                bias=True,
                batch_first=True,
                dropout=dropout,  # with the default single layer, inter-layer dropout has no effect
            )
            current_dims = lstm_dims * 2

        if final_layer_dims:
            self.fc1 = nn.Linear(
                in_features=current_dims,
                out_features=final_layer_dims,
            )
            current_dims = final_layer_dims

        #################################
        ## Linear Layer(s) to Predictions
        #################################
        self.final_layer_dims = final_layer_dims
        self.logits = nn.Linear(
            in_features=current_dims,
            out_features=self.num_classes,
            bias=final_bias,
        )

    def forward(self, x):
        ########################
        ## Embedding
        ########################
        # Cast as pytorch tensor
        # x = Tensor(x)

        # Convert to int because it may be simply a byte
        x = x.int()
        x = self.embed(x)

        ########################
        ## Convolutional Layer
        ########################
        # Transpose seq_len with embedding dims to suit convention of pytorch CNNs (batch_size, input_size, seq_len)
        x = x.transpose(1, 2)

        if self.multi_kernel_sizes:
            conv_results = [conv(x) for conv in self.convolutions]
            x = torch.cat(conv_results, dim=-2)

            x = self.pool(x)
        else:
            x = self.cnn_layers(x)

        # Current shape: batch, filters, seq_len
        # With batch_first=True, LSTM expects shape: batch, seq, feature
        x = x.transpose(2, 1)

        ########################
        ## Recurrent Layer
        ########################

        # BiLSTM
        if self.lstm_dims:
            output, (h_n, c_n) = self.bi_lstm(x)
            # h_n has shape (num_layers * num_directions, batch, hidden_size).
            # We are using a single layer with 2 directions, so the two final hidden
            # states are h_n[0, :, :] and h_n[1, :, :]
            x = torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=-1)
        else:
            # If there is no recurrent layer then simply sum over the sequence dimension
            x = torch.sum(x, dim=1)

        #################################
        ## Linear Layer(s) to Predictions
        #################################
        # Skip if final_layer_dims is zero
        if self.final_layer_dims:
            x = F.relu(self.fc1(x))
        # Return logits; nn.CrossEntropyLoss applies the softmax internally
        out = self.logits(x)

        return out
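

# A minimal usage sketch (illustrative sizes, not part of the original module):
# the classifier takes integer-encoded sequences (vocab of 5: N, A, C, G, T)
# with shape [batch, seq_len] and returns one logit per class.
def _example_conv_recurrent_classifier() -> Tensor:
    model = ConvRecurrantClassifier(num_classes=10)
    seqs = torch.randint(0, 5, (2, 1024))  # [batch=2, seq_len=1024]
    return model(seqs)                     # -> [2, 10]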


class ConvClassifier(nn.Module):
    def __init__(
        self,
        embedding_dim=8,
        cnn_layers=6,
        num_classes=5,
        cnn_dims_start=64,
        kernel_size_maxpool=2,
        num_embeddings=5,  # i.e. the size of the vocab which is N, A, C, G, T
        kernel_size=3,
        factor=2,
        padding="same",
        padding_mode="zeros",
        dropout=0.5,
        final_bias=True,
        lstm_dims: int = 0,
        penultimate_dims: int = 1028,
        include_length: bool = False,
        length_scaling: float = 3_000.0,
        transformer_heads: int = 8,
        transformer_layers: int = 6,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.cnn_layers = cnn_layers
        self.num_classes = num_classes
        self.kernel_size_maxpool = kernel_size_maxpool

        self.num_embeddings = num_embeddings
        self.kernel_size = kernel_size
        self.factor = factor
        self.dropout = dropout
        self.include_length = include_length
        self.length_scaling = length_scaling
        self.transformer_layers = transformer_layers
        self.transformer_heads = transformer_heads

        self.embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )

        in_channels = embedding_dim
        out_channels = cnn_dims_start
        conv_layers = []
        for layer_index in range(cnn_layers):
            conv_layers += [
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    padding_mode=padding_mode,
                ),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.MaxPool1d(kernel_size_maxpool),
                # nn.Conv1d(
                #     in_channels=out_channels,
                #     out_channels=out_channels,
                #     kernel_size=kernel_size,
                #     stride=kernel_size_maxpool,
                # ),
            ]
            in_channels = out_channels
            out_channels = int(out_channels * factor)

        self.conv = nn.Sequential(*conv_layers)

        if self.transformer_layers:
            self.positional_encoding = PositionalEncoding(d_model=in_channels)
            encoder_layer = nn.TransformerEncoderLayer(d_model=in_channels, nhead=self.transformer_heads, batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.transformer_layers)
        else:
            self.transformer_encoder = None

        self.lstm_dims = lstm_dims
        if lstm_dims:
            self.bi_lstm = nn.LSTM(
                input_size=in_channels,  # must match the channel dimension coming out of the convolutional stack
                hidden_size=lstm_dims,
                bidirectional=True,
                bias=True,
                batch_first=True,
                dropout=dropout,  # with the default single layer, inter-layer dropout has no effect
            )
            current_dims = lstm_dims * 2
        else:
            current_dims = in_channels

        self.average_pool = nn.AdaptiveAvgPool1d(1)

        current_dims += int(include_length)
        self.final = nn.Sequential(
            # nn.Linear(in_features=current_dims, out_features=current_dims, bias=True),
            # nn.ReLU(),
            nn.Linear(in_features=current_dims, out_features=penultimate_dims, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=penultimate_dims, out_features=num_classes, bias=final_bias),
        )

    def forward(self, x):
        # Convert to int because it may be simply a byte
        x = x.int()
        length = x.shape[-1]
        x = self.embedding(x)

        # Transpose seq_len with embedding dims to suit convention of pytorch CNNs (batch_size, input_size, seq_len)
        x = x.transpose(1, 2)
        x = self.conv(x)

        if getattr(self, 'transformer_encoder', None) is not None:
            # PositionalEncoding expects [seq_len, batch, d_model], whereas the
            # encoder was built with batch_first=True, so switch layouts accordingly.
            x = x.permute(2, 0, 1)    # [batch, channels, seq] -> [seq, batch, channels]
            x = self.positional_encoding(x)
            x = x.permute(1, 0, 2)    # -> [batch, seq, channels]
            x = self.transformer_encoder(x)
            x = x.transpose(1, 2)     # back to [batch, channels, seq]

        if self.lstm_dims:
            x = x.transpose(2, 1)
            output, (h_n, c_n) = self.bi_lstm(x)
            # h_n has shape (num_layers * num_directions, batch, hidden_size).
            # We are using a single layer with 2 directions, so the two final hidden
            # states are h_n[0, :, :] and h_n[1, :, :]
            x = torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=-1)
        elif hasattr(self, 'average_pool'):
            x = self.average_pool(x)
            x = torch.flatten(x, 1)
        else:
            x = torch.mean(x, dim=-1)

        if getattr(self, 'include_length', False):
            length_tensor = torch.full((x.shape[0], 1), length / self.length_scaling, device=x.device)
            x = torch.cat([x, length_tensor], dim=1)

        predictions = self.final(x)

        return predictions

    def new_final(self, output_size, final_bias: bool = True):
        final_in_features = list(self.final.modules())[1].in_features

        self.final = nn.Sequential(
            nn.Linear(in_features=final_in_features, out_features=final_in_features, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=final_in_features, out_features=output_size, bias=final_bias),
        )
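

# A minimal usage sketch (illustrative sizes, not part of the original module):
# integer-encoded sequences of shape [batch, seq_len] go in, class logits come out.
# transformer_layers=0 keeps this example to the pure convolutional path.
def _example_conv_classifier() -> Tensor:
    model = ConvClassifier(num_classes=5, transformer_layers=0)
    seqs = torch.randint(0, 5, (2, 512))  # [batch=2, seq_len=512]
    return model(seqs)                    # -> [2, 5]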


class SequentialDebug(nn.Sequential):
    def forward(self, input):
        macs_cumulative = 0
        from thop import profile  # lazy import so `thop` is only required when debugging

        console.print(f"Input shape {input.shape}")
        for module in self:
            console.print(f"Module: {module} ({type(module)})")
            macs, _ = profile(module, inputs=(input, ))
            macs_cumulative += int(macs)
            console.print(f"MACs: {int(macs)} (cumulative {macs_cumulative})")

            input = module(input)
            console.print(f"Output shape: {input.shape}")

        return input
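

# A minimal usage sketch (illustrative, not part of the original module): SequentialDebug
# behaves like nn.Sequential but prints each module's MAC count and output shape as the
# input flows through (requires the optional `thop` package).
def _example_sequential_debug() -> Tensor:
    stack = SequentialDebug(
        nn.Conv1d(8, 64, kernel_size=3, padding="same"),
        nn.ReLU(),
        nn.MaxPool1d(2),
    )
    return stack(torch.randn(1, 8, 128))  # prints per-module stats, returns [1, 64, 64]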


class ConvProcessor(nn.Sequential):
    def __init__(
        self,
        in_channels=8,
        cnn_layers=6,
        cnn_dims_start=64,
        kernel_size_maxpool=2,
        kernel_size=3,
        factor=2,
        dropout=0.5,
        padding="same",
        padding_mode="zeros",
    ):
        out_channels = cnn_dims_start
        conv_layers = []
        for layer_index in range(cnn_layers):
            conv_layers += [
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    padding_mode=padding_mode,
                ),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.MaxPool1d(kernel_size_maxpool),
                # nn.Conv1d(
                #     in_channels=out_channels,
                #     out_channels=out_channels,
                #     kernel_size=kernel_size,
                #     stride=kernel_size_maxpool,
                # ),
            ]
            in_channels = out_channels
            out_channels = int(out_channels * factor)

        super().__init__(*conv_layers)
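

# A minimal shape sketch (illustrative, not part of the original module): each of the
# cnn_layers blocks halves the sequence length via max-pooling and multiplies the
# channel count by `factor`, starting from cnn_dims_start.
def _example_conv_processor() -> Tensor:
    processor = ConvProcessor(in_channels=8, cnn_layers=3, cnn_dims_start=64)
    x = torch.randn(2, 8, 800)  # [batch, channels, seq_len]
    return processor(x)         # -> [2, 256, 100]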


def calc_cnn_dims_start(
    macc,
    seq_len: int,
    embedding_dim: int,
    cnn_layers: int,
    kernel_size: int,
    factor: float,
    penultimate_dims: int,
    num_classes: int,
):
    r"""
    Solves the equation

        M = s k e c + \sum_{l=1}^{L-1} \frac{s}{2^{l}} k c^2 f^{2l-1} + c f^{L-1} p + p o

    for c, where M = macc, s = seq_len, k = kernel_size, e = embedding_dim,
    c = cnn_dims_start, L = cnn_layers, f = factor, p = penultimate_dims and o = num_classes.

    Args:
        macc (int): The target number of multiply-accumulate operations.
        seq_len (int): The length of the input sequence.
        embedding_dim (int): The size of the embedding.
        cnn_layers (int): The number of CNN layers.
        kernel_size (int): The size of the kernel in the CNN.
        factor (float): The multiplying factor for the CNN output channels.
        penultimate_dims (int): The size of the penultimate linear layer.
        num_classes (int): The number of output classes.
    """
    b = kernel_size * embedding_dim * seq_len + factor ** (cnn_layers - 1) * penultimate_dims
    c = penultimate_dims * num_classes - macc

    if cnn_layers == 1:
        cnn_dims_start = -c / b
    else:
        a = 0.0
        for layer_index in range(1, cnn_layers):
            a += seq_len * kernel_size * (0.5 ** layer_index) * (factor ** (2 * layer_index - 1))

        cnn_dims_start = (-b + np.sqrt(b ** 2 - 4 * a * c)) / (2 * a)

    return int(cnn_dims_start + 0.5)
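

# A minimal usage sketch (illustrative numbers, not part of the original module):
# choose the starting channel count so that a 6-layer CNN over a 1,000-base sequence
# stays near a budget of 100 million multiply-accumulate operations.
def _example_calc_cnn_dims_start() -> int:
    return calc_cnn_dims_start(
        macc=100_000_000,
        seq_len=1_000,
        embedding_dim=8,
        cnn_layers=6,
        kernel_size=3,
        factor=2.0,
        penultimate_dims=1028,
        num_classes=5,
    )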