Coverage for polytorch/util.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import List, Tuple
2from torch import Tensor
def total_size(data_types) -> int:
    """
    Adds up the feature counts of every data type in a list of output types.

    Args:
        data_types (List[PolyData]): The data types to predict. Each element must
            provide a ``size()`` method returning its feature count.

    Returns:
        int: The total number of features needed to predict all of the given data types
        (0 for an empty list).
    """
    total = 0
    for data_type in data_types:
        total += data_type.size()
    return total
def split_tensor(tensor: Tensor, data_types, feature_axis: int = -1) -> Tuple[Tensor, ...]:
    """
    Splits a tensor into a tuple of tensors, one for each data type.

    Consecutive chunks along ``feature_axis`` are assigned to the data types in order;
    each chunk's width is that data type's ``size()``. The pieces are produced by basic
    slicing.

    Args:
        tensor (Tensor): The predictions tensor.
        data_types (List[PolyData]): The data types to predict. Each element must
            provide a ``size()`` method returning its feature count.
        feature_axis (int, optional): The axis which has the features to predict.
            Defaults to the last axis.

    Returns:
        Tuple[Tensor, ...]: A tuple of tensors, one for each data type.

    Raises:
        AssertionError: If the summed sizes of the data types do not equal the length
            of ``tensor`` along ``feature_axis``.
    """
    # Evaluate every size once up front so a mismatch is reported before any slicing.
    sizes = [data_type.size() for data_type in data_types]
    total = sum(sizes)
    available = tensor.shape[feature_axis]
    if total != available:
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        raise AssertionError(
            f"Data types require {total} features but the tensor has {available} "
            f"along axis {feature_axis}."
        )

    index = [slice(None)] * len(tensor.shape)
    pieces = []
    start = 0
    for size in sizes:
        index[feature_axis] = slice(start, start + size)
        # Index with a tuple: a list of slices depends on legacy
        # "treat sequence as tuple" indexing heuristics.
        pieces.append(tensor[tuple(index)])
        start += size

    return tuple(pieces)
def permute_feature_axis(tensor: Tensor, old_axis: int, new_axis: int) -> Tensor:
    """
    Changes the shape of a tensor so that the feature axis is in a new axis.

    The axis at ``old_axis`` is moved to position ``new_axis``; the relative order of
    all other axes is preserved. Negative indices count from the end, as in PyTorch.

    Args:
        tensor (torch.Tensor): The tensor to permute.
        old_axis (int): The current index of the feature axis.
        new_axis (int): The desired index of the feature axis.

    Returns:
        torch.Tensor: The predictions tensor with the feature axis at the specified index.
        If the axis is already in place, the input tensor itself is returned.

    Raises:
        IndexError: If either axis is out of range for the tensor's number of dimensions.
    """
    axes_count = len(tensor.shape)
    for axis in (old_axis, new_axis):
        if not -axes_count <= axis < axes_count:
            raise IndexError(
                f"Axis {axis} is out of range for a tensor with {axes_count} dimensions."
            )

    # Normalize negative indices before using them with list.insert: inserting at -1
    # places the element *before* the last item, so the feature axis would otherwise
    # end up second-to-last instead of last.
    old_axis %= axes_count
    new_axis %= axes_count
    if old_axis == new_axis:
        return tensor

    axes = list(range(axes_count))
    axes.insert(new_axis, axes.pop(old_axis))
    return tensor.permute(*axes)
def squeeze_prediction(prediction: Tensor, target: Tensor, feature_axis: int):
    """
    Squeeze feature axis if necessary.

    The prediction's feature axis is removed only when the prediction has exactly one
    more dimension than the target and its shape, with the feature axis dropped, equals
    the target's shape. In every other case the prediction is returned unchanged.

    Args:
        prediction (Tensor): The predictions tensor.
        target (Tensor): The target tensor whose shape the prediction is compared with.
        feature_axis (int): Index of the feature axis in ``prediction``; negative
            values count from the end.

    Returns:
        Tensor: ``prediction`` with ``feature_axis`` squeezed, or ``prediction`` itself.
    """
    axis = feature_axis % len(prediction.shape)

    # Guard: the prediction must carry exactly one extra (feature) dimension.
    if len(prediction.shape) != len(target.shape) + 1:
        return prediction

    # Guard: apart from the feature axis, the shapes must line up exactly.
    shape_without_features = tuple(prediction.shape[:axis]) + tuple(prediction.shape[axis + 1:])
    if shape_without_features != tuple(target.shape):
        return prediction

    return prediction.squeeze(axis)