Coverage for rdgai/classification.py: 100.00%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2025-01-03 01:37 +0000

1from pathlib import Path 

2from langchain.schema.output_parser import StrOutputParser 

3from langchain_core.language_models.llms import LLM 

4import llmloader 

5from rich.console import Console 

6from rich.progress import track 

7 

8from .prompts import build_template 

9from .parsers import CategoryParser 

10from .apparatus import Doc, Pair 

11 

12 

# Model identifier used as the default for the `llm` parameter of `classify`.
DEFAULT_MODEL_ID = "gpt-4o"

14 

15 

def classify_pair(
    doc: Doc,
    pair: Pair,
    llm: LLM,
    output: Path,
    verbose: bool = False,
    prompt_only: bool = False,
    examples: int = 10,
    console: Console | None = None,
):
    """
    Classify the relation for a single pair of readings using an LLM.

    A prompt is built for ``pair`` and, unless ``prompt_only`` is set, sent
    through ``llm``. The parsed category and description are printed, and if
    the category is one of ``doc.relation_types`` the relation (with its
    inverse) is added to ``pair`` and ``doc`` is written to ``output``.

    Args:
        doc: Document that owns ``pair`` and defines the known relation types.
        pair: The pair of readings to classify.
        llm: Language model placed in the middle of the chain.
        output: Destination passed to ``doc.write``.
        verbose: Print the prompt before calling the model.
        prompt_only: Print the prompt and return without calling the model.
        examples: Forwarded to ``build_template`` as ``examples``.
        console: Console used for printing; a new ``Console`` is created
            when none is given.

    Returns:
        None.
    """
    assert isinstance(doc, Doc), f"Expected Doc, got {type(doc)}"

    if not console:
        console = Console()

    prompt = build_template(pair, examples=examples)
    if verbose or prompt_only:
        prompt.pretty_print()
    if prompt_only:
        # Nothing is sent to the model and nothing is written in this mode.
        return

    # Assemble the chain step by step: prompt -> model (generation stops at
    # "----") -> plain string -> (category, description).
    chain = prompt | llm.bind(stop=["----"])
    chain = chain | StrOutputParser()
    chain = chain | CategoryParser(doc.relation_types.keys())

    # The document is written once before the model is invoked.
    # NOTE(review): presumably so `output` exists even if the call fails - confirm.
    doc.write(output)

    result = chain.invoke({})
    category, description = result

    console.print()
    pair.print(console)
    console.print(category, style="green bold")
    console.print(description, style="grey46")

    relation_type = doc.relation_types.get(category)
    if relation_type is not None:
        # Only categories known to the document are recorded and saved.
        back_reference = f"c.f. {pair.active} ➞ {pair.passive}"
        pair.add_type_with_inverse(
            relation_type,
            responsible="#rdgai",
            description=description,
            inverse_description=back_reference,
        )
        doc.write(output)

58 

59 

def classify(
    doc: Doc,
    output: Path,
    pairs: list[Pair] | None = None,
    verbose: bool = False,
    api_key: str = "",
    llm: str = DEFAULT_MODEL_ID,
    temperature: float = 0.1,
    prompt_only: bool = False,
    examples: int = 10,
    console: Console | None = None,
):
    """
    Classify relations for pairs of readings in a document.

    Loads a language model through ``llmloader`` and calls ``classify_pair``
    for every pair, showing a progress bar while iterating.

    Args:
        doc: Document whose pairs are classified.
        output: Destination forwarded to ``classify_pair``.
        pairs: Pairs to classify. When ``None`` or empty, the result of
            ``doc.get_unclassified_pairs(redundant=False)`` is used instead.
        verbose: Forwarded to ``classify_pair``.
        api_key: API key handed to ``llmloader.load``.
        llm: Model identifier handed to ``llmloader.load``.
        temperature: Sampling temperature handed to ``llmloader.load``.
        prompt_only: Forwarded to ``classify_pair``.
        examples: Forwarded to ``classify_pair``.
        console: Console shared by all pairs; a new ``Console`` is created
            when none is given.

    Returns:
        None.
    """
    assert isinstance(doc, Doc), f"Expected Doc, got {type(doc)}"

    if not console:
        console = Console()
    model = llmloader.load(model=llm, api_key=api_key, temperature=temperature)

    # Truthiness check: an explicit empty list also falls back to the
    # document's unclassified pairs.
    targets = pairs or doc.get_unclassified_pairs(redundant=False)

    # Options that are identical for every pair.
    shared_options = {
        "verbose": verbose,
        "prompt_only": prompt_only,
        "examples": examples,
        "console": console,
    }
    for current in track(targets):
        classify_pair(doc, current, model, output, **shared_options)

83 

84