import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from Bio.SeqRecord import SeqRecord
from fastcore.transform import Transform

from .tensor import TensorDNA

class SplitTransform(Transform):
    """
    Only performs the transform if the split index matches `only_split_index`.
    """

    def __init__(self, only_split_index=None, **kwargs):
        super().__init__(**kwargs)
        self.current_split_idx = None
        self.only_split_index = only_split_index

    def __call__(
        self,
        b,
        split_idx: int = None,  # index of the train/valid dataset
        **kwargs,
    ):
        if self.only_split_index == split_idx or self.only_split_index is None:
            return super().__call__(b, split_idx=split_idx, **kwargs)
        return b
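
# Illustrative sketch (assuming fastai's convention that split_idx=0 is the
# training set and split_idx=1 the validation set): a transform that should
# only run on training batches can be declared as
#
#   class TrainOnlyNoise(SplitTransform):
#       def __init__(self, **kwargs):
#           super().__init__(only_split_index=0, **kwargs)
#
#       def encodes(self, batch):
#           ...  # augmentation applied to training batches only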

def slice_tensor(tensor, size, start_index=None):
    """Return a slice of `tensor` of length `size`, zero-padding at the end if needed."""
    original_length = tensor.shape[0]
    if start_index is None:
        if original_length <= size:
            start_index = 0
        else:
            start_index = random.randrange(0, original_length - size)
    end_index = start_index + size
    if end_index > original_length:
        # The tensor is too short: take what is available and pad the end with zeros ('N').
        sliced = tensor[start_index:]
        sliced = nn.ConstantPad1d((0, end_index - original_length), 0)(sliced)
    else:
        sliced = tensor[start_index:end_index]
    return sliced
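
# Example (illustrative, assuming TensorDNA wraps a 1-D integer tensor):
#
#   slice_tensor(TensorDNA([1, 2, 3, 4]), 6, start_index=0)
#   # -> tensor of [1, 2, 3, 4, 0, 0]  (too short, so zero-padded with 'N')
#   slice_tensor(TensorDNA([1, 2, 3, 4]), 2, start_index=1)
#   # -> tensor of [2, 3]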

class SliceTransform(Transform):
    def __init__(self, size):
        self.size = size

    def encodes(self, tensor: TensorDNA):
        return slice_tensor(tensor, self.size)

VOCAB = "NACGT"
CHAR_TO_INT = dict(zip(VOCAB, range(len(VOCAB))))


def char_to_int(c):
    """Map a nucleotide character to its integer code; unknown characters map to 0 ('N')."""
    return CHAR_TO_INT.get(c, 0)
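
# For example, CHAR_TO_INT == {'N': 0, 'A': 1, 'C': 2, 'G': 3, 'T': 4},
# so char_to_int('G') == 3 and char_to_int('x') == 0.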

class CharsToTensorDNA(Transform):
    """Converts a sequence of characters to a TensorDNA.

    fastcore's Transform dispatches `encodes` on the argument's type,
    so both methods below are registered rather than shadowing each other.
    """

    def encodes(self, seq: list):
        return TensorDNA([char_to_int(c) for c in seq])

    def encodes(self, seq: SeqRecord):
        seq_as_numpy = np.array(seq, "c")
        seq_as_numpy = seq_as_numpy.view(np.uint8)
        # Ignore any characters in the sequence below an ASCII value of 'A', i.e. 65
        seq_as_numpy = seq_as_numpy[seq_as_numpy >= ord("A")]
        # Map vocabulary characters to their integer codes; the codes (0-4) are all
        # below ord('A'), so already-mapped values cannot be remapped.
        for character, value in CHAR_TO_INT.items():
            seq_as_numpy[seq_as_numpy == ord(character)] = value
        # Any remaining values are ASCII codes of out-of-vocabulary characters; drop them.
        seq_as_numpy = seq_as_numpy[seq_as_numpy < len(CHAR_TO_INT)]
        return TensorDNA(seq_as_numpy)
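
# Illustrative example (assumes a Bio.Seq.Seq-backed record):
#
#   from Bio.Seq import Seq
#   record = SeqRecord(Seq("ACGTN"))
#   CharsToTensorDNA()(record)
#   # -> tensor of [1, 2, 3, 4, 0]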

class RowToTensorDNA(Transform):
    """Looks up the sequence for a DataFrame row from its category's sequence store."""

    def __init__(self, categories, **kwargs):
        super().__init__(**kwargs)
        self.category_dict = {category.name: category for category in categories}

    def encodes(self, row: pd.Series):
        if 'sequence' in row:
            # The sequence is already present in the row, so return it directly.
            return row['sequence']
        return TensorDNA(self.category_dict[row['category']].get_seq(row["accession"]))
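
# Illustrative sketch (the `categories` objects are assumed to expose `.name`
# and `.get_seq(accession)`, as used above):
#
#   row = pd.Series({"category": "bacteria", "accession": "NC_000913"})
#   transform = RowToTensorDNA(categories)
#   dna = transform(row)  # TensorDNA for the stored sequence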

class RandomSliceBatch(SplitTransform):
    """Slices every item in a batch to a shared random length drawn from a distribution."""

    rand_generator = None

    def __init__(self, rand_generator=None, distribution=None, minimum: int = 150, maximum: int = 3_000, **kwargs):
        super().__init__(**kwargs)
        self.rand_generator = rand_generator or self.default_rand_generator
        if distribution is None:
            from scipy.stats import skewnorm

            distribution = skewnorm(5, loc=600, scale=1000)
        self.distribution = distribution
        self.minimum = minimum
        self.maximum = maximum

    def default_rand_generator(self):
        # Draw a length from the distribution and clamp it to [minimum, maximum].
        seq_len = int(self.distribution.rvs())
        seq_len = max(self.minimum, seq_len)
        seq_len = min(self.maximum, seq_len)
        return seq_len

    def encodes(self, batch):
        seq_len = self.rand_generator()

        def slice(tensor):
            # Each item is a tuple whose first element is the sequence tensor.
            return (slice_tensor(tensor[0], seq_len),) + tensor[1:]

        return list(map(slice, batch))
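
# Illustrative usage sketch: one length is drawn per batch, so every item is
# cut (or padded) to the same size, e.g.
#
#   transform = RandomSliceBatch(minimum=150, maximum=3_000)
#   transform(batch)  # each (sequence, ...) tuple sliced to one shared length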

class DeterministicSliceBatch(SplitTransform):
    """Slices every item in a batch to a fixed length, starting from the beginning of the sequence."""

    def __init__(self, seq_length, **kwargs):
        super().__init__(**kwargs)
        self.seq_length = seq_length

    def slice(self, tensor):
        return (slice_tensor(tensor[0], self.seq_length, start_index=0),) + tensor[1:]

    def encodes(self, batch):
        return list(map(self.slice, batch))

class ShortRandomSliceBatch(RandomSliceBatch):
    """RandomSliceBatch with lengths drawn uniformly from a short range."""

    def __init__(self, rand_generator=None, distribution=None, minimum: int = 80, maximum: int = 150):
        if distribution is None:
            from scipy.stats import uniform

            distribution = uniform(loc=minimum, scale=maximum - minimum)
        super().__init__(rand_generator=rand_generator, distribution=distribution, minimum=minimum, maximum=maximum)

class PadBatch(Transform):
    """Pads every item in a batch to the length of the longest sequence."""

    def encodes(self, batch):
        max_len = 0
        for item in batch:
            max_len = max(item[0].shape[0], max_len)

        def pad(tensor):
            return (slice_tensor(tensor[0], max_len),) + tensor[1:]

        return list(map(pad, batch))
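
# Since max_len is the length of the longest sequence in the batch,
# slice_tensor only ever zero-pads here and never truncates: a batch with
# sequence lengths [5, 8, 3] comes out with every sequence at length 8.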

class PadBatchX(Transform):
    """Pads a batch of bare tensors (no labels) and stacks them into a single tensor."""

    def encodes(self, batch):
        max_len = 0
        for item in batch:
            max_len = max(item.shape[0], max_len)

        def pad(tensor):
            return slice_tensor(tensor, max_len).unsqueeze(dim=0)

        # The trailing comma wraps the stacked (batch, max_len) tensor in a 1-tuple.
        return torch.cat(list(map(pad, batch))),

class DeformBatch(SplitTransform):
    """Randomly mutates bases in a batch using exponentially distributed mutation times."""

    def __init__(self, deform_lambda=4, **kwargs):
        super().__init__(**kwargs)
        self.deform_lambda = deform_lambda
        self.distribution = torch.distributions.exponential.Exponential(deform_lambda)
        self.states = len(VOCAB)
    def deform(self, tensor):
        # Sample a 'mutation time' and derive the probability that a base is unchanged.
        times = self.distribution.sample()
        alt_states = self.states - 1
        probability_same = 1.0 / self.states + alt_states / self.states * torch.exp(-times)
        # Probability of staying in the same state, then an equal share for each alternative state.
        weights = torch.as_tensor([probability_same] + alt_states * [(1 - probability_same) / alt_states])
        # Shift each base by a sampled offset (0 = unchanged), wrapping around the vocabulary.
        x = tensor[0] + torch.multinomial(weights, len(tensor[0]), replacement=True)
        x = x % self.states
        return (x,) + tensor[1:]
    def encodes(self, batch):
        return list(map(self.deform, batch))
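
# Illustrative sketch of the mutation model above: for an exponentially
# distributed time t, each base stays the same with probability
# 1/5 + (4/5) * exp(-t), which resembles a Jukes-Cantor-style transition
# probability generalised to the 5-letter vocabulary, and moves to each of
# the 4 other states with equal probability. For example:
#
#   deform = DeformBatch(deform_lambda=4)
#   batch = [(TensorDNA([1, 2, 3, 4]), label)]
#   deformed = deform(batch)  # each base independently mutated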