import math
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from rich.console import Console

console = Console()


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding for a sequence-first batch of embeddings.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
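

# A minimal usage sketch for PositionalEncoding (the shapes below are assumptions
# chosen for illustration, not part of the module's API):
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=64)
    x = torch.zeros(100, 8, 64)  # [seq_len, batch_size, embedding_dim]
    return pe(x)                 # same shape; sinusoidal positions added, then dropout applied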

def conv3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d:
    """1D convolution of kernel size 3, padded so the sequence length is preserved when stride == 1."""
    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
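

# A quick length-preservation check for conv3 (channel sizes here are arbitrary assumptions):
def _demo_conv3():
    conv = conv3(in_planes=4, out_planes=8, dilation=2)
    x = torch.randn(1, 4, 50)
    # padding == dilation keeps the sequence length when stride == 1
    assert conv(x).shape == (1, 8, 50)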

class ResidualBlock1D(nn.Module):
    """Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py"""

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        # A projection shortcut is needed whenever the residual branch changes
        # the temporal resolution (stride != 1) or the number of channels.
        if downsample is None and (stride != 1 or inplanes != planes):
            downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes),
            )

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
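

# A small shape sketch for ResidualBlock1D (channel and length values are illustrative assumptions):
def _demo_residual_block():
    block = ResidualBlock1D(inplanes=8, planes=16, stride=2)
    x = torch.randn(2, 8, 100)
    # stride 2 halves the length; the 1x1 projection shortcut matches channels for the addition
    assert block(x).shape == (2, 16, 50)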

class ConvRecurrantClassifier(nn.Module):
    def __init__(
        self,
        num_classes,
        embedding_dim: int = 8,
        filters: int = 256,
        cnn_layers: int = 6,
        kernel_size_cnn: int = 9,
        lstm_dims: int = 256,
        final_layer_dims: int = 0,  # If this is zero then it isn't used.
        dropout: float = 0.5,
        kernel_size_maxpool: int = 2,
        residual_blocks: bool = False,
        final_bias: bool = True,
        multi_kernel_sizes: bool = True,
    ):
        super().__init__()

        num_embeddings = 5  # i.e. the size of the vocab, which is N, A, C, G, T

        self.num_classes = num_classes
        self.num_embeddings = num_embeddings

        ########################
        ## Embedding
        ########################
        self.embedding_dim = embedding_dim
        self.embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.dropout = nn.Dropout(dropout)

        ########################
        ## Convolutional Layer
        ########################

        self.multi_kernel_sizes = multi_kernel_sizes
        if multi_kernel_sizes:
            # Parallel convolutions with odd kernel sizes 5, 7, 9, ... whose outputs are concatenated
            kernel_size = 5
            convolutions = []
            for _ in range(cnn_layers):
                convolutions.append(
                    nn.Conv1d(in_channels=embedding_dim, out_channels=filters, kernel_size=kernel_size, padding='same')
                )
                kernel_size += 2

            self.convolutions = nn.ModuleList(convolutions)
            self.pool = nn.MaxPool1d(kernel_size=kernel_size_maxpool)
            current_dims = filters * cnn_layers
        else:
            self.filters = filters
            self.residual_blocks = residual_blocks
            self.intermediate_filters = 128
            if residual_blocks:
                self.cnn_layers = nn.Sequential(
                    ResidualBlock1D(embedding_dim, embedding_dim),
                    ResidualBlock1D(embedding_dim, self.intermediate_filters, 2),
                    ResidualBlock1D(self.intermediate_filters, self.intermediate_filters),
                    ResidualBlock1D(self.intermediate_filters, filters, 2),
                    ResidualBlock1D(filters, filters),
                )
            else:
                self.kernel_size_cnn = kernel_size_cnn
                self.cnn_layers = nn.Sequential(
                    nn.Conv1d(in_channels=embedding_dim, out_channels=filters, kernel_size=kernel_size_cnn),
                    nn.MaxPool1d(kernel_size=kernel_size_maxpool),
                )
            current_dims = filters

        ########################
        ## Recurrent Layer
        ########################
        self.lstm_dims = lstm_dims
        if lstm_dims:
            self.bi_lstm = nn.LSTM(
                input_size=current_dims,  # receives the channel dimension of the convolutional output
                hidden_size=lstm_dims,
                bidirectional=True,
                bias=True,
                batch_first=True,
                dropout=dropout,  # NB: has no effect with a single LSTM layer and PyTorch will warn
            )
            current_dims = lstm_dims * 2

        if final_layer_dims:
            self.fc1 = nn.Linear(
                in_features=current_dims,
                out_features=final_layer_dims,
            )
            current_dims = final_layer_dims

        #################################
        ## Linear Layer(s) to Predictions
        #################################
        self.final_layer_dims = final_layer_dims
        self.logits = nn.Linear(
            in_features=current_dims,
            out_features=self.num_classes,
            bias=final_bias,
        )

    def forward(self, x):
        ########################
        ## Embedding
        ########################
        # Convert to int because the input may arrive as bytes
        x = x.int()
        x = self.embed(x)

        ########################
        ## Convolutional Layer
        ########################
        # Transpose seq_len with embedding dims to suit the PyTorch CNN convention (batch_size, channels, seq_len)
        x = x.transpose(1, 2)

        if self.multi_kernel_sizes:
            conv_results = [conv(x) for conv in self.convolutions]
            x = torch.cat(conv_results, dim=-2)

            x = self.pool(x)
        else:
            x = self.cnn_layers(x)

        # Current shape: batch, filters, seq_len
        # With batch_first=True, LSTM expects shape: batch, seq, feature
        x = x.transpose(2, 1)

        ########################
        ## Recurrent Layer
        ########################

        # BiLSTM
        if self.lstm_dims:
            output, (h_n, c_n) = self.bi_lstm(x)
            # h_n has shape (num_layers * num_directions, batch, hidden_size).
            # We use a single layer with 2 directions, so the final hidden states
            # are h_n[0, :, :] (forward) and h_n[1, :, :] (backward).
            x = torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=-1)
        else:
            # If there is no recurrent layer then simply sum over the sequence dimension
            x = torch.sum(x, dim=1)

        #################################
        ## Linear Layer(s) to Predictions
        #################################
        # Skipped if final_layer_dims is zero
        if self.final_layer_dims:
            x = F.relu(self.fc1(x))
        # Return logits; the cross-entropy loss applies the softmax internally
        out = self.logits(x)

        return out
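

# A minimal end-to-end sketch for ConvRecurrantClassifier (the batch size, sequence
# length and class count below are assumptions chosen for illustration):
def _demo_conv_recurrent_classifier():
    model = ConvRecurrantClassifier(num_classes=5)
    seqs = torch.randint(0, 5, (4, 200))  # batch of 4 integer-encoded sequences (N, A, C, G, T)
    logits = model(seqs)
    assert logits.shape == (4, 5)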

class ConvClassifier(nn.Module):
    def __init__(
        self,
        embedding_dim=8,
        cnn_layers=6,
        num_classes=5,
        cnn_dims_start=64,
        kernel_size_maxpool=2,
        num_embeddings=5,  # i.e. the size of the vocab, which is N, A, C, G, T
        kernel_size=3,
        factor=2,
        padding="same",
        padding_mode="zeros",
        dropout=0.5,
        final_bias=True,
        lstm_dims: int = 0,
        penultimate_dims: int = 1028,
        include_length: bool = False,
        length_scaling: float = 3_000.0,
        transformer_heads: int = 8,
        transformer_layers: int = 6,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.cnn_layers = cnn_layers
        self.num_classes = num_classes
        self.kernel_size_maxpool = kernel_size_maxpool

        self.num_embeddings = num_embeddings
        self.kernel_size = kernel_size
        self.factor = factor
        self.dropout = dropout
        self.include_length = include_length
        self.length_scaling = length_scaling
        self.transformer_layers = transformer_layers
        self.transformer_heads = transformer_heads

        self.embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )

        # Stack of conv blocks; the channel count grows by `factor` at each layer
        # while max-pooling shrinks the sequence length.
        in_channels = embedding_dim
        out_channels = cnn_dims_start
        conv_layers = []
        for _ in range(cnn_layers):
            conv_layers += [
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    padding_mode=padding_mode,
                ),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.MaxPool1d(kernel_size_maxpool),  # a strided Conv1d could be used here instead
            ]
            in_channels = out_channels
            out_channels = int(out_channels * factor)

        self.conv = nn.Sequential(*conv_layers)

        if self.transformer_layers:
            self.positional_encoding = PositionalEncoding(d_model=in_channels)
            encoder_layer = nn.TransformerEncoderLayer(d_model=in_channels, nhead=self.transformer_heads, batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.transformer_layers)
        else:
            self.transformer_encoder = None

        self.lstm_dims = lstm_dims
        if lstm_dims:
            self.bi_lstm = nn.LSTM(
                input_size=in_channels,  # receives the channel dimension of the convolutional output
                hidden_size=lstm_dims,
                bidirectional=True,
                bias=True,
                batch_first=True,
                dropout=dropout,  # NB: has no effect with a single LSTM layer and PyTorch will warn
            )
            current_dims = lstm_dims * 2
        else:
            current_dims = in_channels

        self.average_pool = nn.AdaptiveAvgPool1d(1)

        current_dims += int(include_length)
        self.final = nn.Sequential(
            nn.Linear(in_features=current_dims, out_features=penultimate_dims, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=penultimate_dims, out_features=num_classes, bias=final_bias),
        )

    def forward(self, x):
        # Convert to int because the input may arrive as bytes
        x = x.int()
        length = x.shape[-1]
        x = self.embedding(x)

        # Transpose seq_len with embedding dims to suit the PyTorch CNN convention (batch_size, channels, seq_len)
        x = x.transpose(1, 2)
        x = self.conv(x)

        if hasattr(self, 'transformer_encoder') and self.transformer_encoder:
            # PositionalEncoding expects (seq_len, batch, features), so go sequence-first
            # for the encoding, then batch-first for the transformer.
            x = x.permute(2, 0, 1)
            x = self.positional_encoding(x)
            x = x.permute(1, 0, 2)
            x = self.transformer_encoder(x)
            x = x.transpose(1, 2)

        if self.lstm_dims:
            x = x.transpose(2, 1)
            output, (h_n, c_n) = self.bi_lstm(x)
            # h_n has shape (num_layers * num_directions, batch, hidden_size).
            # We use a single layer with 2 directions, so the final hidden states
            # are h_n[0, :, :] (forward) and h_n[1, :, :] (backward).
            x = torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=-1)
        elif hasattr(self, 'average_pool'):
            x = self.average_pool(x)
            x = torch.flatten(x, 1)
        else:
            x = torch.mean(x, dim=-1)

        if getattr(self, 'include_length', False):
            length_tensor = torch.full((x.shape[0], 1), length / self.length_scaling, device=x.device)
            x = torch.cat([x, length_tensor], dim=1)

        predictions = self.final(x)

        return predictions


    def new_final(self, output_size, final_bias: bool = True):
        # The first Linear layer in self.final determines the incoming feature size
        final_in_features = list(self.final.modules())[1].in_features

        self.final = nn.Sequential(
            nn.Linear(in_features=final_in_features, out_features=final_in_features, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=final_in_features, out_features=output_size, bias=final_bias),
        )
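

# A minimal end-to-end sketch for ConvClassifier (the input sizes are illustrative
# assumptions; the sequence must be long enough to survive cnn_layers rounds of pooling):
def _demo_conv_classifier():
    model = ConvClassifier(num_classes=5, transformer_layers=0)
    seqs = torch.randint(0, 5, (2, 1024))  # batch of 2 integer-encoded sequences
    logits = model(seqs)
    assert logits.shape == (2, 5)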

class SequentialDebug(nn.Sequential):
    """A Sequential that prints each module's output shape and MAC count as data flows through."""

    def forward(self, input):
        from thop import profile  # imported lazily so thop is only required when debugging

        macs_cumulative = 0
        console.print(f"Input shape {input.shape}")
        for module in self:
            console.print(f"Module: {module} ({type(module)})")
            macs, _ = profile(module, inputs=(input,))
            macs_cumulative += int(macs)
            console.print(f"MACs: {int(macs)} (cumulative {macs_cumulative})")

            input = module(input)
            console.print(f"Output shape: {input.shape}")

        return input
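

# A usage sketch for SequentialDebug (layer sizes are arbitrary assumptions; requires thop):
def _demo_sequential_debug():
    layers = SequentialDebug(
        nn.Conv1d(8, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool1d(2),
    )
    return layers(torch.randn(1, 8, 100))  # prints shapes and MAC counts per module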

class ConvProcessor(nn.Sequential):
    """The convolutional stack from ConvClassifier, packaged as a standalone Sequential."""

    def __init__(
        self,
        in_channels=8,
        cnn_layers=6,
        cnn_dims_start=64,
        kernel_size_maxpool=2,
        kernel_size=3,
        factor=2,
        dropout=0.5,
        padding="same",
        padding_mode="zeros",
    ):
        out_channels = cnn_dims_start
        conv_layers = []
        for _ in range(cnn_layers):
            conv_layers += [
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    padding_mode=padding_mode,
                ),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.MaxPool1d(kernel_size_maxpool),  # a strided Conv1d could be used here instead
            ]
            in_channels = out_channels
            out_channels = int(out_channels * factor)

        super().__init__(*conv_layers)
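

# A shape sketch for ConvProcessor (the values below are illustrative assumptions):
def _demo_conv_processor():
    proc = ConvProcessor(in_channels=8, cnn_layers=3, cnn_dims_start=16)
    x = torch.randn(2, 8, 64)
    # three rounds of pooling: 64 -> 32 -> 16 -> 8; channels: 16 -> 32 -> 64
    assert proc(x).shape == (2, 64, 8)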

def calc_cnn_dims_start(
    macc,
    seq_len: int,
    embedding_dim: int,
    cnn_layers: int,
    kernel_size: int,
    factor: float,
    penultimate_dims: int,
    num_classes: int,
):
    r"""
    Solves the equation

        M = s k e c + \sum_{l=1}^{L-1} \frac{s}{2^l} k c^2 f^{2l-1} + c f^{L-1} p + p o

    for c, where M is the MACC budget, s the sequence length, k the kernel size,
    e the embedding dimension, L the number of CNN layers, f the channel growth
    factor, p the penultimate layer size, o the number of classes, and c the
    starting number of CNN channels (``cnn_dims_start``).

    Args:
        macc (int): The target number of multiply-accumulate operations.
        seq_len (int): The length of the input sequence.
        embedding_dim (int): The size of the embedding.
        cnn_layers (int): The number of CNN layers.
        kernel_size (int): The size of the kernel in the CNN.
        factor (float): The multiplying factor for the CNN output channels.
        penultimate_dims (int): The size of the penultimate linear layer.
        num_classes (int): The number of output classes.
    """
    b = kernel_size * embedding_dim * seq_len + factor ** (cnn_layers - 1) * penultimate_dims
    c = penultimate_dims * num_classes - macc

    if cnn_layers == 1:
        # The quadratic term vanishes, leaving the linear equation b * cnn_dims_start + c = 0
        cnn_dims_start = -c / b
    else:
        # Coefficient of the quadratic term, summed over layers 1..L-1
        a = 0.0
        for layer_index in range(1, cnn_layers):
            a += seq_len * kernel_size * (0.5 ** layer_index) * (factor ** (2 * layer_index - 1))

        # Positive root of a * x^2 + b * x + c = 0
        cnn_dims_start = (-b + np.sqrt(b ** 2 - 4 * a * c)) / (2 * a)

    return int(cnn_dims_start + 0.5)
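

# A worked sketch for calc_cnn_dims_start (the budget and shape values below are
# assumptions chosen purely for illustration):
def _demo_calc_cnn_dims_start():
    c = calc_cnn_dims_start(
        macc=10_000_000,
        seq_len=1024,
        embedding_dim=8,
        cnn_layers=6,
        kernel_size=3,
        factor=2.0,
        penultimate_dims=1028,
        num_classes=5,
    )
    # c is the rounded starting channel count whose estimated MACC cost is
    # closest to the requested budget
    return c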