litdata: Tested streaming speed is slower than expected.

🐛 Bug

The tested streaming speed is much slower than expected.

Code sample

import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms import Compose, Lambda
from litdata import StreamingDataset, StreamingDataLoader
from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo

input_dir = 's3://extract_frames/'
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

class ImagenetStreamingDataset(StreamingDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transform = Compose(
            [
                Lambda(lambda x: x / 255.0),
                NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
                # ShortSideScale(size=224),
                CenterCropVideo(224),
            ]
        )
    
    def __getitem__(self, index):
        data = super().__getitem__(index)
        # Stack the 8 decoded frames into a (C, T, H, W) video tensor.
        video_data = []
        for i in range(8):
            frame = np.array(data["image"][i])
            video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
        video_data = torch.stack(video_data, dim=1)
        video_data = self.transform(video_data)
        return video_data

dataset = ImagenetStreamingDataset(input_dir, shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=8)
for batch in tqdm(dataloader, total=len(dataloader)):
    pass

Expected behavior

There are approximately 200,000 samples, each consisting of 8 extracted frames. Based on the tested speed, iterating over them should be very fast, but in reality it is not.

[Screenshot 2024-03-07 at 20:42:20]

The tested speed is approximately as follows: [Screenshot 2024-03-07 at 20:48:30]

Environment

  • PyTorch Version (e.g., 1.0): 2.2.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.9
  • CUDA/cuDNN version: 11.6

About this issue

  • State: open
  • Created 4 months ago
  • Comments: 19

Most upvoted comments

@tikboaHIT I also fixed chunk_bytes not being handled correctly by the optimize operator.

A more efficient approach would be to encode the images as JPEG, as follows:

import torch
from litdata import optimize
from PIL import Image
from io import BytesIO

def generate_images(video_path):
    images = []
    for _ in range(8):
        # Synthetic stand-in for a real extracted frame.
        random_image = torch.randint(0, 255, (320, 568, 3), dtype=torch.uint8).numpy()
        buff = BytesIO()
        # You can implement a better resizing logic here.
        Image.fromarray(random_image).save(buff, quality=90, format='JPEG')
        buff.seek(0)
        img = buff.read()
        images.append(Image.open(BytesIO(img)))
    return {
        "name": video_path,
        "image": images,
    }

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/teamspace/datasets/videos5",
    num_workers=1,  # or os.cpu_count()
    chunk_bytes="64MB",
)

When streaming it from the cloud, it now takes about one second:

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 50.08it/s]
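
For completeness, here is a minimal sketch of how the JPEG-optimized samples could be read back on the streaming side. The output directory and the data["image"] layout come from the snippets above; the class name is illustrative:

import numpy as np
import torch
from litdata import StreamingDataset

class JpegVideoDataset(StreamingDataset):
    def __getitem__(self, index):
        data = super().__getitem__(index)
        # Each sample stores 8 JPEG-encoded frames that deserialize to PIL images;
        # decode them into a single (C, T, H, W) uint8 tensor.
        frames = [torch.from_numpy(np.array(img)).permute(2, 0, 1) for img in data["image"]]
        return torch.stack(frames, dim=1)

dataset = JpegVideoDataset("/teamspace/datasets/videos5", shuffle=True)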

Additionally, I recommend using torchvision.transforms.v2, which is roughly 40% faster at resizing images and similar operations.
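
For reference, a sketch of the original transform pipeline in the v2 API (this assumes torchvision >= 0.16, where ToDtype supports scale=True; note that v2 transforms expect (..., C, H, W) tensors, so the clip is arranged as (T, C, H, W) here rather than (C, T, H, W)):

import torch
from torchvision.transforms import v2

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

transform = v2.Compose(
    [
        v2.ToDtype(torch.float32, scale=True),  # replaces the x / 255.0 Lambda
        v2.Normalize(mean=list(OPENAI_DATASET_MEAN), std=list(OPENAI_DATASET_STD)),
        v2.CenterCrop(224),
    ]
)

# v2 treats leading dims as batch dims, so feed the clip as (T, C, H, W).
video = torch.randint(0, 255, (8, 3, 320, 568), dtype=torch.uint8)
out = transform(video)  # float32, normalized, cropped to (8, 3, 224, 224)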

Alternatively, we support videos through torchvision's video support: https://pytorch.org/audio/stable/build.ffmpeg.html. If you convert your clips to AV1, they should become very small, and you should be able to stream them easily and deserialize them faster. Worth exploring.
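
As a rough sketch of that direction (the clip path is illustrative, and decoding AV1 requires a torchvision/FFmpeg build with an AV1-capable decoder):

from torchvision.io import read_video

# Decode a short clip on the fly instead of storing pre-extracted frames.
# output_format="TCHW" yields a (T, C, H, W) uint8 tensor.
video, _, info = read_video("clip.mp4", pts_unit="sec", output_format="TCHW")
frames = video[:8]  # first 8 frames, mirroring the original pipeline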

It should appear where you run the command. Maybe reduce the batch size and number of workers.

Could you provide a synthetic example for me to debug as well? This helps tremendously when optimizing these things. Here is another user's synthetic script for reference: #62 (comment).

Sure:

import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset

def generate_images(video_path):
    # Synthetic sample: a random (C, T, H, W) tensor standing in for 8 frames.
    data = {
        "name": video_path,
        "image": torch.rand((3, 8, 320, 568)),
    }
    return data

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/root/data/example_data/chunk_cache",
    num_workers=1,
    chunk_bytes="256MB",
)

# Stream from local disk.
input_dir = '/root/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

# Stream from S3.
input_dir = 's3://pzhao/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

The speed is as follows when loading from local disk: [Screenshot 2024-03-08 at 19:48:10]

The speed is as follows when loading from S3: [Screenshot 2024-03-08 at 19:48:29]