Coverage for rdgai/parsers.py: 100.00%
22 statements
« prev ^ index » next coverage.py v7.6.4, created at 2025-01-03 01:37 +0000
1from dataclasses import dataclass
2from langchain_core.runnables import Runnable
@dataclass
class CategoryParser(Runnable):
    """
    Parses raw language-model output into a ``(category, justification)`` pair.

    The category is whichever name from ``relation_type_names`` occurs
    earliest in the output; the justification is the text that follows it.
    """

    # Candidate category names searched for in the model output.
    relation_type_names: list[str]

    def invoke(self, llm_output: str, *args, **kwargs) -> tuple[str, str]:
        """
        Parses the output of a language model to category and justification.

        The output is stripped, and everything from the first ``-----``
        separator onwards is ignored. The category is the name from
        ``relation_type_names`` that occurs earliest in the remaining text.
        When several names start at the same position, the longest one wins,
        so a name that is a prefix of another (e.g. ``"Spelling"`` vs
        ``"Spelling variant"``) cannot shadow it. Empty names are ignored.

        The justification is the text on the lines after the line containing
        the category. If no newline follows the category, it is the rest of
        that line after the category name.

        Args:
            llm_output: Raw text produced by the language model.
            *args: Ignored; accepted for ``Runnable.invoke`` compatibility.
            **kwargs: Ignored; accepted for ``Runnable.invoke`` compatibility.

        Returns:
            A ``(category, justification)`` tuple. Both strings are empty
            when no category name is found.
        """
        llm_output = llm_output.strip()

        # Anything after a "-----" separator is not part of the answer.
        separator_index = llm_output.find("-----")
        if separator_index != -1:
            llm_output = llm_output[:separator_index]

        category = ""
        best_index = len(llm_output)
        for relation_type in self.relation_type_names:
            if not relation_type:
                # An empty name would trivially "match" at position 0.
                continue
            index = llm_output.find(relation_type)
            if index == -1:
                continue
            # Earliest match wins; on a tie, prefer the longer name.
            is_earlier = index < best_index
            is_longer_tie = index == best_index and len(relation_type) > len(category)
            if is_earlier or is_longer_tie:
                best_index = index
                category = relation_type

        if not category:
            return "", ""

        # Start after the *whole* category name (not one character into it),
        # so that when no newline follows, the justification does not contain
        # a truncated copy of the category name itself.
        remainder = llm_output[best_index + len(category):]
        newline_index = remainder.find("\n")
        # If there is no newline, find() returns -1 and the slice below keeps
        # the whole remainder of the line.
        justification = remainder[newline_index + 1:].strip()
        return category, justification