detecto: Can't do

I got the following error when I tried this code on Google Colab:

from detecto import core, utils, visualize

dataset = core.Dataset('images/')
model = core.Model(['dog', 'cat'])

model.fit(dataset)

This is the error I got in Google Colab. Can anyone help?

/usr/local/lib/python3.6/dist-packages/torch/nn/ UserWarning: The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. 
  warnings.warn("The default behavior for interpolate/upsample with float scale_factor will change "
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-11e9a74e8844> in <module>()
      4 model = core.Model(['dog', 'cat'])
----> 6

/usr/local/lib/python3.6/dist-packages/detecto/ in fit(self, dataset, val_dataset, epochs, learning_rate, momentum, weight_decay, gamma, lr_step_size, verbose)
    467             # Training step
    468             self._model.train()
--> 469             for images, targets in dataset:
    470                 self._convert_to_int_labels(targets)
    471                 images, targets = self._to_device(images, targets)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/ in __next__(self)
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/ in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/ in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/ in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/detecto/ in __getitem__(self, idx)
    196                         box[0, 0], box[0, 2] = box[0, (2, 0)]
    197                 else:
--> 198                     image = t(image)
    200             # Scale down box if necessary

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/ in __call__(self, tensor)
    164             Tensor: Normalized Tensor image.
    165         """
--> 166         return F.normalize(tensor, self.mean, self.std, self.inplace)
    168     def __repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/ in normalize(tensor, mean, std, inplace)
    206     if std.ndim == 1:
    207         std = std[:, None, None]
--> 208     tensor.sub_(mean).div_(std)
    209     return tensor

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
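The last line says that Normalize is subtracting a 3-element per-channel mean from an image whose channel dimension has size 4. The sketch below reproduces the same shape clash with NumPy standing in for torch, so the exception type and wording differ. The mean/std values and the normalize helper are illustrative, not detecto's own code.

```python
import numpy as np

# One value per RGB channel, shaped (3, 1, 1) so it broadcasts over height x width.
# ImageNet statistics, used here only for illustration.
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def normalize(chw):
    """Subtract the mean and divide by the std, channel by channel."""
    return (chw - MEAN) / STD


rng = np.random.default_rng(0)
rgba = rng.random((4, 8, 8))  # channels-first image WITH an alpha channel

try:
    normalize(rgba)
    rgba_error = None
except ValueError as err:  # NumPy's version of the size mismatch
    rgba_error = str(err)

print("RGBA:", rgba_error)
print("RGB :", normalize(rgba[:3]).shape)  # dropping alpha first makes the shapes line up
```

Dropping the fourth channel before normalizing is exactly what the fix in the comments below does.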

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 16 (7 by maintainers)

Most upvoted comments

The error seems to be that the PNG images have 4 channels instead of 3, i.e. RGBA vs. RGB. You’ll want to convert your images from RGBA to RGB, which you can do by applying a custom transform when creating your Dataset. For example:

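from detecto import core, utils
from torchvision import transforms

augmentations = transforms.Compose([
    transforms.Lambda(lambda x: x[:,:,:3]),  # drop the alpha channel (RGBA -> RGB)
    transforms.ToTensor(),                   # H x W x C array -> C x H x W tensor
    utils.normalize_transform(),             # same normalization as detecto's default
])

dataset = core.Dataset('images/', transform=augmentations)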

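Here, transforms.Lambda(lambda x: x[:,:,:3]) takes in an image (a height x width x channels array) and returns only its first three channels (RGB), discarding the alpha channel. Because a custom transform replaces the default one, it also needs ToTensor and the normalization step. This should allow you to train your model as normal: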

model = core.Model(['dog', 'cat'])
model.fit(dataset)
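To see why the slice is written as x[:,:,:3] and why the Lambda has to come before ToTensor, here is a small NumPy sketch. It assumes, as the slice implies, that the image reaches the transform as a height x width x channels array; after ToTensor the channels move to the front, so the same slice would cut the wrong axis. The drop_alpha name is hypothetical, used only for illustration.

```python
import numpy as np


def drop_alpha(x):
    """Same slice as the Lambda transform: keep the first three entries of the last axis."""
    return x[:, :, :3]


hwc = np.zeros((480, 640, 4), dtype=np.uint8)  # height x width x RGBA
print(drop_alpha(hwc).shape)  # (480, 640, 3): alpha removed

chw = np.zeros((4, 480, 640), dtype=np.float32)  # same image, channels first
print(drop_alpha(chw).shape)  # (4, 480, 3): width cropped, alpha kept
```

The second case silently produces a wrong shape rather than an error, which is why ordering matters inside transforms.Compose.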

I can’t seem to open your attachments, but here are the answers to your questions:

  1. If you run the code I shared above and it works without any errors (and you can successfully print dataset[0]), then your dataset should be valid.
  2. Yes, that shouldn’t be an issue.
  3. Yes, there is the visualize.detect_live function you can use. However, note that it doesn’t work on VMs like Google Colaboratory.

Basically, try running the following code to see whether any similar errors come up when reading in and plotting your images:

from detecto import core, utils
import matplotlib.pyplot as plt

image = utils.read_image('images/cat0.jpg')
plt.imshow(image)
plt.show()

If that runs without errors, you can also check whether the issue is with the dataset:

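from detecto import core, utils
from torchvision import transforms

augmentations = transforms.Compose([
    transforms.Lambda(lambda x: x[:,:,:3]),
    transforms.ToTensor(),
    utils.normalize_transform(),
])

dataset = core.Dataset('images/', transform=augmentations)
print(dataset[0])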
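If you would rather find the offending files before building a Dataset at all, a small stdlib-only sketch can read each PNG's header and flag the ones whose color type includes alpha (color types 4 and 6 in the PNG specification). The helper names are hypothetical. Palette PNGs (color type 3) that get transparency from a tRNS chunk are not flagged, so treat this as a quick first pass rather than a complete check.

```python
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
ALPHA_COLOR_TYPES = {4, 6}  # 4 = grayscale + alpha, 6 = RGB + alpha (RGBA)


def png_has_alpha(path):
    """Return True if the PNG header declares a color type with an alpha channel."""
    with open(path, "rb") as f:
        # 8-byte signature, then IHDR: length, type, width, height, bit depth, color type
        header = f.read(26)
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        return False  # not a PNG (e.g. a JPEG), so this check does not apply
    return header[25] in ALPHA_COLOR_TYPES  # byte 25 is the IHDR color type


def find_alpha_pngs(folder):
    """List the PNG files in folder whose header declares an alpha channel."""
    return sorted(
        p.name
        for p in Path(folder).glob("*")
        if p.suffix.lower() == ".png" and png_has_alpha(p)
    )
```

For example, print(find_alpha_pngs('images/')) lists the files that would need the alpha-dropping transform (or a one-time conversion to RGB).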
