Coverage for rdgai/classification.py: 100.00%
38 statements
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
1from pathlib import Path
2from langchain.schema.output_parser import StrOutputParser
3from langchain_core.language_models.llms import LLM
4import llmloader
5from rich.console import Console
6from rich.progress import track
8from .prompts import build_template
9from .parsers import CategoryParser
10from .apparatus import Doc, Pair
# Model identifier used as the default for ``classify``'s ``llm`` argument,
# which is handed to ``llmloader.load(model=...)``.
DEFAULT_MODEL_ID: str = "gpt-4o"
def classify_pair(
    doc: Doc,
    pair: Pair,
    llm: LLM,
    output: Path,
    verbose: bool = False,
    prompt_only: bool = False,
    examples: int = 10,
    console: Console | None = None,
):
    """
    Ask the language model to classify the relation for one pair of readings.

    A prompt for ``pair`` is built with ``build_template``. The model's reply is
    parsed by ``CategoryParser`` (given the names of ``doc.relation_types``)
    into a ``(category, description)`` pair, which is printed. When the
    category is a key of ``doc.relation_types``, that relation type and its
    inverse are attached to the pair and the document is written to
    ``output``. For any other category nothing is attached and no second write
    happens.

    Args:
        doc: Document that owns the pair and defines the relation types.
        pair: The pair of readings to classify.
        llm: Language model used to answer the prompt.
        output: Path the document is written to.
        verbose: Print the prompt before querying the model.
        prompt_only: Print the prompt and return immediately, without
            querying the model or writing the document.
        examples: Passed through to ``build_template`` as ``examples``.
        console: Console used for output; a fresh ``Console`` is created
            when not supplied.

    Returns:
        None.
    """
    assert isinstance(doc, Doc), f"Expected Doc, got {type(doc)}"

    console = console if console else Console()

    prompt = build_template(pair, examples=examples)
    if prompt_only:
        # Show the prompt and stop: no model call and no file writes.
        prompt.pretty_print()
        return
    if verbose:
        prompt.pretty_print()

    # Generation is cut at "----" (presumably the delimiter used between
    # examples in the prompt template -- NOTE(review): confirm).
    pipeline = (
        prompt
        | llm.bind(stop=["----"])
        | StrOutputParser()
        | CategoryParser(doc.relation_types.keys())
    )

    # The document is written once before the model is queried.
    # NOTE(review): presumably a checkpoint so the output exists even if the
    # model call fails -- confirm intent.
    doc.write(output)

    label, rationale = pipeline.invoke({})

    console.print()
    pair.print(console)
    console.print(label, style="green bold")
    console.print(rationale, style="grey46")

    relation_type = doc.relation_types.get(label)
    if relation_type is not None:
        # The inverse relation's description points back at the original
        # active -> passive direction.
        pair.add_type_with_inverse(
            relation_type,
            responsible="#rdgai",
            description=rationale,
            inverse_description=f"c.f. {pair.active} ➞ {pair.passive}",
        )
        doc.write(output)
def classify(
    doc: Doc,
    output: Path,
    pairs: list[Pair] | None = None,
    verbose: bool = False,
    api_key: str = "",
    llm: str = DEFAULT_MODEL_ID,
    temperature: float = 0.1,
    prompt_only: bool = False,
    examples: int = 10,
    console: Console | None = None,
):
    """
    Classify relations in TEI documents.

    Loads the language model named by ``llm`` via ``llmloader.load`` and runs
    ``classify_pair`` on each pair, showing a progress bar.

    Args:
        doc: Document whose pairs are classified.
        output: Path the document is written to (by ``classify_pair``).
        pairs: Pairs to classify. When ``None`` (the default), the pairs
            returned by ``doc.get_unclassified_pairs(redundant=False)`` are
            used. An explicit empty list classifies nothing.
        verbose: Print each prompt before querying the model.
        api_key: Passed to ``llmloader.load``.
        llm: Model identifier passed to ``llmloader.load`` as ``model``.
        temperature: Sampling temperature passed to ``llmloader.load``.
        prompt_only: Only print the prompts; see ``classify_pair``.
        examples: Passed through to ``classify_pair``.
        console: Console used for output; a fresh ``Console`` is created
            when not supplied.

    Returns:
        None.
    """
    assert isinstance(doc, Doc), f"Expected Doc, got {type(doc)}"

    console = console or Console()
    # ``llm`` is the model *name*; keep the loaded model object under its own
    # name instead of rebinding the str-typed parameter.
    model = llmloader.load(model=llm, api_key=api_key, temperature=temperature)

    # Fall back to the document's unclassified pairs only when no pairs were
    # supplied. Using ``or`` here would also treat an explicit empty list as
    # "not supplied" and silently start an LLM run over every unclassified
    # pair in the document.
    if pairs is None:
        pairs = doc.get_unclassified_pairs(redundant=False)

    for pair in track(pairs):
        classify_pair(
            doc,
            pair,
            model,
            output,
            verbose=verbose,
            prompt_only=prompt_only,
            examples=examples,
            console=console,
        )