Source code for icenet.plotting.utils

import datetime as dt
import glob
import logging
import os
import re

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as da
import imageio_ffmpeg as ffmpeg
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr

from cartopy.feature import ShapelyFeature, NaturalEarthFeature
from cartopy.feature import AdaptiveScaler
from functools import cache
from ibicus.debias import LinearScaling
from matplotlib.path import Path
from pyproj import CRS, Transformer
from rasterio.enums import Resampling
from shapely.geometry import Polygon

from icenet.data.sic.mask import Masks

def broadcast_forecast(start_date: object,
                       end_date: object,
                       datafiles: object = None,
                       dataset: object = None,
                       target: object = None) -> object:
    """Build a dataset indexed by target date from (time, leadtime) forecasts.

    For every date in ``pd.date_range(start_date, end_date)`` one slice is
    selected from the forecast dataset by initialisation ``time`` and
    ``leadtime``. The slices are concatenated along a ``time`` dimension
    whose coordinate holds the target dates.

    :param start_date: first date of the range, passed to ``pd.date_range``
    :param end_date: last date of the range, passed to ``pd.date_range``
    :param datafiles: iterable of file path strings, opened together with
        ``xr.open_mfdataset(..., engine="netcdf4")``; exclusive with
        ``dataset``
    :param dataset: already opened dataset exposing ``time`` and
        ``leadtime``; exclusive with ``datafiles``
    :param target: optional output path; when truthy the result is written
        there with ``to_netcdf``
    :return: the dataset concatenated along ``time``
    :raises AssertionError: unless exactly one of ``datafiles`` and
        ``dataset`` is ``None``
    """

    assert (datafiles is None) ^ (dataset is None), \
        "Only one of datafiles and dataset can be set"

    if datafiles:
        logging.info("Using {} to generate forecast through {} to {}".format(
            ", ".join(datafiles), start_date, end_date))
        dataset = xr.open_mfdataset(datafiles, engine="netcdf4")

    dates = pd.date_range(start_date, end_date)
    # Index of the forecast initialisation currently being read from
    i = 0

    logging.debug("Dataset summary: \n{}".format(dataset))

    # Skip initialisations that are superseded before the first wanted date.
    # NOTE(review): i + 1 is not bounds-checked here, so this looks like it
    # raises IndexError when every later initialisation precedes dates[0].
    if len(dataset.time.values) > 1:
        while dataset.time.values[i + 1] < dates[0]:
            i += 1

    logging.info("Starting index will be {} for {} - {}".format(
        i, dates[0], dates[-1]))
    dt_arr = []

    for d in dates:
        logging.debug("Looking for date {}".format(d))
        arr = None

        # Loop until a slice for this date has been selected
        while arr is None:
            if d >= dataset.time.values[i]:
                # Whole days between the initialisation and the target date
                d_lead = (d - dataset.time.values[i]).days

                if i + 1 < len(dataset.time.values):
                    # Move on to the next initialisation once the target is
                    # at least one day beyond it, then retry this date
                    if pd.to_datetime(dataset.time.values[i]) + \
                            dt.timedelta(days=d_lead) >= \
                            pd.to_datetime(dataset.time.values[i + 1]) + \
                            dt.timedelta(days=1):
                        i += 1
                        continue

                logging.debug("Selecting date {} and lead {}".format(
                    pd.to_datetime(
                        dataset.time.values[i]).strftime("%D"), d_lead))

                # Replace the (time, leadtime) pair with a single time
                # coordinate carrying the target date
                arr = dataset.sel(time=dataset.time.values[i],
                                  leadtime=d_lead).\
                    copy().\
                    drop("time").\
                    assign_coords(dict(time=d)).\
                    drop("leadtime")
            else:
                # Target date precedes the current initialisation.
                # NOTE(review): this advances i without a bounds check.
                i += 1

        dt_arr.append(arr)

    target_ds = xr.concat(dt_arr, dim="time")

    if target:
        logging.info("Saving dataset to {}".format(target))
        target_ds.to_netcdf(target)

    return target_ds
def get_seas_forecast_init_dates(
    hemisphere: str,
    source_path: object = os.path.join(".", "data", "mars.seas")
) -> object:
    """List the dates for which a SEAS forecast file is present on disk.

    The directory ``<source_path>/<hemisphere>/siconca`` is listed and every
    entry named ``YYYYMMDD.nc`` contributes its date stamp.

    :param hemisphere: string, typically either 'north' or 'south'
    :param source_path: path where north and south SEAS forecasts are stored
    :return: the date stamps converted with ``pd.to_datetime``, in directory
        listing order
    """
    forecast_dir = os.path.join(source_path, hemisphere, "siconca")

    date_stamps = []
    for entry in os.listdir(forecast_dir):
        # Anything not named exactly as eight digits plus ".nc" is ignored
        if re.search(r'^\d{8}\.nc$', entry):
            date_stamps.append(entry.split('.')[0])

    return pd.to_datetime(date_stamps)
def get_seas_forecast_da(
        hemisphere: str,
        date: str,
        bias_correct: bool = True,
        source_path: object = os.path.join(".", "data", "mars.seas"),
) -> tuple:
    """Load the SEAS ``siconc`` forecast covering ``date``.

    The forecast file is looked up as
    ``<source_path>/<hemisphere>/siconca/YYYYMMDD.nc`` using the first day
    of the month of ``date``. With ``bias_correct`` enabled the values are
    adjusted by an additive ``LinearScaling`` debiaser, using observations
    from ``get_obs_da`` and the other SEAS files in the same directory as
    history. The result is then cut to start at ``date``.

    The original docstring described the data as "Atmospheric model Ensemble
    15-day forecast (Set III - ENS)" with coordinates ``time``, ``yc`` and
    ``xc``.

    NOTE(review): ``date`` is annotated ``str`` but is used as a datetime
    (``date.replace(day=1)``, ``date - dt.timedelta(...)``); the return
    annotation is ``tuple`` but a DataArray or ``None`` is returned.

    :param hemisphere: string, typically either 'north' or 'south'
    :param date: date the returned forecast should start from
    :param bias_correct: whether to apply the linear-scaling bias correction
    :param source_path: path where north and south SEAS forecasts are stored
    :return: the ``siconc`` data from ``date`` onwards, or ``None`` when the
        forecast file does not exist
    :raises ValueError: from ``list.index`` when ``date`` is not one of the
        forecast's time values
    """

    seas_file = os.path.join(
        source_path, hemisphere, "siconca",
        "{}.nc".format(date.replace(day=1).strftime("%Y%m%d")))

    if os.path.exists(seas_file):
        seas_da = xr.open_dataset(seas_file).siconc
    else:
        logging.warning("No SEAS data available at {}".format(seas_file))
        return None

    if bias_correct:
        # Let's have some maximum, though it's quite high
        (start_date, end_date) = (date - dt.timedelta(days=10 * 365),
                                  date + dt.timedelta(days=10 * 365))
        obs_da = get_obs_da(hemisphere, start_date, end_date)
        # Map absolute path -> date parsed from the file name for every other
        # YYYYMMDD.nc file in the directory, ordered by sorting the items
        seas_hist_files = dict(
            sorted({
                os.path.abspath(el):
                    dt.datetime.strptime(os.path.basename(el)[0:8], "%Y%m%d")
                for el in glob.glob(
                    os.path.join(source_path, hemisphere, "siconca", "*.nc"))
                if re.search(r'^\d{8}\.nc$', os.path.basename(el)) and
                el != seas_file
            }.items()))

        def strip_overlapping_time(ds):
            # Preprocess hook: trim each historical file so that it ends the
            # day before the date of the next file in seas_hist_files
            data_file = os.path.abspath(ds.encoding["source"])

            try:
                idx = list(seas_hist_files.keys()).index(data_file)
            except ValueError:
                logging.exception("\n{} not in \n\n{}".format(
                    data_file, seas_hist_files))
                return None

            if idx < len(seas_hist_files) - 1:
                max_date = seas_hist_files[
                    list(seas_hist_files.keys())[idx + 1]] \
                    - dt.timedelta(days=1)
                logging.debug("Stripping {} to {}".format(data_file, max_date))
                return ds.sel(time=slice(None, max_date))
            else:
                # Last file in the mapping: nothing follows it, keep it whole
                logging.debug("Not stripping {}".format(data_file))
                return ds

        # NOTE(review): a dict is passed as the paths argument here - confirm
        # xarray handles it as the sequence of its keys
        hist_da = xr.open_mfdataset(seas_hist_files,
                                    preprocess=strip_overlapping_time).siconc

        debiaser = LinearScaling(delta_type="additive",
                                 variable="siconc",
                                 reasonable_physical_range=[0., 1.])

        logging.info("Debiaser input ranges: obs {:.2f} - {:.2f}, "
                     "hist {:.2f} - {:.2f}, fut {:.2f} - {:.2f}".format(
                         float(obs_da.min()), float(obs_da.max()),
                         float(hist_da.min()), float(hist_da.max()),
                         float(seas_da.min()), float(seas_da.max())))

        seas_array = debiaser.apply(obs_da.values, hist_da.values,
                                    seas_da.values)
        # Overwrite the forecast values in place with the debiased ones
        seas_da.values = seas_array
        logging.info("Debiaser output range: {:.2f} - {:.2f}".format(
            float(seas_da.min()), float(seas_da.max())))

    logging.info("Returning SEAS data from {} from {}".format(seas_file, date))

    # This isn't great looking, but we know we're not dealing with huge
    # indexes in here
    date_location = list(seas_da.time.values).index(pd.Timestamp(date))
    if date_location > 0:
        logging.warning("SEAS forecast started {} day before the requested "
                        "date {}, make sure you account for this!".format(
                            date_location, date))

    # Keep only the part of the forecast from the requested date onwards
    seas_da = seas_da.sel(time=slice(date, None))
    logging.debug("SEAS data range: {} - {}, {} dates".format(
        pd.to_datetime(min(seas_da.time.values)).strftime("%Y-%m-%d"),
        pd.to_datetime(max(seas_da.time.values)).strftime("%Y-%m-%d"),
        len(seas_da.time)))

    return seas_da
def get_forecast_ds(forecast_file: object,
                    forecast_date: str,
                    stddev: bool = False) -> object:
    """Read one forecast variable for a single initialisation date.

    :param forecast_file: a path to a .nc file
    :param forecast_date: initialisation date of the forecast
    :param stddev: select ``sic_stddev`` instead of ``sic_mean``
    :return: the selected variable, restricted to the time slice that
        starts and ends at ``forecast_date``
    """
    init_date = pd.to_datetime(forecast_date)
    opened = xr.open_dataset(forecast_file, decode_coords="all")

    if stddev:
        variable_name = "sic_stddev"
    else:
        variable_name = "sic_mean"

    at_init_date = opened.sel(time=slice(init_date, init_date))
    return getattr(at_init_date, variable_name)
def filter_ds_by_obs(ds: object, obs_da: object, forecast_date: str) -> object:
    """Broadcast a forecast over the dates for which observations exist.

    The wanted date range runs from ``forecast_date`` plus the smallest lead
    time to ``forecast_date`` plus the largest lead time. When there are
    fewer observation times than lead times, the range is narrowed to the
    first and last observation time instead.

    :param ds: forecast dataset exposing ``leadtime``
    :param obs_da: observational data exposing ``time``
    :param forecast_date: initialisation date of the forecast
    :return: the result of ``broadcast_forecast`` over the chosen range
    :raises RuntimeError: when ``obs_da`` holds no time steps at all
    """
    forecast_date = pd.to_datetime(forecast_date)

    (start_date, end_date) = (forecast_date +
                              dt.timedelta(days=int(ds.leadtime.min())),
                              forecast_date +
                              dt.timedelta(days=int(ds.leadtime.max())))

    if len(obs_da.time) < len(ds.leadtime):
        if len(obs_da.time) < 1:
            raise RuntimeError("No observational data available between {} "
                               "and {}".format(start_date.strftime("%D"),
                                               end_date.strftime("%D")))

        # Positional access must go through .iloc: plain integer indexing of
        # a datetime-indexed Series relies on a deprecated positional
        # fallback in pandas.
        obs_times = obs_da.time.to_series()
        obs_start, obs_end = obs_times.iloc[0], obs_times.iloc[-1]

        logging.warning("Observational data not available for full range of "
                        "forecast lead times: {}-{} vs {}-{}".format(
                            obs_start.strftime("%D"),
                            obs_end.strftime("%D"),
                            start_date.strftime("%D"),
                            end_date.strftime("%D")))
        (start_date, end_date) = (obs_start, obs_end)

    # We broadcast to get a nicely compatible dataset for plotting
    return broadcast_forecast(start_date=start_date,
                              end_date=end_date,
                              dataset=ds)
def get_obs_da(
        hemisphere: str,
        start_date: str,
        end_date: str,
        obs_source: object = os.path.join(".", "data", "osisaf"),
) -> object:
    """Open the yearly observation files covering a date range.

    One file per calendar year, ``<obs_source>/<hemisphere>/siconca/<year>.nc``,
    is looked up for every year touched by the range. A warning is logged
    when fewer files than years are found.

    :param hemisphere: string, typically either 'north' or 'south'
    :param start_date: first date of the range
    :param end_date: last date of the range
    :param obs_source: root directory of the observation files
    :return: the ``ice_conc`` variable sliced to ``start_date``..``end_date``
    """
    date_span = pd.date_range(start_date, end_date)
    obs_years = pd.Series(date_span).dt.year.unique()

    obs_dfs = []
    for yr in obs_years:
        year_path = os.path.join(obs_source, hemisphere, "siconca",
                                 "{}.nc".format(yr))
        # glob yields the path only when the file exists
        obs_dfs.extend(glob.glob(year_path))

    if len(obs_dfs) < len(obs_years):
        logging.warning(
            "Cannot find all obs source files for {} - {} in {}".format(
                start_date, end_date, obs_source))

    logging.info("Got files: {}".format(obs_dfs))

    combined = xr.open_mfdataset(obs_dfs)
    in_range = combined.sel(time=slice(start_date, end_date))
    return in_range.ice_conc
def get_crs(crs_str: str):
    """Get Coordinate Reference System (CRS) from string input argument

    Args:
        crs_str: A CRS given as EPSG code (e.g. `EPSG:3347` for North Canada)
            or, a pre-defined Cartopy CRS call (e.g. "PlateCarree"). The
            special value "Mercator.GOOGLE" selects `ccrs.Mercator.GOOGLE`.

    Returns:
        The matching Cartopy CRS.

    Raises:
        AttributeError: If `crs_str` does not name an attribute of
            `cartopy.crs`; the message lists the supported options.
    """
    if crs_str.casefold().startswith("epsg"):
        # Expected form is "<prefix>:<code>"; only the code is used
        return ccrs.epsg(int(crs_str.split(":")[1]))

    if crs_str == "Mercator.GOOGLE":
        return ccrs.Mercator.GOOGLE

    try:
        return getattr(ccrs, crs_str)()
    except AttributeError as err:
        # Every CRS class exposed by cartopy.crs, plus the special case above
        supported = sorted(
            [
                name for name in dir(ccrs)
                if isinstance(getattr(ccrs, name), type)
                and issubclass(getattr(ccrs, name), ccrs.CRS)
            ] + ["Mercator.GOOGLE"]
        )
        # Build ONE message string: passing the text and the option list as
        # separate arguments makes the exception render as a tuple.
        raise AttributeError(
            "Unsupported CRS defined, supported options are: "
            f"{', '.join(supported)}"
        ) from err
def calculate_extents(x1: int, x2: int, y1: int, y2: int):
    """Convert pixel bounds into extents around a fixed base of 5387500.

    Lower bounds are offset upwards from ``-5387500`` by ``25000`` per pixel;
    upper bounds are offset downwards from ``5387500`` by ``25000`` for each
    pixel short of ``432``.

    :param x1: lower x pixel bound
    :param x2: upper x pixel bound
    :param y1: lower y pixel bound
    :param y2: upper y pixel bound
    :return: list ``[x_lower, x_upper, y_lower, y_upper]``
    """
    data_extent_base = 5387500

    x_lower = -data_extent_base + (x1 * 25000)
    x_upper = data_extent_base - ((432 - x2) * 25000)
    y_lower = -data_extent_base + (y1 * 25000)
    y_upper = data_extent_base - ((432 - y2) * 25000)

    extents = [x_lower, x_upper, y_lower, y_upper]
    logging.debug("Data extents: {}".format(extents))
    return extents
def pixel_to_projection(pixel_x_min,
                        pixel_x_max,
                        pixel_y_min,
                        pixel_y_max,
                        x_min_proj: float=-5387500,
                        x_max_proj: float=5387500,
                        y_min_proj: float=-5387500,
                        y_max_proj: float=5387500,
                        image_width: int=432,
                        image_height: int=432,
                        ):
    """Converts pixel coordinates to CRS projection coordinates

    Each pixel value is taken as a fraction of the image size and mapped
    linearly onto the projection range of the same axis.

    Returns:
        Tuple `(proj_x_min, proj_x_max, proj_y_min, proj_y_max)`.
    """

    def to_projection(pixel, size, lower, upper):
        # Linear map: 0 -> lower, size -> upper
        return (pixel / size) * (upper - lower) + lower

    proj_x_min, proj_x_max = (
        to_projection(px, image_width, x_min_proj, x_max_proj)
        for px in (pixel_x_min, pixel_x_max)
    )
    proj_y_min, proj_y_max = (
        to_projection(py, image_height, y_min_proj, y_max_proj)
        for py in (pixel_y_min, pixel_y_max)
    )

    return proj_x_min, proj_x_max, proj_y_min, proj_y_max
def get_bounds(proj=None, pole=1):
    """Get min/max bounds for a given CRS projection

    When `proj` is `None` or a `LambertAzimuthalEqualArea` instance, a new
    `LambertAzimuthalEqualArea(0, pole * 90)` is created and fixed bounds of
    +/-5387500 are used on both axes. Otherwise the bounds are read from the
    projection's `x_limits` and `y_limits`.

    NOTE(review): a `LambertAzimuthalEqualArea` instance passed in is
    replaced by the newly built one, so its own parameters are not kept.

    Args:
        proj: Projection to inspect, or `None` for the default.
        pole: Multiplied by 90 to give the second argument of the default
            projection.

    Returns:
        Tuple `(proj, x_min_proj, x_max_proj, y_min_proj, y_max_proj)`.
    """
    if proj is None or isinstance(proj, ccrs.LambertAzimuthalEqualArea):
        proj = ccrs.LambertAzimuthalEqualArea(0, pole * 90)
        x_min_proj, x_max_proj = [-5387500, 5387500]
        y_min_proj, y_max_proj = [-5387500, 5387500]
    else:
        x_min_proj, x_max_proj = proj.x_limits
        y_min_proj, y_max_proj = proj.y_limits

    # Logs the projection's own limits, which can differ from the fixed
    # bounds returned for the Lambert azimuthal case.
    # NOTE(review): the nesting of this call was not recoverable from the
    # flattened source; confirm it is meant to run for both branches.
    logging.debug(f"Projection bounds: {proj.x_limits}, {proj.y_limits}")
    return proj, x_min_proj, x_max_proj, y_min_proj, y_max_proj
def get_plot_axes(x1: int = 0,
                  x2: int = 432,
                  y1: int = 0,
                  y2: int = 432,
                  north: bool = True,
                  south: bool = False,
                  geoaxes: bool = True,
                  target_crs: object = None,
                  figsize: tuple = (10, 8),
                  dpi: int = 150,
                  ):
    """Create a figure holding a single, optionally map-projected, axes.

    :param x1: accepted for signature compatibility; not used here
    :param x2: accepted for signature compatibility; not used here
    :param y1: accepted for signature compatibility; not used here
    :param y2: accepted for signature compatibility; not used here
    :param north: select the northern hemisphere (exclusive with ``south``)
    :param south: select the southern hemisphere (exclusive with ``north``)
    :param geoaxes: create the axes with a map projection
    :param target_crs: projection for the axes; when ``None`` a
        ``LambertAzimuthalEqualArea`` centred on the selected pole is used
    :param figsize: figure size passed to ``plt.figure``
    :param dpi: figure resolution passed to ``plt.figure``
    :return: tuple ``(fig, ax)``
    :raises AssertionError: unless exactly one of ``north`` and ``south`` is
        set
    """
    assert north ^ south, "Only one hemisphere must be selected"

    fig = plt.figure(figsize=figsize, dpi=dpi, layout="tight")

    if geoaxes:
        if target_crs is None:
            pole = 1 if north else -1
            proj = ccrs.LambertAzimuthalEqualArea(central_longitude=0,
                                                  central_latitude=pole * 90)
        else:
            proj = target_crs
        ax = fig.add_subplot(1, 1, 1, projection=proj)
    else:
        ax = fig.add_subplot(1, 1, 1)

    return fig, ax
def set_plot_geoaxes(ax,
                     region_definition: str = None,
                     extent: list = None,
                     coastlines: str = None,
                     gridlines: bool = False,
                     north: bool = True,
                     south: bool = False,
                     ):
    """Apply extent, land/coastline features and gridlines to a map axes.

    :param ax: axes to configure
    :param region_definition: ``"pixel"`` or ``"geographic"``; selects how
        ``extent`` is interpreted
    :param extent: pixel bounds forwarded to ``calculate_extents``, or
        ``[lon_min, lon_max, lat_min, lat_max]`` for a geographic region
    :param coastlines: when truthy, land and coastlines are drawn
    :param gridlines: when true, labelled gridlines are drawn
    :param north: use the northern pole for the data projection
    :param south: not read by this function
    :return: the axes passed in
    """
    plt.tight_layout(pad=4.0)

    # Set colour for areas outside of `process_region()` - i.e., no data here.
    ax.set_facecolor("dimgrey")

    pole = 1 if north else -1
    proj = ccrs.LambertAzimuthalEqualArea(0, pole * 90)

    # Only set for a geographic extent; used to clip land further down
    clipping_polygon = None

    if extent:
        if region_definition == "pixel":
            ax.set_extent(calculate_extents(*extent), crs=proj)
        elif region_definition == "geographic":
            lon_min, lon_max, lat_min, lat_max = extent
            # With some projections like Mercator, it doesn't like having
            # exact boundary longitude. Both ends are checked independently
            # so that a full [-180, 180] extent has both nudged inwards.
            if lon_min == -180:
                lon_min = -179.99
            if lon_max == 180:
                lon_max = 179.99
            ax.set_extent([lon_min, lon_max, lat_min, lat_max],
                          crs=ccrs.PlateCarree())

            clipping_polygon = Polygon(get_geoextent_polygon(extent))

    if coastlines:
        auto_scaler = AdaptiveScaler("110m", (("50m", 150), ("10m", 50)))
        land = NaturalEarthFeature("physical", "land", scale="10m",
                                   facecolor="dimgrey")

        if clipping_polygon is not None:
            clipped_land = ShapelyFeature(
                [clipping_polygon.intersection(geom)
                 for geom in land.geometries()],
                ccrs.PlateCarree(),
                facecolor="dimgrey")
            ax.add_feature(clipped_land)
            # Outline the clipping region itself
            ax.add_geometries([clipping_polygon],
                              ccrs.PlateCarree(),
                              edgecolor="red",
                              facecolor="none",
                              linewidth=0.75,
                              linestyle="dashed",
                              zorder=100)
        else:
            ax.add_feature(land)

        ax.coastlines(resolution=auto_scaler)

    if gridlines:
        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True)
        # Prevent generating labels beneath the colourbar
        gl.top_labels = False
        gl.right_labels = False

    return ax
def get_geoextent_polygon(extent, crs=ccrs.PlateCarree(), n_points=100):
    """Trace the outline of a lon/lat box as a densely sampled vertex list.

    Each of the four edges is sampled at `n_points` evenly spaced positions,
    so that the edges can follow curves once drawn in another projection.
    Edges are emitted in the order bottom, right, top, left.

    Args:
        extent: `(lon_min, lon_max, lat_min, lat_max)`.
        crs: Not used by this function.
        n_points: Number of samples per edge.

    Returns:
        List of `[lon, lat]` pairs, `4 * n_points` long.
    """
    lon_min, lon_max, lat_min, lat_max = extent

    lon_samples = np.linspace(lon_min, lon_max, n_points)
    lat_samples = np.linspace(lat_min, lat_max, n_points)

    bottom_edge = [[lon, lat_min] for lon in lon_samples]
    right_edge = [[lon_max, lat] for lat in lat_samples]
    # The remaining two edges run backwards to close the outline
    top_edge = [[lon, lat_max] for lon in lon_samples[::-1]]
    left_edge = [[lon_min, lat] for lat in lat_samples[::-1]]

    return bottom_edge + right_edge + top_edge + left_edge
def set_plot_geoextent(ax, extent, crs=ccrs.PlateCarree(), n_points=100):
    """Set the axes extent and clip the axes to a high-resolution boundary.

    Args:
        ax: Axes to modify.
        extent: `(lon_min, lon_max, lat_min, lat_max)`, expressed in `crs`.
        crs: Coordinate reference system of `extent`.
        n_points: Number of samples per boundary edge, forwarded to
            `get_geoextent_polygon`.
    """
    ax.set_extent(extent, crs=crs)

    # `crs` and `n_points` belong to the polygon helper, not to
    # matplotlib's Path, whose constructor rejects them with a TypeError.
    vertices = get_geoextent_polygon(extent, crs=crs, n_points=n_points)
    boundary = Path(vertices)

    # Sets custom boundary, buggy with small lat/lon bounds
    # Coastlines, land, and gridlines spill outside of boundary
    # The vertices come straight from `extent`, so they are in `crs`.
    ax.set_boundary(boundary, transform=crs)
def show_img(ax,
             arr,
             x1: int = 0,
             x2: int = 432,
             y1: int = 0,
             y2: int = 432,
             cmap: object = None,
             geoaxes: bool = True,
             vmin: float = 0.,
             vmax: float = 1.,
             north: bool = True,
             south: bool = False,
             crs: object = None,
             extents: list = None
             ):
    """Draw an array on the axes with ``imshow``.

    With ``geoaxes`` the image is placed using the extents computed from the
    pixel bounds, in a Lambert azimuthal equal-area projection centred on the
    selected pole, and coastlines are added. Without it the array is drawn
    as a plain image.

    :param ax: axes to draw on
    :param arr: array to display
    :param x1: lower x pixel bound
    :param x2: upper x pixel bound
    :param y1: lower y pixel bound
    :param y2: upper y pixel bound
    :param cmap: colour map passed to ``imshow``
    :param geoaxes: whether ``ax`` is a map axes
    :param vmin: lower colour limit
    :param vmax: upper colour limit
    :param north: select the northern hemisphere (exclusive with ``south``)
    :param south: select the southern hemisphere (exclusive with ``north``)
    :param crs: not used by this function
    :param extents: any value passed in is ignored; recomputed from the
        pixel bounds when ``geoaxes`` is set
    :return: the image returned by ``imshow``
    """
    assert north ^ south, "One hemisphere only must be selected"

    if not geoaxes:
        return ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)

    pole = 1 if north else -1
    data_crs = ccrs.LambertAzimuthalEqualArea(0, pole * 90)
    extents = calculate_extents(x1, x2, y1, y2)

    image = ax.imshow(arr,
                      vmin=vmin,
                      vmax=vmax,
                      cmap=cmap,
                      transform=data_crs,
                      extent=extents)
    ax.coastlines()
    return image
def process_probes(probes, data) -> tuple:
    """Select the probe locations from every array in ``data``.

    ``data`` is modified in place and also returned. Entries that are
    ``None`` are left untouched.

    :param probes: A sequence of locations (pairs)
    :param data: A sequence of xr.DataArray
    :return: ``data``, with each non-``None`` entry replaced by its values
        at the probe locations
    """

    # index into each element of data with a xr.DataArray, for pointwise
    # selection.  Construct the indexing DataArray as follows:

    probes_da = xr.DataArray(probes, dims=('probe', 'coord'))
    xcs, ycs = probes_da.sel(coord=0), probes_da.sel(coord=1)

    for idx, arr in enumerate(data):
        # Guard first: a None entry has no assign_coords to call
        if arr is None:
            continue

        # Attach positional index coordinates alongside xc / yc
        arr = arr.assign_coords({
            "xi": ("xc", np.arange(len(arr.xc))),
            "yi": ("yc", np.arange(len(arr.yc))),
        })
        data[idx] = arr.isel(xc=xcs, yc=ycs)

    return data
def reproject_array(array, target_crs):
    """Reproject `array` to `target_crs` through its `rio` accessor.

    The target is given by the CRS's `proj4_init` string, and NaN is used as
    the nodata value. No resampling method is passed.
    """
    destination = target_crs.proj4_init
    return array.rio.reproject(destination, nodata=np.nan)
def process_block(block, target_crs):
    """Reproject one block and drop its `time` variable.

    The block is handed to `reproject_array` unchanged.
    """
    return reproject_array(block, target_crs).drop_vars(["time"])
def reproject_projected_coords(data: object,
                               target_crs: object,
                               pole: int=1,
                               ) -> object:
    """
    Reprojects an xarray Dataset from LambertAzimuthalEqualArea to `target_crs`.

    The Dataset is expected to have dims of (xc, yc). Its `xc` and `yc`
    coordinates are multiplied by 1000 before anything else happens, and this
    rescaled copy is what is returned when `target_crs` equals the source
    projection.

    Args:
        data: xarray dataset with dims (xc, yc), and also coords of lon and lat.
            The code below also reads `time` and `leadtime` from it.
        target_crs: Cartopy CRS to project to (e.g. `ccrs.Mercator()`)
        pole: Whether north (`1`) or south pole (`-1`).

    Returns:
        Reprojected data as an xarray dataset, with `lon` and `lat` recomputed
        for the new grid and the spatial coordinates named `xc` and `yc`.

    Examples:
        >>> reprojected_data = reproject_projected_coords(arr, # doctest: +SKIP
        >>>                     target_crs=target_crs,
        >>>                     pole=pole,
        >>>                    )
    """
    # Eastings/Northings projection
    data_crs_proj = ccrs.LambertAzimuthalEqualArea(0, pole*90)
    # geographic projection
    data_crs_geo = ccrs.PlateCarree()

    data_reproject = data.copy()
    data_reproject = data_reproject.assign_coords({"xc": data_reproject.xc.data*1000,
                                                   "yc": data_reproject.yc.data*1000
                                                   })

    # Need to use correctly scaled xc and yc to get coastlines working even if not reprojecting.
    # So, just return scaled DataArray back and not reproject if don't need to.
    if target_crs == data_crs_proj:
        return data_reproject

    data_reproject = data_reproject.drop_vars(["Lambert_Azimuthal_Grid", "lon", "lat"])

    # Set xc, yc (eastings and northings) projection details
    data_reproject = data_reproject.rename({"xc": "x", "yc": "y"})
    data_reproject.rio.write_crs(data_crs_proj.proj4_init, inplace=True)
    data_reproject.rio.write_nodata(np.nan, inplace=True)

    times = len(data_reproject.time)
    # Not read again within this function
    leadtimes = len(data_reproject.leadtime)

    # Create a sample image block for use as template for Dask
    sample_block = data_reproject.isel(time=0, leadtime=0)
    sample_reprojected = reproject_array(sample_block, target_crs)

    # Create a template DataArray based on the reprojected sample block
    template_shape = (data_reproject.sizes['leadtime'], sample_reprojected.sizes['y'], sample_reprojected.sizes['x'])
    template_data = da.zeros(template_shape, chunks=(1, -1, -1))
    template = xr.DataArray(template_data,
                            dims=['leadtime', 'y', 'x'],
                            coords={'leadtime': data_reproject.coords['leadtime'],
                                    'y': sample_reprojected.coords['y'],
                                    'x': sample_reprojected.coords['x'],
                                    }
                            )

    # Reproject each time step separately, then stitch them back together
    reprojected_data = []
    for time in range(times):
        leadtime_data = xr.map_blocks(process_block,
                                      data_reproject.isel(time=time),
                                      template=template,
                                      kwargs={"target_crs": target_crs})
        reprojected_data.append(leadtime_data)

    # TODO: Add projection info into DataArray, like the `Lambert_Azimuthal_Grid` dropped above
    reprojected_data = xr.concat(reprojected_data, dim="time")
    reprojected_data.coords["time"] = data_reproject.time.data

    # Set attributes
    reprojected_data.rio.write_crs(target_crs.proj4_init, inplace=True)
    reprojected_data.rio.write_nodata(np.nan, inplace=True)

    # Compute geographic for reprojected image
    transformer = Transformer.from_crs(target_crs.proj4_init, data_crs_geo.proj4_init)
    x = reprojected_data.x.values
    y = reprojected_data.y.values

    X, Y = np.meshgrid(x, y)
    lon_grid, lat_grid = transformer.transform(X, Y)

    reprojected_data["lon"] = (("y", "x"), lon_grid)
    reprojected_data["lat"] = (("y", "x"), lat_grid)

    # Rename back to 'xc' and 'yc', although, these are now in metres rather than 1000 metres
    reprojected_data = reprojected_data.rename({"x": "xc", "y": "yc"})

    return reprojected_data
def projection_to_geographic_coords(data, target_crs):
    """Attach `lon` and `lat` grids computed from the `xc`/`yc` coordinates.

    The `xc` and `yc` values are multiplied by 1000, expanded to a grid and
    transformed from `target_crs` to `PlateCarree`. The two resulting grids
    are stored on `data` under the dims `("yc", "xc")`.

    Args:
        data: Object exposing `xc` and `yc`; it is modified in place.
        target_crs: CRS in which `xc` and `yc` are expressed.

    Returns:
        The same `data` object, now carrying `lon` and `lat`.
    """
    geographic_crs = ccrs.PlateCarree()
    to_geographic = Transformer.from_crs(target_crs.proj4_init,
                                         geographic_crs.proj4_init)

    eastings = data.xc.values*1000
    northings = data.yc.values*1000
    grid_x, grid_y = np.meshgrid(eastings, northings)

    lon_grid, lat_grid = to_geographic.transform(grid_x, grid_y)

    grid_dims = ("yc", "xc")
    data["lon"] = (grid_dims, lon_grid)
    data["lat"] = (grid_dims, lat_grid)
    return data
def process_region(region: tuple=None,
                   data: tuple=None,
                   pole: int=1,
                   src_da: object=None,
                   region_definition: str = "pixel",
                   ) -> tuple:
    """Extract subset of pan-Arctic/Antarctic region based on region bounds.

    Every entry of `data` that is not `None` is replaced, in place, by its
    clipped version; `data` itself is returned. When `region` is `None`
    nothing is clipped.

    :param region: Either image pixel bounds, or geographic bounds. Unpacked
                    as `(x1, y1, x2, y2)`.
    :param data: Contains list of xarray DataArrays. Entries may also be
                    IceNet `Masks` instances.
    :param pole: Multiplied by 90 for the projection used when `lon`/`lat`
                    have to be computed for an array.
    :param src_da: Supplies `xc`, `yc`, `lon` and `lat` when clipping a
                    `Masks` entry geographically.
    :param region_definition: Whether providing pixel coordinates or
                    geographic (i.e. lon/lat).

    :return: `data`, with the clipped entries.
    :raises AssertionError: if `region` does not hold four values, is not
                    increasing, or (geographic) leaves -180..180 in x.
    :raises NotImplementedError: for an array entry when `region_definition`
                    is neither 'pixel' nor 'geographic'.
    """

    if region is not None:
        assert len(region) == 4, "Region needs to be a list of four integers"
        x1, y1, x2, y2 = region
        assert x2 > x1 and y2 > y1, "Region is not valid"
        # NOTE(review): this comparison is case-sensitive, while the branches
        # further down compare the casefolded value.
        if region_definition == "geographic":
            assert x1 >= -180 and x2 <= 180, "Expect longitude range to be `-180<=longitude>=180`"

    for idx, arr in enumerate(data):
        if arr is not None and region is not None:
            logging.debug(f"Clipping data to specified bounds: {region}")
            # Case when not an array, but an IceNet Masks class
            if isinstance(arr, Masks):
                if region_definition.casefold() == "geographic":
                    masks = arr
                    xc, yc = src_da.xc, src_da.yc
                    lon, lat = src_da.lon, src_da.lat
                    # Edge cases, where the time dimension is passed in,
                    # seems to be with "./data/osisaf/north/siconca/2020.nc"
                    # and, possibly newer.
                    if "time" in lon.dims:
                        lon = lon.isel(time=0)
                    if "time" in lat.dims:
                        lat = lat.isel(time=0)
                    masks.set_region_by_lonlat(xc, yc, lon,lat, region)
                    data[idx] = masks
                elif region_definition.casefold() == "pixel":
                    # Rows are counted from the opposite end of a 432-long axis
                    data[idx] = arr[..., (432 - y2):(432 - y1), x1:x2]
            else:
                # If array only contains "xc" and "yc", but not "lon" and "lat".
                # Reproject using pyproj to get it.
                if "lon" not in arr.coords and "lat" not in arr.coords:
                    target_crs = ccrs.LambertAzimuthalEqualArea(0, pole*90)
                    arr = projection_to_geographic_coords(arr, target_crs)

                lon, lat = arr.lon, arr.lat

                if region_definition.casefold() == "geographic":
                    # Limit to lon/lat region, within a given tolerance
                    tolerance = 0

                    # Create mask where data is within geographic (lon/lat) region
                    mask = (lon >= x1-tolerance) & (lon <= x2+tolerance) & \
                           (lat >= y1-tolerance) & (lat <= y2+tolerance)

                    # Extract subset within region using where()
                    data[idx] = arr.where(mask.compute(), drop=True)
                elif region_definition.casefold() == "pixel":
                    x_max, y_max = arr.xc.shape[0], arr.yc.shape[0]

                    # Clip the data array to specified pixel region
                    data[idx] = arr[..., (y_max - y2):(y_max - y1), x1:x2]
                else:
                    raise NotImplementedError("Only region_definition='pixel' or 'geographic' bounds are supported")

    return data
@cache
def geographic_box(lon_bounds: np.array, lat_bounds: np.array, segments: int=1):
    """Sampled coordinates along the four sides of a lon/lat rectangle.

    Each side contributes `segments + 1` evenly spaced points. Results are
    memoised with `functools.cache`, so the arguments must be hashable and
    repeated calls return the same array objects.

    Args:
        lon_bounds: (min, max) lon values
        lat_bounds: (min, max) lat values
        segments: Number of segments per edge

    Returns:
        (lons, lats) arrays, each of length `4 * (segments + 1)`
    """
    points_per_side = segments + 1
    total_points = points_per_side * 4

    # (start, stop) indices into the bounds for each of the four sides;
    # the longitude sides use the same pairs in reverse order.
    lat_ends = ((0, 0), (-1, 0), (-1, -1), (0, -1))
    lon_ends = lat_ends[::-1]

    lons = np.empty(total_points)
    lats = np.empty(total_points)

    for side, (lat_pair, lon_pair) in enumerate(zip(lat_ends, lon_ends)):
        span = slice(side*points_per_side, (side+1)*points_per_side)
        lats[span] = np.linspace(lat_bounds[lat_pair[0]],
                                 lat_bounds[lat_pair[1]],
                                 num=points_per_side)
        lons[span] = np.linspace(lon_bounds[lon_pair[0]],
                                 lon_bounds[lon_pair[1]],
                                 num=points_per_side)

    return lons, lats
def get_custom_cmap(cmap):
    """Build a listed colormap from `cmap` with custom bad/under colours.

    The colours of `cmap` are sampled at `cmap.N` evenly spaced positions.
    On the new colormap, invalid values are drawn fully transparent and
    values below the range are drawn in "dimgrey".

    Hack since cartopy needs transparency for nan regions to wraparound
    correctly with pcolormesh.
    """
    sample_positions = np.linspace(0, 1, cmap.N)
    listed = mpl.colors.ListedColormap(cmap(sample_positions))

    listed.set_bad("dimgrey", alpha=0)
    listed.set_under("dimgrey")
    return listed
def set_ffmpeg_path():
    """Point Matplotlib's animation writer at imageio_ffmpeg's executable.

    Stores the path reported by `imageio_ffmpeg` in the
    `animation.ffmpeg_path` rcParam.
    """
    plt.rcParams['animation.ffmpeg_path'] = ffmpeg.get_ffmpeg_exe()