Utils
File: utils.py
This module also provides experiment-workflow utilities that can be imported and used directly.
BaseModelInterface
Wrapper class that provides a model interface for torch.nn models. Its main method, 'predict', runs an input through the forward-pass pipeline:
preprocess --> model --> postprocess.
Users must provide a torch.nn model and can optionally specify preprocess and postprocess functions.
Source code in torchplate/utils.py
__init__(model)
- model (torch.nn.Module): the PyTorch model to wrap.
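A minimal sketch of the preprocess --> model --> postprocess pattern follows. The names PipelineModel and ArgmaxClassifier are hypothetical. The sketch assumes that preprocess and postprocess are hooks a subclass overrides, with identity behavior by default. It is not torchplate's actual BaseModelInterface source.

```python
import torch
import torch.nn as nn


class PipelineModel:
    """Hypothetical preprocess --> model --> postprocess wrapper (not torchplate's code)."""

    def __init__(self, model: nn.Module):
        self.model = model

    def preprocess(self, inputs):
        # Default hook (assumption): pass inputs through unchanged.
        return inputs

    def postprocess(self, outputs):
        # Default hook (assumption): pass outputs through unchanged.
        return outputs

    def predict(self, inputs):
        # Assumption: predict is inference-only, so switch to eval mode
        # and disable gradient tracking for the forward pass.
        self.model.eval()
        with torch.no_grad():
            return self.postprocess(self.model(self.preprocess(inputs)))


class ArgmaxClassifier(PipelineModel):
    """Example subclass: convert lists to float tensors, return class indices."""

    def preprocess(self, inputs):
        return torch.as_tensor(inputs, dtype=torch.float32)

    def postprocess(self, outputs):
        return outputs.argmax(dim=1)


torch.manual_seed(0)
clf = ArgmaxClassifier(nn.Linear(4, 3))
preds = clf.predict([[0.1, 0.2, 0.3, 0.4], [1.0, 0.0, -1.0, 0.5]])
print(preds.shape)  # torch.Size([2])
```

Putting the hooks on a subclass keeps predict generic. The same wrapper can then serve different input formats without changing the model itself.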
XYDataset
Bases: Dataset
PyTorch Dataset class for datasets of the form [(x,y), ..., (x,y)].
__init__(data_set)
- data_set (sequence): sequence of the form [(x,y), ..., (x,y)] representing the dataset.
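For illustration, a map-style Dataset over (x, y) pairs needs only __len__ and __getitem__. The class below, PairDataset, is a hypothetical stand-in and is not the XYDataset source.

```python
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Hypothetical stand-in for XYDataset: wraps a sequence of (x, y) pairs."""

    def __init__(self, data_set):
        # Copy into a list so any sequence (or iterable) of pairs is indexable.
        self.data_set = list(data_set)

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, idx):
        # Return the pair as an (x, y) tuple, the shape DataLoader collates.
        x, y = self.data_set[idx]
        return x, y


ds = PairDataset([(0.0, 0), (1.0, 1), (2.0, 0)])
print(len(ds))  # 3
print(ds[1])    # (1.0, 1)
```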
get_loaders(torch_sets)
Given a sequence of torch.utils.data.Dataset objects, this function wraps each one in a torch.utils.data.DataLoader and returns the loaders as a sequence in the same order. It does not accept custom arguments for the DataLoader call; to use options such as batch_size, call torch.utils.data.DataLoader directly.
- torch_sets (sequence): a sequence of torch.utils.data.Dataset objects.
- loaders (sequence): the returned torch.utils.data.DataLoader objects, in the same order as torch_sets.
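When custom DataLoader options are needed, the loaders can be built directly, as the description recommends. The helper make_loaders below is hypothetical and is not part of torchplate. TensorDataset is used only to create toy data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(torch_sets, **loader_kwargs):
    # Wrap each Dataset in a DataLoader, preserving the input order.
    # Unlike get_loaders, this hypothetical helper forwards keyword
    # arguments (e.g. batch_size, shuffle) to every DataLoader.
    return [DataLoader(ds, **loader_kwargs) for ds in torch_sets]


train = TensorDataset(torch.arange(10.0).unsqueeze(1), torch.zeros(10))
test = TensorDataset(torch.arange(4.0).unsqueeze(1), torch.ones(4))
train_loader, test_loader = make_loaders([train, test], batch_size=4)
print(len(train_loader), len(test_loader))  # 3 1
```

With batch_size=4 and the default drop_last=False, the 10 training samples yield 3 batches and the 4 test samples yield 1.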
get_xy_dataset(distribution)
Given a dataset of the form [(x,y), ..., (x,y)], returns a PyTorch Dataset object.
- distribution (sequence): sequence of the form [(x,y), ..., (x,y)] representing the dataset.
- returns: a torch.utils.data.Dataset object containing the (x,y) pairs.
get_xy_loaders(distribution)
End-to-end function that returns train and test loaders given a sequence of the form [(x,y), ..., (x,y)]. If more customization is needed, call the other utility functions individually.
- distribution (sequence): dataset of the form [(x,y), ..., (x,y)].
- loaders (sequence): the train and test splits, each wrapped in a torch.utils.data.DataLoader (train loader first, then test loader).
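The end-to-end flow (pairs, then a random train/test split, then DataLoaders) can be sketched as follows. The function xy_loaders and its seed parameter are hypothetical. The sketch assumes a 90/10 default split, matching split_dataset's default ratio. It uses the list of pairs directly as a map-style dataset, since a list supports len() and indexing, instead of torchplate's XYDataset.

```python
import torch
from torch.utils.data import DataLoader, random_split


def xy_loaders(distribution, ratio=0.9, seed=0):
    # Hypothetical end-to-end helper: pairs -> random split -> DataLoaders.
    # A plain list of (x, y) tuples supports len() and indexing, so it can
    # be used directly as a map-style dataset here.
    data = list(distribution)
    train_size = int(ratio * len(data))
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    train_set, test_set = random_split(
        data, [train_size, len(data) - train_size], generator=generator
    )
    # Default DataLoader settings, i.e. no custom arguments.
    return DataLoader(train_set), DataLoader(test_set)


pairs = [(float(i), i % 2) for i in range(20)]
train_loader, test_loader = xy_loaders(pairs)
print(len(train_loader.dataset), len(test_loader.dataset))  # 18 2
```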
split_dataset(torch_set, ratio=0.9)
Given a torch.utils.data.Dataset object, this function randomly splits it into train and test torch.utils.data.Dataset objects, with sizes determined by ratio.
- torch_set: a torch.utils.data.Dataset object containing the entire dataset.
- ratio: fraction of the samples assigned to the training set. Default is 0.9.
Returns a tuple consisting of:
- trainset: a torch.utils.data.Dataset object to be used for training
- testset: a torch.utils.data.Dataset object to be used for testing
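A sketch of such a random split follows. It assumes that ratio is the fraction of samples placed in the training set. The function split_train_test and its seed parameter are hypothetical. torch.utils.data.random_split performs the shuffling.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_train_test(torch_set, ratio=0.9, seed=None):
    # Assumption: `ratio` is the fraction of samples in the train split.
    train_size = int(ratio * len(torch_set))
    test_size = len(torch_set) - train_size
    if seed is None:
        # Use PyTorch's global RNG: a different split on each call.
        return random_split(torch_set, [train_size, test_size])
    # A seeded generator makes the split reproducible.
    generator = torch.Generator().manual_seed(seed)
    return random_split(torch_set, [train_size, test_size], generator=generator)


data = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
trainset, testset = split_train_test(data, ratio=0.9, seed=0)
print(len(trainset), len(testset))  # 90 10
```

Both returned objects are torch.utils.data.Subset views over the original dataset, so no samples are copied.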