Coverage for hierarchicalsoftmax/treedict.py: 100.00%

180 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

1import copy 

2from pathlib import Path 

3from attrs import define, field 

4from collections import UserDict 

5import pickle 

6from collections import Counter 

7from rich.progress import track 

8import typer 

9 

10from .nodes import SoftmaxNode 

11 

12 

@define
class NodeDetail:
    """
    Stores metadata for a key in the TreeDict.

    Attributes:
        partition (int): The partition ID this key belongs to.
        node (SoftmaxNode | None): The node in the classification tree associated with the key.
            It is not pickled: after unpickling it is ``None`` and the node must be
            recovered from the tree via ``node_id``.
        node_id (int | None): The index of the node in the tree (used during pickling).
            ``None`` until it is assigned (e.g. by ``TreeDict.set_indexes``).
    """
    partition: int
    # eq=False: the node object is excluded from NodeDetail equality comparisons.
    node: SoftmaxNode | None = field(default=None, eq=False)
    node_id: int | None = None

    def __getstate__(self):
        # Serialize only the partition and the node index; the node object itself
        # is deliberately left out of the pickled state.
        return (self.partition, self.node_id)

    def __setstate__(self, state):
        # Restore from the (partition, node_id) tuple produced by __getstate__.
        # The node reference is unknown at this point, so it is reset to None.
        self.partition, self.node_id = state
        self.node = None

33 

34 

class AlreadyExists(Exception):
    """Raised when a key already in a TreeDict is added again with a different node."""

37 

38 

class TreeDict(UserDict):
    """
    A dictionary mapping keys to ``NodeDetail`` records.

    Each record stores the partition a key belongs to and the node of
    ``classification_tree`` the key is assigned to. The tree is held on the
    instance, so ``save``/``load`` round-trip the mapping together with the tree.
    """

    def __init__(self, classification_tree: SoftmaxNode | None = None):
        """
        Initialize a TreeDict.

        Args:
            classification_tree (SoftmaxNode, optional): The root of the classification tree.
                If not provided (or falsy), a new root node named "root" will be created.
        """
        super().__init__()
        self.classification_tree = classification_tree or SoftmaxNode("root")

    def add(self, key: str, node: SoftmaxNode, partition: int) -> NodeDetail:
        """
        Associate a key with a node and a partition.

        Re-adding an existing key with the same node is allowed; the stored
        record (including its partition) is replaced.

        Args:
            key (str): The unique identifier for the item.
            node (SoftmaxNode): The node in the classification tree to associate with the key.
            partition (int): The partition index for the key.

        Raises:
            AssertionError: If ``node`` is not part of this TreeDict's classification tree.
            AlreadyExists: If the key already exists with a different node.

        Returns:
            NodeDetail: The metadata object for the added key.
        """
        # NOTE(review): this is an ``assert`` and is skipped under ``python -O``;
        # kept as-is so callers relying on AssertionError are unaffected.
        assert node.root == self.classification_tree
        if key in self:
            old_node = self.node(key)
            if node != old_node:
                raise AlreadyExists(
                    f"Accession {key} already exists in TreeDict at node {old_node}. Cannot change to {node}"
                )

        detail = NodeDetail(partition=partition, node=node)
        self[key] = detail
        return detail

    def set_indexes(self):
        """
        Ensure the tree has assigned node indexes, and record the node_id for each key.

        Only records that still hold a node object are updated; records without
        one (e.g. freshly unpickled) keep their existing ``node_id``.
        """
        self.classification_tree.set_indexes_if_unset()
        node_to_id = self.classification_tree.node_to_id
        for detail in self.values():
            # Explicit None check (matching ``node()``) so a node object is never
            # skipped just because it happens to evaluate as falsy.
            if detail.node is not None:
                detail.node_id = node_to_id[detail.node]

    def save(self, path: Path):
        """
        Save the TreeDict to a pickle file, creating parent directories as needed.

        Args:
            path (Path): The file path to save the TreeDict.
        """
        path = Path(path)
        path.parent.mkdir(exist_ok=True, parents=True)

        # NodeDetail pickles only (partition, node_id), so node ids must be
        # recorded before dumping for ``node()`` to resolve keys after loading.
        self.set_indexes()
        with open(path, "wb") as handle:
            pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path: Path) -> "TreeDict":
        """
        Load a TreeDict from a pickle file.

        Warning:
            This uses ``pickle``, which can execute arbitrary code while loading.
            Only load files from trusted sources.

        Args:
            path (Path): The path to the serialized TreeDict.

        Returns:
            TreeDict: The loaded TreeDict instance.
        """
        with open(path, "rb") as handle:
            return pickle.load(handle)

    def node(self, key: str):
        """
        Retrieve the node associated with a key.

        If the record still holds a node object it is returned directly;
        otherwise (e.g. after unpickling, which drops the node) it is looked up
        in ``classification_tree.node_list`` by the record's ``node_id``.

        Args:
            key (str): The key for which to retrieve the node.

        Returns:
            SoftmaxNode: The node corresponding to the key.
        """
        detail = self[key]
        if detail.node is not None:
            return detail.node
        return self.classification_tree.node_list[detail.node_id]

    def keys_in_partition(self, partition: int):
        """
        Yield all keys that belong to a given partition.

        Args:
            partition (int): The partition to filter by.

        Yields:
            str: Keys in the specified partition.
        """
        for key, detail in self.items():
            if detail.partition == partition:
                yield key

    def keys(self, partition: int | None = None):
        """
        Return keys in the TreeDict, optionally filtering by partition.

        Args:
            partition (int | None): The partition to filter keys by. If None, return all keys.

        Returns:
            Iterator[str]: A keys view when ``partition`` is None, otherwise a generator
            over the keys in that partition.
        """
        return super().keys() if partition is None else self.keys_in_partition(partition)

    def truncate(self, max_depth: int) -> "TreeDict":
        """
        Truncate the classification tree to a specified maximum depth and return a new TreeDict.

        The tree is deep-copied, so this TreeDict is left unchanged. Nodes with
        ``max_depth`` or more ancestors are detached from the copy. Keys assigned to
        such nodes are reassigned to ``node.ancestors[max_depth - 1]``, which is the
        ancestor at depth ``max_depth - 1`` when ancestors are ordered from the root
        (as in anytree; confirm for SoftmaxNode).

        Args:
            max_depth (int): The maximum number of ancestor levels to keep. Must be >= 1.

        Raises:
            ValueError: If ``max_depth`` is less than 1. (Previously this either crashed
                with IndexError or silently mapped keys to their parents.)

        Returns:
            TreeDict: A new truncated TreeDict.
        """
        if max_depth < 1:
            raise ValueError(f"max_depth must be at least 1, got {max_depth}")

        # Indexes are needed to map each original node onto its deep-copied twin.
        self.classification_tree.set_indexes_if_unset()
        classification_tree = copy.deepcopy(self.classification_tree)
        new_treedict = TreeDict(classification_tree)
        original_node_to_id = self.classification_tree.node_to_id

        for key, detail in track(self.items()):
            node_id = original_node_to_id[self.node(key)]
            node = classification_tree.node_list[node_id]

            ancestors = node.ancestors
            if len(ancestors) >= max_depth:
                node = ancestors[max_depth - 1]
            new_treedict.add(key, node, detail.partition)

        # Remove any nodes that are beyond the max depth. Per-node state is reset
        # first: readonly is cleared, presumably so the node can be re-parented, and
        # softmax_start_index is cleared, presumably so set_indexes_if_unset
        # recomputes the indexes (confirm both in SoftmaxNode).
        for node in new_treedict.classification_tree.pre_order_iter():
            node.readonly = False
            node.softmax_start_index = None
            if len(node.ancestors) >= max_depth:
                node.parent = None

        new_treedict.set_indexes()

        return new_treedict

    def add_counts(self):
        """
        Count the number of keys assigned to each node, and store the count in each node.

        Every node gets a ``count`` attribute (0 when no key maps to it). Counts are
        for keys assigned directly to the node; descendants are not included.
        """
        for node in self.classification_tree.post_order_iter():
            node.count = 0

        for key in self.keys():
            self.node(key).count += 1

    def add_partition_counts(self):
        """
        Count the number of keys in each partition per node and store it in the node.

        Every node gets a ``partition_counts`` Counter keyed by partition.
        """
        for node in self.classification_tree.post_order_iter():
            node.partition_counts = Counter()

        for key, detail in self.items():
            self.node(key).partition_counts[detail.partition] += 1

    def render(self, count: bool = False, partition_counts: bool = False, **kwargs):
        """
        Render the tree as text, optionally showing key counts or partition counts.

        If both flags are set, ``partition_counts`` takes precedence.

        Args:
            count (bool): If True, show the number of keys at each node.
            partition_counts (bool): If True, show partition-wise key counts at each node.
            **kwargs: Additional arguments passed to the underlying tree render method.
                When either flag is set, ``attr`` is overridden with "render_str".

        Returns:
            anytree.RenderTree or str: The rendered tree.
        """
        if partition_counts:
            self.add_partition_counts()
            for node in self.classification_tree.post_order_iter():
                partition_counts_str = "; ".join(
                    f"{k}->{node.partition_counts[k]}" for k in sorted(node.partition_counts.keys())
                )
                node.render_str = f"{node.name} {partition_counts_str}"
            kwargs["attr"] = "render_str"
        elif count:
            self.add_counts()
            for node in self.classification_tree.post_order_iter():
                # Nodes with no keys are shown by name only.
                node.render_str = f"{node.name} ({node.count})" if getattr(node, "count", 0) else node.name
            kwargs["attr"] = "render_str"

        return self.classification_tree.render(**kwargs)

    def sunburst(self, **kwargs) -> "go.Figure":
        """
        Generate a Plotly sunburst plot based on the TreeDict.

        Node values are the number of keys mapped directly to each node
        (``branchvalues="remainder"``). Sectors are linked through unique ``ids``
        rather than labels, so nodes that share a name are still drawn correctly.

        Args:
            **kwargs: Additional keyword arguments passed to Plotly layout.

        Returns:
            plotly.graph_objects.Figure: A sunburst plot.
        """
        import plotly.graph_objects as go

        self.add_counts()
        ids = []
        labels = []
        parents = []
        values = []
        # Keyed by object identity so the mapping never depends on node __eq__/__hash__.
        sector_ids = {}

        for index, node in enumerate(self.classification_tree.pre_order_iter()):
            sector_id = str(index)
            sector_ids[id(node)] = sector_id
            ids.append(sector_id)
            labels.append(node.name)
            # Pre-order visits a parent before its children, so the parent's id is known.
            parents.append(sector_ids[id(node.parent)] if node.parent is not None else "")
            values.append(node.count)

        fig = go.Figure(go.Sunburst(
            ids=ids,
            labels=labels,
            parents=parents,
            values=values,
            branchvalues="remainder",
        ))

        fig.update_layout(margin=dict(t=10, l=10, r=10, b=10), **kwargs)
        return fig

    def keys_to_file(self, file: Path) -> None:
        """
        Write all keys to a text file, one per line.

        Args:
            file (Path): Path to the output text file.
        """
        with open(file, "w") as f:
            for key in self.keys():
                print(key, file=f)

    def csv(self, file: Path) -> None:
        """
        Write all keys, node names and partitions to a CSV file.

        Fields containing commas, quotes or newlines are quoted, so arbitrary
        key and node names produce a well-formed file.

        Args:
            file (Path): Path to the output CSV file.
        """
        # Local import: the module-level ``csv`` CLI command shadows the stdlib
        # module name at global scope.
        import csv as csv_module

        with open(file, "w") as f:
            # "\n" terminator (instead of csv's default "\r\n") keeps the same line
            # endings as the previous print()-based output.
            writer = csv_module.writer(f, lineterminator="\n")
            writer.writerow(["key", "node", "partition"])
            for key, detail in self.items():
                writer.writerow([key, self.node(key).name.strip(), detail.partition])

    def pickle_tree(self, output: Path):
        """
        Save only the classification tree (not the key-to-node mapping) to a pickle file.

        Uses pickle's default protocol (``save`` uses ``HIGHEST_PROTOCOL``).

        Args:
            output (Path): Path to the output file.
        """
        with open(output, "wb") as pickle_file:
            pickle.dump(self.classification_tree, pickle_file)

311 

312 

# Typer application; the subcommands below register themselves via ``@app.command()``.
app: typer.Typer = typer.Typer()

314 

@app.command()
def keys(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    partition: int | None = typer.Option(None, help="The index of the partition to list."),
):
    """
    Prints a list of keys in a TreeDict.

    If a partition is given, then only the keys for that partition are given.
    """
    # Output is one key per line on stdout.
    loaded = TreeDict.load(treedict)
    for item_key in loaded.keys(partition=partition):
        print(item_key)

328 

329 

@app.command()
def csv(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    csv: Path = typer.Argument(..., help="The path to the output CSV file."),
):
    """
    Writes a CSV file with the key, node name and partition.
    """
    # Inside this body ``csv`` is the output-path argument, not the command itself.
    TreeDict.load(treedict).csv(csv)

340 

341 

@app.command()
def render(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    output: Path | None = typer.Option(None, help="The path to save the rendered tree."),
    print_tree: bool = typer.Option(True, help="Whether or not to print the tree to the screen."),
    count: bool = typer.Option(False, help="Whether or not to print the count of keys at each node."),
    partition_counts: bool = typer.Option(False, help="Whether or not to print the count of each partition at each node."),
):
    """
    Render the tree as text, optionally showing key counts or partition counts.
    """
    # ``filepath`` and ``print`` are forwarded untouched to the tree's render method.
    loaded = TreeDict.load(treedict)
    loaded.render(
        filepath=output,
        print=print_tree,
        count=count,
        partition_counts=partition_counts,
    )

355 

356 

@app.command()
def count(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
):
    """
    Prints the number of keys in the TreeDict.
    """
    total = len(TreeDict.load(treedict))
    print(total)

366 

367 

@app.command()
def sunburst(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    show: bool = typer.Option(False, help="Whether or not to show the plot."),
    output: Path | None = typer.Option(None, help="The path to save the rendered tree."),
    width: int = typer.Option(1000, help="The width of the plot."),
    height: int = typer.Option(0, help="The height of the plot. If 0 then it will be calculated based on the width."),
):
    """
    Renders the TreeDict as a sunburst plot.
    """
    treedict = TreeDict.load(treedict)
    # A height of 0 falls back to the width, giving a square plot.
    height = height or width

    fig = treedict.sunburst(width=width, height=height)
    if show:
        fig.show()

    if output:
        output = Path(output)
        output.parent.mkdir(exist_ok=True, parents=True)

        # If kaleido is installed, turn off mathjax
        # https://github.com/plotly/plotly.py/issues/3469
        # Deliberately best-effort: kaleido may be missing or may not expose this
        # attribute, and exporting should still proceed in that case.
        try:
            import plotly.io as pio
            pio.kaleido.scope.mathjax = None
        except Exception:
            pass

        # ".html" (any case) is written as interactive HTML; anything else as a static image.
        output_func = fig.write_html if output.suffix.lower() == ".html" else fig.write_image
        output_func(output)

400 

401 

@app.command()
def truncate(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    max_depth: int = typer.Argument(..., help="The maximum depth to truncate the tree."),
    output: Path = typer.Argument(..., help="The path to the output file."),
):
    """
    Truncates the tree to a maximum depth.
    """
    # Load, build a truncated copy (the loaded TreeDict is not modified), then save the copy.
    truncated = TreeDict.load(treedict).truncate(max_depth)
    truncated.save(output)

414 

415 

@app.command()
def layer_size(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
):
    """
    Prints the size of the neural network layer to predict the classification tree.
    """
    tree = TreeDict.load(treedict).classification_tree
    print(tree.layer_size)

425 

426 

@app.command()
def pickle_tree(
    treedict: Path = typer.Argument(..., help="The path to the TreeDict."),
    output: Path = typer.Argument(..., help="The path to the output pickle file."),
):
    """
    Pickles the classification tree to a file.
    """
    # Only the tree is written; the key-to-node mapping is not included.
    TreeDict.load(treedict).pickle_tree(output)