Coverage for rdgai/parsers.py: 100.00%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2025-01-03 01:37 +0000

1from dataclasses import dataclass 

2from langchain_core.runnables import Runnable 

3 

@dataclass
class CategoryParser(Runnable):
    """Parse a language-model response into a ``(category, justification)`` pair.

    The response is searched for the names in ``relation_type_names``. The name
    that occurs earliest is taken as the category, and the text following it is
    taken as the justification.
    """

    # Candidate category names to look for in the model output.
    relation_type_names: list[str]

    def invoke(self, llm_output: str, *args, **kwargs) -> tuple[str, str]:
        """
        Parse the output of a language model into a category and a justification.

        The output is stripped, and everything from the first ``-----`` onwards
        is discarded. The category is the entry of ``relation_type_names`` that
        starts earliest in the remaining text. When several names start at the
        same position, the longest one wins, so a name that is a prefix of
        another cannot shadow it. Empty names are ignored.

        The justification is the text on the lines after the one holding the
        category. If nothing follows on a later line, it is the rest of that
        same line after the category name. It is stripped of surrounding
        whitespace.

        Args:
            llm_output: Raw text produced by the language model.
            *args: Accepted for call-signature compatibility; ignored.
            **kwargs: Accepted for call-signature compatibility; ignored.

        Returns:
            A ``(category, justification)`` tuple. Both strings are empty when
            no category name occurs in the output.
        """
        # Keep only the part before a "-----" separator (if any).
        text = llm_output.strip().partition("-----")[0]

        category = ""
        justification = ""
        minimum_index = len(text)
        for relation_type in self.relation_type_names:
            # An empty name would trivially "match" at index 0 of any text.
            if not relation_type:
                continue
            index = text.find(relation_type)
            if index == -1:
                continue
            is_earlier = index < minimum_index
            # Prefer the more specific (longer) name when two start together.
            is_longer_tie = index == minimum_index and len(relation_type) > len(category)
            if is_earlier or is_longer_tie:
                minimum_index = index
                category = relation_type
                # Start after the whole category name, not one character in.
                rest = text[index + len(relation_type):]
                # Skip to the next line when there is one. find() returns -1
                # otherwise, so the slice then keeps the rest of this line.
                justification = rest[rest.find("\n") + 1:].strip()

        return category, justification

30