xbatcher: Generating the batches seems slow
I’ve just come across xbatcher, and I think it could be just what I need for using CNNs on data stored in dask-backed xarrays. I’ve got a number of questions about how it works, and some issues I’m having. If this isn’t the appropriate place for these questions, please let me know and I’ll direct them elsewhere. I decided not to create a separate issue for each question: I expect a number of them are problems with my understanding rather than with xbatcher, and separate issues would clog up the issues board. If some of these questions need extracting into a separate issue, I’m happy to do that.
Firstly, thanks for putting this together - it has already solved a lot of problems for me.
To give some context, I’m trying to use xbatcher to run batches of image data through a PyTorch CNN on Microsoft Planetary Computer. I’m not doing training here, just inference, so I only need to push the raw array data through the ML model and get the results out.
Now on to the questions:
1. Generating the batches seems slow
I’m trying to create batches from a DataArray of size (8172, 8082), which is a single band of a satellite image. I’m using the following call to create a BatchGenerator:
patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          concat_input_dims=True, preload_batch=False)
That should create DataArrays that are 64 x 64 (in x and y), with 100 of those entries in the batch.
I’m then running a loop over the batch generator, doing something with the batches. We’ll come to what I’m doing later, but for the moment let’s just append the result to a list:
import tqdm

results = []
for batch in tqdm.tqdm(bgen):
    results.append(batch)
This takes around 1s per batch, and creates a very small Dask task that goes away and generates the batch (I’ve already run b1.persist() to ensure all the data is on the Dask cluster). I have a few questions about this:
a) Is this sort of speed expected? From some rough calculations at 1s per batch, for a 64 x 64 batch, it’ll take hours to batch up my ~8000 x 8000 array.
b) With preload_batch=False I’d expect these to be generated lazily, and it does seem that the underlying data in the DataArray is a dask array. However, it still takes around a second per batch.
c) Should I be approaching this in a different way to get a better speed?
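For c), one possible alternative when the patches don't overlap is to skip per-patch xarray objects altogether and cut the patches out of a plain NumPy array with a single reshape. This is a swapped-in technique, not something xbatcher does; it assumes the band fits in memory, and `extract_patches` is a hypothetical helper that takes a 2D array:

```python
import numpy as np


def extract_patches(arr, patch_size):
    """Split a 2D array into non-overlapping square patches.

    Rows/columns that don't fill a whole patch are trimmed from the
    bottom/right edges. Patches come out in row-major order: every patch
    along the first row of patches, then the next row, and so on.
    Returns an array of shape (n_patches, patch_size, patch_size).
    """
    n_rows = arr.shape[0] // patch_size
    n_cols = arr.shape[1] // patch_size
    trimmed = arr[: n_rows * patch_size, : n_cols * patch_size]
    # (n_rows, p, n_cols, p) -> (n_rows, n_cols, p, p) -> (n_rows * n_cols, p, p)
    return (
        trimmed.reshape(n_rows, patch_size, n_cols, patch_size)
        .transpose(0, 2, 1, 3)
        .reshape(-1, patch_size, patch_size)
    )
```

For the (8172, 8082) band in the question, loaded with something like `b1.values`, this would give 127 x 126 = 16,002 patches, which can then be sent to the model 100 at a time with ordinary slicing such as `patches[i:i + 100]`. The trade-off is memory: the whole band is materialised (roughly 264 MB per copy if it is float32).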
2. How do you put batches back together after processing?
My machine learning model produces a single value as its output, so for a batch of 100 64x64 patches, I get a 100-element array. What’s the best way of putting this back into a DataArray that has the same format/co-ordinates as the original input array? I’d be happy with either an array with dimensions of original_size / 64 in both the x and y dimensions, or an array of the same size as the input with the single output value repeated for each of the input pixels in that patch.
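One way to get the first option (one value per patch on a coarser grid) is sketched below. It assumes non-overlapping patches, dims named `y` and `x` that carry coordinate values, edge pixels that don't fill a whole patch dropped, and outputs ordered row-major (every patch along the first row of patches, then the next row). `stitch_patch_outputs` is a hypothetical helper; it labels each output with its patch's centre by averaging the coordinate values in each window with xarray's `coarsen(...).mean()`:

```python
import numpy as np
import xarray as xr


def stitch_patch_outputs(da, outputs, patch_size):
    """Arrange one model output per patch on a coarse (y, x) grid.

    Assumes non-overlapping patch_size x patch_size patches, edge pixels
    that don't fill a whole patch dropped from the bottom/right, and
    `outputs` ordered row-major: all patches along the first row of
    patches (increasing x), then the next row, and so on.
    """
    n_y = da.sizes["y"] // patch_size
    n_x = da.sizes["x"] // patch_size
    # Coarsening only the 1D coordinate arrays gives the mean (centre)
    # coordinate of each patch without touching the image data itself.
    y_centres = da["y"].coarsen(y=patch_size, boundary="trim").mean().values
    x_centres = da["x"].coarsen(x=patch_size, boundary="trim").mean().values
    return xr.DataArray(
        np.asarray(outputs).reshape(n_y, n_x),
        dims=("y", "x"),
        coords={"y": y_centres, "x": x_centres},
    )
```

If the patches arrive in some other order, the reshape has to change to match, so this approach depends on knowing the order in which the patches were generated.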
I’ve tried to put some of this together myself, but it seems that the x co-ordinate value in the batch DataArray is the same for each batch. I’d have thought this would represent the x co-ordinates that had been extracted from the original DataArray, but it doesn’t seem to. For example, if I run:
batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break
to get the first two batches, I can then compare their x co-ordinate values:
import numpy as np
np.all(batches[0].to_array().squeeze().x == batches[1].to_array().squeeze().x)
and it shows that they’re all equal.
Do you have any ideas as to what I could do to be able to put the batches back together?
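An order-independent alternative is to record where each patch came from as it is cut out, and write each model output straight back to those pixels. This also gives the second option from item 2 (the output value repeated over every input pixel of its patch). The sketch works on a plain 2D NumPy array rather than on xbatcher's batches; `iter_patch_batches`, `predict_full_resolution`, and the `predict` callable (a stand-in for the PyTorch model that returns one value per patch) are all hypothetical names:

```python
import numpy as np


def iter_patch_batches(arr, patch_size, batch_size):
    """Yield (slices, batch) pairs for non-overlapping patches of a 2D array.

    `slices` holds one (row_slice, col_slice) tuple per patch, recording
    where that patch came from so results can be written back later.
    `batch` has shape (n, patch_size, patch_size) with n <= batch_size.
    Edge rows/columns that don't fill a whole patch are skipped.
    """
    slices, patches = [], []
    for r in range(0, arr.shape[0] - patch_size + 1, patch_size):
        for c in range(0, arr.shape[1] - patch_size + 1, patch_size):
            sl = (slice(r, r + patch_size), slice(c, c + patch_size))
            slices.append(sl)
            patches.append(arr[sl])
            if len(patches) == batch_size:
                yield slices, np.stack(patches)
                slices, patches = [], []
    if patches:  # final, partially filled batch
        yield slices, np.stack(patches)


def predict_full_resolution(arr, predict, patch_size, batch_size):
    """Run `predict` batch by batch and paint each patch's single output
    value over that patch's pixels. Skipped edge pixels stay NaN."""
    out = np.full(arr.shape, np.nan)
    for slices, batch in iter_patch_batches(arr, patch_size, batch_size):
        values = predict(batch)  # expected: one value per patch
        for sl, value in zip(slices, values):
            out[sl] = value
    return out
```

To get back a DataArray with the same co-ordinates as the input, the returned array `out` could be wrapped with something like `xr.DataArray(out, coords=b1.coords, dims=b1.dims)`, assuming it was computed from `b1.values`.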
3. Documentation and tutorial notebook
It took me quite a while to find the example notebook that is sitting in #31, but the notebook was really helpful (a lot more helpful than the documentation on ReadTheDocs). Could this be merged soon, with a prominent link to it in the docs/README? I think this would significantly help other users trying to get to grips with xbatcher.
4. Overlap seems to cause hang
Trying to batch the array with an overlap seems to take ages to do anything; I’m not sure whether it has hung or is just taking a long time to do the computations. If I run:
patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          input_overlap=dict(x=patch_size-1, y=patch_size-1),
                          concat_input_dims=True,
                          preload_batch=False)
and then try and get the first two batches:
batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break
I leave it running for a few minutes, and nothing seems to happen. When I interrupt it, it seems to be deep inside xarray/pandas index handling code.
Any idea what’s going on?
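A likely part of the explanation is the amount of work an overlap of `patch_size - 1` asks for: consecutive 64-pixel windows then share 63 pixels, so the stride is a single pixel. The sketch below counts windows along one dimension, assuming windows advance by `window - overlap` and only windows that fit entirely are kept (the usual sliding-window rule). `n_windows` is a hypothetical helper, not part of xbatcher:

```python
def n_windows(dim_size, window, overlap=0):
    """Number of full windows of length `window` along one dimension when
    consecutive windows share `overlap` elements (stride = window - overlap).
    Partial windows at the end are dropped."""
    stride = window - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than the window")
    if dim_size < window:
        return 0
    return (dim_size - window) // stride + 1


patch, region = 64, 640  # patch size and batch region size from the question
for overlap in (0, patch - 1):
    per_region = n_windows(region, patch, overlap) ** 2
    whole = n_windows(8172, patch, overlap) * n_windows(8082, patch, overlap)
    print(f"overlap={overlap}: {per_region:,} patches per 640x640 region, "
          f"{whole:,} over the full image")
```

This prints 100 patches per region and 16,002 over the whole image without overlap, against 332,929 per region and 65,026,071 over the whole image with a one-pixel stride. If the generator builds a separate small xarray object for each patch before concatenating them, a single batch would involve hundreds of thousands of index operations, which would be consistent with the time spent in xarray/pandas index code.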
About this issue
- State: open
- Created 3 years ago
- Reactions: 1
- Comments: 15 (7 by maintainers)
Not exactly… just pointing out that “use a for-loop” might be a better path for folks like OP until the library matures. I’m also trying to lightly nudge you in a certain direction that I think many would find compelling. Apologies if I rubbed you the wrong way. I can be blunt, but I think a library like this has a definite path forward. Not sure why you don’t see opportunities for improving the performance… maybe you just meant single threaded. I think xbatcher would be awesome if it was bundled with parallel execution frameworks like rechunker is.
I am having trouble keeping up with this issue because we are discussing at least three (maybe four) separate things at once in a single thread. I think we need to take some time to split this up into multiple distinct issues.
@leifdenby - thanks for #40 - it’s a good idea and someone will try to review it asap and give you feedback.
@nbren12 - It is probably impossible for xbatcher to ever beat the baseline you defined above without bypassing Xarray completely. I think you know this. If that’s a dealbreaker for you, no one is going to force you to use xbatcher. You seem to be implying that we should abandon the project entirely. Is that the correct interpretation of your comment?
@robintw - it would be fantastic if you could edit your original issue and split items 2, 3, and 4 into distinct issues.
I do see @rabernat’s point about establishing APIs and then optimizing, but I’m not sure many will use this project without convincing performance benchmarks.
The competition for xbatcher is a simple for-loop like this:
I suspect this code is immediately clear to most xarray users… and any bugs can be quickly fixed without interacting with upstream and learning a new code base. I personally would only replace this code with an external dependency if it were much faster. Clean abstractions should be weighed against the substantial maintenance burden added by taking on a dependency like xbatcher.
@RichardScottOZ Oh wow, I can’t believe I was that silly. Yes, the y co-ordinates are different and the x co-ordinates are the same, which makes perfect sense.
I still can’t seem to work out what exactly I need to do to stitch them back together, but at least now I know that I have the information required.
Sorry for wasting people’s time with my mistake - but I hope the other parts of my questions are still valid.