scvi-tools: GPU memory overflow (leak?) for large datasets
We love scVI and use it all the time for batch-effect correction / dataset harmonization. As we scale to larger and larger datasets, we seem to have hit a wall with scVI’s scalability. Datasets with more than ~800k cells (using 2000 highly variable genes) never seem to complete training. I seem to be getting GPU out-of-memory errors late in training. This is strange, since I would not expect GPU memory usage to grow during training itself.
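One way to test the expectation that GPU memory should not grow during training is to record the peak allocation once per epoch and look at the trend after the first few epochs. The sketch below assumes those per-epoch readings have already been collected (with PyTorch, for example, by calling `torch.cuda.max_memory_allocated()` at the end of each epoch and then `torch.cuda.reset_peak_memory_stats()`). The function names, the warm-up length and the 1 MB-per-epoch cutoff are hypothetical choices, not anything taken from scVI.

```python
from typing import Sequence


def memory_slope(readings: Sequence[float]) -> float:
    """Least-squares slope of per-epoch memory readings, in bytes per epoch."""
    n = len(readings)
    if n < 2:
        raise ValueError("need at least two readings")
    mean_x = (n - 1) / 2
    mean_y = sum(readings) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in enumerate(readings))
    var = sum((x - mean_x) ** 2 for x in range(n))
    return cov / var


def looks_like_leak(readings: Sequence[float], warmup: int = 3,
                    max_bytes_per_epoch: float = 1e6) -> bool:
    """Flag steady growth in peak memory after a short warm-up.

    The first `warmup` readings are skipped because caches and optimizer
    state are typically allocated early and then stay flat. Both `warmup`
    and `max_bytes_per_epoch` are hypothetical thresholds, not scVI values.
    """
    steady = list(readings[warmup:])
    if len(steady) < 2:
        return False
    return memory_slope(steady) > max_bytes_per_epoch
```

A flat series after the warm-up would argue against a per-epoch leak on the GPU, while a steady positive slope would support it.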
The error only occurs for datasets with roughly 800k cells or more. Here is a sketch of the script that I run (not the whole thing) to do batch-effect correction:
```python
import argparse

import numpy as np
import pandas as pd
import scanpy as sc
import scvi
from scvi.dataset import Dataset10X
from scvi.dataset.anndataset import AnnDatasetFromAnnData
from scvi.models import VAE
from scvi.inference import UnsupervisedTrainer

# flags as they appear in the logged command line; this argparse block is a
# reconstruction, not the original script's parser
parser = argparse.ArgumentParser()
parser.add_argument('--file', required=True)
parser.add_argument('--latent_dim', type=int, default=50)
parser.add_argument('--max_epochs', type=int, default=300)
parser.add_argument('--umap', action='store_true')
args = parser.parse_args()

# placeholder: the real value is set in a part of the script not shown here
dispersion = 'gene'

# load h5ad
adata = sc.read_h5ad(filename=args.file)

# batch info
batch = np.array(adata.obs['scvi_batch'].values, dtype=np.double).flatten()
# ensure the batch numbers are in order starting at zero
# (note: this excerpt computes `batch` but does not pass it on below)
lookup = dict(zip(np.unique(batch), np.arange(0, np.unique(batch).size)))
batch = np.array([lookup[b] for b in batch], dtype=np.double)

# subset to highly variable genes (avoids copying all genes first)
adata_hvg = adata[:, adata.var['highly_variable'].values].copy()

# create scvi Dataset object
scvi.set_verbosity('WARNING')
scvi_dataset = AnnDatasetFromAnnData(adata_hvg)

# training params
n_epochs = args.max_epochs
lr = 1e-3
use_cuda = True

# train the model and output model likelihood every epoch
vae = VAE(scvi_dataset.nb_genes,
          n_batch=scvi_dataset.n_batches,
          n_latent=args.latent_dim,
          dispersion=dispersion)
trainer = UnsupervisedTrainer(vae,
                              scvi_dataset,
                              train_size=0.9,
                              use_cuda=use_cuda,
                              frequency=1,
                              early_stopping_kwargs={'early_stopping_metric': 'elbo',
                                                     'save_best_state_metric': 'elbo',
                                                     'patience': 20,
                                                     'threshold': 0.5})
print('Running scVI...')
trainer.train(n_epochs=n_epochs, lr=lr)
print('scVI complete.')

# obtain latents from scvi
posterior = trainer.create_posterior(trainer.model, scvi_dataset,
                                     indices=np.arange(len(scvi_dataset))).sequential()
latent, _, _ = posterior.get_latent()  # https://github.com/YosefLab/scVI/blob/master/tests/notebooks/scanpy_pbmc3k.ipynb

# store scvi latents as part of the AnnData object
adata.obsm['X_scvi'] = latent
```
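The `early_stopping_kwargs` in the script ask training to stop once the 'elbo' metric stops improving. The sketch below shows a generic patience/threshold rule, assuming 'elbo' is treated as a loss (lower is better) and that an epoch only counts as an improvement if it beats the best value so far by more than `threshold`. The function is hypothetical; scVI's own early-stopping code may differ in details such as how the threshold is compared.

```python
from typing import Iterable, Optional


def early_stop_epoch(metric_per_epoch: Iterable[float], patience: int = 20,
                     threshold: float = 0.5) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None.

    Generic patience/threshold rule (an assumption, not scVI's exact code):
    the metric is a loss to minimize, and an epoch only counts as an
    improvement if it beats the best value so far by more than `threshold`.
    Training stops after `patience` consecutive epochs without improvement.
    """
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if best - value > threshold:
            best = value  # meaningful improvement: reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

Under this rule, a run that never goes 20 consecutive epochs without a 0.5-point improvement keeps going until `n_epochs`.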
Below is an example of the output from a log file. I run scVI with a maximum of 300 epochs but use early stopping, so training can end before that. In this example, training reaches epoch 142 and then the process is killed. Because this happens during training, and my script never prints its "scVI complete." message, I know the memory error was GPU memory and not CPU memory used after training had completed.
Linux error: Killed
```
Starting workflow.
Working on file.h5ad
scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300
Dataset for scvi:
GeneExpressionDataset object with n_cells x nb_genes = 729211 x 2364
gene_attribute_names: 'highly_variable', 'gene_names'
cell_attribute_names: 'local_means', 'batch_indices', 'scvi_batch', 'local_vars', 'labels'
cell_categorical_attribute_names: 'batch_indices', 'labels'
Running scVI...
training: 0%| | 0/300 [00:00<?, ?it/s]
training: 0%| | 1/300 [02:18<11:30:33, 138.58s/it]
training: 1%| | 2/300 [04:37<11:29:01, 138.73s/it]
[... epochs 3 to 140 omitted; reported rate stays between 134.20 and 142.00 s/it ...]
training: 47%|████▋ | 141/300 [5:23:41<6:02:57, 136.97s/it]
training: 47%|████▋ | 142/300 [5:25:56<5:58:37, 136.18s/it]
16 Killed python scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300
```
When the number of cells is 500k or 600k, this works just fine!
Versions:
- Ubuntu 16.04
- scVI 0.6.5
- scanpy 1.5.1
Hardware:
Tested on a Tesla K80 GPU with 12 GB of memory and a Tesla P100 GPU with 16 GB of memory. We see the same failure on both.
But the point I want to emphasize is that the answer here is not simply “more GPU memory”; I think something goes wrong during training. Why should memory usage increase over the course of training? Since training is done in mini-batches, why should it matter whether the dataset has 400k or 900k cells?
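Part of the question comes down to arithmetic: the tensors for one mini-batch do not depend on the number of cells, but anything that holds the whole expression matrix at once grows linearly with the cell count. The sketch below only computes sizes for the shape reported in the log (729,211 cells × 2,364 genes). The mini-batch size of 128 is an assumption, the helper names are hypothetical, and whether any step in scVI 0.6.5 actually materializes the full matrix densely is not established here.

```python
BYTES_PER_VALUE = {"float32": 4, "float64": 8}


def dense_matrix_bytes(n_cells: int, n_genes: int, dtype: str = "float32") -> int:
    """Bytes needed to hold an n_cells x n_genes matrix densely."""
    return n_cells * n_genes * BYTES_PER_VALUE[dtype]


def gib(n_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return n_bytes / 2**30


# Shape taken from the log; batch_size = 128 is an assumed mini-batch size.
n_cells, n_genes, batch_size = 729_211, 2_364, 128
full32 = dense_matrix_bytes(n_cells, n_genes, "float32")
full64 = dense_matrix_bytes(n_cells, n_genes, "float64")
batch32 = dense_matrix_bytes(batch_size, n_genes, "float32")
print(f"full matrix, float32: {gib(full32):.2f} GiB")
print(f"full matrix, float64: {gib(full64):.2f} GiB")
print(f"one mini-batch, float32: {batch32 / 2**20:.2f} MiB")
```

This prints about 6.42 GiB for float32, 12.84 GiB for float64 and 1.15 MiB per mini-batch, so one dense full-size copy is a sizable fraction of a 12 GB or 16 GB card, while one mini-batch is negligible.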
About this issue
- State: closed
- Created 4 years ago
- Comments: 18 (15 by maintainers)
Got it. FWIW, this should be one of the most memory-efficient platforms for single-cell integration.
@watiss, can you take a look at this? You can add this callback to the trainer:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.gpu_stats_monitor.html
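scVI 0.6.5, the version in this report, uses its own trainer classes rather than PyTorch Lightning, so the linked callback would not plug into it directly. A framework-agnostic sketch of the same idea records one peak-memory reading per epoch. The class and its methods are hypothetical; the read and reset functions are injected (with PyTorch they could be `torch.cuda.max_memory_allocated` and `torch.cuda.reset_peak_memory_stats`), and calling `end_of_epoch` from the training loop is left to the caller.

```python
from typing import Callable, Dict, Optional


class EpochMemoryLog:
    """Hypothetical helper: keep one peak-memory reading per epoch.

    `read_peak` returns the current peak in bytes; `reset_peak`, if given,
    clears it so the next reading covers a single epoch. With PyTorch these
    could be torch.cuda.max_memory_allocated and
    torch.cuda.reset_peak_memory_stats.
    """

    def __init__(self, read_peak: Callable[[], int],
                 reset_peak: Optional[Callable[[], None]] = None) -> None:
        self.read_peak = read_peak
        self.reset_peak = reset_peak
        self.peaks: Dict[int, int] = {}

    def end_of_epoch(self, epoch: int) -> int:
        """Store and return the peak for `epoch`, then reset the counter."""
        peak = self.read_peak()
        self.peaks[epoch] = peak
        if self.reset_peak is not None:
            self.reset_peak()
        return peak

    def growth_since_first(self) -> int:
        """Difference between the latest and the earliest recorded peak."""
        if not self.peaks:
            return 0
        first, last = min(self.peaks), max(self.peaks)
        return self.peaks[last] - self.peaks[first]
```

Injecting the two callables keeps the helper testable without a GPU; on a real run, a `growth_since_first()` that keeps rising from epoch to epoch would be the signature of the growth described in the report.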