Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from anytree.exporter import DotExporter 

2from anytree import PreOrderIter 

3 

class ThresholdDotExporter(DotExporter):
    """DOT exporter that omits low-probability nodes and colours selected nodes.

    Every node of the tree is expected to expose ``index_in_softmax_layer``,
    which is used as an index into ``probabilities``.  A node is left out of
    the output when it is not the root, is not in ``greedy_nodes`` and its
    probability is below ``threshold``.  Nodes in ``greedy_nodes`` (and the
    edges leading to them) are drawn in ``prediction_color``; other edges
    are drawn in ``non_prediction_color``.
    """

    def __init__(
        self,
        node,
        probabilities,
        greedy_nodes,
        graph="digraph",
        name="tree",
        options=None,
        indent=4,
        nodenamefunc=None,
        nodeattrfunc=None,
        edgeattrfunc=None,
        edgetypefunc=None,
        prediction_color="red",
        non_prediction_color="gray",
        horizontal: bool = True,
        threshold: float = 0.005,
    ):
        """Configure the exporter.

        Args:
            node: Root node of the tree to export.
            probabilities: Object indexable by ``node.index_in_softmax_layer``
                that yields a number per node.
            greedy_nodes: Container (anything supporting ``in``) of the nodes
                to highlight; these are never excluded.
            graph, name, options, indent, nodenamefunc, nodeattrfunc,
            edgeattrfunc, edgetypefunc: Forwarded to ``DotExporter``.
            prediction_color: Colour for highlighted nodes and their
                incoming edges.
            non_prediction_color: Colour for all other edges.
            horizontal: When true, ``rankdir="LR";`` is added to the options.
            threshold: Probability below which a non-root, non-highlighted
                node is omitted.
        """
        # Work on a private copy: appending to the caller's list would leak
        # the extra option back to the caller (and duplicate it whenever the
        # same list is reused for another exporter).  Copying also lets any
        # iterable of options be passed, not just a list.
        options = list(options) if options else []
        if horizontal:
            options.append('rankdir="LR";')

        super().__init__(
            node,
            graph=graph,
            name=name,
            options=options,
            indent=indent,
            nodenamefunc=nodenamefunc,
            nodeattrfunc=nodeattrfunc,
            edgeattrfunc=edgeattrfunc,
            edgetypefunc=edgetypefunc,
        )
        self.greedy_nodes = greedy_nodes
        self.probabilities = probabilities
        self.prediction_color = prediction_color
        self.non_prediction_color = non_prediction_color
        self.threshold = threshold

    def _default_nodeattrfunc(self, node):
        """Return the DOT attribute string for ``node``.

        Highlighted nodes get ``color=<prediction_color>``; every other node
        gets an empty attribute string.
        """
        return f"color={self.prediction_color}" if node in self.greedy_nodes else ""

    def _default_edgeattrfunc(
        self,
        parent,
        child,
    ):
        """Return the DOT attribute string for the edge ``parent -> child``.

        The label is the child's probability formatted with two decimals;
        the colour depends only on whether ``child`` is highlighted
        (``parent`` is not consulted).
        """
        color = self.prediction_color if child in self.greedy_nodes else self.non_prediction_color
        return f"label={self.probabilities[child.index_in_softmax_layer]:.2f},color={color}"

    def exclude_node(self, node):
        """Return True when ``node`` must be left out of the output.

        The root and the highlighted nodes are always kept; any other node
        is dropped when its probability is strictly below ``self.threshold``.
        """
        return (
            not node.is_root
            and node not in self.greedy_nodes
            and self.probabilities[node.index_in_softmax_layer] < self.threshold
        )

    # The two methods below use the name-mangled spelling of DotExporter's
    # private ``__iter_nodes`` / ``__iter_edges`` so that they replace the
    # base-class implementations.  ``*args, **kwargs`` swallow any further
    # arguments the base class may pass (presumably a filter callable in
    # newer anytree releases -- TODO confirm against the pinned version).

    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
        """Yield one DOT node statement per node that is not excluded.

        NOTE(review): skipping an excluded node does not prune its subtree;
        its descendants are still visited and tested individually.
        """
        # ``self.maxlevel`` is not assigned in this class; it is expected to
        # be set by the base-class constructor.
        for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
            if self.exclude_node(node):
                continue
            nodename = nodenamefunc(node)
            nodeattr = nodeattrfunc(node)
            nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
            yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)

    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
        """Yield one DOT edge statement per edge whose child is not excluded.

        NOTE(review): only the child is tested, so an edge from an excluded
        parent to a retained child is still emitted.
        """
        # Parents are visited one level shallower than the node limit so that
        # no edge points to a child beyond ``maxlevel``.
        maxlevel = self.maxlevel - 1 if self.maxlevel else None
        for node in PreOrderIter(self.node, maxlevel=maxlevel):
            nodename = nodenamefunc(node)
            for child in node.children:
                if self.exclude_node(child):
                    continue

                childname = nodenamefunc(child)
                edgeattr = edgeattrfunc(node, child)
                edgetype = edgetypefunc(node, child)
                edgeattr = " [%s]" % edgeattr if edgeattr is not None else ""
                yield '%s"%s" %s "%s"%s;' % (indent, DotExporter.esc(nodename), edgetype,
                                             DotExporter.esc(childname), edgeattr)