diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 7d36dbc6226..bb8102e4d6f 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -326,47 +326,129 @@ def to_xarray(self): y = self.y["data"] x = self.x["data"] - time = np.array([num2date(self.time["data"][0], self.time["units"])]) + time = np.array( + [num2date(self.time["data"][0], units=self.time["units"])], + ) ds = xarray.Dataset() - for field in list(self.fields.keys()): - field_data = self.fields[field]["data"] + for field, field_info in self.fields.items(): + field_data = field_info["data"] data = xarray.DataArray( np.ma.expand_dims(field_data, 0), dims=("time", "z", "y", "x"), coords={ - "time": (["time"], time), - "z": (["z"], z), + "time": time, + "z": z, "lat": (["y", "x"], lat), "lon": (["y", "x"], lon), - "y": (["y"], y), - "x": (["x"], x), + "y": y, + "x": x, }, ) - for meta in list(self.fields[field].keys()): + + for meta, value in field_info.items(): if meta != "data": - data.attrs.update({meta: self.fields[field][meta]}) + data.attrs.update({meta: value}) ds[field] = data - ds.lon.attrs = [ - ("long_name", "longitude of grid cell center"), - ("units", "degree_E"), - ("standard_name", "Longitude"), - ] - ds.lat.attrs = [ - ("long_name", "latitude of grid cell center"), - ("units", "degree_N"), - ("standard_name", "Latitude"), - ] - - ds.z.attrs = get_metadata("z") - ds.y.attrs = get_metadata("y") - ds.x.attrs = get_metadata("x") - - ds.z.encoding["_FillValue"] = None - ds.lat.encoding["_FillValue"] = None - ds.lon.encoding["_FillValue"] = None - ds.close() + + ds.lon.attrs = [ + ("long_name", "longitude of grid cell center"), + ("units", "degree_E"), + ("standard_name", "Longitude"), + ] + ds.lat.attrs = [ + ("long_name", "latitude of grid cell center"), + ("units", "degree_N"), + ("standard_name", "Latitude"), + ] + + ds.z.attrs = get_metadata("z") + ds.y.attrs = get_metadata("y") + ds.x.attrs = get_metadata("x") + + for attr in [ds.z, ds.lat, ds.lon]: + attr.encoding["_FillValue"] = None + + # Delayed import + from ..io.grid_io import _make_coordinatesystem_dict + + ds.coords["ProjectionCoordinateSystem"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + attrs=_make_coordinatesystem_dict(self), + ) + + # write the projection dictionary as a scalar + projection = self.projection.copy() + # NetCDF does not support boolean attribute, covert to string + if "_include_lon_0_lat_0" in projection: + include = projection["_include_lon_0_lat_0"] + projection["_include_lon_0_lat_0"] = ["false", "true"][include] + ds.coords["projection"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + dims=None, + attrs=projection, + ) + + for attr_name in [ + "origin_latitude", + "origin_longitude", + "origin_altitude", + "radar_altitude", + "radar_latitude", + "radar_longitude", + "radar_time", + ]: + if hasattr(self, attr_name): + attr_data = getattr(self, attr_name) + if attr_data is not None: + if attr_name in [ + "origin_latitude", + "origin_longitude", + "origin_altitude", + ]: + # Adjusting the dims to 'time' for the origin attributes + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + dims = ("time",) + else: + if "radar_time" not in attr_name: + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + else: + attr_value = [ + np.array( + num2date( + attr_data["data"][0], + units=attr_data["units"], + ), + dtype="datetime64[ns]", + ) + ] + dims = ("nradar",) + + ds.coords[attr_name] = xarray.DataArray( + attr_value, dims=dims, attrs=get_metadata(attr_name) + ) + + if "radar_time" in ds.variables: + ds.radar_time.attrs.pop("calendar") + + if self.radar_name is not None: + radar_name = self.radar_name["data"] + ds["radar_name"] = xarray.DataArray( + np.array([b"".join(radar_name)]), + dims=("nradar"), + attrs=get_metadata("radar_name"), + ) + + ds.attrs = self.metadata + for key in ds.attrs: + try: + ds.attrs[key] = ds.attrs[key].decode("utf-8") + except AttributeError: + # If the attribute is not a byte string, just pass + pass + + ds.close() return ds def add_field(self, field_name, field_dict, replace_existing=False): @@ -389,7 +471,7 @@ def add_field(self, field_name, field_dict, replace_existing=False): if "data" not in field_dict: raise KeyError('Field dictionary must contain a "data" key') if field_name in self.fields and replace_existing is False: - raise ValueError(f"A field named {field_name} already exists") + raise ValueError("A field named %s already exists" % (field_name)) if field_dict["data"].shape != (self.nz, self.ny, self.nx): raise ValueError("Field has invalid shape")