import xarray as xr
import numpy as np
from sklearn.neighbors import KDTree
from sklearn.metrics import r2_score, mean_squared_error
# Widen xarray's text repr (global setting): show up to 50 rows, 800 columns wide.
xr.set_options(display_max_rows=50, display_width=800)
# --- High-resolution target grid ---------------------------------------------
high = xr.open_dataset(r'F:\pr\pr_2000y_12geyue.nc')
high_values = high.pr.values
# The first time step defines the valid-cell mask: only cells that are not NaN
# there receive resampled values.
extent = high.isel(time=0).pr.values
high_lon = high.lon.values
high_lat = high.lat.values
# meshgrid's default 'xy' indexing gives arrays shaped (n_lat, n_lon); this
# matches `extent` assuming high.pr is ordered (time, lat, lon) — not checked.
high_ilon, high_ilat = np.meshgrid(high_lon, high_lat)
# Tuple (row_indices, col_indices) of the cells that are valid in `extent`.
valid_index = np.nonzero(~np.isnan(extent))
# Mean per calendar month over every time step in the file (new dim `month`).
high_multiyearmonthlymean = high.groupby(high.time.dt.month).mean()
# One (lon, lat) row per valid cell. Fancy indexing with `valid_index` already
# returns 1-D arrays, so no ravel() is needed.
high_grids_lonlat = np.column_stack((high_ilon[valid_index], high_ilat[valid_index]))
# --- Low-resolution source grid (file name indicates CRU TS 4.07, 1901-2022) --
low = xr.open_dataset(r'F:\pr\cru_ts4.07.1901.2022.pre.dat.nc')
low_values = low.pre.values
low_lon, low_lat = low.lon.values, low.lat.values
low_ilon, low_ilat = np.meshgrid(low_lon, low_lat)
# One (lon, lat) row per low-res cell, flattened in C order. Assuming low.pre
# is ordered (time, lat, lon), row k corresponds to element k of
# low_values[t].ravel(), so KD-tree indices map directly onto flattened fields.
low_grids_lonlat = np.column_stack((low_ilon.ravel(), low_ilat.ravel()))
# --- Nearest-neighbour resampling of the low-res field onto the high-res grid --
# KD-tree over low-res cell centres; distances are plain Euclidean in lon/lat.
tree = KDTree(low_grids_lonlat, leaf_size=2)
# Query every valid high-res cell in a single call rather than one Python-level
# query per cell; for k=1 the index array has shape (n_valid, 1).
_, ind = tree.query(high_grids_lonlat, k=1)
low_id = ind.ravel()
# NaN everywhere except the valid cells, which are filled per time step below.
resample = np.full((low_values.shape[0], high_values.shape[1], high_values.shape[2]),
                   np.nan, dtype=np.float32)
for t in range(low_values.shape[0]):
    # Flatten time step t in C order (matching low_grids_lonlat) and gather the
    # nearest low-res value for each valid cell. Looping over time keeps peak
    # memory to one extra 2-D slice instead of a full (time, n_valid) copy.
    resample[t, valid_index[0], valid_index[1]] = low_values[t].ravel()[low_id]
# Wrap the resampled array with the low-res time axis and high-res lat/lon.
resample = xr.Dataset({"pre": (('time', 'lat', 'lon'), resample)},
                      coords={"time": low.time, "lat": high.lat, "lon": high.lon})
# Keep using the in-memory Dataset built above. Re-reading a cached file here
# would throw this computation away and could silently pick up a stale copy.
# Set the flag to True to additionally cache the result on disk.
WRITE_RESAMPLE = False
if WRITE_RESAMPLE:
    resample.to_netcdf(r'F:\pr\resample.nc')
# Compute delta from the period in which low- and high-resolution data overlap.
# NOTE(review): range(1961, 2000) is end-exclusive, i.e. 1961-1999. Confirm this
# matches the period that the high-resolution climatology represents.
resample_section = resample.sel(time=resample.time.dt.year.isin(range(1961, 2000)))
resample_multiyearmonthlymean = resample_section.groupby(resample_section.time.dt.month).mean()
# Multiplicative delta per calendar month and grid cell. The high-resolution
# dataset stores precipitation as `pr` (high.pr), not `pre`.
delta = high_multiyearmonthlymean.pr / resample_multiyearmonthlymean.pre
# A zero low-res monthly mean makes the ratio inf (or NaN for 0/0). Treat such
# cells as undefined so they become NaN in the downscaled output, not inf.
delta = delta.where(np.isfinite(delta))
# Set to True to save the delta field for inspection.
WRITE_DELTA = False
if WRITE_DELTA:
    delta.to_netcdf(r'G:\111kkk\delta.nc')
# --- Apply the monthly delta and write the downscaled series -----------------
# GroupBy arithmetic multiplies every time step by the delta of its calendar
# month (delta has a `month` dimension) and keeps the time steps in their
# original order. That ordering is what allows the values to be labelled with
# low.time directly below.
bc_pre = resample.pre.groupby(resample.time.dt.month) * delta
bc_pre = bc_pre.transpose("time", "lat", "lon")
bc_down_scale = xr.Dataset({"pre": (('time', 'lat', 'lon'), bc_pre.values.astype("float32"))},
                           coords={"time": low.time, "lat": high.lat, "lon": high.lon})
bc_down_scale.to_netcdf(r'F:\pr\delta_down_scale.nc')