xarray: multidim groupby on dask arrays: dask.array.reshape error
If I try to run a groupby operation using a multidimensional group, I get an error from dask about “dask.array.reshape requires that reshaped dimensions after the first contain at most one chunk”.
This error arises with dask 0.11.0 but NOT with dask 0.8.0.
Consider the following test example:
import dask.array as da
import xarray as xr
nz, ny, nx = (10,20,30)
data = da.ones((nz,ny,nx), chunks=(5,ny,nx))
coord_2d = da.random.random((ny,nx), chunks=(ny,nx))>0.5
ds = xr.Dataset({'thedata': (('z','y','x'), data)},
                coords={'thegroup': (('y','x'), coord_2d)})
# this works fine
ds.thedata.groupby('thegroup')
Now I rechunk one of the later dimensions and group again:
ds.chunk({'x': 5}).thedata.groupby('thegroup')
This raises the following error and stack trace:
ValueError Traceback (most recent call last)
<ipython-input-16-1b0095ee24a0> in <module>()
----> 1 ds.chunk({'x': 5}).thedata.groupby('thegroup')
/Users/rpa/RND/open_source/xray/xarray/core/common.pyc in groupby(self, group, squeeze)
343 if isinstance(group, basestring):
344 group = self[group]
--> 345 return self.groupby_cls(self, group, squeeze=squeeze)
346
347 def groupby_bins(self, group, bins, right=True, labels=None, precision=3,
/Users/rpa/RND/open_source/xray/xarray/core/groupby.pyc in __init__(self, obj, group, squeeze, grouper, bins, cut_kwargs)
170 # the copy is necessary here, otherwise read only array raises error
171 # in pandas: https://github.com/pydata/pandas/issues/12813>
--> 172 group = group.stack(**{stacked_dim_name: orig_dims}).copy()
173 obj = obj.stack(**{stacked_dim_name: orig_dims})
174 self._stacked_dim = stacked_dim_name
/Users/rpa/RND/open_source/xray/xarray/core/dataarray.pyc in stack(self, **dimensions)
857 DataArray.unstack
858 """
--> 859 ds = self._to_temp_dataset().stack(**dimensions)
860 return self._from_temp_dataset(ds)
861
/Users/rpa/RND/open_source/xray/xarray/core/dataset.pyc in stack(self, **dimensions)
1359 result = self
1360 for new_dim, dims in dimensions.items():
-> 1361 result = result._stack_once(dims, new_dim)
1362 return result
1363
/Users/rpa/RND/open_source/xray/xarray/core/dataset.pyc in _stack_once(self, dims, new_dim)
1322 shape = [self.dims[d] for d in vdims]
1323 exp_var = var.expand_dims(vdims, shape)
-> 1324 stacked_var = exp_var.stack(**{new_dim: dims})
1325 variables[name] = stacked_var
1326 else:
/Users/rpa/RND/open_source/xray/xarray/core/variable.pyc in stack(self, **dimensions)
801 result = self
802 for new_dim, dims in dimensions.items():
--> 803 result = result._stack_once(dims, new_dim)
804 return result
805
/Users/rpa/RND/open_source/xray/xarray/core/variable.pyc in _stack_once(self, dims, new_dim)
771
772 new_shape = reordered.shape[:len(other_dims)] + (-1,)
--> 773 new_data = reordered.data.reshape(new_shape)
774 new_dims = reordered.dims[:len(other_dims)] + (new_dim,)
775
/Users/rpa/anaconda/lib/python2.7/site-packages/dask/array/core.pyc in reshape(self, *shape)
1101 if len(shape) == 1 and not isinstance(shape[0], Number):
1102 shape = shape[0]
-> 1103 return reshape(self, shape)
1104
1105 @wraps(topk)
/Users/rpa/anaconda/lib/python2.7/site-packages/dask/array/core.pyc in reshape(array, shape)
2585
2586 if any(len(c) != 1 for c in array.chunks[ndim_same+1:]):
-> 2587 raise ValueError('dask.array.reshape requires that reshaped '
2588 'dimensions after the first contain at most one chunk')
2589
ValueError: dask.array.reshape requires that reshaped dimensions after the first contain at most one chunk
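The guard at the bottom of the traceback can be reproduced with plain Python over the chunk tuples. After `ds.chunk({'x': 5})`, the group variable `thegroup` spans dims (y, x) with y in a single chunk of 20 and x split into six chunks of 5; stacking collapses both dims into one, so no leading dimensions are preserved and the check runs over every dim after the first:

```python
# chunk tuples of thegroup's dims (y, x) after ds.chunk({'x': 5})
chunks = ((20,), (5, 5, 5, 5, 5, 5))

# dask 0.11.0's reshape keeps the chunking of the first reshaped dimension
# and requires every later reshaped dimension to hold exactly one chunk
ndim_same = 0  # stacking (y, x) -> one dim preserves no leading dims
violates = any(len(c) != 1 for c in chunks[ndim_same + 1:])
print(violates)  # True -> this is the condition that raises the ValueError
```

Because x carries six chunks, the condition is true and `reshape` raises before building any graph.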
I am using the latest xarray master and dask version 0.11.0. Note that the example works fine with an earlier version of dask (e.g. 0.8.0, the only other version I tested). This suggests an upstream issue in dask, but I wanted to bring it up here first.
About this issue
- State: closed
- Created 8 years ago
- Comments: 17 (11 by maintainers)
This is what I was looking for:
So in this case (where the chunk size is already 1), dask.array.reshape could actually work fine and the error is unnecessary (we don’t have the exploding task issue). So this could potentially be fixed upstream in dask.
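To make the comment concrete: even when every block along the reshaped trailing dimensions has size 1, so each input block maps to exactly one output block and there is no task explosion, the dask 0.11.0 guard still fires, because it counts chunks rather than inspecting their sizes (a sketch with hypothetical chunk tuples):

```python
# hypothetical chunking: trailing dimension split into 30 size-1 chunks
chunks = ((20,), (1,) * 30)

# the dask 0.11.0 guard rejects this because the trailing dim has more
# than one chunk, even though a size-1-chunk reshape is perfectly safe
guard_fires = any(len(c) != 1 for c in chunks[1:])
print(guard_fires)  # True
```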
For now, the best work-around (because you don’t have any memory concerns) is to “rechunk” into a single block along the last axis before reshaping, e.g.,
.chunk(allpoints=259200)
or .chunk(allpoints=1e9)
(or something arbitrarily large).
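A minimal sketch of the same workaround on a bare dask array (the shape and chunking here are illustrative, not taken from the issue): rechunk the axis to be reshaped into a single block, after which reshape goes through:

```python
import dask.array as da

# 20 x 30 array with the trailing axis split into six chunks of 5,
# mirroring the chunking that triggers the error above
x = da.ones((20, 30), chunks=(20, 5))

# collapse the trailing axis into one block, then reshape safely
y = x.rechunk({1: 30}).reshape((600,))
print(y.shape)  # (600,)
```

This trades graph safety for memory: the rechunked axis now lives in a single block, which is exactly why the maintainer notes it is only appropriate when memory is not a concern.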