
from fastcore.transform import DisplayedTransform
from fastai.data.block import TransformBlock
import hdf5storage
import torch
import numpy as np
from pathlib import Path
from fastai.vision.data import TensorImage
from skimage.transform import resize as skresize
from skimage import io
from fastai.vision.core import PILImageBW

DEEPROCK_HDF5_KEY = "temp"



def read3D(path:Path):
    """Reads a 3D volume from a .mat file (stored under DEEPROCK_HDF5_KEY) or any format supported by skimage.io.imread."""
    path = Path(path)
    if path.suffix == ".mat":
        data_dict = hdf5storage.loadmat(str(path))
        if DEEPROCK_HDF5_KEY not in data_dict:
            keys_found = ",".join(data_dict.keys())
            raise Exception(f"expected key {DEEPROCK_HDF5_KEY} not found in '{path}'.\nCheck the following keys: {keys_found}")

        # .mat volumes are stored with values in [0, 255]; rescale to [0, 1]
        result = np.float32(data_dict[DEEPROCK_HDF5_KEY]/255.0)
    else:
        result = np.float32(io.imread(path))

    return result



def write3D(path:Path, data):
    """Writes a 3D volume to a .mat file (under DEEPROCK_HDF5_KEY, rescaled back to [0, 255]) or any format supported by skimage.io.imsave."""
    path = Path(path)

    if path.suffix == ".mat":
        hdf5storage.savemat(str(path), {DEEPROCK_HDF5_KEY:data*255.0}, format='7.3', oned_as='column', store_python_metadata=True)
    else:
        io.imsave(path, data)


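# A minimal round-trip sketch (the file name is hypothetical; the .mat file is
# assumed to store its data under the "temp" key expected by DEEPROCK_HDF5_KEY):
#
#     volume = read3D("volume.mat")     # float32 array, values scaled to [0, 1]
#     write3D("copy.mat", volume)       # values written back in the [0, 255] range

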

def unsqueeze(inputs):
    """Adds a dimension for the single channel."""
    return inputs.unsqueeze(dim=1)



def ImageBlock3D():
    """A TransformBlock that reads volumes with read3D and adds a channel dimension to each batch."""
    return TransformBlock(
        type_tfms=read3D,
        batch_tfms=unsqueeze,
    )


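# A minimal sketch of how this block could plug into a fastai DataBlock
# (the source directory, getter and splitter below are hypothetical; a real
# pipeline would add getters to pair low- and high-resolution volumes):
#
#     from fastai.data.block import DataBlock
#     from fastai.data.transforms import RandomSplitter
#
#     dblock = DataBlock(
#         blocks=(ImageBlock3D(), ImageBlock3D()),   # e.g. input volume, target volume
#         get_items=lambda source: list(Path(source).glob("*.mat")),
#         splitter=RandomSplitter(valid_pct=0.2),
#     )
#     dls = dblock.dataloaders("path/to/volumes", bs=4)

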

class InterpolateTransform(DisplayedTransform):
    """Resizes a 2D image or 3D volume to a fixed shape using cubic interpolation (skimage.transform.resize with order=3)."""
    def __init__(self, width=None, *, depth=None, height=None, dim=2):
        self.width = width
        assert width is not None

        # height and depth default to width, giving square/cubic output shapes
        self.height = height or width
        self.depth = depth or width

        self.dim = dim
        if dim == 3:
            self.shape = (self.depth, self.height, self.width)
        elif dim == 2:
            self.shape = (self.height, self.width)
        else:
            raise ValueError("dim must be 2 or 3")

    def encodes(self, data):
        # Drop a leading channel dimension, if present, before resizing
        if len(data.shape) == self.dim + 1:
            data = data.squeeze(0)

        result = skresize(data, self.shape, order=3)
        assert result.shape == self.shape
        # Restore the channel dimension
        return np.expand_dims(result, 0)


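# A minimal usage sketch (the input volume here is synthetic):
#
#     tfm = InterpolateTransform(64, dim=3)
#     volume = np.random.rand(100, 100, 100).astype(np.float32)
#     resized = tfm(volume)    # shape (1, 64, 64, 64): resized volume with a channel axis

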

class RescaleImage(DisplayedTransform):
    """Rescales tensor values from [0, 1] to [-1, 1]."""
    order = 20  # needs to run after IntToFloatTensor

    def encodes(self, item):
        if not isinstance(item, torch.Tensor):
            item = torch.tensor(item)

        return item.float()*2.0 - 1.0


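# For example:
#
#     RescaleImage()(torch.tensor([0.0, 0.25, 1.0]))    # -> tensor([-1.0, -0.5, 1.0])

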

class RescaleImageMinMax(DisplayedTransform):
    """Rescales each item from its own [min, max] range to [rescaled_min, rescaled_max]."""
    def __init__(self, rescaled_min=-0.95, rescaled_max=0.95):
        self.extrema = []
        self.rescaled_min = rescaled_min
        self.rescaled_max = rescaled_max
        self.factor = (self.rescaled_max - self.rescaled_min)

    def encodes(self, item):
        if isinstance(item, PILImageBW):
            item = np.expand_dims(np.asarray(item), 0)

        if not isinstance(item, torch.Tensor):
            item = torch.tensor(item)

        # Remember the original extrema so the transform can be undone later
        min, max = item.min(), item.max()
        self.extrema.append((min, max))
        transformed_item = (item.float() - min) / (max - min) * self.factor + self.rescaled_min
        return transformed_item

    def decodes(self, item, min, max):
        # Inverse of encodes, given the original extrema of the item
        return (item - self.rescaled_min)/self.factor * (max - min) + min


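# A minimal round-trip sketch (synthetic tensor, default [-0.95, 0.95] range);
# the extrema recorded during encoding are passed back to undo the rescale:
#
#     tfm = RescaleImageMinMax()
#     x = torch.tensor([0.0, 50.0, 100.0])
#     y = tfm(x)                        # approx. tensor([-0.95, 0.00, 0.95])
#     vmin, vmax = tfm.extrema[-1]
#     tfm.decodes(y, vmin, vmax)        # recovers tensor([0., 50., 100.])

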

class CropTransform(DisplayedTransform):
    """Crops a 2D image or 3D volume to the given pixel/voxel bounds."""
    def __init__(self, start_x:int=None, end_x:int=None, start_y:int=None, end_y:int=None, start_z:int=None, end_z:int=None):
        # a bound of 0 is treated the same as None (i.e. no crop on that side)
        self.start_x = start_x or None
        self.end_x = end_x or None
        self.start_y = start_y or None
        self.end_y = end_y or None
        self.start_z = start_z or None
        self.end_z = end_z or None

    def encodes(self, data):
        if isinstance(data, PILImageBW):
            data = np.expand_dims(np.asarray(data), 0)

        if len(data.shape) == 3:
            return data[self.start_z:self.end_z, self.start_y:self.end_y, self.start_x:self.end_x]
        return data[self.start_y:self.end_y, self.start_x:self.end_x]


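# A minimal usage sketch on a synthetic volume:
#
#     crop = CropTransform(start_x=10, end_x=74, start_y=10, end_y=74, start_z=0, end_z=64)
#     volume = np.zeros((100, 100, 100), dtype=np.float32)
#     crop(volume).shape    # (64, 64, 64)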