Coverage for hierarchicalsoftmax/dotexporter.py: 100.00%

50 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

1from anytree.exporter import DotExporter 

2from anytree import PreOrderIter 

3 

class ThresholdDotExporter(DotExporter):
    """Graphviz DOT exporter that annotates a softmax tree with probabilities.

    Each edge ``parent -> child`` is labelled with the child's probability
    (two decimal places, or ``x`` when the child has no softmax index).
    Nodes and edges belonging to ``greedy_nodes`` are drawn in
    ``prediction_color``; all other edges use ``non_prediction_color``.
    Nodes with a softmax index are omitted when their probability is below
    ``threshold``, unless they are the root or belong to ``greedy_nodes``.
    Nodes without a softmax index are omitted exactly when their parent is.

    Args:
        node: Root of the tree to export.
        probabilities: Indexable looked up with each node's
            ``index_in_softmax_layer``. Presumably one probability per softmax
            output (e.g. a tensor row); confirm against callers.
        greedy_nodes: Container of nodes to highlight (tested with ``in``).
        graph, name, options, indent, nodenamefunc, nodeattrfunc,
        edgeattrfunc, edgetypefunc, maxlevel: Forwarded to
            ``anytree.exporter.DotExporter``. ``options`` is copied, never
            mutated.
        prediction_color: DOT colour for highlighted nodes and edges.
        non_prediction_color: DOT colour for non-highlighted edges.
        horizontal: If true, adds ``rankdir="LR";`` so the tree is laid out
            left-to-right.
        threshold: Minimum probability for an indexed node to be drawn.
    """

    def __init__(
        self,
        node,
        probabilities,
        greedy_nodes,
        graph="digraph",
        name="tree",
        options=None,
        indent=4,
        nodenamefunc=None,
        nodeattrfunc=None,
        edgeattrfunc=None,
        edgetypefunc=None,
        prediction_color="red",
        non_prediction_color="gray",
        horizontal: bool = True,
        threshold: float = 0.005,
        maxlevel=None,
    ):
        # Copy rather than alias: appending to a caller-owned list would mutate
        # it, and reusing that list for another exporter would then accumulate
        # duplicate rankdir options. list() also accepts tuples/other iterables.
        options = list(options) if options else []
        if horizontal:
            options.append('rankdir="LR";')

        super().__init__(
            node,
            graph=graph,
            name=name,
            options=options,
            indent=indent,
            nodenamefunc=nodenamefunc,
            nodeattrfunc=nodeattrfunc,
            edgeattrfunc=edgeattrfunc,
            edgetypefunc=edgetypefunc,
            maxlevel=maxlevel,
        )
        self.greedy_nodes = greedy_nodes
        self.probabilities = probabilities
        self.prediction_color = prediction_color
        self.non_prediction_color = non_prediction_color
        self.threshold = threshold
        # Cache of nodes already determined to be excluded. Only exclusions are
        # stored and the set is never cleared here.
        # NOTE(review): if `probabilities`/`threshold` are changed after an
        # export, previously excluded nodes stay excluded.
        self.excluded_nodes = set()

    def _default_nodeattrfunc(self, node):
        """Return DOT attributes for ``node``.

        Highlighted (greedy) nodes get ``color=<prediction_color>``. Every other
        node gets an empty string, which is emitted as an empty ``[]`` list.
        """
        return f"color={self.prediction_color}" if node in self.greedy_nodes else ""

    def _default_edgeattrfunc(
        self,
        parent,
        child,
    ):
        """Return DOT attributes for the edge ``parent -> child``.

        The label is the child's probability formatted with ``.2f``, or ``x``
        when the child has no softmax index. The colour reflects whether the
        child is in ``greedy_nodes``.
        """
        color = self.prediction_color if child in self.greedy_nodes else self.non_prediction_color
        index = child.index_in_softmax_layer
        label = f"{self.probabilities[index]:.2f}" if index is not None else "x"
        return f"label={label},color={color}"

    def exclude_node(self, node) -> bool:
        """Return True if ``node`` should be left out of the exported graph.

        * Nodes without a softmax index are excluded exactly when their parent
          is excluded (a parentless node is kept).
        * Nodes with a softmax index are kept if they are the root, are in
          ``greedy_nodes``, or have probability ``>= threshold``.

        Excluded nodes are remembered in ``excluded_nodes``.
        """
        if node in self.excluded_nodes:
            return True

        index = node.index_in_softmax_layer
        if index is None:
            # Resolve the parent's status directly instead of only consulting
            # the cache, so the result does not depend on the parent having
            # been visited first. Under pre-order traversal this is identical.
            parent = node.parent
            excluded = parent is not None and self.exclude_node(parent)
        else:
            keep = (
                node.is_root
                or node in self.greedy_nodes
                or self.probabilities[index] >= self.threshold
            )
            excluded = not keep

        if excluded:
            self.excluded_nodes.add(node)
        return excluded

    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
        """Yield a DOT node statement for every node that is not excluded.

        The name matches the name-mangled form of ``DotExporter.__iter_nodes``
        so that it replaces the base-class implementation (assumed to be
        called by the base class when generating DOT lines). Extra arguments
        are accepted and ignored for compatibility.
        """
        for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
            if self.exclude_node(node):
                continue
            nodename = nodenamefunc(node)
            nodeattr = nodeattrfunc(node)
            nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
            yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)

    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
        """Yield a DOT edge statement for each ``parent -> child`` whose child is kept.

        The name matches the name-mangled form of ``DotExporter.__iter_edges``
        so that it replaces the base-class implementation. Traversal depth is
        one less than ``maxlevel`` because each visited node emits edges to
        its children.
        """
        maxlevel = self.maxlevel - 1 if self.maxlevel else None
        for node in PreOrderIter(self.node, maxlevel=maxlevel):
            nodename = nodenamefunc(node)
            for child in node.children:
                # NOTE(review): only the child's exclusion is checked. If a
                # kept child has an excluded parent (possible if probabilities
                # are not monotone down the tree), this edge still references
                # the parent, which Graphviz will create implicitly.
                if self.exclude_node(child):
                    continue

                childname = nodenamefunc(child)
                edgeattr = edgeattrfunc(node, child)
                edgetype = edgetypefunc(node, child)
                edgeattr = " [%s]" % edgeattr if edgeattr is not None else ""
                yield '%s"%s" %s "%s"%s;' % (indent, DotExporter.esc(nodename), edgetype,
                                             DotExporter.esc(childname), edgeattr)