Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from fastcore.transform import DisplayedTransform
2from fastai.data.block import TransformBlock
3import hdf5storage
4import torch
5import numpy as np
6from pathlib import Path
7from fastai.vision.data import TensorImage
8from skimage.transform import resize as skresize
9from skimage import io
10from fastai.vision.core import PILImageBW
# Key under which read3D expects, and write3D stores, the volume array in `.mat` files.
DEEPROCK_HDF5_KEY = "temp"
def read3D(path: Path):
    """Load a volume/image from disk as float32 numpy data.

    ``.mat`` files are read with ``hdf5storage.loadmat`` and the array stored
    under ``DEEPROCK_HDF5_KEY`` is divided by 255.0 (``write3D`` multiplies by
    255.0 when saving ``.mat``). Any other suffix is read with
    ``skimage.io.imread`` and only cast to float32, without rescaling.

    Args:
        path: File to read; anything accepted by ``pathlib.Path``.

    Returns:
        The loaded data as float32.

    Raises:
        ValueError: if a ``.mat`` file has no ``DEEPROCK_HDF5_KEY`` entry. The
            message lists the keys that were found.
    """
    path = Path(path)
    if path.suffix == ".mat":
        data_dict = hdf5storage.loadmat(str(path))
        if DEEPROCK_HDF5_KEY not in data_dict:
            keys_found = ",".join(data_dict.keys())
            # ValueError rather than a bare Exception: the file content does not
            # match the expected layout. Callers catching Exception still catch it.
            raise ValueError(f"expected key {DEEPROCK_HDF5_KEY} not found in '{path}'.\nCheck the following keys: {keys_found}")

        result = np.float32(data_dict[DEEPROCK_HDF5_KEY] / 255.0)
    else:
        result = np.float32(io.imread(path))

    return result
def write3D(path: Path, data):
    """Save ``data`` to ``path``, picking the writer from the file suffix.

    Non-``.mat`` targets go straight to ``skimage.io.imsave`` unchanged.
    For ``.mat`` targets the data is multiplied by 255.0 and stored under
    ``DEEPROCK_HDF5_KEY`` using ``hdf5storage.savemat`` in format 7.3.
    """
    target = Path(path)

    if target.suffix != ".mat":
        io.imsave(target, data)
        return

    payload = {DEEPROCK_HDF5_KEY: data * 255.0}
    hdf5storage.savemat(
        str(target),
        payload,
        format='7.3',
        oned_as='column',
        store_python_metadata=True,
    )
def unsqueeze(inputs):
    """Insert a size-1 axis at position 1 (the single channel axis).

    For example, a tensor shaped ``(B, D, H, W)`` comes back as
    ``(B, 1, D, H, W)``.
    """
    channel_axis = 1
    return inputs.unsqueeze(channel_axis)
def ImageBlock3D():
    """Build a fastai ``TransformBlock`` for single-channel volumes.

    Each item is loaded with ``read3D``. Each collated batch is then passed
    through ``unsqueeze`` to add the channel axis at position 1.
    """
    loader, add_channel = read3D, unsqueeze
    return TransformBlock(batch_tfms=add_channel, type_tfms=loader)
class InterpolateTransform(DisplayedTransform):
    """Resize a 2D image or 3D volume to a fixed shape with order-3 spline interpolation.

    Args:
        width: Target size of the last axis. Required.
        depth: Target size of the first spatial axis (used only when
            ``dim == 3``). Defaults to ``width``.
        height: Target size of the second-to-last axis. Defaults to ``width``.
        dim: Number of spatial dimensions, either 2 or 3.

    Raises:
        ValueError: if ``width`` is None or ``dim`` is not 2 or 3.
    """

    def __init__(self, width=None, *, depth=None, height=None, dim=2):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if width is None:
            raise ValueError("width must be given")
        self.width = width
        self.height = height or width
        self.depth = depth or width

        self.dim = dim
        # Build the target shape from the resolved sizes (self.*), not the raw
        # arguments. Otherwise an omitted depth/height would leave None in the
        # shape instead of falling back to width.
        if dim == 3:
            self.shape = (self.depth, self.height, self.width)
        elif dim == 2:
            self.shape = (self.height, self.width)
        else:
            raise ValueError("dim must be 2 or 3")

    def encodes(self, data):
        """Resize ``data`` to ``self.shape`` and return it with a new leading axis of size 1.

        If ``data`` has exactly one more axis than ``dim``, axis 0 is squeezed
        away first. Presumably this is a singleton channel axis (TODO confirm).
        """
        if len(data.shape) == self.dim + 1:
            data = data.squeeze(0)

        # order=3 selects cubic spline interpolation in skimage.transform.resize.
        result = skresize(data, self.shape, order=3)
        assert result.shape == self.shape
        return np.expand_dims(result, 0)
class RescaleImage(DisplayedTransform):
    """Apply the affine map ``x -> 2x - 1`` elementwise.

    This sends [0, 1] onto [-1, 1]. Inputs that are not tensors are converted
    with ``torch.tensor`` first, and the result is always a float tensor.
    """

    # Need to run after IntToFloatTensor
    order = 20

    def encodes(self, item):
        tensor = item if isinstance(item, torch.Tensor) else torch.tensor(item)
        return tensor.float() * 2.0 - 1.0
class RescaleImageMinMax(DisplayedTransform):
    """Linearly rescale each item so its min/max map to ``rescaled_min``/``rescaled_max``.

    Every call to ``encodes`` appends the item's original ``(min, max)`` pair
    to ``self.extrema``. That pair is what ``decodes`` needs to invert the
    mapping.

    Args:
        rescaled_min: Value the item's minimum is mapped to.
        rescaled_max: Value the item's maximum is mapped to.
    """

    def __init__(self, rescaled_min=-0.95, rescaled_max=0.95):
        # NOTE(review): grows by one entry per encoded item and is never cleared here.
        self.extrema = []
        self.rescaled_min = rescaled_min
        self.rescaled_max = rescaled_max
        self.factor = (self.rescaled_max - self.rescaled_min)

    def encodes(self, item):
        """Return ``item`` as a float tensor rescaled into [rescaled_min, rescaled_max].

        A greyscale PIL image is first converted to a ``(1, H, W)`` array.
        Any other non-tensor input goes through ``torch.tensor``. A constant
        item (max == min) maps every value to ``rescaled_min`` instead of NaN.
        """
        if isinstance(item, PILImageBW):
            item = np.expand_dims(np.asarray(item), 0)

        if not isinstance(item, torch.Tensor):
            item = torch.tensor(item)

        lo, hi = item.min(), item.max()
        self.extrema.append((lo, hi))

        span = hi - lo
        # For a constant item, dividing by a zero span would produce NaN
        # everywhere. Dividing by 1 maps it to rescaled_min instead, and
        # decodes still recovers the constant because it multiplies by
        # (max - min) == 0 and adds min.
        if span == 0:
            span = 1
        return (item.float() - lo) / span * self.factor + self.rescaled_min

    def decodes(self, item, min, max):
        """Invert ``encodes`` using the ``(min, max)`` pair recorded for this item.

        The parameter names shadow the builtins ``min``/``max``. They are kept
        because they are part of the public signature and may be passed by
        keyword.
        """
        return (item - self.rescaled_min) / self.factor * (max - min) + min
class CropTransform(DisplayedTransform):
    """Crop the trailing spatial axes of an image or volume with fixed slice bounds.

    Each bound is used as a Python slice index. Falsy bounds (``None`` or
    ``0``) are normalised to ``None``, meaning "no crop on that side". Note that
    ``end_* = 0`` therefore means "up to the end", not an empty slice.
    """

    def __init__(self, start_x: int = None, end_x: int = None, start_y: int = None,
                 end_y: int = None, start_z: int = None, end_z: int = None):
        self.start_x = start_x or None
        self.end_x = end_x or None
        self.start_y = start_y or None
        self.end_y = end_y or None
        self.start_z = start_z or None
        self.end_z = end_z or None

    def encodes(self, data):
        """Return ``data`` cropped along its trailing axes.

        With 3 or more axes, the last three are cropped as (z, y, x). With 2
        axes they are cropped as (y, x). Any leading axes, such as the size-1
        axis added by InterpolateTransform, are kept whole.

        NOTE(review): a greyscale PIL image is converted to a ``(1, H, W)``
        array, so ``start_z``/``end_z`` are applied to that singleton axis.
        """
        if isinstance(data, PILImageBW):
            data = np.expand_dims(np.asarray(data), 0)

        y_slice = slice(self.start_y, self.end_y)
        x_slice = slice(self.start_x, self.end_x)
        # Index from the end with `...` so extra leading axes do not shift
        # which axes receive the z/y/x bounds. For exactly 2 or 3 axes this is
        # identical to plain positional indexing.
        if len(data.shape) >= 3:
            return data[..., slice(self.start_z, self.end_z), y_slice, x_slice]
        return data[..., y_slice, x_slice]