Coverage for rdgai/apparatus.py: 100.00%

463 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2025-01-03 01:37 +0000

1from typing import Optional 

2from pathlib import Path 

3from dataclasses import dataclass, field 

4from lxml.etree import _Element as Element 

5from lxml.etree import _ElementTree as ElementTree 

6from lxml import etree as ET 

7from rich.console import Console 

8import functools 

9import Levenshtein 

10import numpy as np 

11 

12# from .relations import Relation, get_reading_identifier 

13from .tei import read_tei, find_elements, extract_text, find_parent, find_element, write_tei, make_nc_name, get_language, get_reading_identifier 

14from .mapper import Mapper 

15 

@dataclass
class Reading():
    """
    One reading element belonging to an apparatus entry.

    ``n``, ``text`` and ``witnesses`` are always recomputed from ``element``
    in ``__post_init__``; values passed to the constructor are overwritten.
    """
    element: Element
    app: "App"
    n: str = None
    text: str = None
    witnesses: list[str] = field(default_factory=list)

    def __post_init__(self):
        """ Derives the identifier, stripped text and witness list from the element. """
        rdg = self.element
        self.n = get_reading_identifier(rdg)
        self.text = extract_text(rdg).strip()
        # ``wit`` is a whitespace-separated list; a missing attribute gives an empty list.
        self.witnesses = rdg.attrib.get("wit", "").split()

    def __hash__(self):
        """ Hashes on the underlying XML element. """
        return hash(self.element)

    def __str__(self):
        """ Returns the reading text, or ``'OMIT'`` when the text is empty. """
        if self.text:
            return self.text
        return 'OMIT'

    def witnesses_str(self) -> str:
        """ Returns the witness sigla joined by single spaces. """
        separator = " "
        return separator.join(self.witnesses)

37 

38 

@dataclass
class RelationType():
    """
    A named category of relation between two readings.

    Holds the XML element associated with the type, its name and description,
    an optional inverse type, and the set of pairs currently classified with it.
    Equality and hashing use ``(name, element, description)`` only.
    """
    element: Element
    name: str
    description: str
    inverse: Optional['RelationType'] = None
    pairs: set['Pair'] = field(default_factory=set)

    def __str__(self):
        return self.name

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, other):
        if isinstance(other, RelationType):
            return (self.name, self.element, self.description) == (other.name, other.element, other.description)
        return False

    def __hash__(self):
        return hash((self.name, self.element, self.description))

    def str_with_description(self) -> str:
        """ Returns ``name`` or, when a description is set, ``"name: description"``. """
        result = self.name
        if self.description:
            result += f": {self.description}"
        return result

    def pairs_sorted(self, exclude_rdgai:bool = False) -> list['Pair']:
        """
        Returns the pairs of this type sorted by app name, active identifier and passive identifier.

        Args:
            exclude_rdgai: When true, pairs for which ``rdgai_responsible()`` is true are left out.
        """
        pairs = sorted(self.pairs, key=lambda pair: (str(pair.active.app), pair.active.n, pair.passive.n))
        if exclude_rdgai:
            pairs = [pair for pair in pairs if not pair.rdgai_responsible()]
        return pairs

    def get_inverse(self) -> 'RelationType':
        """ Returns the inverse type, or this type itself when no inverse is set. """
        return self.inverse if self.inverse else self

    @staticmethod
    def _medoid_pairs(pairs_list:list['Pair'], k:int, random_state:int=42) -> list['Pair']:
        """
        Picks ``k`` medoids from ``pairs_list``.

        The distance between two pairs is the Levenshtein distance between their
        active texts plus that between their passive texts. When there are at most
        ``k`` pairs they are all returned without clustering.
        """
        import kmedoids
        if len(pairs_list) <= k:
            return pairs_list

        size = len(pairs_list)
        distance_matrix = np.zeros((size, size))
        for index1, pair in enumerate(pairs_list):
            for index2 in range(index1+1, size):
                other_pair = pairs_list[index2]
                distance = (
                    Levenshtein.distance(pair.active.text, other_pair.active.text)
                    + Levenshtein.distance(pair.passive.text, other_pair.passive.text)
                )
                # The matrix is symmetric, so fill both halves at once.
                distance_matrix[index1, index2] = distance
                distance_matrix[index2, index1] = distance

        result = kmedoids.fasterpam(distance_matrix, k, random_state=random_state, init="build")
        return [pairs_list[index] for index in result.medoids]

    def representative_examples(self, k:int, random_state:int=42) -> list['Pair']:
        """
        Returns up to ``k`` representative pairs of this type.

        Pairs for which ``rdgai_responsible()`` is true are ignored. Pairs with a
        description are preferred; pairs without one are only used to fill up
        the remaining slots.

        The clustering result is memoised on this instance. Unlike a
        ``functools.lru_cache`` on the method, the memo is checked against the
        current candidate pairs on every call, so it is recomputed after pairs
        are added or removed or their descriptions change. It also does not keep
        the instance alive in a module-level cache. A fresh list is returned
        each time so callers cannot corrupt the memo.

        Args:
            k: Maximum number of pairs to return.
            random_state: Seed passed to the k-medoids routine.
        """
        described: list['Pair'] = []
        undescribed: list['Pair'] = []
        for pair in self.pairs_sorted(exclude_rdgai=True):
            (described if pair.has_description() else undescribed).append(pair)

        # Stored in the instance dict rather than as a dataclass field so that the
        # constructor signature stays untouched.
        memo = self.__dict__.setdefault("_representative_examples_memo", {})
        state = (tuple(described), tuple(undescribed))
        cached = memo.get((k, random_state))
        if cached is None or cached[0] != state:
            representative_pairs: list['Pair'] = []
            if described:
                representative_pairs = list(self._medoid_pairs(described, k, random_state=random_state))

            if len(representative_pairs) < k:
                additional_pairs = self._medoid_pairs(undescribed, k-len(representative_pairs), random_state=random_state)
                representative_pairs.extend(additional_pairs)

            cached = (state, representative_pairs)
            memo[(k, random_state)] = cached

        return list(cached[1])

109 

110 

@dataclass
class Pair():
    """
    A directed pair of readings (``active`` to ``passive``) with its relation types.

    The relation types are mirrored in the XML as ``relation`` elements inside the
    app's ``listRelation[@type='transcriptional']``; the methods below keep the
    in-memory sets and those elements in step.
    """
    active: Reading
    passive: Reading
    types: set[RelationType] = field(default_factory=set)

    def __post_init__(self):
        # Register this pair with every type it was constructed with.
        for relation_type in self.types:
            relation_type.pairs.add(self)

    def __str__(self):
        return f"{self.active} ➞ {self.passive}"

    def print(self, console):
        """ Prints the app name and both readings with markup via ``console.print``. """
        console.print(f"[bold red]{self.app}[/bold red]: [green]{self.active}[/green] [red]➞[/red] [green]{self.passive}[/green]")

    def reading_transition_str(self) -> str:
        """ Returns ``"<active> → <passive>"`` using the string form of each reading. """
        # NOTE(review): Reading defines neither __bool__ nor __len__, so a Reading is
        # always truthy and the 'OMISSION' fallback is not reached; empty readings
        # render via Reading.__str__. Left unchanged to keep the emitted text stable.
        return f"{self.active or 'OMISSION'} → {self.passive or 'OMISSION'}"

    def __repr__(self) -> str:
        return str(self)

    @property
    def app(self) -> "App":
        """ The app of the active reading. """
        # assert self.active.app == self.passive.app
        return self.active.app

    def __hash__(self):
        return hash((self.active, self.passive))

    def app_element(self) -> Element:
        """ Returns the ``app`` ancestor of the active reading's element. """
        return find_parent(self.active.element, "app")

    def _list_relation(self, create:bool=False) -> Element|None:
        """ Returns the app's transcriptional ``listRelation``, optionally creating it. """
        app_element = self.app_element()
        list_relation = find_element(app_element, ".//listRelation[@type='transcriptional']")
        if list_relation is None and create:
            list_relation = ET.SubElement(app_element, "listRelation", attrib={"type":"transcriptional"})
        return list_relation

    def _relation_xpath(self) -> str:
        """ Returns the XPath selecting the ``relation`` elements of this pair. """
        return f".//relation[@active='{self.active.n}'][@passive='{self.passive.n}']"

    @staticmethod
    def _ana_tokens(relation:Element) -> list[str]:
        """ Returns the tokens of ``@ana``; an absent attribute gives an empty list. """
        return relation.attrib.get("ana", "").split()

    def relation_elements(self) -> list[Element]:
        """ Returns the ``relation`` elements for this pair, or ``[]`` without a ``listRelation``. """
        list_relation = self._list_relation()
        if list_relation is None:
            return []

        return find_elements(list_relation, self._relation_xpath())

    def element_for_type(self, type:RelationType) -> Element|None:
        """ Returns the first ``relation`` element of this pair whose ``@ana`` lists ``type``. """
        pointer = f"#{type.name}"
        for relation in self.relation_elements():
            # A relation created only to hold a description has no ``ana`` attribute.
            if pointer in self._ana_tokens(relation):
                return relation
        return None

    def get_inverse(self) -> "Pair":
        """
        Returns the pair of the same app that runs in the opposite direction.

        Raises:
            AssertionError: If the app holds no such pair.
        """
        found_pair = None
        for pair in self.app.pairs:
            if pair.active == self.passive and pair.passive == self.active:
                found_pair = pair
                break
        assert found_pair is not None, f"No inverse pair found for {self}"
        return found_pair

    def add_type_with_inverse(self, type:RelationType, responsible:str|None=None, description:str="", inverse_description:str="") -> Element:
        """
        Adds ``type`` to this pair and the inverse type to the inverse pair.

        Returns:
            The ``relation`` element of this pair (not of the inverse).
        """
        relation = self.add_type(type, responsible=responsible, description=description)
        inverse = self.get_inverse()
        inverse.add_type(type.get_inverse(), responsible=responsible, description=inverse_description)
        return relation

    def add_type(self, type:RelationType, responsible:str|None=None, description:str="") -> Element:
        """
        Classifies this pair as ``type`` and records it in the XML.

        If a ``relation`` element already lists the type it is returned untouched
        (``responsible`` and ``description`` are then not applied). Otherwise the
        type is appended to the pair's existing ``relation`` element, or a new one
        is created.

        Args:
            type: The relation type to add.
            responsible: Value for the ``resp`` attribute; left alone when ``None``.
            description: Text for the ``desc`` child; ignored when blank.

        Returns:
            The ``relation`` element carrying the type.
        """
        self.types.add(type)
        type.pairs.add(self)

        # Check if the relation already exists
        relation = self.element_for_type(type)
        if relation is not None:
            return relation

        list_relation = self._list_relation(create=True)
        pointer = f"#{type.name}"

        relation = find_element(list_relation, self._relation_xpath())
        if relation is None:
            relation = ET.SubElement(list_relation, "relation", attrib={"active":self.active.n, "passive":self.passive.n, "ana":pointer})
        else:
            tokens = self._ana_tokens(relation)
            # Bare names (no leading '#') are still treated as already present.
            if pointer not in tokens and type.name not in tokens:
                tokens.append(pointer)
                relation.set("ana", " ".join(tokens))

        if responsible is not None:
            relation.set("resp", responsible)

        self.add_description(description, relation)

        return relation

    def remove_description(self):
        """ Removes every ``desc`` element found under this pair's ``relation`` elements. """
        for relation in self.relation_elements():
            for desc in find_elements(relation, ".//desc"):
                # Remove via the real parent: ``desc`` need not be a direct child.
                desc.getparent().remove(desc)

    def add_description(self, description:str, relation:Element|None=None):
        """
        Sets the description text on a ``relation`` element of this pair.

        Args:
            description: The text; surrounding whitespace is stripped and a blank
                value leaves any existing ``desc`` unchanged.
            relation: The element to annotate. When ``None`` the first existing
                ``relation`` of the pair is used, or a new one without ``ana`` is created.
        """
        if relation is None:
            relation_elements = self.relation_elements()
            if relation_elements:
                relation = relation_elements[0]
            else:
                relation = ET.SubElement(
                    self._list_relation(create=True),
                    "relation",
                    attrib={"active":self.active.n, "passive":self.passive.n},
                )

        description = description.strip()
        if description:
            description_element = find_element(relation, ".//desc")
            if description_element is None:
                description_element = ET.SubElement(relation, "desc")

            description_element.text = description

    def remove_type(self, relation_type:RelationType):
        """
        Removes ``relation_type`` from this pair, in memory and in the XML.

        A ``relation`` element whose ``ana`` becomes empty by this removal is
        deleted. Relations that never listed the type are left alone.
        """
        self.types.discard(relation_type)
        relation_type.pairs.discard(self)

        list_relation = self._list_relation()
        if list_relation is None:
            return

        pointer = f"#{relation_type.name}"
        for relation in find_elements(list_relation, self._relation_xpath()):
            tokens = self._ana_tokens(relation)
            if pointer not in tokens:
                continue
            remaining = [ana for ana in tokens if ana != pointer]
            relation.attrib['ana'] = " ".join(remaining)
            if not remaining:
                relation.getparent().remove(relation)

    def remove_type_with_inverse(self, relation_type:RelationType):
        """ Removes ``relation_type`` here and its inverse from the inverse pair. """
        self.remove_type(relation_type)
        inverse = self.get_inverse()
        inverse.remove_type(relation_type.get_inverse())

    def remove_all_types(self):
        """ Removes every type of this pair together with the inverses. """
        # Iterate over a copy because remove_type mutates ``self.types``.
        for relation_type in set(self.types):
            self.remove_type_with_inverse(relation_type)

    def rdgai_responsible(self) -> bool:
        """ Returns whether any ``relation`` element of this pair has ``resp='#rdgai'``. """
        return any(element.attrib.get('resp', '') == '#rdgai' for element in self.relation_elements())

    def relation_type_names(self) -> set[str]:
        """ Returns the names of the types of this pair. """
        return {type.name for type in self.types}

    def has_description(self) -> bool:
        """ Returns whether any ``relation`` element of this pair contains a ``desc``. """
        return any(find_element(relation, ".//desc") is not None for relation in self.relation_elements())

    def get_description(self) -> str:
        """ Returns the text of all ``desc`` elements of this pair, one per line, stripped. """
        texts = [
            extract_text(desc)
            for relation_element in self.relation_elements()
            for desc in find_elements(relation_element, ".//desc")
        ]
        return "\n".join(texts).strip()

268 

269 

@dataclass
class App():
    """
    One ``app`` element with its readings and every ordered pair of them.

    ``pairs`` holds both directions of each pair; ``non_redundant_pairs`` holds
    one direction only (the earlier reading as active).
    """
    element: Element
    doc: "Doc"
    readings: list[Reading] = field(default_factory=list)
    pairs: list[Pair] = field(default_factory=list)
    non_redundant_pairs: list[Pair] = field(default_factory=list)

    def __post_init__(self):
        """ Builds the readings and all pairs, attaching the types found in the XML. """
        for reading in find_elements(self.element, ".//rdg"):
            self.readings.append(Reading(reading, app=self))

        # Index the type names once by (active, passive) instead of rescanning
        # every relation element for every pair of readings.
        type_names_by_pair: dict[tuple, set[str]] = {}
        for list_relation in find_elements(self.element, ".//listRelation[@type='transcriptional']"):
            for relation_element in find_elements(list_relation, ".//relation"):
                key = (relation_element.attrib.get("active"), relation_element.attrib.get("passive"))
                names = type_names_by_pair.setdefault(key, set())
                for ana in relation_element.attrib.get("ana", "").split():
                    # Pointers normally start with '#', but bare names are accepted too.
                    if ana.startswith("#"):
                        ana = ana[1:]
                    if ana:
                        names.add(ana)

        # Build list of relation pairs
        active_visited = set()
        for active in self.readings:
            active_visited.add(active)
            for passive in self.readings:
                if active == passive:
                    continue

                pair_relation_types = set()
                for type_name in type_names_by_pair.get((active.n, passive.n), ()):
                    if type_name in self.doc.relation_types:
                        relation_type = self.doc.relation_types[type_name]
                    else:
                        relation_type = self.doc.add_relation_type(type_name)
                    pair_relation_types.add(relation_type)

                # Pair.__post_init__ registers the pair with each of its types.
                pair = Pair(active=active, passive=passive, types=pair_relation_types)
                self.pairs.append(pair)
                if passive not in active_visited:
                    self.non_redundant_pairs.append(pair)

        assert len(self.pairs) == len(self.non_redundant_pairs) * 2

    def get_classified_pairs(self, redundant:bool=True) -> list[Pair]:
        """ Returns the pairs that have at least one type. """
        pairs = self.pairs if redundant else self.non_redundant_pairs
        return [pair for pair in pairs if len(pair.types) > 0]

    def get_unclassified_pairs(self, redundant:bool=True) -> list[Pair]:
        """ Returns the pairs that have no type. """
        pairs = self.pairs if redundant else self.non_redundant_pairs
        return [pair for pair in pairs if len(pair.types) == 0]

    def __hash__(self):
        return hash(self.element)

    def __str__(self):
        """
        Returns the app's name with spaces and colons replaced by underscores.

        The name is ``xml:id`` if set, else ``n``. If neither is set, a name is
        generated from the enclosing ``ab`` name and the app's position, and is
        stored on the element as ``xml:id`` (a side effect of this method).
        """
        xml_id = '{http://www.w3.org/XML/1998/namespace}id'
        name = self.element.attrib.get(xml_id, '')
        if not name:
            name = self.element.attrib.get('n', '')

        if not name:
            ab = self.ab()
            # Without an enclosing ``ab`` number the app against the whole document
            # instead of failing on a missing parent.
            scope = ab if ab is not None else self.element.getroottree()
            for index, app in enumerate(find_elements(scope, ".//app")):
                if app == self.element:
                    name = make_nc_name(f"{self.ab_name()}-{index+1}")
                    self.element.attrib[xml_id] = name
                    break
        return str(name).replace(" ", "_").replace(":", "_")

    def ab(self) -> Element|None:
        """ Returns the enclosing ``ab`` element as found by ``find_parent``. """
        return find_parent(self.element, "ab")

    def ab_name(self) -> str:
        """ Returns the ``n`` attribute of the enclosing ``ab``, or ``""`` if there is none. """
        ab = self.ab()
        if ab is None:
            return ""
        return ab.attrib.get("n", "")

    def text_before(self) -> str:
        """ Returns the text of the ``ab`` children that precede this app, space-joined. """
        ab = self.ab()
        if ab is None:
            return ""

        items = []
        for child in ab:
            if child == self.element:
                break
            child_text = extract_text(child)
            if child_text:
                items.append(child_text)

        return " ".join(items).strip()

    def text_in_context(self, text="") -> str:
        """ Returns the preceding text, the marked-up app text and the following text. """
        return f"{self.text_before()} {self.text_with_signs(text)} {self.text_after()}".strip()

    def text(self) -> str:
        """ Returns the extracted text of the app element. """
        return extract_text(self.element)

    def text_with_signs(self, text="") -> str:
        """ Wraps ``text`` (or the app's own text) in ``⸂…⸃``; empty text gives ``⸆``. """
        text = text or self.text()
        if not text:
            return "⸆"
        return f"⸂{text}⸃"

    def text_after(self) -> str:
        """ Returns the text of the ``ab`` children that follow this app, space-joined. """
        ab = self.ab()
        if ab is None:
            return ""

        items = []
        reached_element = False
        for child in ab:
            if reached_element:
                child_text = extract_text(child)
                if child_text:
                    items.append(child_text)
            if child == self.element:
                reached_element = True

        return " ".join(items).strip()

397 

398 

@dataclass
class Doc():
    """
    A TEI document with its apparatus entries and relation types.

    ``tree``, ``relation_types`` and ``apps`` are populated from ``path`` in
    ``__post_init__``.
    """
    path: Path
    tree: ElementTree = field(default=None)
    apps: list[App] = field(default_factory=list)
    relation_types: dict[str,RelationType] = field(default_factory=dict)

    def __post_init__(self):
        self.tree = read_tei(self.path)
        # The relation types must exist before the apps, which look them up.
        self.relation_types = self.get_relation_types()

        for app_element in find_elements(self.tree, ".//app"):
            app = App(app_element, doc=self)
            self.apps.append(app)

    def get_interpgrp(self) -> Element:
        """
        Returns the ``interpGrp[@type='transcriptional']`` element.

        If none exists, one is created as the first child of the ``text`` element.
        """
        text = find_element(self.tree, "text")
        interp_group = find_element(text, ".//interpGrp[@type='transcriptional']")
        if interp_group is None:
            interp_group = ET.Element("interpGrp", attrib={"type":"transcriptional"})
            text.insert(0, interp_group)

        return interp_group

    def add_relation_type(self, name:str, description:str="") -> RelationType:
        """
        Returns the relation type called ``name``, creating it if necessary.

        A newly created ``interp`` element receives ``description`` as its text,
        and the new ``RelationType`` carries the description as well.

        Args:
            name: Name of the type, used as the ``xml:id`` of its ``interp``.
            description: Description of the type.

        Raises:
            AssertionError: If the type exists with a different description.
        """
        if name in self.relation_types:
            assert self.relation_types[name].description == description, f"RelationType {name} already exists with a different description."
            return self.relation_types[name]

        interp_group = self.get_interpgrp()
        interp = find_element(interp_group, f".//interp[@xml:id='{name}']")
        if interp is None:
            interp = ET.Element("interp", attrib={"{http://www.w3.org/XML/1998/namespace}id":name})
            if description:
                # get_relation_types reads the description back from the element text.
                interp.text = description
            interp_group.append(interp)

        # The description used to be dropped here, which made a second call with
        # the same arguments fail the assertion above.
        relation_type = RelationType(name=name, element=interp, description=description)
        self.relation_types[name] = relation_type
        return relation_type

    def __str__(self):
        return str(self.path)

    def __repr__(self) -> str:
        return str(self)

    def write(self, output:str|Path):
        """ Writes the tree to ``output`` via ``write_tei``. """
        write_tei(self.tree, output)

    @property
    def language(self):
        """ The language reported by ``get_language`` for the tree. """
        return get_language(self.tree)

    def get_relation_types(self, categories_to_ignore:list[str]|None=None) -> dict[str,RelationType]:
        """
        Reads the relation types from the transcriptional ``interpGrp``.

        Args:
            categories_to_ignore: Names of ``interp`` elements to skip.

        Returns:
            A dict mapping each type name to its ``RelationType``. Types whose
            ``corresp`` attribute names another type are linked as inverses.

        Raises:
            AssertionError: If a type is claimed as inverse by two different types.
        """
        interp_group = self.get_interpgrp()
        categories_to_ignore = categories_to_ignore or []

        relation_types = dict()
        for interp in find_elements(interp_group, "./interp"):
            name = interp.attrib.get("{http://www.w3.org/XML/1998/namespace}id", "")
            if name in categories_to_ignore:
                continue

            description = extract_text(interp).strip()
            relation_types[name] = RelationType(name=name, element=interp, description=description)

        # get corresponding relations
        for category in relation_types.values():
            inverse_name = category.element.attrib.get("corresp", "")
            if inverse_name.startswith("#"):
                inverse_name = inverse_name[1:]

            if inverse_name in relation_types:
                inverse = relation_types[inverse_name]
                category.inverse = inverse
                if inverse.inverse is None:
                    inverse.inverse = category
                else:
                    assert inverse.inverse == category, f"Inverse category {inverse} already has an inverse {inverse.inverse}."

        return relation_types

    def get_classified_pairs(self, redundant:bool=True) -> list[Pair]:
        """ Returns the classified pairs of all apps. """
        pairs = []
        for app in self.apps:
            pairs.extend(app.get_classified_pairs(redundant=redundant))

        return pairs

    def get_unclassified_pairs(self, redundant:bool=True) -> list[Pair]:
        """ Returns the unclassified pairs of all apps. """
        pairs = []
        for app in self.apps:
            pairs.extend(app.get_unclassified_pairs(redundant=redundant))

        return pairs

    def print_classified_pairs(self, console:Console|None=None) -> None:
        """ Prints every relation type followed by its sorted pairs. """
        console = console or Console()
        for relation_type in self.relation_types.values():
            console.rule(str(relation_type))
            console.print(relation_type.description, style="grey46")
            for pair in relation_type.pairs_sorted():
                pair.print(console)

            console.print("")

    def render_html(self, output:Path|None=None, all_apps:bool=False) -> str:
        """
        Renders the ``server.html`` template for this document.

        Args:
            output: If given, the HTML is also written there as UTF-8 (parent
                directories are created).
            all_apps: Passed through to the template.

        Returns:
            The rendered HTML.
        """
        from flask import Flask, render_template

        mapper = Mapper()
        app = Flask(__name__)

        with app.app_context():
            html = render_template('server.html', doc=self, mapper=mapper, all_apps=all_apps)

        if output:
            output.parent.mkdir(parents=True, exist_ok=True)
            # Explicit encoding: the locale default may be unable to encode the text.
            output.write_text(html, encoding="utf-8")

        return html

    def flask_app(self, output:Path, all_apps:bool=False):
        """
        Builds a Flask app for viewing and editing the classifications.

        The document is written to ``output`` immediately and again after every
        successful API call.

        Routes:
            ``/``: the rendered ``server.html`` template.
            ``/api/relation-type`` (POST): add or remove a relation type on a pair.
            ``/api/desc`` (POST): add or remove the description of a pair.

        Returns:
            The Flask application (not started).
        """
        from flask import Flask, request, render_template

        self.write(output)
        mapper = Mapper()

        app = Flask(__name__)

        @app.route("/")
        def root():
            return render_template('server.html', doc=self, mapper=mapper, all_apps=all_apps)

        @app.route("/api/relation-type", methods=['POST'])
        def api_relation_type():
            data = request.get_json()

            relation_type = mapper.obj(data['relation_type'])
            assert isinstance(relation_type, RelationType), f"Expected RelationType, got {type(relation_type)}"

            pair = mapper.obj(data['pair'])
            assert isinstance(pair, Pair), f"Expected Pair, got {type(pair)}"

            try:
                if data['operation'] == 'remove':
                    print('remove', relation_type)
                    pair.remove_type_with_inverse(relation_type)
                elif data['operation'] == 'add':
                    print('add', relation_type)
                    pair.add_type_with_inverse(relation_type)
                else:
                    raise ValueError(f"Unknown operation {data['operation']}")

                print('write', output)
                self.write(output)
                return "Success", 200
            except Exception as e:
                # Any failure is reported to the client as a 400 with the message.
                print(str(e))
                return str(e), 400

        @app.route("/api/desc", methods=['POST'])
        def desc():
            data = request.get_json()

            pair = mapper.obj(data['pair'])
            assert isinstance(pair, Pair), f"Expected Pair, got {type(pair)}"

            try:
                if data['operation'] == 'remove':
                    pair.remove_description()
                elif data['operation'] == 'add':
                    pair.add_description(data['description'])
                else:
                    raise ValueError(f"Unknown operation {data['operation']}")

                print('write', output)
                self.write(output)
                return "Success", 200
            except Exception as e:
                # Any failure is reported to the client as a 400 with the message.
                print(str(e))
                return str(e), 400

        return app
        # app.run(debug=True, use_reloader=True)

    def clean(self, output:Path|None=None):
        """
        Cleans a TEI XML file for common errors.

        Within every ``listRelation``:
          * each token of a relation's ``ana`` attribute gets a leading ``#``;
          * relations sharing the same ``active`` and ``passive`` are merged into
            the first one, whose ``ana`` becomes the sorted union of all tokens.

        Relations lacking ``active`` or ``passive`` are not merged, and a missing
        ``ana`` is left missing.

        Args:
            output: If given, the cleaned document is written there.
        """
        # NOTE(review): merging keeps only the first relation element, so ``desc``
        # children and ``resp`` values of the removed duplicates are lost.
        for list_relation in find_elements(self.tree, ".//listRelation"):
            relations_by_key: dict[tuple[str, str], list] = {}
            for relation in find_elements(list_relation, ".//relation"):
                # make sure that every pointer in the ana attribute starts with a #
                tokens = relation.attrib.get('ana', '').split()
                if tokens:
                    relation.attrib['ana'] = " ".join(
                        token if token.startswith("#") else f"#{token}" for token in tokens
                    )

                active = relation.attrib.get('active')
                passive = relation.attrib.get('passive')
                if active is None or passive is None:
                    continue
                relations_by_key.setdefault((active, passive), []).append(relation)

            # consolidate duplicate relations
            for relations in relations_by_key.values():
                if len(relations) < 2:
                    continue

                analytic_set = set()
                for relation in relations:
                    analytic_set.update(relation.attrib.get('ana', '').split())

                for duplicate in relations[1:]:
                    # Remove via the real parent: the relation may not be a direct child.
                    duplicate.getparent().remove(duplicate)

                if analytic_set:
                    relations[0].attrib['ana'] = " ".join(sorted(analytic_set))

        if output:
            output = Path(output)
            output.parent.mkdir(parents=True, exist_ok=True)
            print("Writing to", output)
            self.write(output)