Coverage for hierarchicalsoftmax/treedict.py: 100.00%
180 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-02 01:49 +0000
1import copy
2from pathlib import Path
3from attrs import define, field
4from collections import UserDict
5import pickle
6from collections import Counter
7from rich.progress import track
8import typer
10from .nodes import SoftmaxNode
@define
class NodeDetail:
    """
    Stores metadata for a key in the TreeDict.

    Attributes:
        partition (int): The partition ID this key belongs to.
        node (SoftmaxNode | None): The node in the classification tree associated with the key.
            It is excluded from equality comparisons and is not pickled; after unpickling it
            is ``None`` and the node has to be looked up through ``node_id``.
        node_id (int | None): The index of the node in the tree (used during pickling).
            ``None`` until it is recorded (``TreeDict.set_indexes`` does this).
    """
    partition: int
    node: SoftmaxNode | None = field(default=None, eq=False)
    node_id: int | None = None

    def __getstate__(self):
        # Only the partition and the node index are serialized; the node object itself is dropped.
        return (self.partition, self.node_id)

    def __setstate__(self, state):
        self.partition, self.node_id = state
        # The node reference is not restored here; callers resolve it from ``node_id``.
        self.node = None
class AlreadyExists(Exception):
    """Raised by ``TreeDict.add`` when a key is re-added with a different node."""
class TreeDict(UserDict):
    """
    A dictionary mapping keys to ``NodeDetail`` records.

    Each record ties a key to a node of ``classification_tree`` and to an integer partition.
    """

    def __init__(self, classification_tree:SoftmaxNode|None=None):
        """
        Initialize a TreeDict.

        Args:
            classification_tree (SoftmaxNode, optional): The root of the classification tree.
                If not provided, a new root node named "root" will be created.
        """
        super().__init__()
        # Compare against None explicitly so a supplied tree is never replaced because of
        # how the node object happens to evaluate in a boolean context.
        self.classification_tree = (
            classification_tree if classification_tree is not None else SoftmaxNode("root")
        )

    def add(self, key:str, node:SoftmaxNode, partition:int) -> NodeDetail:
        """
        Associate a key with a node and a partition.

        Re-adding an existing key with the same node is allowed and replaces its partition.

        Args:
            key (str): The unique identifier for the item.
            node (SoftmaxNode): The node in the classification tree to associate with the key.
            partition (int): The partition index for the key.

        Raises:
            AlreadyExists: If the key already exists with a different node.
            AssertionError: If ``node`` is not part of this TreeDict's classification tree
                (only checked when assertions are enabled).

        Returns:
            NodeDetail: The metadata object for the added key.
        """
        assert node.root == self.classification_tree
        if key in self:
            old_node = self.node(key)
            if node != old_node:
                raise AlreadyExists(f"Accession {key} already exists in TreeDict at node {old_node}. Cannot change to {node}")

        detail = NodeDetail(
            partition=partition,
            node=node,
        )
        self[key] = detail
        return detail

    def set_indexes(self):
        """
        Ensure the tree has assigned node indexes, and record the node_id for each key.

        Details whose ``node`` is ``None`` (e.g. after unpickling) keep their existing ``node_id``.
        """
        self.classification_tree.set_indexes_if_unset()
        node_to_id = self.classification_tree.node_to_id
        for detail in self.values():
            if detail.node is not None:
                detail.node_id = node_to_id[detail.node]

    def save(self, path:Path):
        """
        Save the TreeDict to a pickle file.

        Node indexes are recorded first so that keys can be resolved after loading.

        Args:
            path (Path): The file path to save the TreeDict. Parent directories are created.
        """
        path = Path(path)
        path.parent.mkdir(exist_ok=True, parents=True)

        self.set_indexes()
        with open(path, 'wb') as handle:
            pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path:Path) -> "TreeDict":
        """
        Load a TreeDict from a pickle file.

        Warning:
            This uses ``pickle``. Only load files from trusted sources, because unpickling
            can execute arbitrary code.

        Args:
            path (Path): The path to the serialized TreeDict.

        Returns:
            TreeDict: The loaded TreeDict instance.
        """
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def node(self, key:str) -> SoftmaxNode:
        """
        Retrieve the node associated with a key.

        Args:
            key (str): The key for which to retrieve the node.

        Raises:
            KeyError: If the key is not in the TreeDict.

        Returns:
            SoftmaxNode: The node corresponding to the key.
        """
        detail = self[key]
        if detail.node is not None:
            return detail.node
        # Unpickled details only carry the node index, so resolve it through the tree.
        return self.classification_tree.node_list[detail.node_id]

    def keys_in_partition(self, partition:int):
        """
        Yield all keys that belong to a given partition.

        Args:
            partition (int): The partition to filter by.

        Yields:
            str: Keys in the specified partition.
        """
        for key, detail in self.items():
            if detail.partition == partition:
                yield key

    def keys(self, partition:int|None = None):
        """
        Return keys in the TreeDict, optionally filtering by partition.

        Args:
            partition (int | None): The partition to filter keys by. If None, return all keys.

        Returns:
            Iterator[str]: An iterator over the keys.
        """
        return super().keys() if partition is None else self.keys_in_partition(partition)

    def truncate(self, max_depth:int) -> "TreeDict":
        """
        Truncate the classification tree to a specified maximum depth and return a new TreeDict.

        Keys deeper than the depth limit will be reassigned to the ancestor node at that depth.
        The original TreeDict and its tree are not modified; a deep copy of the tree is truncated.

        Args:
            max_depth (int): The maximum number of ancestor levels to keep. Must be at least 1.

        Raises:
            ValueError: If ``max_depth`` is less than 1.

        Returns:
            TreeDict: A new truncated TreeDict.
        """
        if max_depth < 1:
            raise ValueError(f"max_depth must be at least 1, got {max_depth}")

        self.classification_tree.set_indexes_if_unset()
        classification_tree = copy.deepcopy(self.classification_tree)
        new_treedict = TreeDict(classification_tree)
        for key in track(self.keys()):
            # Find the copy of the key's node by using its index in the original tree.
            original_node = self.node(key)
            node_id = self.classification_tree.node_to_id[original_node]
            node = classification_tree.node_list[node_id]

            ancestors = node.ancestors
            if len(ancestors) >= max_depth:
                node = ancestors[max_depth-1]
            new_treedict.add(key, node, self[key].partition)

        # Remove any nodes that are beyond the max depth, clearing the cached indexing state
        # so that set_indexes recomputes it for the smaller tree.
        for node in new_treedict.classification_tree.pre_order_iter():
            node.readonly = False
            node.softmax_start_index = None
            if len(node.ancestors) >= max_depth:
                node.parent = None

        new_treedict.set_indexes()

        return new_treedict

    def add_counts(self):
        """
        Count the number of keys assigned to each node, and store the count in each node.

        The result is written to ``node.count``. Every node is reset to 0 first.
        """
        for node in self.classification_tree.post_order_iter():
            node.count = 0

        for key in self.keys():
            node = self.node(key)
            node.count += 1

    def add_partition_counts(self):
        """
        Count the number of keys in each partition per node and store it in the node.

        The result is written to ``node.partition_counts`` as a ``Counter`` keyed by partition.
        """
        for node in self.classification_tree.post_order_iter():
            node.partition_counts = Counter()

        for key, detail in self.items():
            node = self.node(key)
            partition = detail.partition
            node.partition_counts[partition] += 1

    def render(self, count:bool=False, partition_counts:bool=False, **kwargs):
        """
        Render the tree as text, optionally showing key counts or partition counts.

        Args:
            count (bool): If True, show the number of keys at each node.
            partition_counts (bool): If True, show partition-wise key counts at each node.
                Takes precedence over ``count``.
            **kwargs: Additional arguments passed to the underlying tree render method.

        Returns:
            anytree.RenderTree or str: The rendered tree.
        """
        if partition_counts:
            self.add_partition_counts()
            for node in self.classification_tree.post_order_iter():
                partition_counts_str = "; ".join([f"{k}->{node.partition_counts[k]}" for k in sorted(node.partition_counts.keys())])
                node.render_str = f"{node.name} {partition_counts_str}"
            kwargs['attr'] = "render_str"
        elif count:
            self.add_counts()
            for node in self.classification_tree.post_order_iter():
                node.render_str = f"{node.name} ({node.count})" if getattr(node, "count", 0) else node.name

            kwargs['attr'] = "render_str"

        return self.classification_tree.render(**kwargs)

    def sunburst(self, **kwargs) -> "go.Figure":
        """
        Generate a Plotly sunburst plot based on the TreeDict.

        Node values are based on the number of keys mapped to each node.

        Args:
            **kwargs: Additional keyword arguments passed to Plotly layout.

        Returns:
            plotly.graph_objects.Figure: A sunburst plot.
        """
        import plotly.graph_objects as go

        self.add_counts()
        labels = []
        parents = []
        values = []

        for node in self.classification_tree.pre_order_iter():
            labels.append(node.name)
            parents.append(node.parent.name if node.parent else "")
            values.append(node.count)

        fig = go.Figure(go.Sunburst(
            labels=labels,
            parents=parents,
            values=values,
            branchvalues="remainder",
        ))

        fig.update_layout(margin=dict(t=10, l=10, r=10, b=10), **kwargs)
        return fig

    def keys_to_file(self, file:Path) -> None:
        """
        Write all keys to a text file, one per line.

        Args:
            file (Path): Path to the output text file.
        """
        with open(file, "w") as f:
            for key in self.keys():
                print(key, file=f)

    def csv(self, file:Path) -> None:
        """
        Write all keys, node names and partitions to a CSV file.

        Fields containing commas, quotes or line breaks are quoted so the file stays parseable.

        Args:
            file (Path): Path to the output CSV file.
        """
        # Imported locally: a module-level ``csv`` command function in this file shadows the stdlib name.
        import csv as csv_module

        with open(file, "w") as f:
            # "\n" terminator matches the line endings previously produced with print().
            writer = csv_module.writer(f, lineterminator="\n")
            writer.writerow(["key", "node", "partition"])
            for key, detail in self.items():
                node = self.node(key)
                # str() keeps a None partition as "None" (csv.writer would emit an empty field).
                writer.writerow([key, node.name.strip(), str(detail.partition)])

    def pickle_tree(self, output:Path):
        """
        Save only the classification tree (not the key-to-node mapping) to a pickle file.

        Args:
            output (Path): Path to the output file.
        """
        with open(output, 'wb') as pickle_file:
            pickle.dump(self.classification_tree, pickle_file)
# Typer application; the sub-commands below register themselves on it via @app.command().
app: typer.Typer = typer.Typer()
@app.command()
def keys(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    partition:int|None = typer.Option(None,help="The index of the partition to list."),
):
    """
    Prints a list of keys in a TreeDict.

    If a partition is given, then only the keys for that partition are given.
    """
    # A partition of None lists every key; otherwise only that partition's keys.
    loaded = TreeDict.load(treedict)
    for item in loaded.keys(partition=partition):
        print(item)
@app.command()
def csv(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    csv:Path = typer.Argument(...,help="The path to the output CSV file."),
):
    """
    Writes a CSV file with the key, node name and partition.
    """
    # Load the serialized TreeDict and let it write its own CSV.
    TreeDict.load(treedict).csv(csv)
@app.command()
def render(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    output:Path|None = typer.Option(None, help="The path to save the rendered tree."),
    print_tree:bool = typer.Option(True, help="Whether or not to print the tree to the screen."),
    count:bool = typer.Option(False, help="Whether or not to print the count of keys at each node."),
    partition_counts:bool = typer.Option(False, help="Whether or not to print the count of each partition at each node."),
):
    """
    Render the tree as text, optionally showing key counts or partition counts.
    """
    loaded = TreeDict.load(treedict)
    # Collect the options TreeDict.render forwards to the tree renderer (filepath/print)
    # alongside its own counting flags.
    render_options = dict(
        filepath=output,
        print=print_tree,
        count=count,
        partition_counts=partition_counts,
    )
    loaded.render(**render_options)
@app.command()
def count(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
):
    """
    Prints the number of keys in the TreeDict.
    """
    n_keys = len(TreeDict.load(treedict))
    print(n_keys)
@app.command()
def sunburst(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    show:bool = typer.Option(False, help="Whether or not to show the plot."),
    output:Path|None = typer.Option(None, help="The path to save the rendered tree."),
    width:int = typer.Option(1000, help="The width of the plot."),
    height:int = typer.Option(0, help="The height of the plot. If 0 then it will be calculated based on the width."),
):
    """
    Renders the TreeDict as a sunburst plot.
    """
    treedict = TreeDict.load(treedict)
    # A height of 0 means "not given": use the width, giving a square plot.
    height = height or width

    fig = treedict.sunburst(width=width, height=height)
    if show:
        fig.show()

    if output:
        output = Path(output)
        output.parent.mkdir(exist_ok=True, parents=True)

        # if kaleido is installed, turn off mathjax
        # https://github.com/plotly/plotly.py/issues/3469
        # Deliberately best-effort: any failure here (e.g. kaleido not installed) is ignored
        # so that saving still proceeds, for example as HTML.
        try:
            import plotly.io as pio
            pio.kaleido.scope.mathjax = None
        except Exception:
            pass

        # ".html" is written as interactive HTML; any other suffix goes through write_image.
        output_func = fig.write_html if output.suffix.lower() == ".html" else fig.write_image
        output_func(output)
@app.command()
def truncate(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    max_depth:int = typer.Argument(...,help="The maximum depth to truncate the tree."),
    output:Path = typer.Argument(...,help="The path to the output file."),
):
    """
    Truncates the tree to a maximum depth.
    """
    # Load, build a truncated copy, and save the copy; the input file is left unchanged.
    TreeDict.load(treedict).truncate(max_depth).save(output)
@app.command()
def layer_size(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
):
    """
    Prints the size of the neural network layer to predict the classification tree.
    """
    tree = TreeDict.load(treedict).classification_tree
    print(tree.layer_size)
@app.command()
def pickle_tree(
    treedict:Path = typer.Argument(...,help="The path to the TreeDict."),
    output:Path = typer.Argument(...,help="The path to the output pickle file."),
):
    """
    Pickles the classification tree to a file.
    """
    # Only the tree is written; the key-to-node mapping is not included.
    source = TreeDict.load(treedict)
    source.pickle_tree(output)