detecto: Can't do model.fit(dataset)
I got the following error when running this code on Google Colab:
from detecto import core, utils, visualize
dataset = core.Dataset('images/')
model = core.Model(['dog', 'cat'])
model.fit(dataset)
This is the full error output I got in Google Colab. Any help would be appreciated.
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:2854: UserWarning: The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.
warnings.warn("The default behavior for interpolate/upsample with float scale_factor will change "
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-11e9a74e8844> in <module>()
4 model = core.Model(['dog', 'cat'])
5
----> 6 model.fit(dataset)
/usr/local/lib/python3.6/dist-packages/detecto/core.py in fit(self, dataset, val_dataset, epochs, learning_rate, momentum, weight_decay, gamma, lr_step_size, verbose)
467 # Training step
468 self._model.train()
--> 469 for images, targets in dataset:
470 self._convert_to_int_labels(targets)
471 images, targets = self._to_device(images, targets)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/detecto/core.py in __getitem__(self, idx)
196 box[0, 0], box[0, 2] = box[0, (2, 0)]
197 else:
--> 198 image = t(image)
199
200 # Scale down box if necessary
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, tensor)
164 Tensor: Normalized Tensor image.
165 """
--> 166 return F.normalize(tensor, self.mean, self.std, self.inplace)
167
168 def __repr__(self):
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
206 if std.ndim == 1:
207 std = std[:, None, None]
--> 208 tensor.sub_(mean).div_(std)
209 return tensor
210
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
About this issue
- State: closed
- Created 4 years ago
- Comments: 16 (7 by maintainers)
The error seems to be that the PNG images have 4 channels instead of 3, i.e. RGBA vs. RGB. You’ll want to convert your images from RGBA to RGB, which you can do by applying a custom transform when creating your Dataset. For example:
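Below is a minimal sketch of such a transform. It makes several assumptions. First, detecto's core.Dataset accepts a transform argument built with transforms.Compose. Second, the default pipeline is ToPILImage, Resize(800), ToTensor and utils.normalize_transform(). Third, the transform receives the raw H x W x C image array, which is what the x[:,:,:3] slice implies. drop_alpha, normalize_chw and make_rgb_dataset are hypothetical helper names. The numpy part reproduces the shape clash behind the error: normalization holds exactly three per-channel values, so a four-channel image cannot be broadcast against them.

```python
import numpy as np

# Standard ImageNet per-channel statistics, shaped C x 1 x 1 for broadcasting.
# There are exactly three entries, one per RGB channel.
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def drop_alpha(image):
    """Keep only the first three channels of an H x W x C image array (RGBA -> RGB)."""
    return image[:, :, :3]


def normalize_chw(image_chw):
    """Normalize a C x H x W array; broadcasting fails unless C == 3."""
    return (image_chw - MEAN) / STD


def make_rgb_dataset(folder):
    """Hypothetical helper: build a detecto Dataset whose transform strips alpha first."""
    # Imported lazily so the helpers above can be used without detecto installed
    from torchvision import transforms
    from detecto import core, utils

    custom_transforms = transforms.Compose([
        transforms.Lambda(drop_alpha),  # runs on the raw H x W x C array, before resizing
        transforms.ToPILImage(),
        transforms.Resize(800),
        transforms.ToTensor(),          # image becomes C x H x W from here on
        utils.normalize_transform(),
    ])
    return core.Dataset(folder, transform=custom_transforms)
```

After this, dataset = make_rgb_dataset('images/') followed by model = core.Model(['dog', 'cat']) and model.fit(dataset) should train as before. The slice has to run before ToTensor. After ToTensor the channel axis comes first, so the equivalent slice there would be x[:3].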
Here, the transforms.Lambda(lambda x: x[:,:,:3]) step takes in an image and returns only its first three channels (RGB), which should allow you to train your model as normal.
I can't seem to open your attachments, but here are the answers to your questions:
If dataset[0] runs without errors, then your dataset should be valid.
Basically, just try running the following code to see whether any similar errors come up when reading in and plotting your images:
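Here is a sketch of such a check. It uses Pillow instead of detecto's own image reader, which is a swapped-in choice, and it leaves the plotting out. It opens every image in the folder and lists the ones that are not plain RGB or that fail to decode. find_non_rgb_images and the set of extensions it scans are assumptions.

```python
import os

from PIL import Image

# File extensions treated as images; adjust to match your folder (assumption)
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def find_non_rgb_images(folder):
    """Return (filename, mode) for every image in `folder` that is not plain RGB."""
    problems = []
    for name in sorted(os.listdir(folder)):
        if os.path.splitext(name)[1].lower() not in IMAGE_EXTENSIONS:
            continue  # skip XML label files and anything else
        path = os.path.join(folder, name)
        try:
            with Image.open(path) as img:
                img.load()  # force a full decode so corrupt files fail here
                if img.mode != "RGB":
                    problems.append((name, img.mode))
        except OSError as exc:
            # Unreadable or truncated files are reported instead of stopping the scan
            problems.append((name, f"unreadable: {exc}"))
    return problems
```

An empty list from find_non_rgb_images('images/') means every image decodes as RGB. Entries with mode RGBA point to the files that trigger the 4-vs-3 mismatch. Calling img.convert('RGB') and saving the result is one way to fix those files on disk instead of in the transform.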
If that doesn’t have any errors, you can also try seeing if the issue is with the dataset:
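Along the same lines, here is a small loop that indexes every sample, doing for all of them what dataset[0] does for one. It records which samples raise an error or come back with the wrong number of channels. find_bad_samples is a hypothetical name. It assumes each sample is an (image, targets) pair with a C x H x W image, as with detecto's default transforms.

```python
def find_bad_samples(dataset, expected_channels=3):
    """Index every sample of a map-style dataset and collect the ones that fail.

    Assumes dataset[i] returns (image, targets) with the image laid out as C x H x W.
    """
    bad = []
    for idx in range(len(dataset)):
        try:
            image, _targets = dataset[idx]
        except Exception as exc:  # deliberately broad: report any loading/transform error
            bad.append((idx, f"{type(exc).__name__}: {exc}"))
            continue
        channels = len(image)  # len() of a C x H x W tensor or array is C
        if channels != expected_channels:
            bad.append((idx, f"{channels} channels"))
    return bad
```

For example, find_bad_samples(core.Dataset('images/')) returns an empty list when every sample loads cleanly. Otherwise, the indices show which entries to inspect. With RGBA images and the default transforms, those entries show up as the RuntimeError from the traceback above.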