Coverage for hierarchicalsoftmax/dotexporter.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from anytree.exporter import DotExporter
2from anytree import PreOrderIter
class ThresholdDotExporter(DotExporter):
    """
    Export a tree to Graphviz DOT, highlighting a predicted path and hiding
    low-probability branches.

    Nodes contained in ``greedy_nodes`` are drawn in ``prediction_color``.
    Every edge is labelled with the child's value in ``probabilities``.
    An edge is coloured ``prediction_color`` when its child is in
    ``greedy_nodes``, and ``non_prediction_color`` otherwise.

    A non-root node is hidden when it is not in ``greedy_nodes`` and its value
    in ``probabilities`` is below ``threshold``. Hiding a node also hides its
    whole subtree, so no emitted edge refers to a hidden node. Graphviz would
    otherwise recreate such a node implicitly from the edge statement.

    Args:
        node: Root node of the tree to export.
        probabilities: Indexed by each non-root node's
            ``index_in_softmax_layer`` attribute. The values are used for edge
            labels (formatted with ``:.2f``) and for the threshold comparison.
        greedy_nodes: Container of nodes to highlight; these are never hidden.
        graph, name, indent, nodenamefunc, nodeattrfunc, edgeattrfunc,
        edgetypefunc, maxlevel: Forwarded unchanged to
            ``anytree.exporter.DotExporter``.
        options: Extra DOT option lines, forwarded to ``DotExporter``. The
            sequence is copied, so the caller's object is never modified.
        prediction_color: Graphviz colour for highlighted nodes and edges.
        non_prediction_color: Graphviz colour for all other edges.
        horizontal: If true, append ``rankdir="LR";`` to lay the graph out
            left-to-right.
        threshold: Values below this hide non-highlighted nodes.
    """

    def __init__(
        self,
        node,
        probabilities,
        greedy_nodes,
        graph="digraph",
        name="tree",
        options=None,
        indent=4,
        nodenamefunc=None,
        nodeattrfunc=None,
        edgeattrfunc=None,
        edgetypefunc=None,
        prediction_color="red",
        non_prediction_color="gray",
        horizontal: bool = True,
        threshold: float = 0.005,
        maxlevel=None,
    ):
        # Copy instead of appending to the caller's list: reusing one list for
        # several exporters would otherwise accumulate duplicate rankdir lines.
        options = list(options) if options else []
        if horizontal:
            options.append('rankdir="LR";')

        super().__init__(
            node,
            graph=graph,
            name=name,
            options=options,
            indent=indent,
            nodenamefunc=nodenamefunc,
            nodeattrfunc=nodeattrfunc,
            edgeattrfunc=edgeattrfunc,
            edgetypefunc=edgetypefunc,
            maxlevel=maxlevel,
        )
        self.greedy_nodes = greedy_nodes
        self.probabilities = probabilities
        self.prediction_color = prediction_color
        self.non_prediction_color = non_prediction_color
        self.threshold = threshold

    def _default_nodeattrfunc(self, node):
        """Return the DOT attributes for ``node``: the prediction colour if highlighted, else none."""
        return f"color={self.prediction_color}" if node in self.greedy_nodes else ""

    def _default_edgeattrfunc(
        self,
        parent,
        child,
    ):
        """Return the DOT attributes for the edge ``parent -> child``: value label plus colour."""
        color = self.prediction_color if child in self.greedy_nodes else self.non_prediction_color
        return f"label={self.probabilities[child.index_in_softmax_layer]:.2f},color={color}"

    def exclude_node(self, node):
        """
        Return True if ``node`` should be hidden from the export.

        The root and highlighted nodes are always kept. Any other node is hidden
        when its value in ``probabilities`` is below ``threshold``.
        """
        return (
            not node.is_root
            and node not in self.greedy_nodes
            and self.probabilities[node.index_in_softmax_layer] < self.threshold
        )

    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
        """
        Yield one DOT node statement per visible node.

        This overrides anytree's private ``DotExporter.__iter_nodes`` by defining
        its name-mangled attribute name.
        """
        # ``stop`` prunes the entire subtree below an excluded node, rather than
        # only skipping the excluded node while still visiting its descendants.
        for node in PreOrderIter(self.node, stop=self.exclude_node, maxlevel=self.maxlevel):
            nodename = nodenamefunc(node)
            nodeattr = nodeattrfunc(node)
            nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
            yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)

    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
        """
        Yield one DOT edge statement per edge between two visible nodes.

        This overrides anytree's private ``DotExporter.__iter_edges`` by defining
        its name-mangled attribute name.
        """
        # Parents are one level shallower than the deepest node that is drawn.
        maxlevel = self.maxlevel - 1 if self.maxlevel else None
        # Pruning excluded parents here guarantees no edge starts at a hidden
        # node; the per-child check below guarantees none ends at one.
        for node in PreOrderIter(self.node, stop=self.exclude_node, maxlevel=maxlevel):
            nodename = nodenamefunc(node)
            for child in node.children:
                if self.exclude_node(child):
                    continue
                childname = nodenamefunc(child)
                edgeattr = edgeattrfunc(node, child)
                edgetype = edgetypefunc(node, child)
                edgeattr = " [%s]" % edgeattr if edgeattr is not None else ""
                yield '%s"%s" %s "%s"%s;' % (
                    indent,
                    DotExporter.esc(nodename),
                    edgetype,
                    DotExporter.esc(childname),
                    edgeattr,
                )