Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import re
2import sys
3import yaml
4import importlib
5import pytest
6import torch
7from typing import get_type_hints
8from click.testing import CliRunner
9from pathlib import Path
10import difflib
11from torch import nn
12from collections import OrderedDict
13from fastai.data.core import DataLoaders
14from fastai.learner import Learner
15from rich.console import Console
16from fastai.torch_core import TensorBase
18from .apps import TorchApp
# Shared rich console used to print diffs when an app's output differs from its expected file.
console = Console()
22######################################################################
23## pytest fixtures
24######################################################################
@pytest.fixture
def interactive(request):
    """Fixture: True when pytest was invoked with ``-s`` (capture disabled)."""
    capture_mode = request.config.getoption("-s")
    return "no" == capture_mode
32######################################################################
33## YAML functions from https://stackoverflow.com/a/8641732
34######################################################################
class quoted(str):
    """Marker string subclass: values of this type are YAML-dumped with double quotes."""
def quoted_presenter(dumper, data):
    """YAML presenter: emit ``quoted`` strings using double-quote scalar style."""
    scalar_tag = "tag:yaml.org,2002:str"
    return dumper.represent_scalar(scalar_tag, data, style='"')
# Register the presenter so any `quoted` instance dumps with double-quote style.
yaml.add_representer(quoted, quoted_presenter)
class literal(str):
    """Marker string subclass: values of this type are YAML-dumped in literal ("|") block style."""
def literal_presenter(dumper, data):
    """YAML presenter: emit ``literal`` strings using the literal ("|") block style."""
    scalar_tag = "tag:yaml.org,2002:str"
    return dumper.represent_scalar(scalar_tag, data, style="|")
# Register the presenter so any `literal` instance dumps in block ("|") style.
yaml.add_representer(literal, literal_presenter)
def ordered_dict_presenter(dumper, data):
    """YAML presenter: serialize an OrderedDict as a plain mapping, preserving key order."""
    entries = data.items()
    return dumper.represent_dict(entries)
# Register the presenter so OrderedDicts dump as plain YAML mappings (no python/object tag).
yaml.add_representer(OrderedDict, ordered_dict_presenter)
63######################################################################
64## TorchApp Testing Utils
65######################################################################
class TorchAppTestCaseError(Exception):
    """Raised when an app's output does not match the recorded expected file."""
def get_diff(a, b):
    """
    Return a unified diff between the string representations of two objects.

    Args:
        a: The first (expected) object; rendered via ``str(a)``.
        b: The second (observed) object; rendered via ``str(b)``.

    Returns:
        str: A unified-diff string, or "" when both render identically.
    """
    a_lines = str(a).splitlines(keepends=True)  # keyword instead of magic positional `1`
    b_lines = str(b).splitlines(keepends=True)

    diff = difflib.unified_diff(a_lines, b_lines)

    # Most diff lines keep their own trailing newline, so joining with "\n"
    # doubles them; collapsing "\n\n" restores single spacing while still
    # separating any input lines that lacked a trailing newline.
    return "\n".join(diff).replace("\n\n", "\n")
def clean_output(output):
    """
    Normalize an output value into a stable string for comparisons.

    Tensors are summarized as their type plus shape (their values are not
    compared), and hexadecimal literals — e.g. object memory addresses in
    reprs — are masked so results are reproducible across runs.
    """
    if isinstance(output, (TensorBase, torch.Tensor)):
        output = f"{type(output)} {tuple(output.shape)}"

    text = str(output)
    # Replace hex literals such as 0x7f3a... with a stable placeholder.
    return re.sub(r"0[xX][0-9a-fA-F]+", "<HEX>", text)
def strip_whitespace_recursive(obj):
    """
    Normalize whitespace in strings nested inside common containers.

    Strings have newlines converted to spaces, leading/trailing whitespace
    stripped, and internal whitespace runs collapsed to single spaces.
    Dicts, lists, and tuples are processed recursively; any other object is
    returned unchanged.

    Args:
        obj: The object to normalize (str, dict, list, tuple, or anything).

    Returns:
        The whitespace-normalized equivalent of ``obj``.
    """
    if isinstance(obj, str):
        collapsed = obj.replace("\n", " ").strip()
        return re.sub(r"\s+", " ", collapsed)
    if isinstance(obj, dict):
        return {key: strip_whitespace_recursive(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Previously lists/tuples were returned untouched, so outputs that
        # differed only in whitespace inside a list failed the comparison;
        # recurse so they normalize the same way dict values do.
        return type(obj)(strip_whitespace_recursive(item) for item in obj)

    return obj
def assert_output(file: Path, interactive: bool, params: dict, output, expected, regenerate: bool = False, threshold: float = 0.9):
    """
    Tests to see if the output is the same as the expected data and allows for saving a new version of the expected files if needed.

    Args:
        file (Path): The path to the expected file in yaml format.
        interactive (bool): Whether or not to prompt for replacing the expected file.
        params (dict): The dictionary of parameters to store in the expected file.
        output (str): The string representation of the output from the app.
        expected (str): The expected output from the yaml file.
        regenerate (bool): If True, overwrite the expected file without prompting. Defaults to False.
        threshold (float): Currently unused by this function. Defaults to 0.9.

    Raises:
        TorchAppTestCaseError: If the outputs differ and the expected file was not regenerated.
    """
    if expected == output:
        return

    # if expected and output are both strings, check to see if they are equal when normalizing whitespace
    expected_cleaned = strip_whitespace_recursive(expected)
    output_cleaned = strip_whitespace_recursive(output)

    if expected_cleaned == output_cleaned:
        return

    # Print a per-key diff when both values are dicts, otherwise one whole-value diff.
    if isinstance(expected, dict) and isinstance(output, dict):
        keys = set(expected.keys()) | set(output.keys())
        diff = {}
        for key in keys:
            diff[key] = get_diff(expected.get(key, ""), output.get(key, ""))
            if diff[key]:
                console.print(diff[key])
    else:
        diff = get_diff(str(expected), str(output))
        console.print(diff)

    if interactive or regenerate:
        # If we aren't automatically regenerating the expected files, then prompt the user
        if not regenerate:
            prompt_response = input(
                f"\nExpected file '{file.name}' does not match test output (see diff above).\n"
                "Should this file be replaced? (y/N) "
            )
            regenerate = prompt_response.lower() == "y"

        if regenerate:
            with open(file, "w") as f:
                # Multiline strings are wrapped in `literal` so they dump in YAML block ("|") style.
                output_for_yaml = literal(output) if isinstance(output, str) and "\n" in output else output
                # order the params dictionary if necessary
                if isinstance(params, dict):
                    params = OrderedDict(params)

                data = OrderedDict(params=params, output=output_for_yaml)
                yaml.dump(data, f)
            return

    raise TorchAppTestCaseError(diff)
class TorchAppTestCase:
    """
    Automated tests for TorchApp classes.

    Subclasses set ``app_class`` (or override ``get_app``) and place expected
    YAML files under ``expected/<TestCaseClassName>/<test_name>/``; each
    ``test_*`` method replays those files against the app and compares the
    results with ``assert_output``.
    """

    # The TorchApp subclass under test; subclasses must set this (or override get_app).
    app_class = None
    # Root directory holding expected output files; derived from the test
    # module's location when left as None.
    expected_base = None

    def get_expected_base(self) -> Path:
        """Return (and cache) the base directory that holds the expected files."""
        if not self.expected_base:
            # Default to an "expected" directory beside the module that defines this test case.
            module = importlib.import_module(self.__module__)
            self.expected_base = Path(module.__file__).parent / "expected"

        self.expected_base = Path(self.expected_base)
        return self.expected_base

    def test_cli_version(self):
        """The CLI's --version flag exits cleanly and prints a version-like string."""
        app = self.get_app()
        runner = CliRunner()
        result = runner.invoke(app.cli(), "--version")
        assert result.exit_code == 0
        # Matches strings such as "0.1.2", "1.0", or "1.*"
        assert re.match(r"^(\d+\.)?(\d+\.)?(\*|\d+)$", result.stdout)

    def get_expected_dir(self) -> Path:
        """
        Returns the path to the directory where the expected files are stored.

        It creates the directory if it doesn't already exist.
        """
        expected_dir = self.get_expected_base() / self.__class__.__name__
        expected_dir.mkdir(exist_ok=True, parents=True)
        return expected_dir

    def get_app(self) -> TorchApp:
        """
        Returns an instance of the app for this test case.

        It instantiates an object from `app_class`.
        Override `app_class` or this method so the correct app is returned from calling this method.
        """
        assert self.app_class is not None
        app = self.app_class()

        assert isinstance(app, TorchApp)
        return app

    def subtest_dir(self, name: str):
        """Return (creating it if needed) the expected-files directory for subtest `name`."""
        directory = self.get_expected_dir() / name
        directory.mkdir(exist_ok=True, parents=True)
        return directory

    def subtest_files(self, name: str):
        """Return the list of YAML expected files recorded for subtest `name`."""
        directory = self.subtest_dir(name)
        files = list(directory.glob("*.yaml"))
        return files

    def subtests(self, app, name: str):
        """
        Yield ``(params, expected_output, file)`` for each expected file of `name`.

        Skips the calling test when no expected files exist for `name`.
        """
        files = self.subtest_files(name)

        if len(files) == 0:
            pytest.skip(
                f"Skipping test for '{name}' because no expected files were found in:\n" f"{self.subtest_dir(name)}."
            )

        for file in files:
            with open(file) as f:
                file_dict = yaml.safe_load(f) or {}
                params = file_dict.get("params", {})
                output = file_dict.get("output", "")

            yield params, output, file

    def test_model(self, interactive: bool):
        """
        Tests the method of a TorchApp to create a pytorch model.

        The expected output is the string representation of the model created.

        Args:
            interactive (bool): Whether or not failed tests should prompt the user to regenerate the expected files.
        """
        app = self.get_app()
        name = sys._getframe().f_code.co_name  # i.e. "test_model"
        method_name = name[5:] if name.startswith("test_") else name
        regenerate = False

        if interactive:
            if not self.subtest_files(name):
                # Offer to scaffold an empty expected file so the loop below can regenerate it.
                prompt_response = input(
                    f"\nNo expected files for '{name}' when testing '{app}'.\n"
                    "Should a default expected file be automatically generated? (y/N) "
                )
                if prompt_response.lower() == "y":
                    regenerate = True
                    directory = self.subtest_dir(name)
                    with open(directory / f"{method_name}_default.yaml", "w") as f:
                        # The output will be autogenerated later
                        data = OrderedDict(params={}, output="")
                        yaml.dump(data, f)

        for params, expected_output, file in self.subtests(app, name):
            model = app.model(**params)
            if model is None:
                model_summary = "None"
            else:
                assert isinstance(model, nn.Module)
                model_summary = str(model)

            assert_output(file, interactive, params, model_summary, expected_output, regenerate=regenerate)

    def test_dataloaders(self, interactive: bool):
        """
        Tests the dataloaders created by the app against a summary stored in the expected files.

        Args:
            interactive (bool): Whether or not failed tests should prompt the user to regenerate the expected files.
        """
        app = self.get_app()
        for params, expected_output, file in self.subtests(app, sys._getframe().f_code.co_name):
            # Make all paths relative to the result of get_expected_dir()
            modified_params = dict(params)
            hints = get_type_hints(app.dataloaders)
            for key, value in hints.items():
                # if this is a union class, then loop over all options
                if not isinstance(value, type) and hasattr(value, "__args__"):  # This is the case for unions
                    values = value.__args__
                else:
                    values = [value]

                for v in values:
                    if key in params and Path in v.__mro__:
                        relative_path = params[key]
                        modified_params[key] = (self.get_expected_dir() / relative_path).resolve()
                        break

            dataloaders = app.dataloaders(**modified_params)

            assert isinstance(dataloaders, DataLoaders)

            # Summarize a single training batch rather than storing raw tensor values.
            batch = dataloaders.train.one_batch()
            dataloaders_summary = OrderedDict(
                type=type(dataloaders).__name__,
                train_size=len(dataloaders.train),
                validation_size=len(dataloaders.valid),
                batch_x_type=type(batch[0]).__name__,
                batch_y_type=type(batch[1]).__name__,
                batch_x_shape=str(batch[0].shape),
                batch_y_shape=str(batch[1].shape),
            )

            assert_output(file, interactive, params, dataloaders_summary, expected_output)

    def perform_subtests(self, interactive: bool, name: str):
        """
        Performs a number of subtests for a method on the app.

        Args:
            interactive (bool): Whether or not the user should be prompted for creating or regenerating expected files.
            name (str): The name of the method to be tested with the string `test_` prepended to it.
        """
        app = self.get_app()
        regenerate = False
        method_name = name[5:] if name.startswith("test_") else name
        method = getattr(app, method_name)

        if interactive:
            if not self.subtest_files(name):
                # Offer to scaffold an empty expected file so the loop below can regenerate it.
                prompt_response = input(
                    f"\nNo expected files for '{name}' when testing '{app}'.\n"
                    "Should a default expected file be automatically generated? (y/N) "
                )
                if prompt_response.lower() == "y":
                    regenerate = True
                    directory = self.subtest_dir(name)
                    with open(directory / f"{method_name}_default.yaml", "w") as f:
                        # The output will be autogenerated later
                        data = OrderedDict(params={}, output="")
                        yaml.dump(data, f)

        for params, expected_output, file in self.subtests(app, name):
            modified_params = dict(params)
            # Resolve any Path-typed parameters relative to the expected-files directory.
            hints = get_type_hints(method)
            for key, value in hints.items():
                # if this is a union class, then loop over all options
                if not isinstance(value, type) and hasattr(value, "__args__"):  # This is the case for unions
                    values = value.__args__
                else:
                    values = [value]

                for v in values:
                    if key in params and Path in v.__mro__:
                        relative_path = params[key]
                        modified_params[key] = (self.get_expected_dir() / relative_path).resolve()
                        break

            output = clean_output(method(**modified_params))
            assert_output(file, interactive, params, output, expected_output, regenerate=regenerate)

    # Each of the following tests delegates to perform_subtests for the
    # app method with the same name (minus the "test_" prefix).
    def test_goal(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_metrics(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_loss_func(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_monitor(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_activation(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_pretrained_location(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_size(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_output_size(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_one_batch_loss(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_bibliography(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def test_bibtex(self, interactive: bool):
        self.perform_subtests(interactive=interactive, name=sys._getframe().f_code.co_name)

    def cli_commands_to_test(self):
        """Return the CLI invocations exercised by `test_cli`."""
        return [
            "--help",
            "train --help",
            "infer --help",
            "show-batch --help",
            "tune --help",
            "bibtex",
            "bibliography",
        ]

    def test_cli(self):
        """Each standard CLI command exits cleanly and produces some output."""
        app = self.get_app()
        runner = CliRunner()
        for command in self.cli_commands_to_test():
            print(command)
            result = runner.invoke(app.cli(), command.split())
            assert result.exit_code == 0
            assert result.stdout