import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from Bio.SeqRecord import SeqRecord
from fastcore.transform import Transform

from .tensor import TensorDNA


class SplitTransform(Transform):
    """Only performs the transform if the split index matches `only_split_index`."""

    def __init__(self, only_split_index=None, **kwargs):
        super().__init__(**kwargs)
        self.current_split_idx = None
        self.only_split_index = only_split_index

    def __call__(
        self,
        b,
        split_idx: int = None,  # index of the train/valid dataset
        **kwargs,
    ):
        if self.only_split_index == split_idx or self.only_split_index is None:
            return super().__call__(b, split_idx=split_idx, **kwargs)
        return b
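
# Usage sketch (illustration only, assuming fastai's convention that
# split_idx is 0 for the training set and 1 for the validation set):
#
#     tfm = RandomSliceBatch(only_split_index=0)  # defined below
#     tfm(batch, split_idx=0)  # transform applied
#     tfm(batch, split_idx=1)  # batch returned unchanged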


def slice_tensor(tensor, size, start_index=None):
    """Returns a slice of `tensor` with length `size`, zero-padding the end if necessary."""
    original_length = tensor.shape[0]
    if start_index is None:
        if original_length <= size:
            start_index = 0
        else:
            # randrange excludes its upper bound, so add one to allow the
            # final valid window (starting at original_length - size)
            start_index = random.randrange(0, original_length - size + 1)
    end_index = start_index + size
    if end_index > original_length:
        # pad on the right with zeros (i.e. 'N' in the vocabulary below)
        sliced = tensor[start_index:]
        sliced = nn.ConstantPad1d((0, end_index - original_length), 0)(sliced)
    else:
        sliced = tensor[start_index:end_index]
    return sliced
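
# Usage sketch (illustration only):
#
#     t = TensorDNA([1, 2, 3, 4])        # i.e. "ACGT"
#     slice_tensor(t, 6, start_index=0)  # -> length 6, padded: [1, 2, 3, 4, 0, 0]
#     slice_tensor(t, 2)                 # random window of length 2, e.g. [3, 4]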


class SliceTransform(Transform):
    """Slices every TensorDNA to a fixed size."""

    def __init__(self, size):
        self.size = size

    def encodes(self, tensor: TensorDNA):
        return slice_tensor(tensor, self.size)


VOCAB = "NACGT"
CHAR_TO_INT = dict(zip(VOCAB, range(len(VOCAB))))


def char_to_int(c):
    """Maps a nucleotide character to its integer code; unknown characters map to 0 ('N')."""
    return CHAR_TO_INT.get(c, 0)


class CharsToTensorDNA(Transform):
    """Converts a list of characters or a Biopython SeqRecord to a TensorDNA.

    fastcore dispatches `encodes` on the type annotation of its argument.
    """

    def encodes(self, seq: list):
        return TensorDNA([char_to_int(c) for c in seq])

    def encodes(self, seq: SeqRecord):
        seq_as_numpy = np.array(seq, "c")
        seq_as_numpy = seq_as_numpy.view(np.uint8)
        # Ignore any characters in the sequence with an ASCII value below 'A' (i.e. 65)
        seq_as_numpy = seq_as_numpy[seq_as_numpy >= ord("A")]
        for character, value in CHAR_TO_INT.items():
            seq_as_numpy[seq_as_numpy == ord(character)] = value
        # Discard anything that was not remapped to a vocabulary index
        seq_as_numpy = seq_as_numpy[seq_as_numpy < len(CHAR_TO_INT)]

        return TensorDNA(seq_as_numpy)
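
# Usage sketch (illustration only):
#
#     CharsToTensorDNA()(["A", "C", "G", "T", "N"])  # -> TensorDNA([1, 2, 3, 4, 0])
#
# For a SeqRecord, lowercase and other non-vocabulary letters are filtered
# out entirely rather than mapped to 'N'.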


class RowToTensorDNA(Transform):
    """Converts a DataFrame row to the TensorDNA for its accession."""

    def __init__(self, categories, **kwargs):
        super().__init__(**kwargs)
        self.category_dict = {category.name: category for category in categories}

    def encodes(self, row: pd.Series):
        if 'sequence' in row:
            return row['sequence']  # hack for rows that already carry the sequence
        return TensorDNA(self.category_dict[row['category']].get_seq(row["accession"]))
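
# Usage sketch (illustration only; assumes each category object exposes a
# `name` attribute and a `get_seq(accession)` method, as used above):
#
#     tfm = RowToTensorDNA(categories)
#     tfm(dataframe.iloc[0])  # -> TensorDNA for the first row's accession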


class RandomSliceBatch(SplitTransform):
    """Slices every item in a batch to a single random length drawn from a distribution."""

    def __init__(self, rand_generator=None, distribution=None, minimum: int = 150, maximum: int = 3_000, **kwargs):
        super().__init__(**kwargs)

        self.rand_generator = rand_generator or self.default_rand_generator
        if distribution is None:
            from scipy.stats import skewnorm

            # a right-skewed distribution of sequence lengths with a long tail
            distribution = skewnorm(5, loc=600, scale=1000)
        self.distribution = distribution
        self.minimum = minimum
        self.maximum = maximum

    def default_rand_generator(self):
        # draw a length from the distribution and clamp it to [minimum, maximum]
        seq_len = int(self.distribution.rvs())
        seq_len = max(self.minimum, seq_len)
        seq_len = min(self.maximum, seq_len)
        return seq_len

    def encodes(self, batch):
        seq_len = self.rand_generator()

        def slice(tensor):
            # each item is a tuple whose first element is the DNA tensor
            return (slice_tensor(tensor[0], seq_len),) + tensor[1:]

        return list(map(slice, batch))
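
# Usage sketch (illustration only): one length is drawn per call, so every
# item within a batch ends up the same size.
#
#     tfm = RandomSliceBatch(minimum=150, maximum=3_000)
#     tfm(batch, split_idx=0)  # all DNA tensors sliced/padded to one shared length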


class DeterministicSliceBatch(SplitTransform):
    """Slices every item in a batch to a fixed length, always starting at index zero."""

    def __init__(self, seq_length, **kwargs):
        super().__init__(**kwargs)

        self.seq_length = seq_length

    def slice(self, tensor):
        return (slice_tensor(tensor[0], self.seq_length, start_index=0),) + tensor[1:]

    def encodes(self, batch):
        return list(map(self.slice, batch))
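
# Usage sketch (illustration only): deterministic slicing gives reproducible
# batches, e.g. for the validation set.
#
#     tfm = DeterministicSliceBatch(seq_length=1_000, only_split_index=1)
#     tfm(batch, split_idx=1)  # every item trimmed or padded to 1,000 bases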


class ShortRandomSliceBatch(RandomSliceBatch):
    """RandomSliceBatch with lengths drawn uniformly from a short range."""

    def __init__(self, rand_generator=None, distribution=None, minimum: int = 80, maximum: int = 150):
        if distribution is None:
            from scipy.stats import uniform

            # uniform over [minimum, maximum]
            distribution = uniform(loc=minimum, scale=maximum - minimum)
        super().__init__(rand_generator=rand_generator, distribution=distribution, minimum=minimum, maximum=maximum)


class PadBatch(Transform):
    """Pads every item in a batch to the length of the longest item."""

    def encodes(self, batch):
        max_len = 0
        for item in batch:
            max_len = max(item[0].shape[0], max_len)

        def pad(tensor):
            # slice_tensor zero-pads when `max_len` exceeds the tensor's length
            return (slice_tensor(tensor[0], max_len),) + tensor[1:]

        return list(map(pad, batch))


class PadBatchX(Transform):
    """Pads a batch of bare tensors (without labels) and stacks them into a single tensor."""

    def encodes(self, batch):
        max_len = 0
        for item in batch:
            max_len = max(item.shape[0], max_len)

        def pad(tensor):
            return slice_tensor(tensor, max_len).unsqueeze(dim=0)

        # the trailing comma makes this a one-element tuple
        return torch.cat(list(map(pad, batch))),
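
# Usage sketch (illustration only):
#
#     PadBatch()([(TensorDNA([1, 2, 3]), 0), (TensorDNA([4]), 1)])
#     # -> [(TensorDNA([1, 2, 3]), 0), (TensorDNA([4, 0, 0]), 1)]
#
#     PadBatchX()([TensorDNA([1, 2, 3]), TensorDNA([4])])
#     # -> (tensor of shape (2, 3),)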


class DeformBatch(SplitTransform):
    """Randomly mutates bases with a Jukes-Cantor-style substitution model.

    Each item gets an exponentially distributed mutation 'time'; the longer
    the time, the more likely each base is to change state.
    """

    def __init__(self, deform_lambda=4, **kwargs):
        super().__init__(**kwargs)
        self.deform_lambda = deform_lambda
        self.distribution = torch.distributions.exponential.Exponential(deform_lambda)
        self.states = len(VOCAB)

    def deform(self, tensor):
        times = self.distribution.sample()
        alt_states = self.states - 1
        # probability that a base is unchanged after `times`
        probability_same = 1.0 / self.states + alt_states / self.states * torch.exp(-times)
        weights = torch.as_tensor([probability_same] + alt_states * [(1 - probability_same) / alt_states])
        # offset each base by a draw from `weights`, wrapping around the vocabulary;
        # offset 0 (probability_same) leaves the base unchanged
        x = tensor[0] + torch.multinomial(weights, len(tensor[0]), replacement=True)
        x = x % self.states

        return (x,) + tensor[1:]

    def encodes(self, batch):
        return list(map(self.deform, batch))
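
# Usage sketch (illustration only): for an Exponential(lambda) time,
# E[exp(-t)] = lambda / (lambda + 1), so with deform_lambda=4 a base stays
# unchanged with average probability 1/5 + 4/5 * 0.8 = 0.84.
#
#     tfm = DeformBatch(deform_lambda=4, only_split_index=0)
#     deformed = tfm(batch, split_idx=0)  # augmentation applied to training batches only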