Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import List, Tuple 

2from torch import Tensor 

3 

4 

def total_size(data_types) -> int:
    """
    Adds up the feature counts reported by each output type.

    Args:
        data_types (List[PolyData]): The data types to predict. Each item
            must expose a ``size()`` method returning its feature count.

    Returns:
        int: The combined number of features across all given data types
            (``0`` when ``data_types`` is empty).
    """
    feature_count = 0
    for entry in data_types:
        feature_count += entry.size()
    return feature_count

16 

17 

def split_tensor(tensor:Tensor, data_types, feature_axis:int=-1) -> Tuple[Tensor, ...]:
    """
    Splits a tensor into a tuple of tensors, one for each data type.

    Consecutive slices are taken along ``feature_axis``; the width of each
    slice is the value returned by the corresponding ``data_type.size()``.
    All other axes are kept whole.

    Args:
        tensor (Tensor): The predictions tensor.
        data_types (List[PolyData]): The data types to predict.
        feature_axis (int, optional): The axis which has the features to predict. Defaults to last axis.

    Returns:
        Tuple[Tensor, ...]: A tuple of tensors, one for each data type, in the
            same order as ``data_types``.

    Raises:
        AssertionError: If the sizes of ``data_types`` do not add up to the
            length of ``tensor`` along ``feature_axis``.
    """
    current_index = 0
    split_tensors = []
    # Start with a "take everything" slice on every axis; only the feature
    # axis entry is overwritten inside the loop.
    slice_indices = [slice(None)] * len(tensor.shape)
    for data_type in data_types:
        size = data_type.size()
        slice_indices[feature_axis] = slice(current_index, current_index + size)
        # Index with a tuple: passing a plain list of slices is the deprecated
        # "non-tuple sequence" form of multidimensional indexing.
        split_tensors.append(tensor[tuple(slice_indices)])
        current_index += size

    # Every feature must be consumed exactly once by the data types.
    assert current_index == tensor.shape[feature_axis], (
        f"data types cover {current_index} features but axis {feature_axis} "
        f"has {tensor.shape[feature_axis]}"
    )

    return tuple(split_tensors)

42 

43 

def permute_feature_axis(tensor:Tensor, old_axis:int, new_axis:int) -> Tensor:
    """
    Changes the shape of a tensor so that the feature axis is in a new axis.

    The feature axis is moved from ``old_axis`` to ``new_axis``; the relative
    order of all other axes is preserved.

    Args:
        tensor (torch.Tensor): The tensor to permute.
        old_axis (int): The current index of the feature axis. May be negative.
        new_axis (int): The desired index of the feature axis. May be negative.

    Returns:
        torch.Tensor: The predictions tensor with the feature axis at the specified index.
            The input tensor itself is returned when both indices refer to the same axis.
    """
    axes_count = len(tensor.shape)
    if old_axis % axes_count == new_axis % axes_count:
        return tensor

    axes = list(range(axes_count))
    feature = axes.pop(old_axis)
    # ``list.insert`` with a negative index inserts *before* that position of
    # the already-shortened list, which would place the feature axis one slot
    # too early (e.g. -1 would land second-to-last). Convert negative targets
    # to their non-negative equivalent first.
    if new_axis < 0:
        new_axis %= axes_count
    axes.insert(new_axis, feature)
    return tensor.permute(*axes)

61 

62 

def squeeze_prediction(prediction:Tensor, target:Tensor, feature_axis:int):
    """
    Squeeze feature axis if necessary

    The feature axis is squeezed only when ``prediction`` has exactly one more
    axis than ``target`` and its shape with the feature axis removed equals
    the shape of ``target``. Otherwise ``prediction`` is returned unchanged.

    Args:
        prediction (Tensor): The predictions tensor.
        target (Tensor): The tensor whose shape the prediction is compared to.
        feature_axis (int): Index of the feature axis in ``prediction``. May be negative.

    Returns:
        Tensor: The result of ``prediction.squeeze(feature_axis)`` when the
            shapes line up as described, else ``prediction`` itself.
    """
    pred_shape = prediction.shape
    rank = len(pred_shape)
    # Normalise a possibly negative axis so it can be used for shape slicing.
    axis = feature_axis % rank

    if rank != len(target.shape) + 1:
        return prediction

    shape_without_feature = pred_shape[:axis] + pred_shape[axis + 1:]
    if shape_without_feature != target.shape:
        return prediction

    return prediction.squeeze(axis)