xarray: multidim groupby on dask arrays: dask.array.reshape error

If I try to run a groupby operation using a multidimensional group, I get an error from dask about “dask.array.reshape requires that reshaped dimensions after the first contain at most one chunk”.

This error arises with dask 0.11.0 but NOT with dask 0.8.0.

Consider the following test example:

```python
import dask.array as da
import xarray as xr

nz, ny, nx = (10,20,30)
data = da.ones((nz,ny,nx), chunks=(5,ny,nx))
coord_2d = da.random.random((ny,nx), chunks=(ny,nx))>0.5
ds = xr.Dataset({'thedata': (('z','y','x'), data)},
                coords={'thegroup': (('y','x'), coord_2d)})
# this works fine
ds.thedata.groupby('thegroup')
```

Now I rechunk one of the later dimensions and group again:

```python
ds.chunk({'x': 5}).thedata.groupby('thegroup')
```

This raises the following error and stack trace:

```
ValueError                                Traceback (most recent call last)
<ipython-input-16-1b0095ee24a0> in <module>()
----> 1 ds.chunk({'x': 5}).thedata.groupby('thegroup')

/Users/rpa/RND/open_source/xray/xarray/core/common.pyc in groupby(self, group, squeeze)
    343         if isinstance(group, basestring):
    344             group = self[group]
--> 345         return self.groupby_cls(self, group, squeeze=squeeze)
    346
    347     def groupby_bins(self, group, bins, right=True, labels=None, precision=3,

/Users/rpa/RND/open_source/xray/xarray/core/groupby.pyc in __init__(self, obj, group, squeeze, grouper, bins, cut_kwargs)
    170             # the copy is necessary here, otherwise read only array raises error
    171             # in pandas: https://github.com/pydata/pandas/issues/12813>
--> 172             group = group.stack(**{stacked_dim_name: orig_dims}).copy()
    173             obj = obj.stack(**{stacked_dim_name: orig_dims})
    174             self._stacked_dim = stacked_dim_name

/Users/rpa/RND/open_source/xray/xarray/core/dataarray.pyc in stack(self, **dimensions)
    857         DataArray.unstack
    858         """
--> 859         ds = self._to_temp_dataset().stack(**dimensions)
    860         return self._from_temp_dataset(ds)
    861

/Users/rpa/RND/open_source/xray/xarray/core/dataset.pyc in stack(self, **dimensions)
   1359         result = self
   1360         for new_dim, dims in dimensions.items():
-> 1361             result = result._stack_once(dims, new_dim)
   1362         return result
   1363

/Users/rpa/RND/open_source/xray/xarray/core/dataset.pyc in _stack_once(self, dims, new_dim)
   1322                     shape = [self.dims[d] for d in vdims]
   1323                     exp_var = var.expand_dims(vdims, shape)
-> 1324                     stacked_var = exp_var.stack(**{new_dim: dims})
   1325                     variables[name] = stacked_var
   1326                 else:

/Users/rpa/RND/open_source/xray/xarray/core/variable.pyc in stack(self, **dimensions)
    801         result = self
    802         for new_dim, dims in dimensions.items():
--> 803             result = result._stack_once(dims, new_dim)
    804         return result
    805

/Users/rpa/RND/open_source/xray/xarray/core/variable.pyc in _stack_once(self, dims, new_dim)
    771
    772         new_shape = reordered.shape[:len(other_dims)] + (-1,)
--> 773         new_data = reordered.data.reshape(new_shape)
    774         new_dims = reordered.dims[:len(other_dims)] + (new_dim,)
    775

/Users/rpa/anaconda/lib/python2.7/site-packages/dask/array/core.pyc in reshape(self, *shape)
   1101         if len(shape) == 1 and not isinstance(shape[0], Number):
   1102             shape = shape[0]
-> 1103         return reshape(self, shape)
   1104
   1105     @wraps(topk)

/Users/rpa/anaconda/lib/python2.7/site-packages/dask/array/core.pyc in reshape(array, shape)
   2585
   2586     if any(len(c) != 1 for c in array.chunks[ndim_same+1:]):
-> 2587         raise ValueError('dask.array.reshape requires that reshaped '
   2588                          'dimensions after the first contain at most one chunk')
   2589

ValueError: dask.array.reshape requires that reshaped dimensions after the first contain at most one chunk
```
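To see which chunk layouts trip the check at line 2586 of the traceback, it can be mimicked in plain Python. This is a minimal sketch, not dask's code: it assumes that `ndim_same` counts the leading dimensions whose sizes the reshape leaves unchanged, and it ignores any other bookkeeping dask's `reshape` does. The function names are hypothetical. The target shapes follow line 772 of the traceback, where xarray requests `reordered.shape[:len(other_dims)] + (-1,)`.

```python
from math import prod


def resolve_shape(old_shape, new_shape):
    """Replace a single -1 entry in new_shape with the size it must have."""
    known = prod(s for s in new_shape if s != -1)
    return tuple(prod(old_shape) // known if s == -1 else s for s in new_shape)


def old_reshape_would_raise(old_shape, chunks, new_shape):
    """Hypothetical re-creation of the dask 0.11.0 reshape check.

    Assumes ndim_same counts the leading dimensions whose size is unchanged;
    every dimension after the first reshaped one must then be a single chunk.
    """
    new_shape = resolve_shape(old_shape, new_shape)
    ndim_same = 0
    for old, new in zip(old_shape, new_shape):
        if old != new:
            break
        ndim_same += 1
    return any(len(c) != 1 for c in chunks[ndim_same + 1:])


nz, ny, nx = 10, 20, 30

# 'thedata' stacked over ('y', 'x'): xarray asks for shape (nz, -1).
original = ((5, 5), (ny,), (nx,))
rechunked = ((5, 5), (ny,), (5,) * 6)  # chunks after ds.chunk({'x': 5})
print(old_reshape_would_raise((nz, ny, nx), original, (nz, -1)))   # False
print(old_reshape_would_raise((nz, ny, nx), rechunked, (nz, -1)))  # True

# 'thegroup' itself, dims ('y', 'x'), is flattened to shape (-1,).
print(old_reshape_would_raise((ny, nx), ((ny,), (5,) * 6), (-1,)))  # True
```

Under this reading, chunking `y` alone would still pass, because `y` is the first reshaped dimension; only a split along `x`, the dimension after it, triggers the error.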

I am using the latest xarray master and dask version 0.11.0. Note that the example works fine with an earlier version of dask (e.g. 0.8.0, the only other one I tested). This suggests an upstream issue with dask, but I wanted to bring it up here first.

About this issue

  • State: closed
  • Created 8 years ago
  • Comments: 17 (11 by maintainers)

Most upvoted comments

This is what I was looking for:

```
Frozen(SortedKeysDict({'allpoints': (1, 1, 1, 1, 1......(allpoints)....., 1, 1), 'T': (11L,)}))
```

So in this case, where every chunk along allpoints already has size 1, dask.array.reshape could actually work fine and the error is unnecessary, because there is no risk of an exploding number of tasks. This could potentially be fixed upstream in dask.

For now, the best workaround (since you don't have any memory concerns) is to "rechunk" into a single block along the last axis before reshaping, e.g., .chunk(allpoints=259200) or .chunk(allpoints=1e9) (or any arbitrarily large value).
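The workaround amounts to choosing a chunk spec in which every stacked dimension after the first is a single block. Below is a minimal sketch of a hypothetical helper, `single_block_chunks`, that builds such a spec from a mapping of dimension sizes. It assumes the stacking order is the group's dimension order, `('y', 'x')` in the original example, matching the `orig_dims` passed to `stack` in the traceback.

```python
def single_block_chunks(sizes, stacked_dims):
    """Hypothetical helper: build a chunk spec that puts every stacked
    dimension after the first into a single block, so the
    'at most one chunk' requirement of the old dask reshape is met.

    sizes: mapping of dimension name -> length
    stacked_dims: dimensions in the order they will be stacked
    """
    # The first stacked dimension may keep several chunks; the rest
    # are set to their full length, i.e. one chunk each.
    return {dim: sizes[dim] for dim in stacked_dims[1:]}


sizes = {'z': 10, 'y': 20, 'x': 30}
print(single_block_chunks(sizes, ('y', 'x')))  # {'x': 30}
```

With a recent xarray (assuming `Dataset.sizes` is available), the original example could then be rechunked with something like `ds.chunk({'x': 5}).chunk(single_block_chunks(ds.sizes, ('y', 'x')))` before calling `.thedata.groupby('thegroup')`. Passing the full dimension length follows the same idea as the `.chunk(allpoints=259200)` suggestion.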