Coverage for rdgai/apparatus.py: 100.00%
463 statements
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
1from typing import Optional
2from pathlib import Path
3from dataclasses import dataclass, field
4from lxml.etree import _Element as Element
5from lxml.etree import _ElementTree as ElementTree
6from lxml import etree as ET
7from rich.console import Console
8import functools
9import Levenshtein
10import numpy as np
12# from .relations import Relation, get_reading_identifier
13from .tei import read_tei, find_elements, extract_text, find_parent, find_element, write_tei, make_nc_name, get_language, get_reading_identifier
14from .mapper import Mapper
@dataclass
class Reading:
    """One ``<rdg>`` of an apparatus entry.

    ``n``, ``text`` and ``witnesses`` are always recomputed from ``element`` in
    ``__post_init__``; values passed for them at construction are overwritten.
    """
    element: Element
    app: "App"
    n: str = field(default=None)
    text: str = field(default=None)
    witnesses: list[str] = field(default_factory=list)

    def __post_init__(self):
        rdg = self.element
        self.n = get_reading_identifier(rdg)
        self.text = extract_text(rdg).strip()
        # ``wit`` holds whitespace-separated witness sigla; absent means none.
        wit_value = rdg.attrib.get("wit", "")
        self.witnesses = wit_value.split()

    def __str__(self):
        # An empty reading is displayed as an omission.
        if self.text:
            return self.text
        return 'OMIT'

    def witnesses_str(self) -> str:
        """Return the witness sigla joined by single spaces."""
        separator = " "
        return separator.join(self.witnesses)

    def __hash__(self):
        # A reading is identified by its underlying XML element.
        return hash(self.element)
@dataclass
class RelationType():
    """A category of transcriptional relation, backed by an ``<interp>`` element.

    ``pairs`` holds the reading pairs currently classified with this type; it
    is mutated by ``Pair`` when types are added or removed.
    """
    element: Element
    name: str
    description: str
    inverse: Optional['RelationType'] = None
    pairs: set['Pair'] = field(default_factory=set)

    def __str__(self):
        return self.name

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, other):
        # Only name, element and description define identity; ``inverse`` and
        # ``pairs`` are excluded so comparison never walks the object graph.
        if isinstance(other, RelationType):
            return (self.name, self.element, self.description) == (other.name, other.element, other.description)
        return False

    def __hash__(self):
        return hash((self.name, self.element, self.description))

    def str_with_description(self) -> str:
        """Return ``"name: description"``, or just the name if there is no description."""
        result = self.name
        if self.description:
            result += f": {self.description}"
        return result

    def pairs_sorted(self, exclude_rdgai:bool = False) -> list['Pair']:
        """Return the pairs sorted by (app, active n, passive n).

        Args:
            exclude_rdgai: drop pairs for which ``pair.rdgai_responsible()`` is true.
        """
        pairs = sorted(self.pairs, key=lambda pair: (str(pair.active.app), pair.active.n, pair.passive.n))
        if exclude_rdgai:
            pairs = [pair for pair in pairs if not pair.rdgai_responsible()]
        return pairs

    def get_inverse(self) -> 'RelationType':
        """Return the inverse type, or this type itself when it has none."""
        return self.inverse if self.inverse else self

    def representative_examples(self, k:int, random_state:int=42) -> list['Pair']:
        """Return up to ``k`` representative pairs for this type.

        Pairs marked as rdgai-responsible are excluded. Pairs with a
        description are preferred; remaining slots are filled from pairs
        without one. Within each group, when there are more candidates than
        slots, medoids are chosen by k-medoids clustering over the summed
        Levenshtein distance of active and passive texts.

        Results are cached on the instance, keyed by ``(k, random_state)`` and
        validated against the current membership of ``self.pairs``. (A
        method-level ``functools.lru_cache`` kept instances alive forever and
        returned stale results after pairs were added or removed.) Changes to a
        pair's description or ``resp`` that do not alter membership are not
        detected.

        Returns:
            A new list (callers may mutate it without affecting the cache).
        """
        if k <= 0:
            return []

        cache = self.__dict__.setdefault("_representative_examples_cache", {})
        snapshot = frozenset(self.pairs)
        cached = cache.get((k, random_state))
        if cached is None or cached[0] != snapshot:
            cached = (snapshot, self._compute_representative_examples(k, random_state))
            cache[(k, random_state)] = cached
        return list(cached[1])

    def _compute_representative_examples(self, k:int, random_state:int) -> list['Pair']:
        """Uncached body of ``representative_examples``."""
        with_descriptions = []
        without_descriptions = []
        for pair in self.pairs_sorted(exclude_rdgai=True):
            if pair.has_description():
                with_descriptions.append(pair)
            else:
                without_descriptions.append(pair)

        representative_pairs = self._find_medoids(with_descriptions, k, random_state)
        remaining = k - len(representative_pairs)
        if remaining > 0:
            representative_pairs.extend(self._find_medoids(without_descriptions, remaining, random_state))
        return representative_pairs

    @staticmethod
    def _find_medoids(pairs_list:list['Pair'], k:int, random_state:int) -> list['Pair']:
        """Return ``k`` medoid pairs of ``pairs_list``, or a copy of it if it has at most ``k`` items."""
        if len(pairs_list) <= k:
            return list(pairs_list)

        # Imported lazily: only needed when clustering is actually required.
        import kmedoids

        size = len(pairs_list)
        distance_matrix = np.zeros((size, size))
        for index1, pair in enumerate(pairs_list):
            for index2 in range(index1 + 1, size):
                other_pair = pairs_list[index2]
                distance = (
                    Levenshtein.distance(pair.active.text, other_pair.active.text)
                    + Levenshtein.distance(pair.passive.text, other_pair.passive.text)
                )
                distance_matrix[index1, index2] = distance
                distance_matrix[index2, index1] = distance

        result = kmedoids.fasterpam(distance_matrix, k, random_state=random_state, init="build")
        return [pairs_list[index] for index in result.medoids]
@dataclass
class Pair():
    """An ordered (active ➞ passive) pair of readings from the same ``<app>``.

    Relation types attached to a pair are mirrored in the TEI as
    ``<relation active="…" passive="…" ana="#type …">`` elements inside the
    app's ``<listRelation type="transcriptional">``.
    """
    active: Reading
    passive: Reading
    types: set[RelationType] = field(default_factory=set)

    def __post_init__(self):
        # Register the pair with every type it is created with so that
        # ``RelationType.pairs`` stays in sync.
        for relation_type in self.types:
            relation_type.pairs.add(self)

    def __str__(self):
        return f"{self.active} ➞ {self.passive}"

    def print(self, console):
        """Print the pair with rich markup to ``console``."""
        console.print(f"[bold red]{self.app}[/bold red]: [green]{self.active}[/green] [red]➞[/red] [green]{self.passive}[/green]")

    def reading_transition_str(self) -> str:
        # NOTE(review): ``Reading`` defines neither ``__bool__`` nor ``__len__``,
        # so a Reading is always truthy and the 'OMISSION' fallbacks never
        # apply; empty readings are rendered by ``Reading.__str__`` instead.
        # Left unchanged to preserve the current output.
        return f"{self.active or 'OMISSION'} → {self.passive or 'OMISSION'}"

    def __repr__(self) -> str:
        return str(self)

    @property
    def app(self) -> "App":
        """The App of the active reading (both readings are expected to share it)."""
        return self.active.app

    def __hash__(self):
        # Hash on the readings only: ``types`` is mutable.
        return hash((self.active, self.passive))

    def app_element(self) -> Element:
        """Return the ``<app>`` element enclosing the active reading."""
        return find_parent(self.active.element, "app")

    def _list_relation(self, create:bool=False) -> Element|None:
        """Return the app's transcriptional ``<listRelation>``; create it if missing and ``create`` is true."""
        app_element = self.app_element()
        list_relation = find_element(app_element, ".//listRelation[@type='transcriptional']")
        if list_relation is None and create:
            list_relation = ET.SubElement(app_element, "listRelation", attrib={"type":"transcriptional"})
        return list_relation

    def _relation_xpath(self) -> str:
        """XPath selecting the ``<relation>`` elements for this (active, passive) pair."""
        return f".//relation[@active='{self.active.n}'][@passive='{self.passive.n}']"

    @staticmethod
    def _ana_tokens(relation:Element) -> list[str]:
        """Return the tokens of ``relation``'s ``ana`` attribute ([] when absent).

        ``add_description`` can create ``<relation>`` elements without ``ana``,
        so a missing attribute must not be treated as an error.
        """
        return relation.attrib.get("ana", "").split()

    def relation_elements(self) -> list[Element]:
        """Return this pair's ``<relation>`` elements ([] if the app has no listRelation)."""
        list_relation = self._list_relation()
        if list_relation is None:
            return []
        return find_elements(list_relation, self._relation_xpath())

    def element_for_type(self, type:RelationType) -> Element|None:
        """Return the ``<relation>`` whose ``ana`` contains ``#<type.name>``, or None."""
        token = f"#{type.name}"
        for relation in self.relation_elements():
            if token in self._ana_tokens(relation):
                return relation
        return None

    def get_inverse(self) -> "Pair":
        """Return the passive ➞ active pair from the same app.

        Raises:
            AssertionError: if the app has no such pair.
        """
        found_pair = next(
            (pair for pair in self.app.pairs if pair.active == self.passive and pair.passive == self.active),
            None,
        )
        assert found_pair is not None, f"No inverse pair found for {self}"
        return found_pair

    def add_type_with_inverse(self, type:RelationType, responsible:str|None=None, description:str="", inverse_description:str="") -> Element:
        """Add ``type`` here and ``type.get_inverse()`` on the inverse pair; return this pair's ``<relation>``."""
        relation = self.add_type(type, responsible=responsible, description=description)
        inverse = self.get_inverse()
        inverse.add_type(type.get_inverse(), responsible=responsible, description=inverse_description)
        return relation

    def add_type(self, type:RelationType, responsible:str|None=None, description:str="") -> Element:
        """Attach ``type`` to this pair and record it in the TEI.

        If a ``<relation>`` already lists the type, it is returned unchanged.
        Otherwise the type is appended to the pair's existing ``<relation>``
        (or a new one is created). ``responsible`` is written to ``resp``, and
        a non-empty ``description`` is written to a ``<desc>`` child.

        Returns:
            The ``<relation>`` element holding the type.
        """
        self.types.add(type)
        type.pairs.add(self)

        relation = self.element_for_type(type)
        if relation is not None:
            return relation

        token = f"#{type.name}"
        list_relation = self._list_relation(create=True)
        relation = find_element(list_relation, self._relation_xpath())
        if relation is not None:
            # Extend the existing element; it may have no ``ana`` yet (e.g. one
            # created by ``add_description``).
            tokens = self._ana_tokens(relation)
            if token not in tokens:
                relation.set("ana", " ".join(tokens + [token]))
        else:
            relation = ET.SubElement(list_relation, "relation", attrib={"active":self.active.n, "passive":self.passive.n, "ana":token})

        if responsible is not None:
            relation.set("resp", responsible)

        self.add_description(description, relation)
        return relation

    def remove_description(self):
        """Remove every ``<desc>`` from this pair's ``<relation>`` elements."""
        for relation in self.relation_elements():
            for desc in find_elements(relation, ".//desc"):
                relation.remove(desc)

    def add_description(self, description:str, relation:Element|None=None):
        """Set the (stripped) ``description`` as the ``<desc>`` text of a ``<relation>``.

        When ``relation`` is None, the pair's first ``<relation>`` is used, or
        one is created. An empty or whitespace-only description does nothing,
        so no empty ``<relation>`` is ever created.
        """
        description = description.strip()
        if not description:
            return

        if relation is None:
            relation_elements = self.relation_elements()
            if relation_elements:
                relation = relation_elements[0]
            else:
                list_relation = self._list_relation(create=True)
                relation = ET.SubElement(list_relation, "relation", attrib={"active":self.active.n, "passive":self.passive.n})

        description_element = find_element(relation, ".//desc")
        if description_element is None:
            description_element = ET.SubElement(relation, "desc")
        description_element.text = description

    def remove_type(self, relation_type:RelationType):
        """Detach ``relation_type`` from this pair and strip it from the TEI.

        A ``<relation>`` whose ``ana`` becomes empty is removed entirely.
        """
        self.types.discard(relation_type)
        relation_type.pairs.discard(self)

        list_relation = self._list_relation()
        if list_relation is None:
            return

        token = f"#{relation_type.name}"
        for relation in find_elements(list_relation, self._relation_xpath()):
            tokens = self._ana_tokens(relation)
            if token not in tokens:
                continue
            remaining = [ana for ana in tokens if ana != token]
            relation.set("ana", " ".join(remaining))
            if not remaining:
                relation.getparent().remove(relation)

    def remove_type_with_inverse(self, relation_type:RelationType):
        """Remove ``relation_type`` here and its inverse from the inverse pair."""
        self.remove_type(relation_type)
        inverse = self.get_inverse()
        inverse.remove_type(relation_type.get_inverse())

    def remove_all_types(self):
        """Remove every type (and its inverse) from this pair."""
        # Iterate over a copy: ``remove_type`` mutates ``self.types``.
        for relation_type in set(self.types):
            self.remove_type_with_inverse(relation_type)

    def rdgai_responsible(self) -> bool:
        """True if any of this pair's ``<relation>`` elements has ``resp="#rdgai"``."""
        return any(element.attrib.get('resp', '') == '#rdgai' for element in self.relation_elements())

    def relation_type_names(self) -> set[str]:
        """Return the names of the attached relation types."""
        return {relation_type.name for relation_type in self.types}

    def has_description(self) -> bool:
        """True if any of this pair's ``<relation>`` elements contains a ``<desc>``."""
        return any(find_element(relation, ".//desc") is not None for relation in self.relation_elements())

    def get_description(self) -> str:
        """Return the text of all ``<desc>`` elements, newline-joined and stripped."""
        texts = [
            extract_text(desc)
            for relation_element in self.relation_elements()
            for desc in find_elements(relation_element, ".//desc")
        ]
        return "\n".join(texts).strip()
@dataclass
class App():
    """An apparatus entry (``<app>``): its readings and every ordered pair of them."""
    element: Element
    doc: "Doc"
    readings: list[Reading] = field(default_factory=list)
    pairs: list[Pair] = field(default_factory=list)
    non_redundant_pairs: list[Pair] = field(default_factory=list)

    def __post_init__(self):
        """Build readings and all ordered reading pairs from the TEI.

        Each pair receives the relation types named in the ``ana`` of matching
        ``<relation>`` elements. Types missing from ``doc.relation_types`` are
        created with ``doc.add_relation_type``. ``non_redundant_pairs`` keeps one
        pair per unordered combination of readings.
        """
        for reading in find_elements(self.element, ".//rdg"):
            self.readings.append(Reading(reading, app=self))

        # Index type names by (active, passive) once, rather than rescanning
        # every <relation> for every ordered pair of readings. Dicts keep the
        # first-seen order, so unknown types are created deterministically.
        type_names_by_pair = {}
        for list_relation in find_elements(self.element, ".//listRelation[@type='transcriptional']"):
            for relation_element in find_elements(list_relation, ".//relation"):
                key = (relation_element.attrib.get("active"), relation_element.attrib.get("passive"))
                names = type_names_by_pair.setdefault(key, {})
                for ana in relation_element.attrib.get("ana", "").split():
                    name = ana[1:] if ana.startswith("#") else ana
                    if name:
                        names[name] = None

        active_visited = set()
        for active in self.readings:
            active_visited.add(active)
            for passive in self.readings:
                if active == passive:
                    continue

                pair_relation_types = set()
                for type_name in type_names_by_pair.get((active.n, passive.n), {}):
                    relation_type = self.doc.relation_types.get(type_name)
                    if relation_type is None:
                        relation_type = self.doc.add_relation_type(type_name)
                    pair_relation_types.add(relation_type)

                # Pair.__post_init__ registers the pair with each of its types.
                pair = Pair(active=active, passive=passive, types=pair_relation_types)
                self.pairs.append(pair)
                if passive not in active_visited:
                    self.non_redundant_pairs.append(pair)

        assert len(self.pairs) == len(self.non_redundant_pairs) * 2

    def get_classified_pairs(self, redundant:bool=True) -> list[Pair]:
        """Return pairs with at least one type (one direction only if ``redundant`` is False)."""
        pairs = self.pairs if redundant else self.non_redundant_pairs
        return [pair for pair in pairs if pair.types]

    def get_unclassified_pairs(self, redundant:bool=True) -> list[Pair]:
        """Return pairs with no type (one direction only if ``redundant`` is False)."""
        pairs = self.pairs if redundant else self.non_redundant_pairs
        return [pair for pair in pairs if not pair.types]

    def __hash__(self):
        return hash(self.element)

    def __str__(self):
        """Return an identifier for this app.

        Uses ``xml:id``, then ``n``. Failing both, it derives
        ``<ab n>-<1-based position>`` and writes it back to the element as
        ``xml:id``, so calling this can mutate the TEI. Spaces and colons are
        replaced by underscores.
        """
        xml_id = '{http://www.w3.org/XML/1998/namespace}id'
        name = self.element.attrib.get(xml_id, '') or self.element.attrib.get('n', '')

        if not name:
            ab = self.ab()
            # Without an enclosing <ab> there is no position to derive a name from.
            if ab is not None:
                for index, app in enumerate(find_elements(ab, ".//app")):
                    if app == self.element:
                        name = make_nc_name(f"{self.ab_name()}-{index+1}")
                        self.element.attrib[xml_id] = name
                        break
        return str(name).replace(" ", "_").replace(":", "_")

    def ab(self) -> Element|None:
        """Return the enclosing ``<ab>`` element, if any."""
        return find_parent(self.element, "ab")

    def ab_name(self) -> str:
        """Return the ``n`` attribute of the enclosing ``<ab>`` ("" if absent or no ``<ab>``)."""
        ab = self.ab()
        if ab is None:
            return ""
        return ab.attrib.get("n", "")

    def text_before(self) -> str:
        """Return the space-joined text of the ``<ab>`` children before this app.

        Only direct children of the ``<ab>`` are compared with this app's element.
        NOTE(review): if the app is nested deeper, no child matches and the text
        of every child is returned.
        """
        ab = self.ab()
        if ab is None:
            return ""

        items = []
        for child in ab:
            if child == self.element:
                break
            child_text = extract_text(child)
            if child_text:
                items.append(child_text)

        return " ".join(items).strip()

    def text_in_context(self, text="") -> str:
        """Return the preceding text, the marked-up variant (``text`` or the app's own), and the following text."""
        return f"{self.text_before()} {self.text_with_signs(text)} {self.text_after()}".strip()

    def text(self) -> str:
        """Return the extracted text of the whole ``<app>`` element."""
        return extract_text(self.element)

    def text_with_signs(self, text="") -> str:
        """Wrap ``text`` (default: the app's text) in ⸂…⸃, or return ⸆ if it is empty."""
        text = text or self.text()
        if not text:
            return "⸆"
        return f"⸂{text}⸃"

    def text_after(self) -> str:
        """Return the space-joined text of the ``<ab>`` children after this app.

        Only direct children of the ``<ab>`` are compared with this app's element.
        NOTE(review): if the app is nested deeper, "" is returned.
        """
        ab = self.ab()
        if ab is None:
            return ""

        items = []
        reached_element = False
        for child in ab:
            if reached_element:
                child_text = extract_text(child)
                if child_text:
                    items.append(child_text)
            if child == self.element:
                reached_element = True

        return " ".join(items).strip()
@dataclass
class Doc():
    """A TEI document with its apparatus entries and transcriptional relation types."""
    path: Path
    tree: ElementTree = field(default=None)
    apps: list[App] = field(default_factory=list)
    relation_types: dict[str,RelationType] = field(default_factory=dict)

    def __post_init__(self):
        # Relation types must exist before the apps are built, because
        # App.__post_init__ looks them up (and creates missing ones).
        self.tree = read_tei(self.path)
        self.relation_types = self.get_relation_types()

        for app_element in find_elements(self.tree, ".//app"):
            app = App(app_element, doc=self)
            self.apps.append(app)

    def get_interpgrp(self) -> Element:
        """Return ``<interpGrp type="transcriptional">`` under ``<text>``.

        If it is missing, it is created as the first child of ``<text>``.
        """
        text = find_element(self.tree, "text")
        interp_group = find_element(text, ".//interpGrp[@type='transcriptional']")
        if interp_group is None:
            interp_group = ET.Element("interpGrp", attrib={"type":"transcriptional"})
            text.insert(0, interp_group)

        return interp_group

    def add_relation_type(self, name:str, description:str="") -> RelationType:
        """Return the relation type ``name``, creating it (and its ``<interp>``) if needed.

        A newly created ``<interp>`` gets ``description`` as its text, matching
        how ``get_relation_types`` reads descriptions back.

        Raises:
            AssertionError: if ``name`` exists with a different description.
        """
        if name in self.relation_types:
            assert self.relation_types[name].description == description, f"RelationType {name} already exists with a different description."
            return self.relation_types[name]

        interp_group = self.get_interpgrp()
        interp = find_element(interp_group, f".//interp[@xml:id='{name}']")
        if interp is None:
            interp = ET.Element("interp", attrib={"{http://www.w3.org/XML/1998/namespace}id":name})
            if description:
                interp.text = description
            interp_group.append(interp)

        # Previously the description argument was ignored here (always "").
        relation_type = RelationType(name=name, element=interp, description=description)
        self.relation_types[name] = relation_type
        return relation_type

    def __str__(self):
        return str(self.path)

    def __repr__(self) -> str:
        return str(self)

    def write(self, output:str|Path):
        """Write the TEI tree to ``output``."""
        write_tei(self.tree, output)

    @property
    def language(self):
        """The document language as reported by ``get_language``."""
        return get_language(self.tree)

    def get_relation_types(self, categories_to_ignore:list[str]|None=None) -> dict[str, RelationType]:
        """Read the relation types from the transcriptional ``<interpGrp>``.

        Each ``<interp>`` child becomes a RelationType keyed by its ``xml:id``,
        with its stripped text as description. A ``corresp="#other"`` attribute
        links the two types as inverses.

        Args:
            categories_to_ignore: ``xml:id`` values to skip.

        Raises:
            AssertionError: if an inverse is already paired with a different type.
        """
        interp_group = self.get_interpgrp()
        categories_to_ignore = categories_to_ignore or []

        relation_types = dict()
        assert interp_group is not None, "No interpGrp of type='transcriptional' found in TEI file."

        for interp in find_elements(interp_group, "./interp"):
            name = interp.attrib.get("{http://www.w3.org/XML/1998/namespace}id", "")
            if name in categories_to_ignore:
                continue

            description = extract_text(interp).strip()
            relation_types[name] = RelationType(name=name, element=interp, description=description)

        # Link each category to the one named by its ``corresp`` attribute.
        for category in relation_types.values():
            inverse_name = category.element.attrib.get("corresp", "")
            if inverse_name.startswith("#"):
                inverse_name = inverse_name[1:]

            if inverse_name in relation_types:
                inverse = relation_types[inverse_name]
                category.inverse = inverse
                if inverse.inverse is None:
                    inverse.inverse = category
                else:
                    assert inverse.inverse == category, f"Inverse category {inverse} already has an inverse {inverse.inverse}."

        return relation_types

    def get_classified_pairs(self, redundant:bool=True) -> list[Pair]:
        """Return the classified pairs of every app, in document order."""
        pairs = []
        for app in self.apps:
            pairs.extend(app.get_classified_pairs(redundant=redundant))
        return pairs

    def get_unclassified_pairs(self, redundant:bool=True) -> list[Pair]:
        """Return the unclassified pairs of every app, in document order."""
        pairs = []
        for app in self.apps:
            pairs.extend(app.get_unclassified_pairs(redundant=redundant))
        return pairs

    def print_classified_pairs(self, console:Console|None=None) -> None:
        """Print each relation type, its description and its sorted pairs."""
        console = console or Console()
        for relation_type in self.relation_types.values():
            console.rule(str(relation_type))
            console.print(relation_type.description, style="grey46")
            for pair in relation_type.pairs_sorted():
                pair.print(console)

            console.print("")

    def render_html(self, output:Path|None=None, all_apps:bool=False) -> str:
        """Render ``server.html`` for this document and return the HTML.

        Args:
            output: optional path (``str`` or ``Path``); if given, the HTML is
                written there as UTF-8 and parent directories are created.
            all_apps: passed through to the template.
        """
        from flask import Flask, render_template

        mapper = Mapper()
        app = Flask(__name__)

        with app.app_context():
            html = render_template('server.html', doc=self, mapper=mapper, all_apps=all_apps)

        if output:
            output = Path(output)
            output.parent.mkdir(parents=True, exist_ok=True)
            # Explicit encoding: the default depends on the platform locale.
            output.write_text(html, encoding="utf-8")

        return html

    def flask_app(self, output:Path, all_apps:bool=False):
        """Build (but do not run) a Flask app for editing relation types and descriptions.

        The document is written to ``output`` immediately and again after each
        successful edit. Invalid requests (unknown ids, wrong object kinds,
        unknown operations) get a 400 response with the error message.
        """
        from flask import Flask, request, render_template

        self.write(output)
        mapper = Mapper()

        app = Flask(__name__)

        @app.route("/")
        def root():
            return render_template('server.html', doc=self, mapper=mapper, all_apps=all_apps)

        @app.route("/api/relation-type", methods=['POST'])
        def api_relation_type():
            data = request.get_json()
            # Broad catch is deliberate: this is the HTTP boundary, and any
            # failure is reported to the client as a 400 with its message.
            try:
                relation_type = mapper.obj(data['relation_type'])
                if not isinstance(relation_type, RelationType):
                    raise TypeError(f"Expected RelationType, got {type(relation_type)}")

                pair = mapper.obj(data['pair'])
                if not isinstance(pair, Pair):
                    raise TypeError(f"Expected Pair, got {type(pair)}")

                if data['operation'] == 'remove':
                    print('remove', relation_type)
                    pair.remove_type_with_inverse(relation_type)
                elif data['operation'] == 'add':
                    print('add', relation_type)
                    pair.add_type_with_inverse(relation_type)
                else:
                    raise ValueError(f"Unknown operation {data['operation']}")

                print('write', output)
                self.write(output)
                return "Success", 200
            except Exception as e:
                print(str(e))
                return str(e), 400

        @app.route("/api/desc", methods=['POST'])
        def desc():
            data = request.get_json()
            try:
                pair = mapper.obj(data['pair'])
                if not isinstance(pair, Pair):
                    raise TypeError(f"Expected Pair, got {type(pair)}")

                if data['operation'] == 'remove':
                    pair.remove_description()
                elif data['operation'] == 'add':
                    pair.add_description(data['description'])
                else:
                    raise ValueError(f"Unknown operation {data['operation']}")

                print('write', output)
                self.write(output)
                return "Success", 200
            except Exception as e:
                print(str(e))
                return str(e), 400

        return app

    def clean(self, output:Path|None=None):
        """Clean common errors in the ``<listRelation>`` elements of the TEI.

        - Every ``ana`` token is given a leading ``#`` (previously only the
          first token was).
        - ``<relation>`` elements with the same ``active``/``passive`` within a
          listRelation are merged into the first one. The merged ``ana`` is the
          sorted union of tokens, and ``<desc>`` children of the duplicates are
          moved to the kept element instead of being discarded.

        Relations missing ``ana``, ``active`` or ``passive`` no longer raise
        KeyError. If ``output`` is given, the cleaned document is written there.
        """
        for list_relation in find_elements(self.tree, ".//listRelation"):
            relations_so_far = set()
            for relation in find_elements(list_relation, ".//relation"):
                tokens = relation.attrib.get("ana", "").split()
                if tokens:
                    relation.set("ana", " ".join(token if token.startswith("#") else f"#{token}" for token in tokens))

                active = relation.attrib.get("active")
                passive = relation.attrib.get("passive")
                if active is not None and passive is not None:
                    relations_so_far.add((active, passive))

            for active, passive in relations_so_far:
                relations = find_elements(list_relation, f".//relation[@active='{active}'][@passive='{passive}']")
                if len(relations) <= 1:
                    continue

                keeper = relations[0]
                analytic_set = set()
                for relation in relations:
                    analytic_set.update(relation.attrib.get("ana", "").split())

                for duplicate in relations[1:]:
                    # Preserve descriptions before dropping the duplicate.
                    for desc_element in find_elements(duplicate, ".//desc"):
                        keeper.append(desc_element)
                    # Remove via the actual parent: ``.//relation`` may match
                    # elements nested below the listRelation.
                    duplicate.getparent().remove(duplicate)

                if analytic_set:
                    keeper.set("ana", " ".join(sorted(analytic_set)))

        if output:
            output = Path(output)
            output.parent.mkdir(parents=True, exist_ok=True)
            print("Writing to", output)
            self.write(output)