Coverage for rdgai/main.py: 100.00%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2025-01-03 01:37 +0000

1from pathlib import Path 

2import typer 

3from rich.console import Console 

4import pandas as pd 

5 

6from .apparatus import Doc 

7from .export import export_variants_to_excel, import_classifications_from_dataframe 

8from .classification import classify as classify_fn 

9from .evaluation import evaluate_docs 

10from .classification import DEFAULT_MODEL_ID 

11from .validation import validate as validate_fn 

12from .prompts import build_preamble 

13 

# Console used for regular (stdout) output; it is handed to helpers such as
# classify_fn and Doc.print_classified_pairs by the commands below.
console = Console()
# Second console that writes to stderr with a bold red style (presumably
# intended for error messages — it is not referenced in this section).
error_console = Console(style="bold red", stderr=True)


# The Typer application; every @app.command() function below registers on it.
# Typer's pretty (rich-formatted) exception output is switched off.
app = typer.Typer(pretty_exceptions_enable=False)

19 

20 

21 

def get_output_path(doc:Path, output:Path, inplace:bool) -> Path:
    """
    Decide where a command should write its result.

    Exactly one of ``output`` or ``inplace`` must be supplied. With ``inplace``
    the input ``doc`` itself is returned, so the input file gets overwritten;
    otherwise ``output`` is returned as given.

    Raises:
        typer.BadParameter: if both ``output`` and ``inplace`` are given, or
            if neither is.
    """
    if inplace:
        # --inplace excludes an explicit output path.
        if output:
            raise typer.BadParameter("You cannot use both an output path and --inplace/-i at the same time.")
        return doc

    if not output:
        raise typer.BadParameter("You must provide either an output path or use --inplace/-i.")
    return output

33 

34 

@app.command()
def classify(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to classify."),
    output:Path=typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    verbose:bool=typer.Option(False, help="Print verbose output."),
    api_key:str=typer.Option("", help="API key for the LLM."),
    llm:str=typer.Option(DEFAULT_MODEL_ID, help="ID of the language model to use."),
    temperature:float=typer.Option(0.1, help="Temperature for sampling from the language model."),
    prompt_only:bool=typer.Option(False, help="Only print the prompt and not classify."),
    examples:int=typer.Option(10, help="Number of examples to include in the prompt."),
):
    """
    Classifies relations in TEI documents.
    """
    # Resolve the output from the CLI *path* before wrapping it in a Doc.
    # get_output_path returns its ``doc`` argument for --inplace, so calling it
    # with the Doc object would pass a Doc to classify_fn as ``output`` instead
    # of a Path. Doing it first also rejects bad arguments before parsing XML.
    output = get_output_path(doc, output, inplace)
    doc = Doc(doc)

    return classify_fn(
        doc=doc,
        output=output,
        verbose=verbose,
        api_key=api_key,
        llm=llm,
        temperature=temperature,
        prompt_only=prompt_only,
        examples=examples,
        console=console,
    )

64 

65 

@app.command()
def classified_pairs(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document with the classifications."),
):
    """ Print classified pairs in a document. """
    # Load the document and let it print its own classified pairs to the
    # shared stdout console.
    Doc(doc).print_classified_pairs(console)

73 

74 

@app.command()
def html(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to render as HTML."),
    output:Path=typer.Argument(..., help="The path to the output HTML file."),
    all_apps:bool=typer.Option(False, help="Whether or not to use all variation unit `app` elements. By default it shows only non-redundant pairs of readings."),
):
    """ Renders the variation units of a TEI document as HTML. """
    # Rendering is delegated entirely to Doc.render_html; this command only
    # loads the document and forwards the output path and the all_apps flag.
    tei_doc = Doc(doc)
    tei_doc.render_html(output, all_apps=all_apps)

84 

85 

@app.command()
def gui(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to classify."),
    output:Path=typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    debug:bool=True,
    use_reloader:bool=False,
    all_apps:bool=typer.Option(False, help="Whether or not to use all variation unit `app` elements. By default it shows only non-redundant pairs of readings."),
):
    """ Starts a Flask app to view and classify a TEI document. """
    # Resolve the output from the CLI *path* before wrapping it in a Doc:
    # for --inplace get_output_path returns its ``doc`` argument, and that must
    # be the input path, not the Doc object.
    output = get_output_path(doc, output, inplace)
    doc = Doc(doc)

    flask_app = doc.flask_app(output, all_apps=all_apps)
    # NOTE(review): debug defaults to True. Flask's debug mode enables the
    # interactive debugger, so do not expose this server beyond localhost.
    flask_app.run(debug=debug, use_reloader=use_reloader)

100 

101 

@app.command()
def evaluate(
    predicted:Path=typer.Argument(..., help="The path to the TEI XML document with predictions from Rdgai to evaluate."),
    ground_truth:Path=typer.Argument(..., help="The path to the input TEI XML document to use as the ground truth for evaluation."),
    confusion_matrix:Path=typer.Option(None, help="Path to write the confusion matrix plot as a CSV file."),
    confusion_matrix_plot:Path=typer.Option(None, help="Path to write the confusion matrix plot as an HTML file."),
    report:Path=typer.Option(None, help="Path to write the report."),
):
    """ Evaluates the classifications in a predicted document against a ground truth document. """
    # Load both documents (predicted first, then ground truth) and forward the
    # optional output paths to evaluate_docs unchanged.
    evaluate_docs(
        Doc(predicted),
        Doc(ground_truth),
        confusion_matrix=confusion_matrix,
        confusion_matrix_plot=confusion_matrix_plot,
        report=report,
    )

115 

116 

@app.command()
def validate(
    ground_truth:Path=typer.Argument(..., help="The path to the input TEI XML document to use as the ground truth for evaluation."),
    output:Path=typer.Argument(..., help="The path to the output TEI XML file."),
    proportion:float=typer.Option(0.5, help="Proportion of classified pairs to use for validation."),
    api_key:str=typer.Option("", help="API key for the LLM."),
    llm:str=typer.Option(DEFAULT_MODEL_ID, help="ID of the language model to use."),
    temperature:float=typer.Option(0.1, help="Temperature for sampling from the language model."),
    examples:int=typer.Option(10, help="Number of examples to include in the prompt."),
    confusion_matrix:Path=typer.Option(None, help="Path to write the confusion matrix plot as a CSV file."),
    confusion_matrix_plot:Path=typer.Option(None, help="Path to write the confusion matrix plot as an HTML file."),
    seed:int=typer.Option(42, help="Seed for random sampling of validation pairs."),
    report:Path=typer.Option(None, help="Path to write the report."),
):
    """ Takes a ground truth document, chooses a proportion of classified pairs to validate against and outputs a report. """
    # This command only loads the ground-truth document; every other CLI
    # option is forwarded to validate_fn as-is.
    truth_doc = Doc(ground_truth)

    validate_fn(
        truth_doc,
        output,
        proportion=proportion,
        seed=seed,
        llm=llm,
        api_key=api_key,
        temperature=temperature,
        examples=examples,
        confusion_matrix=confusion_matrix,
        confusion_matrix_plot=confusion_matrix_plot,
        report=report,
    )

147 

148 

@app.command()
def clean(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to clean."),
    output:Path=typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
):
    """ Cleans a TEI XML file for common errors. """
    # Resolve the output from the CLI *path* first. With --inplace,
    # get_output_path returns its ``doc`` argument, so passing the Doc object
    # would make Doc.clean receive a Doc as its output instead of a Path.
    output = get_output_path(doc, output, inplace)
    doc = Doc(doc)
    doc.clean(output=output)

159 

160 

@app.command()
def export(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to export."),
    output:Path=typer.Argument(..., help="The path to the output Excel file."),
):
    """ Exports pairs of readings with classifications from a TEI document to an Excel spreadsheet. """
    # Writing the spreadsheet is handled by export_variants_to_excel; this
    # command just loads the TEI document from the given path.
    export_variants_to_excel(Doc(doc), output)

169 

170 

@app.command()
def import_classifications(
    doc:Path=typer.Argument(..., help="The path to the base TEI XML document to use for importing the classifications from Excel."),
    spreadsheet:Path=typer.Argument(..., help="The path to the Excel file to import."),
    output:Path=typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    responsible:str=typer.Option("", help="The responsible party for the classifications. By default it is the name of the spreadsheet."),
):
    """ Imports classifications from a spreadsheet into a TEI document. """
    # Resolve the output from the CLI *path* before wrapping it in a Doc: for
    # --inplace get_output_path returns its ``doc`` argument, which must be the
    # input path rather than the Doc object.
    output = get_output_path(doc, output, inplace)

    # Match the extension case-insensitively (e.g. "Data.XLSX"). Unknown
    # formats are rejected explicitly. Previously they fell through both
    # branches and crashed later with UnboundLocalError on variants_df.
    suffix = spreadsheet.suffix.lower()
    if suffix == ".xlsx":
        variants_df = pd.read_excel(spreadsheet, sheet_name="Variants", keep_default_na=False)
    elif suffix == ".csv":
        variants_df = pd.read_csv(spreadsheet, keep_default_na=False)
    else:
        raise typer.BadParameter(
            f"Unsupported spreadsheet format {spreadsheet.suffix!r}: expected an .xlsx or .csv file."
        )

    doc = Doc(doc)

    # TODO add responsible to TEI header
    # Default to the spreadsheet's file name. Spaces become underscores, and
    # a leading "#" is ensured so the value reads as a pointer-style reference.
    responsible = (responsible or spreadsheet.stem).replace(" ", "_")
    if not responsible.startswith("#"):
        responsible = "#" + responsible

    import_classifications_from_dataframe(doc, variants_df, output, responsible=responsible)

195 

196 

@app.command()
def prompt_preamble(
    doc:Path=typer.Argument(..., help="The path to the TEI XML document to classify."),
    examples:int=typer.Option(10, help="Number of examples to include in the prompt."),
):
    """ Prints the prompt preamble for a TEI document for a given number of examples. """
    # Plain print (not the rich console) so the preamble text is emitted
    # verbatim without any markup interpretation.
    print(build_preamble(Doc(doc), examples))