Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import re 

2import sys 

3import yaml 

4import importlib 

5import pytest 

6import torch 

7from typing import get_type_hints 

8from click.testing import CliRunner 

9from pathlib import Path 

10import difflib 

11from torch import nn 

12from collections import OrderedDict 

13from fastai.data.core import DataLoaders 

14from fastai.learner import Learner 

15from rich.console import Console 

16from fastai.torch_core import TensorBase 

17 

18from .apps import TorchApp 

19 

# Module-level rich console, shared by the helpers below for printing diffs.
console = Console()

21 

22###################################################################### 

23## pytest fixtures 

24###################################################################### 

25 

26 

@pytest.fixture
def interactive(request):
    """Fixture: True when pytest capture is disabled (run with ``-s``).

    Used by the tests below to decide whether the user may be prompted
    to create or regenerate expected-output files.
    """
    capture_mode = request.config.getoption("-s")
    return capture_mode == "no"

30 

31 

32###################################################################### 

33## YAML functions from https://stackoverflow.com/a/8641732 

34###################################################################### 

class quoted(str):
    """Marker string type: dumped to YAML with double-quote style."""

37 

38 

def quoted_presenter(dumper, data):
    """Represent *data* as a double-quoted YAML string scalar."""
    string_tag = "tag:yaml.org,2002:str"
    return dumper.represent_scalar(string_tag, data, style='"')

41 

42 

# Register the presenter so `quoted` instances dump with double-quote style.
yaml.add_representer(quoted, quoted_presenter)

44 

45 

class literal(str):
    """Marker string type: dumped to YAML with literal block (``|``) style."""

48 

49 

def literal_presenter(dumper, data):
    """Represent *data* as a YAML literal block scalar (``|`` style)."""
    string_tag = "tag:yaml.org,2002:str"
    return dumper.represent_scalar(string_tag, data, style="|")

52 

53 

# Register the presenter so `literal` instances dump as literal block scalars.
yaml.add_representer(literal, literal_presenter)

55 

56 

def ordered_dict_presenter(dumper, data):
    """Dump an OrderedDict as a plain YAML mapping, keeping insertion order."""
    ordered_items = data.items()
    return dumper.represent_dict(ordered_items)

59 

60 

# Register the presenter so OrderedDicts dump as ordinary mappings (no !!python tag).
yaml.add_representer(OrderedDict, ordered_dict_presenter)

62 

63###################################################################### 

64## TorchApp Testing Utils 

65###################################################################### 

66 

67 

class TorchAppTestCaseError(Exception):
    """Raised when a test output does not match its expected YAML file."""

70 

71 

def get_diff(a, b):
    """Return a unified diff between the string forms of *a* and *b*.

    Returns an empty string when the two render identically.
    """
    left_lines = str(a).splitlines(keepends=True)
    right_lines = str(b).splitlines(keepends=True)
    delta = difflib.unified_diff(left_lines, right_lines)
    # Collapse the doubled newlines produced by joining lines that
    # already carry their own terminators.
    return "\n".join(delta).replace("\n\n", "\n")

79 

80 

def clean_output(output):
    """Normalize *output* to a stable string for comparison.

    Tensors are reduced to a ``type + shape`` summary (their values are not
    compared), and hexadecimal literals (e.g. object addresses in reprs)
    are masked so runs remain reproducible.
    """
    if isinstance(output, (TensorBase, torch.Tensor)):
        output = f"{type(output)} {tuple(output.shape)}"
    text = str(output)
    return re.sub(r"0[xX][0-9a-fA-F]+", "<HEX>", text)

87 

88 

def strip_whitespace_recursive(obj):
    """Collapse all whitespace in strings, recursing into dict values.

    Strings have newlines replaced, edges stripped, and internal runs of
    whitespace squeezed to a single space. Dicts are rebuilt with each value
    processed recursively. Any other type is returned unchanged.
    """
    if isinstance(obj, str):
        flattened = obj.replace("\n", " ").strip()
        return re.sub(r"\s+", " ", flattened)
    if isinstance(obj, dict):
        return {key: strip_whitespace_recursive(value) for key, value in obj.items()}
    return obj

97 

98 

def assert_output(file: Path, interactive: bool, params: dict, output, expected, regenerate: bool = False, threshold: float = 0.9):
    """
    Tests to see if the output is the same as the expected data and allows for saving a new version of the expected files if needed.

    Args:
        file (Path): The path to the expected file in yaml format.
        interactive (bool): Whether or not to prompt for replacing the expected file.
        params (dict): The dictionary of parameters to store in the expected file.
        output (str): The string representation of the output from the app.
        expected (str): The expected output from the yaml file.
        regenerate (bool): If True, rewrite the expected file immediately without prompting. Defaults to False.
        threshold (float): NOTE(review): accepted but never referenced in this function — confirm whether a
            similarity threshold was intended. Defaults to 0.9.

    Raises:
        TorchAppTestCaseError: If the output differs from the expected value and the
            expected file was not regenerated.
    """
    # Exact match: nothing to do.
    if expected == output:
        return

    # if expected and output are both strings, check to see if they are equal when normalizing whitespace
    expected_cleaned = strip_whitespace_recursive(expected)
    output_cleaned = strip_whitespace_recursive(output)

    if expected_cleaned == output_cleaned:
        return

    # Build and display a diff: per-key for dicts, otherwise one diff of the string forms.
    if isinstance(expected, dict) and isinstance(output, dict):
        keys = set(expected.keys()) | set(output.keys())
        diff = {}
        for key in keys:
            diff[key] = get_diff(expected.get(key, ""), output.get(key, ""))
            if diff[key]:
                console.print(diff[key])
    else:
        diff = get_diff(str(expected), str(output))
        console.print(diff)

    if interactive or regenerate:
        # If we aren't automatically regenerating the expected files, then prompt the user
        if not regenerate:
            prompt_response = input(
                f"\nExpected file '{file.name}' does not match test output (see diff above).\n"
                "Should this file be replaced? (y/N) "
            )
            regenerate = prompt_response.lower() == "y"

        if regenerate:
            with open(file, "w") as f:
                # Multi-line strings are written in YAML literal block style for readability.
                output_for_yaml = literal(output) if isinstance(output, str) and "\n" in output else output
                # order the params dictionary if necessary
                if isinstance(params, dict):
                    params = OrderedDict(params)

                data = OrderedDict(params=params, output=output_for_yaml)
                yaml.dump(data, f)
            # File regenerated; treat the test as passing.
            return

    raise TorchAppTestCaseError(diff)

152 

153 

class TorchAppTestCase:
    """Automated tests for TorchApp classes.

    Subclasses set ``app_class`` to the TorchApp subclass under test.
    Expected outputs are stored as YAML files under
    ``expected/<TestCaseName>/<test name>/`` next to the test module.
    """

    # The TorchApp subclass under test; must be set by subclasses (or get_app overridden).
    app_class = None
    # Base directory for expected-output files; derived lazily from the test module's location.
    expected_base = None

    def get_expected_base(self) -> Path:
        """Return the base ``expected`` directory, deriving it from this test module's path if unset."""
        if not self.expected_base:
            module = importlib.import_module(self.__module__)
            self.expected_base = Path(module.__file__).parent / "expected"

        # Normalize in case a subclass assigned a plain string.
        self.expected_base = Path(self.expected_base)
        return self.expected_base

    def test_cli_version(self):
        """The CLI ``--version`` flag exits cleanly and prints a semantic version."""
        app = self.get_app()
        runner = CliRunner()
        result = runner.invoke(app.cli(), "--version")
        assert result.exit_code == 0
        assert re.match(r"^(\d+\.)?(\d+\.)?(\*|\d+)$", result.stdout)

    def get_expected_dir(self) -> Path:
        """
        Returns the path to the directory where the expected files are stored.

        It creates the directory if it doesn't already exist.
        """
        expected_dir = self.get_expected_base() / self.__class__.__name__
        expected_dir.mkdir(exist_ok=True, parents=True)
        return expected_dir

    def get_app(self) -> TorchApp:
        """
        Returns an instance of the app for this test case.

        It instantiates an object from `app_class`.
        Override `app_class` or this method so the correct app is returned from calling this method.
        """
        assert self.app_class is not None
        app = self.app_class()

        assert isinstance(app, TorchApp)
        return app

    def subtest_dir(self, name: str) -> Path:
        """Return (creating if necessary) the expected-file directory for subtest `name`."""
        directory = self.get_expected_dir() / name
        directory.mkdir(exist_ok=True, parents=True)
        return directory

    def subtest_files(self, name: str) -> list:
        """Return the list of YAML expected files for subtest `name`."""
        directory = self.subtest_dir(name)
        files = list(directory.glob("*.yaml"))
        return files

    def subtests(self, app, name: str):
        """Yield ``(params, expected_output, file)`` for each expected file of `name`.

        Skips the current test (via pytest.skip) when no expected files exist.
        """
        files = self.subtest_files(name)

        if len(files) == 0:
            pytest.skip(
                f"Skipping test for '{name}' because no expected files were found in:\n" f"{self.subtest_dir(name)}."
            )

        for file in files:
            with open(file) as f:
                file_dict = yaml.safe_load(f) or {}
                params = file_dict.get("params", {})
                output = file_dict.get("output", "")

            yield params, output, file

    def _prompt_default_expected_file(self, app, name: str, method_name: str) -> bool:
        """Offer to create an empty default expected file for subtest `name`.

        Only acts when no expected files exist yet. Returns True when the user
        accepted (a ``<method_name>_default.yaml`` stub was written and the
        output should be regenerated into it), otherwise False.
        """
        if self.subtest_files(name):
            return False

        prompt_response = input(
            f"\nNo expected files for '{name}' when testing '{app}'.\n"
            "Should a default expected file be automatically generated? (y/N) "
        )
        if prompt_response.lower() != "y":
            return False

        directory = self.subtest_dir(name)
        with open(directory / f"{method_name}_default.yaml", "w") as f:
            # The output will be autogenerated later
            data = OrderedDict(params={}, output="")
            yaml.dump(data, f)
        return True

    def _resolve_path_params(self, method, params: dict) -> dict:
        """Return a copy of `params` with Path-annotated values resolved.

        Any parameter whose type hint is Path (or a union containing Path) has
        its value interpreted relative to ``get_expected_dir()`` and resolved
        to an absolute path.
        """
        modified_params = dict(params)
        hints = get_type_hints(method)
        for key, value in hints.items():
            # if this is a union class, then loop over all options
            if not isinstance(value, type) and hasattr(value, "__args__"):  # This is the case for unions
                values = value.__args__
            else:
                values = [value]

            for v in values:
                if key in params and Path in v.__mro__:
                    relative_path = params[key]
                    modified_params[key] = (self.get_expected_dir() / relative_path).resolve()
                    break
        return modified_params

    def test_model(self, interactive: bool):
        """
        Tests the method of a TorchApp to create a pytorch model.

        The expected output is the string representation of the model created.

        Args:
            interactive (bool): Whether or not failed tests should prompt the user to regenerate the expected files.
        """
        app = self.get_app()
        name = sys._getframe().f_code.co_name
        method_name = name[5:] if name.startswith("test_") else name
        regenerate = self._prompt_default_expected_file(app, name, method_name) if interactive else False

        for params, expected_output, file in self.subtests(app, name):
            model = app.model(**params)
            if model is None:
                model_summary = "None"
            else:
                assert isinstance(model, nn.Module)
                model_summary = str(model)

            assert_output(file, interactive, params, model_summary, expected_output, regenerate=regenerate)

    def test_dataloaders(self, interactive: bool):
        """Tests the app's dataloaders: type check plus a summary of one training batch."""
        app = self.get_app()
        for params, expected_output, file in self.subtests(app, sys._getframe().f_code.co_name):
            # Make all paths relative to the result of get_expected_dir()
            modified_params = self._resolve_path_params(app.dataloaders, params)

            dataloaders = app.dataloaders(**modified_params)

            assert isinstance(dataloaders, DataLoaders)

            batch = dataloaders.train.one_batch()
            dataloaders_summary = OrderedDict(
                type=type(dataloaders).__name__,
                train_size=len(dataloaders.train),
                validation_size=len(dataloaders.valid),
                batch_x_type=type(batch[0]).__name__,
                batch_y_type=type(batch[1]).__name__,
                batch_x_shape=str(batch[0].shape),
                batch_y_shape=str(batch[1].shape),
            )

            assert_output(file, interactive, params, dataloaders_summary, expected_output)

    def perform_subtests(self, interactive: bool, name: str):
        """
        Performs a number of subtests for a method on the app.

        Args:
            interactive (bool): Whether or not the user should be prompted for creating or regenerating expected files.
            name (str): The name of the method to be tested with the string `test_` prepended to it.
        """
        app = self.get_app()
        method_name = name[5:] if name.startswith("test_") else name
        method = getattr(app, method_name)
        regenerate = self._prompt_default_expected_file(app, name, method_name) if interactive else False

        for params, expected_output, file in self.subtests(app, name):
            # Make all paths relative to the result of get_expected_dir()
            modified_params = self._resolve_path_params(method, params)
            output = clean_output(method(**modified_params))
            assert_output(file, interactive, params, output, expected_output, regenerate=regenerate)

    def test_goal(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_metrics(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_loss_func(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_monitor(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_activation(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_pretrained_location(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_size(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_output_size(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_loss(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_bibliography(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_bibtex(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def cli_commands_to_test(self) -> list:
        """Commands exercised by test_cli; override in subclasses to extend."""
        return [
            "--help",
            "train --help",
            "infer --help",
            "show-batch --help",
            "tune --help",
            "bibtex",
            "bibliography",
        ]

    def test_cli(self):
        """Each listed CLI command exits with code 0 and produces some output."""
        app = self.get_app()
        runner = CliRunner()
        for command in self.cli_commands_to_test():
            print(command)
            result = runner.invoke(app.cli(), command.split())
            assert result.exit_code == 0
            assert result.stdout

397 

398 

399