Coverage for rdgai/evaluation.py: 100.00% (120 statements)

coverage.py v7.6.4, created at 2025-01-03 01:37 +0000

import numpy as np
from pathlib import Path
from dataclasses import dataclass
from langchain_core.language_models.llms import LLM
from langchain.schema.output_parser import StrOutputParser

from .tei import find_elements
from .apparatus import App, Doc, Pair
from .prompts import build_preamble, build_review_prompt


@dataclass
class EvalItem:
    app_id: str
    app: App
    active: str
    passive: str
    text_in_context: str
    reading_transition_str: str
    ground_truth: str
    predicted: str
    description: str = ""
    ground_truth_description: str = ""
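
# EvalItem records are created in evaluate_docs below: one for each rdgai-classified
# reading pair that also has a non-rdgai classification in the ground truth document.
# The resulting correct/incorrect lists feed both the LLM review prompt and the
# HTML report template.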


def llm_review_results(
    doc: Doc,
    correct_items: list[EvalItem],
    incorrect_items: list[EvalItem],
    llm: LLM | None = None,
    examples: int = 10,
):
    # Build the review prompt from the evaluation items; only invoke the LLM if one was provided.
    template = build_review_prompt(doc, correct_items, incorrect_items, examples=examples)
    result = ""
    if llm:
        chain = template | llm | StrOutputParser()
        result = chain.invoke({})

    return template, result
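
# A minimal usage sketch (not part of the module): once evaluate_docs has produced
# correct/incorrect EvalItem lists, llm_review_results can be called directly. The
# model class below is an assumption for illustration; any LangChain LLM or chat
# model that accepts a piped prompt template should work here.
#
#     from langchain_openai import ChatOpenAI   # hypothetical model choice
#     llm = ChatOpenAI(model="gpt-4o")           # assumes an OpenAI API key is configured
#     template, review = llm_review_results(doc, correct_items, incorrect_items, llm=llm)
#     print(review)                              # the LLM's written review of the results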


def evaluate_docs(
    doc: Doc,
    ground_truth: Doc,
    pairs: list[Pair] | None = None,
    confusion_matrix: Path | None = None,
    confusion_matrix_plot: Path | None = None,
    report: Path | None = None,
    llm: LLM | None = None,
    examples: int = 10,
):
    # Compare rdgai classifications in `doc` against `ground_truth`, print
    # precision/recall/F1/accuracy, and optionally write a confusion matrix CSV,
    # an interactive plot, and an HTML report.

    # get dictionary of ground truth apps keyed by their string IDs
    ground_truth_apps = {str(app): app for app in ground_truth.apps}

    # accumulators for predicted and gold labels and the evaluation items behind them
    predicted = []
    gold = []
    correct_items = []
    incorrect_items = []

    # map each ground truth relation element to its corresponding pair
    ground_truth_pair_dict = dict()
    for pair in ground_truth.get_classified_pairs():
        for relation_element in pair.relation_elements():
            ground_truth_pair_dict[relation_element] = pair

    # find all relations in the doc that have been classified with rdgai
    pairs = pairs or [pair for pair in doc.get_classified_pairs() if pair.rdgai_responsible()]

    if len(pairs) == 0:
        print("No rdgai relations found in predicted document.")
        return

    for pair in pairs:
        # find the apparatus for this pair
        app = pair.app
        app_id = str(app)

        active = pair.active.n
        passive = pair.passive.n

        ground_truth_app = ground_truth_apps.get(app_id, None)
        if ground_truth_app is None:
            continue

        # find the matching relation element in the ground truth apparatus
        ground_truth_relations = [
            element
            for element in find_elements(ground_truth_app.element, f".//relation[@active='{active}']")
            if element.attrib['passive'] == passive
        ]
        if not ground_truth_relations:
            continue
        ground_truth_relation = ground_truth_relations[0]
        ground_truth_pair = ground_truth_pair_dict[ground_truth_relation]

        # exclude ground truth relations that were themselves classified with rdgai
        if ground_truth_pair.rdgai_responsible():
            continue

        ground_truth_types = ground_truth_pair.relation_type_names()
        predicted_types = pair.relation_type_names()

        description = pair.get_description()

        eval_item = EvalItem(
            app_id=app_id,
            app=app,
            text_in_context=app.text_in_context(),
            active=ground_truth_pair.active,
            passive=ground_truth_pair.passive,
            reading_transition_str=ground_truth_pair.reading_transition_str(),
            ground_truth=ground_truth_types,
            predicted=predicted_types,
            description=description,
            ground_truth_description=ground_truth_pair.get_description(),
        )
        if ground_truth_types == predicted_types:
            correct_items.append(eval_item)
        else:
            incorrect_items.append(eval_item)

        predicted.append(" ".join(sorted(predicted_types)))
        gold.append(" ".join(sorted(ground_truth_types)))

    print(len(predicted), len(gold))
    assert len(predicted) == len(gold), f"Predicted and gold lengths do not match: {len(predicted)} != {len(gold)}"

    if len(gold) == 0:
        print("No relations found in ground truth.")
        return

    from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, accuracy_score
    print(classification_report(gold, predicted))

    # metrics treat each space-joined, sorted set of relation types as a single class label; scores reported as percentages
    precision = precision_score(gold, predicted, average='macro') * 100.0
    print("precision", precision)
    recall = recall_score(gold, predicted, average='macro') * 100.0
    print("recall", recall)
    f1 = f1_score(gold, predicted, average='macro') * 100.0
    print("f1", f1)
    accuracy = accuracy_score(gold, predicted) * 100.0
    print("accuracy", accuracy)

    # create confusion matrix
    if confusion_matrix or confusion_matrix_plot or report:
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix
        import pandas as pd

        labels = list(ground_truth.relation_types.keys())
        cm = sk_confusion_matrix(gold, predicted, labels=labels)
        confusion_df = pd.DataFrame(cm, index=labels, columns=labels)
        if confusion_matrix:
            confusion_matrix = Path(confusion_matrix)
            confusion_matrix.parent.mkdir(parents=True, exist_ok=True)
            confusion_df.to_csv(confusion_matrix)

        if confusion_matrix_plot or report:
            import plotly.graph_objects as go

            # normalize each row by its total so cells show the proportion of true values
            sums = cm.sum(axis=1, keepdims=True)
            cm_normalized = cm / np.maximum(sums, 1)

            text_annotations = [[str(cm[i][j]) for j in range(len(labels))] for i in range(len(labels))]

            # plot the normalized confusion matrix
            fig = go.Figure(data=go.Heatmap(
                z=cm_normalized,
                x=labels,
                y=labels,
                colorscale='Viridis',
                text=text_annotations,  # raw counts for each cell
                colorbar=dict(title="Proportion of True Values"),
            ))
            annotations = []
            for i in range(len(labels)):
                for j in range(len(labels)):
                    count = cm[i][j]  # raw count
                    proportion = cm_normalized[i][j]  # normalized proportion
                    annotations.append(
                        go.layout.Annotation(
                            x=j, y=i,
                            text=f"{count}",  # annotate each cell with its raw count
                            showarrow=False,
                            font=dict(size=10, color="white" if proportion < 0.5 else "black"),
                        )
                    )
            fig.update_layout(annotations=annotations)

            fig.update_layout(
                xaxis_title='Predicted',
                yaxis_title='Actual',
                xaxis=dict(tickmode='array', tickvals=list(range(len(labels))), ticktext=labels, side="top"),
                yaxis=dict(tickmode='array', tickvals=list(range(len(labels))), ticktext=labels, autorange="reversed"),
            )

            # Save the plot as HTML
            if confusion_matrix_plot:
                confusion_matrix_plot = Path(confusion_matrix_plot)
                confusion_matrix_plot.parent.mkdir(parents=True, exist_ok=True)
                fig.write_html(confusion_matrix_plot)

            if report:
                from flask import Flask, render_template
                report = Path(report)
                report.parent.mkdir(parents=True, exist_ok=True)
                app = Flask(__name__)

                # ask the LLM (if given) to review the correct and incorrect classifications
                review_template, review_result = llm_review_results(
                    doc,
                    correct_items=correct_items,
                    incorrect_items=incorrect_items,
                    examples=examples,
                    llm=llm,
                )

                import plotly.io as pio
                confusion_matrix_html = pio.to_html(fig, full_html=True, include_plotlyjs='inline')

                # render the report template inside a Flask app context so render_template works
                with app.app_context():
                    text = render_template(
                        'report.html',
                        correct_items=correct_items,
                        incorrect_items=incorrect_items,
                        confusion_matrix=confusion_matrix_html,
                        accuracy=accuracy,
                        precision=precision,
                        recall=recall,
                        f1=f1,
                        correct_count=len(correct_items),
                        incorrect_count=len(incorrect_items),
                        prompt=build_preamble(doc, examples=examples),
                        review_template=review_template,
                        review_result=review_result,
                    )

                print(f"Writing HTML report to {report}")
                report.write_text(text)
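

# A minimal end-to-end sketch (not part of the module): the Doc constructor call and
# the file names below are assumptions for illustration; evaluate_docs only needs a
# document with rdgai-classified relations, a hand-classified ground truth document,
# and optional output paths.
#
#     from pathlib import Path
#     from rdgai.apparatus import Doc
#
#     doc = Doc(Path("classified.xml"))               # hypothetical rdgai-classified TEI file
#     ground_truth = Doc(Path("ground_truth.xml"))    # hypothetical hand-classified TEI file
#     evaluate_docs(
#         doc,
#         ground_truth,
#         confusion_matrix=Path("outputs/confusion.csv"),
#         confusion_matrix_plot=Path("outputs/confusion.html"),
#         report=Path("outputs/report.html"),         # rendered with the package's report.html template
#     )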