Coverage for rdgai/main.py: 100.00%
77 statements
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
1from pathlib import Path
2import typer
3from rich.console import Console
4import pandas as pd
6from .apparatus import Doc
7from .export import export_variants_to_excel, import_classifications_from_dataframe
8from .classification import classify as classify_fn
9from .evaluation import evaluate_docs
10from .classification import DEFAULT_MODEL_ID
11from .validation import validate as validate_fn
12from .prompts import build_preamble
# The Typer application; every command below registers itself via @app.command().
# pretty_exceptions_enable=False turns off Typer's rich-formatted exception tracebacks.
app = typer.Typer(pretty_exceptions_enable=False)

# Rich consoles shared by the commands: `console` writes to stdout,
# `error_console` writes to stderr with a bold red style.
console = Console()
error_console = Console(stderr=True, style="bold red")
def get_output_path(doc:Path, output:Path, inplace:bool) -> Path:
    """ Checks if the output path should be replaced with the input doc. 

    Exactly one of ``output`` and ``inplace`` must be supplied.

    Args:
        doc: The input document; it is returned unchanged when ``inplace`` is set.
        output: An explicit destination, or a falsy value (e.g. None) when absent.
        inplace: Whether the input document should be overwritten.

    Returns:
        ``doc`` when ``inplace`` is truthy, otherwise ``output``.

    Raises:
        typer.BadParameter: If both or neither of ``output`` / ``inplace`` are given.
    """
    if inplace:
        # --inplace and an explicit destination are mutually exclusive.
        if output:
            raise typer.BadParameter("You cannot use both an output path and --inplace/-i at the same time.")
        return doc

    # Not in-place: an explicit destination is mandatory.
    if not output:
        raise typer.BadParameter("You must provide either an output path or use --inplace/-i.")
    return output
@app.command()
def classify(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to classify."),
    output: Path = typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    verbose: bool = typer.Option(False, help="Print verbose output."),
    api_key: str = typer.Option("", help="API key for the LLM."),
    llm: str = typer.Option(DEFAULT_MODEL_ID, help="ID of the language model to use."),
    temperature: float = typer.Option(0.1, help="Temperature for sampling from the language model."),
    prompt_only: bool = typer.Option(False, help="Only print the prompt and not classify."),
    examples: int = typer.Option(10, help="Number of examples to include in the prompt."),
):
    """
    Classifies relations in TEI documents.
    """
    # (The docstring above is the CLI help text, so it is kept short and user-facing.)
    #
    # Resolve the destination from the raw input *Path* before parsing it. Previously the
    # parsed Doc was passed to get_output_path, so --inplace handed classify_fn a Doc object
    # as `output` instead of the Path it receives in the non-inplace case. Doing this first
    # also rejects bad --inplace/output combinations before the TEI file is loaded.
    output = get_output_path(doc, output, inplace)
    tei_doc = Doc(doc)

    # All classification work is delegated to classify_fn; its return value is passed through.
    return classify_fn(
        doc=tei_doc,
        output=output,
        verbose=verbose,
        api_key=api_key,
        llm=llm,
        temperature=temperature,
        prompt_only=prompt_only,
        examples=examples,
        console=console,
    )
@app.command()
def classified_pairs(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document with the classifications."),
):
    """ Print classified pairs in a document. """
    # The docstring doubles as CLI help text. Parse the TEI file and let the
    # Doc print its classified pairs to the shared stdout console.
    Doc(doc).print_classified_pairs(console)
@app.command()
def html(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to render as HTML."),
    output: Path = typer.Argument(..., help="The path to the output HTML file."),
    all_apps: bool = typer.Option(False, help="Whether or not to use all variation unit `app` elements. By default it shows only non-redundant pairs of readings."),
):
    """ Renders the variation units of a TEI document as HTML. """
    # The docstring doubles as CLI help text; rendering itself lives in Doc.render_html.
    tei_doc = Doc(doc)
    tei_doc.render_html(output, all_apps=all_apps)
@app.command()
def gui(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to classify."),
    output: Path = typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    debug: bool = True,
    use_reloader: bool = False,
    all_apps: bool = typer.Option(False, help="Whether or not to use all variation unit `app` elements. By default it shows only non-redundant pairs of readings."),
):
    """ Starts a Flask app to view and classify a TEI document. """
    # (The docstring above is the CLI help text.)
    #
    # Resolve the destination from the raw input *Path* before parsing it. Previously the
    # parsed Doc was passed to get_output_path, so --inplace gave the app a Doc object as
    # `output` instead of a Path. This also rejects bad flag combinations before loading the TEI.
    output = get_output_path(doc, output, inplace)
    tei_doc = Doc(doc)

    flask_app = tei_doc.flask_app(output, all_apps=all_apps)
    # NOTE(review): `debug` defaults to True. Flask's debug mode enables the interactive
    # Werkzeug debugger, which must not be reachable from untrusted networks -- confirm this
    # default is intended for a local-only tool.
    flask_app.run(debug=debug, use_reloader=use_reloader)
@app.command()
def evaluate(
    predicted: Path = typer.Argument(..., help="The path to the TEI XML document with predictions from Rdgai to evaluate."),
    ground_truth: Path = typer.Argument(..., help="The path to the input TEI XML document to use as the ground truth for evaluation."),
    confusion_matrix: Path = typer.Option(None, help="Path to write the confusion matrix plot as a CSV file."),
    confusion_matrix_plot: Path = typer.Option(None, help="Path to write the confusion matrix plot as an HTML file."),
    report: Path = typer.Option(None, help="Path to write the report."),
):
    """ Evaluates the classifications in a predicted document against a ground truth document. """
    # The docstring doubles as CLI help text. The prediction file is parsed first,
    # then the ground truth, and both are compared by evaluate_docs.
    predicted_doc = Doc(predicted)
    truth_doc = Doc(ground_truth)

    # Optional destinations for the evaluation artefacts; None means "don't write".
    destinations = {
        "confusion_matrix": confusion_matrix,
        "confusion_matrix_plot": confusion_matrix_plot,
        "report": report,
    }
    evaluate_docs(predicted_doc, truth_doc, **destinations)
@app.command()
def validate(
    ground_truth: Path = typer.Argument(..., help="The path to the input TEI XML document to use as the ground truth for evaluation."),
    output: Path = typer.Argument(..., help="The path to the output TEI XML file."),
    proportion: float = typer.Option(0.5, help="Proportion of classified pairs to use for validation."),
    api_key: str = typer.Option("", help="API key for the LLM."),
    llm: str = typer.Option(DEFAULT_MODEL_ID, help="ID of the language model to use."),
    temperature: float = typer.Option(0.1, help="Temperature for sampling from the language model."),
    examples: int = typer.Option(10, help="Number of examples to include in the prompt."),
    confusion_matrix: Path = typer.Option(None, help="Path to write the confusion matrix plot as a CSV file."),
    confusion_matrix_plot: Path = typer.Option(None, help="Path to write the confusion matrix plot as an HTML file."),
    seed: int = typer.Option(42, help="Seed for random sampling of validation pairs."),
    report: Path = typer.Option(None, help="Path to write the report."),
):
    """ Takes a ground truth document, chooses a proportion of classified pairs to validate against and outputs a report. """
    # The docstring doubles as CLI help text. Only the ground truth is parsed here;
    # the output path is forwarded to validate_fn as-is.
    truth_doc = Doc(ground_truth)

    # Sampling settings (proportion, seed), LLM settings and artefact destinations,
    # all forwarded to validate_fn by keyword.
    settings = dict(
        proportion=proportion,
        seed=seed,
        llm=llm,
        api_key=api_key,
        temperature=temperature,
        examples=examples,
        confusion_matrix=confusion_matrix,
        confusion_matrix_plot=confusion_matrix_plot,
        report=report,
    )
    validate_fn(truth_doc, output, **settings)
@app.command()
def clean(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to clean."),
    output: Path = typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
):
    """ Cleans a TEI XML file for common errors. """
    # (The docstring above is the CLI help text.)
    #
    # Resolve the destination from the raw input *Path* before parsing it. Previously the
    # parsed Doc was passed to get_output_path, so --inplace handed Doc.clean a Doc object as
    # `output` instead of a Path. This also rejects bad flag combinations before loading the TEI.
    output = get_output_path(doc, output, inplace)
    tei_doc = Doc(doc)
    tei_doc.clean(output=output)
@app.command()
def export(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to export."),
    output: Path = typer.Argument(..., help="The path to the output Excel file."),
):
    """ Exports pairs of readings with classifications from a TEI document to an Excel spreadsheet. """
    # The docstring doubles as CLI help text; the spreadsheet is produced by export_variants_to_excel.
    export_variants_to_excel(Doc(doc), output)
@app.command()
def import_classifications(
    doc: Path = typer.Argument(..., help="The path to the base TEI XML document to use for importing the classifications from Excel."),
    spreadsheet: Path = typer.Argument(..., help="The path to the Excel file to import."),
    output: Path = typer.Argument(None, help="The path to the output TEI XML file."),
    inplace: bool = typer.Option(False, "--inplace", "-i", help="Overwrite the input file."),
    responsible: str = typer.Option("", help="The responsible party for the classifications. By default it is the name of the spreadsheet."),
):
    """ Imports classifications from a spreadsheet into a TEI document. """
    # (The docstring above is the CLI help text.)
    #
    # Resolve the destination from the raw input *Path* (not the parsed Doc), so --inplace
    # yields a Path, and reject bad flag combinations before any file is read.
    output = get_output_path(doc, output, inplace)

    # Compare suffixes case-insensitively so e.g. "Variants.XLSX" is accepted too.
    suffix = spreadsheet.suffix.lower()
    if suffix == ".xlsx":
        # Classifications are read from the sheet named "Variants". keep_default_na=False
        # stops pandas from turning strings such as "NA" or empty cells into NaN.
        variants_df = pd.read_excel(spreadsheet, sheet_name="Variants", keep_default_na=False)
    elif suffix == ".csv":
        variants_df = pd.read_csv(spreadsheet, keep_default_na=False)
    else:
        # Previously an unsupported suffix fell through and crashed later with
        # UnboundLocalError on `variants_df`; report it as a CLI parameter error instead.
        raise typer.BadParameter(
            f"Unsupported spreadsheet format '{spreadsheet.suffix}'. Please provide an .xlsx or .csv file."
        )

    tei_doc = Doc(doc)

    # TODO add responsible to TEI header
    # Default to the spreadsheet's file name, replace spaces with underscores, and ensure
    # a leading "#" (presumably so the value acts as a TEI pointer to an xml:id -- confirm).
    responsible = (responsible or spreadsheet.stem).replace(" ", "_")
    if not responsible.startswith("#"):
        responsible = "#" + responsible

    import_classifications_from_dataframe(tei_doc, variants_df, output, responsible=responsible)
@app.command()
def prompt_preamble(
    doc: Path = typer.Argument(..., help="The path to the TEI XML document to classify."),
    examples: int = typer.Option(10, help="Number of examples to include in the prompt."),
):
    """ Prints the prompt preamble for a TEI document for a given number of examples. """
    # The docstring doubles as CLI help text. The preamble is written with the
    # builtin print (not the Rich console), so it is emitted verbatim.
    preamble = build_preamble(Doc(doc), examples)
    print(preamble)