I don’t understand this behavior in xarray, most likely because I don’t fully understand the broadcasting xarray does. Here is a contrived example that illustrates the issue.
import numpy as np
import xarray as xr
# These are coordinates
years = xr.DataArray(np.arange(2018, 2021), dims="year")
ids = xr.DataArray(np.arange(1, 4), dims="id")
# These are data with different coordinates
year_data = xr.DataArray(np.arange(18, 21), dims="year", coords={"year": years})
id_data = xr.DataArray(['a', 'b', 'c'], dims="id", coords={"id": ids})
comb_data = xr.DataArray(np.arange(9).reshape(3, 3), dims=["year", "id"], coords={"year": years, "id": ids})
# Make a dataset
ds = xr.Dataset(data_vars={"comb_data": comb_data, "id_data": id_data, "year_data": year_data})
This makes:
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
Data variables:
    comb_data  (year, id) int64 0 1 2 3 4 5 6 7 8
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20
This is what I want: two data variables that each refer to a different coordinate, and one data variable that uses both. I need to set some of the data to 0, so I use where.
ds.where(ds.coords["id"] == 2, 0)
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
Data variables:
    comb_data  (year, id) int64 0 1 0 0 4 0 0 7 0
    id_data    (id) object 0 'b' 0
    year_data  (year, id) int64 0 18 0 0 19 0 0 20 0
Now the year_data dimension includes id and has created data with no meaning. I just need to ignore the dimensions that aren’t involved in this. I can delete the extraneous data after the fact but that doesn’t feel right. Is there a better way to do this?
>Solution :
xr.Dataset.where always broadcasts every data variable in the dataset against the supplied condition. Since you told it to mask all the data in the dataset except where ds["id"] == 2, every variable is automatically broadcast against the id dimension, including year_data, which did not have that dimension before.
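To see the broadcast in isolation, here is a minimal sketch on a trimmed-down, numeric-only version of the dataset from the question. The names cond and masked are only for illustration. It compares the dimensions of year_data before and after the call:

```python
import numpy as np
import xarray as xr

# Trimmed-down version of the question's dataset (id_data left out).
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

cond = ds["id"] == 2        # 1-D boolean array along "id"
masked = ds.where(cond, 0)  # applied to every data variable

print(ds["year_data"].dims)      # ('year',)
print(masked["year_data"].dims)  # ('year', 'id')
```

The original dataset is left untouched; only the returned copy carries the extra dimension.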
If you don’t want this to happen, you have two options. The first is a good practice in general – unless you really want to carry out an operation across all variables in a dataset, do your operations with the specific variables you want to work with:
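# this will only modify the comb_data array; the second argument (0) fills
# the masked values with 0 as in the question (omit it to get NaN instead)
ds["comb_data"] = ds["comb_data"].where(ds["id"] == 2, 0)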
See the docs on broadcasting and automatic alignment for more info about xarray’s computing rules.
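If several variables need masking, one possible shortcut is to let Dataset.map apply the mask only to variables that already have every dimension of the condition. This is a sketch, not something the answer above prescribes: mask_if_has_dims is a hypothetical helper (not an xarray function), and the dataset is a numeric-only reduction of the one in the question.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

cond = ds["id"] == 2


def mask_if_has_dims(da):
    """Hypothetical helper: mask `da` only if it has all dims of `cond`."""
    if set(cond.dims) <= set(da.dims):
        return da.where(cond, 0)
    return da  # leave variables without the "id" dimension untouched


# Dataset.map calls the function once per data variable.
result = ds.map(mask_if_has_dims)

print(result["year_data"].dims)  # ('year',)
```

Note that any variable defined along id alone (such as id_data in the question) would still be masked by this rule; adjust the test inside the helper if that is not wanted.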
Another option, if "id_data" and "year_data" are more metadata than true data variables, is to set them as non-dimension coordinates. Then, any computation will skip over these:
In [4]: ds = ds.set_coords(["id_data", "year_data"])
In [5]: ds
Out[5]:
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20
Data variables:
    comb_data  (year, id) int64 0 1 2 3 4 5 6 7 8
In [6]: ds.where(ds["id"] == 2)
Out[6]:
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20
Data variables:
    comb_data  (year, id) float64 nan 1.0 nan nan 4.0 nan nan 7.0 nan
See the docs on coordinates for more information about how coordinates are handled in computation.
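If the two variables should end up as data variables again after masking, the coordinate approach can be used as a temporary round trip. The following is a sketch under that assumption; the names protected, tmp, masked and result are made up for illustration, and 0 is used as the fill value to match the question.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", np.array(["a", "b", "c"])),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

protected = ["id_data", "year_data"]

tmp = ds.set_coords(protected)            # shield them from computation
masked = tmp.where(tmp["id"] == 2, 0)     # only comb_data is masked
result = masked.reset_coords(protected)   # turn them back into data variables

print(result["year_data"].dims)  # ('year',)
```

This keeps the Dataset-level call while leaving id_data and year_data with their original dimensions and values.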