Coverage for hierarchicalsoftmax/dotexporter.py: 100.00%
50 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-02 01:49 +0000
1from anytree.exporter import DotExporter
2from anytree import PreOrderIter
class ThresholdDotExporter(DotExporter):
    """Graphviz DOT exporter that highlights the prediction and hides unlikely branches.

    Extends anytree's ``DotExporter``:

    * nodes in ``greedy_nodes`` (and the edges leading into them) are coloured
      with ``prediction_color``; all other edges use ``non_prediction_color``;
    * every edge is labelled with the child's probability (``x`` when the child
      has no softmax index);
    * a node whose probability is below ``threshold`` is omitted together with
      its entire subtree, so no edge ever points at an undeclared node.

    Args:
        node: Root of the tree to export.
        probabilities: Indexable collection of per-node probabilities, looked up
            with each node's ``index_in_softmax_layer``.
        greedy_nodes: Collection of nodes on the predicted path. They are
            highlighted and exempt from the threshold test.
        graph, name, indent, nodenamefunc, nodeattrfunc, edgeattrfunc,
            edgetypefunc, maxlevel: Forwarded unchanged to ``DotExporter``.
        options: Extra DOT graph options. The sequence is copied, never mutated.
        prediction_color: Colour for greedy nodes and the edges into them.
        non_prediction_color: Colour for every other edge.
        horizontal: If true, add ``rankdir="LR";`` for a left-to-right layout.
        threshold: Minimum probability for a non-root, non-greedy indexed node
            to be drawn.
    """

    def __init__(
        self,
        node,
        probabilities,
        greedy_nodes,
        graph="digraph",
        name="tree",
        options=None,
        indent=4,
        nodenamefunc=None,
        nodeattrfunc=None,
        edgeattrfunc=None,
        edgetypefunc=None,
        prediction_color="red",
        non_prediction_color="gray",
        horizontal: bool = True,
        threshold: float = 0.005,
        maxlevel=None,
    ):
        # Copy rather than append in place. Mutating the caller's list would add
        # another rankdir entry every time the same list is reused, and it would
        # fail outright for a tuple.
        options = list(options) if options else []
        if horizontal:
            options.append('rankdir="LR";')

        super().__init__(
            node,
            graph=graph,
            name=name,
            options=options,
            indent=indent,
            nodenamefunc=nodenamefunc,
            nodeattrfunc=nodeattrfunc,
            edgeattrfunc=edgeattrfunc,
            edgetypefunc=edgetypefunc,
            maxlevel=maxlevel,
        )
        self.greedy_nodes = greedy_nodes
        self.probabilities = probabilities
        self.prediction_color = prediction_color
        self.non_prediction_color = non_prediction_color
        self.threshold = threshold
        # Memoised exclusion decisions; both sets are cleared at the start of
        # every node pass so changes to the attributes above take effect.
        self.excluded_nodes = set()
        self._included_nodes = set()

    def _default_nodeattrfunc(self, node):
        """Colour greedy (predicted) nodes; other nodes get an empty attribute list."""
        return f"color={self.prediction_color}" if node in self.greedy_nodes else ""

    def _default_edgeattrfunc(
        self,
        parent,
        child,
    ):
        """Label the edge with the child's probability and colour it by prediction."""
        color = self.prediction_color if child in self.greedy_nodes else self.non_prediction_color
        index = child.index_in_softmax_layer
        label = "x" if index is None else f"{self.probabilities[index]:.2f}"
        return f"label={label},color={color}"

    def exclude_node(self, node) -> bool:
        """Return True if ``node`` should be left out of the rendered graph.

        A node is excluded when its parent is excluded, which drops hidden
        branches as whole subtrees. It is also excluded when it has a softmax
        index, is neither the root nor a greedy node, and its probability is
        below ``self.threshold``. Nodes without a softmax index are otherwise
        kept. The parent is evaluated recursively, so the result does not
        depend on the order in which nodes are queried. Decisions are
        memoised; excluded nodes are recorded in ``self.excluded_nodes``.
        """
        if node in self.excluded_nodes:
            return True
        if node in self._included_nodes:
            return False

        parent = node.parent
        index = node.index_in_softmax_layer
        if parent is not None and self.exclude_node(parent):
            excluded = True
        elif index is None:
            excluded = False
        else:
            excluded = not (
                node.is_root
                or node in self.greedy_nodes
                or self.probabilities[index] >= self.threshold
            )

        (self.excluded_nodes if excluded else self._included_nodes).add(node)
        return excluded

    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
        """Yield DOT node statements in pre-order, skipping excluded nodes.

        Overrides ``DotExporter``'s private ``__iter_nodes`` through its mangled name.
        """
        # threshold/probabilities/greedy_nodes may have changed since the last
        # export, so every export starts with a fresh exclusion cache.
        self.excluded_nodes.clear()
        self._included_nodes.clear()
        for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
            if self.exclude_node(node):
                continue
            nodename = nodenamefunc(node)
            nodeattr = nodeattrfunc(node)
            nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
            yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)

    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
        """Yield DOT edge statements between nodes that are both drawn.

        Overrides ``DotExporter``'s private ``__iter_edges`` through its mangled name.
        """
        # Parents are walked one level shallower so that children stay within maxlevel.
        maxlevel = self.maxlevel - 1 if self.maxlevel else None
        for node in PreOrderIter(self.node, maxlevel=maxlevel):
            # Exclusion propagates to descendants, so an excluded parent has no
            # drawable children.
            if self.exclude_node(node):
                continue
            nodename = nodenamefunc(node)
            for child in node.children:
                if self.exclude_node(child):
                    continue
                childname = nodenamefunc(child)
                edgeattr = edgeattrfunc(node, child)
                edgetype = edgetypefunc(node, child)
                edgeattr = " [%s]" % edgeattr if edgeattr is not None else ""
                yield '%s"%s" %s "%s"%s;' % (
                    indent,
                    DotExporter.esc(nodename),
                    edgetype,
                    DotExporter.esc(childname),
                    edgeattr,
                )