Coverage for rdgai/evaluation.py: 100.00% (120 statements)
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from langchain_core.language_models.llms import LLM
from langchain.schema.output_parser import StrOutputParser

from .tei import find_elements
from .apparatus import App, Doc, Pair
from .prompts import build_preamble, build_review_prompt

@dataclass
class EvalItem:
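    """One classified reading transition, pairing the rdgai prediction with its ground truth classification."""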
    app_id:str
    app:App
    active:str
    passive:str
    text_in_context:str
    reading_transition_str:str
    ground_truth:str
    predicted:str
    description:str = ""
    ground_truth_description:str = ""

def llm_review_results(
    doc:Doc,
    correct_items:list[EvalItem],
    incorrect_items:list[EvalItem],
    llm:LLM|None=None,
    examples:int=10,
):
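    """
    Build the review prompt from the correct and incorrect evaluation items and, if an LLM is given, invoke it.

    Returns the prompt template and the review text (an empty string when no LLM is provided).
    """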
    template = build_review_prompt(doc, correct_items, incorrect_items, examples=examples)
    result = ""
    if llm:
        chain = template | llm | StrOutputParser()
        result = chain.invoke({})

    return template, result

def evaluate_docs(
    doc:Doc,
    ground_truth:Doc,
    pairs:list[Pair]|None=None,
    confusion_matrix:Path|None=None,
    confusion_matrix_plot:Path|None=None,
    report:Path|None=None,
    llm:LLM|None=None,
    examples:int=10,
):
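    """
    Compare relations classified by rdgai in `doc` against the same relations in `ground_truth`.

    Prints a classification report with macro-averaged precision, recall, F1 and accuracy (as percentages),
    and optionally writes a confusion matrix CSV, an interactive confusion matrix plot and an HTML report.
    """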
    # get dictionary of ground truth apps
    ground_truth_apps = {str(app):app for app in ground_truth.apps}

    # find all classified relations in the ground truth that correspond to the classified relations in the doc
    predicted = []
    gold = []
    correct_items = []
    incorrect_items = []

    # Create a dictionary of ground truth relation elements to their corresponding pairs
    ground_truth_pair_dict = dict()
    for pair in ground_truth.get_classified_pairs():
        for relation_element in pair.relation_elements():
            ground_truth_pair_dict[relation_element] = pair

    # find all classified relations in the doc that have been classified with rdgai
    pairs = pairs or [pair for pair in doc.get_classified_pairs() if pair.rdgai_responsible()]

    if len(pairs) == 0:
        print("No rdgai relations found in predicted document.")
        return
    for pair in pairs:
        # find the app for this pair and its identifier
        app = pair.app
        app_id = str(app)

        active = pair.active.n
        passive = pair.passive.n

        ground_truth_app = ground_truth_apps.get(app_id, None)
        if ground_truth_app is None:
            continue

        # find the matching relation element in the ground truth app
        ground_truth_relations = [
            element
            for element in find_elements(ground_truth_app.element, f".//relation[@active='{active}']")
            if element.attrib['passive'] == passive
        ]
        if not ground_truth_relations:
            continue
        ground_truth_relation = ground_truth_relations[0]
        ground_truth_pair = ground_truth_pair_dict[ground_truth_relation]

        # exclude ground truth relations that were themselves classified with rdgai
        if ground_truth_pair.rdgai_responsible():
            continue

        ground_truth_types = ground_truth_pair.relation_type_names()
        predicted_types = pair.relation_type_names()

        description = pair.get_description()

        eval_item = EvalItem(
            app_id=app_id,
            app=app,
            text_in_context=app.text_in_context(),
            active=ground_truth_pair.active,
            passive=ground_truth_pair.passive,
            reading_transition_str=ground_truth_pair.reading_transition_str(),
            ground_truth=ground_truth_types,
            predicted=predicted_types,
            description=description,
            ground_truth_description=ground_truth_pair.get_description(),
        )
        if ground_truth_types == predicted_types:
            correct_items.append(eval_item)
        else:
            incorrect_items.append(eval_item)

        predicted.append(" ".join(sorted(predicted_types)))
        gold.append(" ".join(sorted(ground_truth_types)))

    print(len(predicted), len(gold))
    assert len(predicted) == len(gold), f"Predicted and gold lengths do not match: {len(predicted)} != {len(gold)}"

    if len(gold) == 0:
        print("No relations found in ground truth.")
        return

    from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, accuracy_score
    print(classification_report(gold, predicted))

    precision = precision_score(gold, predicted, average='macro')*100.0
    print("precision", precision)
    recall = recall_score(gold, predicted, average='macro')*100.0
    print("recall", recall)
    f1 = f1_score(gold, predicted, average='macro')*100.0
    print("f1", f1)
    accuracy = accuracy_score(gold, predicted)*100.0
    print("accuracy", accuracy)

    # create confusion matrix
    if confusion_matrix or confusion_matrix_plot or report:
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix
        import pandas as pd

        labels = list(ground_truth.relation_types.keys())
        cm = sk_confusion_matrix(gold, predicted, labels=labels)
        confusion_df = pd.DataFrame(cm, index=labels, columns=labels)
        if confusion_matrix:
            confusion_matrix = Path(confusion_matrix)
            confusion_matrix.parent.mkdir(parents=True, exist_ok=True)
            confusion_df.to_csv(confusion_matrix)

        if confusion_matrix_plot or report:
            import plotly.graph_objects as go

            # normalise each row by its total, guarding against division by zero for empty rows
            sums = cm.sum(axis=1, keepdims=True)
            cm_normalized = cm / np.maximum(sums, 1)

            text_annotations = [[str(cm[i][j]) for j in range(len(labels))] for i in range(len(labels))]

            # Plot the normalized confusion matrix
            fig = go.Figure(data=go.Heatmap(
                z=cm_normalized,
                x=labels,
                y=labels,
                colorscale='Viridis',
                text=text_annotations,  # raw counts for each cell
                colorbar=dict(title="Proportion of True Values"),
            ))
            annotations = []
            for i in range(len(labels)):
                for j in range(len(labels)):
                    count = cm[i][j]  # raw count
                    proportion = cm_normalized[i][j]  # normalized proportion
                    annotations.append(
                        go.layout.Annotation(
                            x=j, y=i,
                            text=f"{count}",  # annotate each cell with its raw count
                            showarrow=False,
                            font=dict(size=10, color="white" if proportion < 0.5 else "black")
                        )
                    )
            fig.update_layout(annotations=annotations)

            fig.update_layout(
                xaxis_title='Predicted',
                yaxis_title='Actual',
                xaxis=dict(tickmode='array', tickvals=list(range(len(labels))), ticktext=labels, side="top"),
                yaxis=dict(tickmode='array', tickvals=list(range(len(labels))), ticktext=labels, autorange="reversed"),
            )

            # Save the plot as HTML
            if confusion_matrix_plot:
                confusion_matrix_plot = Path(confusion_matrix_plot)
                confusion_matrix_plot.parent.mkdir(parents=True, exist_ok=True)
                fig.write_html(confusion_matrix_plot)

        if report:
            from flask import Flask, render_template
            report = Path(report)
            report.parent.mkdir(parents=True, exist_ok=True)
            app = Flask(__name__)

            review_template, review_result = llm_review_results(
                doc,
                correct_items=correct_items,
                incorrect_items=incorrect_items,
                examples=examples,
                llm=llm,
            )

            import plotly.io as pio
            confusion_matrix_html = pio.to_html(fig, full_html=True, include_plotlyjs='inline')

            with app.app_context():
                text = render_template(
                    'report.html',
                    correct_items=correct_items,
                    incorrect_items=incorrect_items,
                    confusion_matrix=confusion_matrix_html,
                    accuracy=accuracy,
                    precision=precision,
                    recall=recall,
                    f1=f1,
                    correct_count=len(correct_items),
                    incorrect_count=len(incorrect_items),
                    prompt=build_preamble(doc, examples=examples),
                    review_template=review_template,
                    review_result=review_result,
                )

            print(f"Writing HTML report to {report}")
            report.write_text(text)
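
# Illustrative usage sketch (not part of the module): how `evaluate_docs` might be called once a
# classified document and its ground truth are loaded. The `Doc(...)` constructor calls below are
# assumptions about the rdgai API; use whatever loader the package actually provides.
#
#     from pathlib import Path
#     from rdgai.apparatus import Doc
#     from rdgai.evaluation import evaluate_docs
#
#     doc = Doc("classified.xml")             # document classified with rdgai (assumed constructor)
#     ground_truth = Doc("ground_truth.xml")  # manually classified document (assumed constructor)
#     evaluate_docs(
#         doc,
#         ground_truth,
#         confusion_matrix=Path("outputs/confusion_matrix.csv"),
#         confusion_matrix_plot=Path("outputs/confusion_matrix.html"),
#         report=Path("outputs/report.html"),
#     )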