|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 |
|
| 3 | +import torch |
3 | 4 | import requests |
4 | 5 | import tqdm |
5 | 6 |
|
@@ -44,3 +45,98 @@ def save_response_content(response, destination): |
44 | 45 | for chunk in response.iter_content(CHUNK_SIZE): |
45 | 46 | if chunk: # filter out keep-alive new chunks |
46 | 47 | f.write(chunk) |
| 48 | + |
| 49 | + |
class InfiniteIterator(object):

    """Wraps an iterable (e.g. a DataLoader) and cycles over it forever.

    Whenever the underlying iterator is exhausted, a fresh iterator is
    created from the same dataloader so `next()` never runs out.

    Raises:
        ValueError: if the wrapped dataloader yields no items at all.
            (The original retry loop would spin forever in that case.)
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # Exhausted: restart from the beginning of the dataloader.
            self.iterator = iter(self.dataloader)
            try:
                return next(self.iterator)
            except StopIteration:
                # A freshly restarted iterator that is immediately empty
                # would otherwise cause an infinite busy-loop.
                raise ValueError(
                    'InfiniteIterator: dataloader yields no items.'
                ) from None
| 66 | + |
def partition_task(data, labels, shots=1, ways=None):
    """Split one task's examples into (support, query) sets.

    For each unique label, the first `shots` examples become the support
    set and the remainder become the query set.  Every class must
    contribute the same number of query examples.

    Args:
        data: tensor of examples, first dimension indexes examples.
        labels: 1-D tensor of per-example labels (same length as `data`).
        shots: number of support examples per class.
        ways: number of classes; defaults to the number of unique labels.

    Returns:
        ((support_data, support_labels), (query_data, query_labels)).
    """
    assert data.size(0) == labels.size(0)
    unique_labels = labels.unique()
    if ways is None:
        ways = unique_labels.numel()
    num_support = ways * shots
    num_query = data.size(0) - num_support
    assert num_query % ways == 0, 'Only query_shot == support_shot supported.'
    query_shots = num_query // ways

    support_chunks, query_chunks = [], []
    support_label_chunks, query_label_chunks = [], []
    for label in unique_labels:
        # Gather every example of this class.  (Boolean indexing copies.)
        class_rows = data[labels == label]
        assert class_rows.size(0) == shots + query_shots, \
            'Only same number of query per label supported.'
        # First `shots` rows go to the support set, the rest to the query set.
        support_chunks.append(class_rows[:shots])
        query_chunks.append(class_rows[shots:])
        # `label` is a 0-dim tensor; expand broadcasts it to the chunk length
        # while keeping the dtype/device of `labels`.
        support_label_chunks.append(label.expand(shots))
        query_label_chunks.append(label.expand(query_shots))

    support_data = torch.cat(support_chunks, dim=0)
    query_data = torch.cat(query_chunks, dim=0)
    support_labels = torch.cat(support_label_chunks, dim=0)
    query_labels = torch.cat(query_label_chunks, dim=0)
    return (support_data, support_labels), (query_data, query_labels)
| 118 | + |
| 119 | + |
class OnDeviceDataset(torch.utils.data.TensorDataset):

    """Materializes a map-style dataset of (x, y) pairs into two tensors.

    All examples are stacked into a single data tensor and all targets
    collected into a label tensor, optionally moved to `device` once at
    construction time.  An optional `transform` is applied to each example
    on retrieval.  If the wrapped dataset exposes `_bookkeeping_path`, it
    is propagated so downstream caching keeps working.
    """

    def __init__(self, dataset, device=None, transform=None):
        samples, targets = [], []
        for sample, target in dataset:  # single pass over the dataset
            samples.append(sample)
            targets.append(target)
        stacked = torch.stack(samples, dim=0)
        target_tensor = torch.tensor(targets)
        if device is not None:
            stacked = stacked.to(device)
            target_tensor = target_tensor.to(device)
        super().__init__(stacked, target_tensor)
        self.transform = transform
        if hasattr(dataset, '_bookkeeping_path'):
            self._bookkeeping_path = dataset._bookkeeping_path

    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target
0 commit comments