import argparse
import datetime as dt
import logging
import os
import re
from concurrent.futures import as_completed, ProcessPoolExecutor
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from icenet.process.predict import get_refcube
from icenet.utils import setup_logging
from icenet.plotting.utils import get_plot_axes, set_plot_geoaxes, set_plot_geoextent, get_custom_cmap
# TODO: This can be a plotting or analysis util function elsewhere
[docs]
def get_dataarray_from_files(files: object, numpy: bool = False) -> object:
    """Load a set of data files into a single xarray.DataArray.

    NetCDF inputs are opened together with ``xr.open_mfdataset`` and the
    first data variable of the combined dataset is returned.

    Numpy inputs must each contain one 2D array and carry a
    ``YYYY_M_D`` style date in the filename. They are stacked along a new
    ``time`` dimension; the ``xc``/``yc`` coordinates are taken from the
    reference cube returned by ``get_refcube``, the hemisphere is inferred
    from ``north``/``south`` appearing as directory components of the first
    file, and the variable name is the second-to-last directory component.

    :param files: Sequence of file paths to load.
    :param numpy: If True treat ``files`` as ``.npy`` arrays, otherwise as
        files readable by ``xr.open_mfdataset``.
    :return: DataArray; for numpy inputs it has dims ``(time, yc, xc)``.
    :raises ValueError: In numpy mode, if ``files`` is empty or a filename
        contains no parseable date.
    :raises AssertionError: In numpy mode, if the first array is not 2D.
    """
    if not numpy:
        ds = xr.open_mfdataset(files)
        # TODO: We're relying on single variable files from downloaders
        #  so maybe allow a specifier for this for multi var files?
        da = ds.to_array(dim=list(ds.data_vars)[0])[0]
    else:
        if not files:
            raise ValueError("No numpy files supplied for use in videos")

        first_file = np.load(files[0])

        # Validate the layout before allocating the stacked array, so a bad
        # input does not trigger a large, pointless allocation.
        assert len(first_file.shape) == 2, \
            "Wrong number of dims for use in videos {}".\
            format(len(first_file.shape))

        arr = np.zeros((len(files), *first_file.shape))
        dates = []

        for np_idx, filename in enumerate(files):
            nom = os.path.basename(filename)
            date_match = re.search(r"(\d{4})_(\d{1,2})_(\d{1,2})", nom)
            if date_match is None:
                raise ValueError(
                    "Cannot parse a YYYY_MM_DD date from filename {}".
                    format(nom))
            dates.append(
                pd.to_datetime(dt.date(*[int(s) for s in date_match.groups()])))
            # The first file is already in memory, avoid reading it twice
            arr[np_idx] = first_file if np_idx == 0 else np.load(filename)

        # FIXME: naive implementations abound
        path_comps = os.path.dirname(files[0]).split(os.sep)
        ref_cube = get_refcube("north" in path_comps, "south" in path_comps)
        var_name = path_comps[-2]

        da = xr.DataArray(
            data=arr,
            dims=("time", "yc", "xc"),
            coords=dict(
                time=[pd.Timestamp(d) for d in dates],
                xc=ref_cube.coord("projection_x_coordinate").points,
                yc=ref_cube.coord("projection_y_coordinate").points,
            ),
            name=var_name,
        )
    return da
[docs]
def xarray_to_video(
    da: object,
    fps: int,
    video_path: object = None,
    reproject: bool = False,
    north: bool = True,
    south: bool = False,
    extent: tuple = None,
    region_definition: str = "pixel",
    coastlines: str = "default",
    gridlines: bool = False,
    target_crs: object = None,
    transform_crs: object = None,
    mask: object = None,
    mask_type: str = 'contour',
    clim: object = None,
    crop: object = None,
    data_type: str = 'abs',
    video_dates: object = None,
    cmap: object = plt.get_cmap("viridis"),
    figsize: tuple = (10, 8),
    dpi: int = 150,
    imshow_kwargs: dict = None,
    ax_init: object = None,
    ax_extra: callable = None,
    colorbar_label: str = '',
) -> object:
    """
    Generate video of an xarray.DataArray. Optionally input a list of
    `video_dates` to show, otherwise the full set of time coordinates
    of the dataset is used.

    :param da: Dataset to create video of. It is indexed by a `time`
        coordinate; if it has both `lon` and `lat` coordinates those are
        used for plotting, otherwise `xc` and `yc` are used.
    :param fps: Frames per second of the video. Also sets the animation
        frame interval to 1000 / fps milliseconds.
    :param video_path: Path to save the video to. If not set, nothing is
        saved and the animation is only returned.
    :param reproject: If True, NaN values in `da` are replaced with -9999
        before plotting so that they fall under the colour range.
    :param north: Plot the northern hemisphere. Exactly one of `north` and
        `south` must be True.
    :param south: Plot the southern hemisphere.
    :param extent: Passed to `set_plot_geoaxes` when `ax_init` is None.
    :param region_definition: Passed to `set_plot_geoaxes` when `ax_init`
        is None.
    :param coastlines: Passed to `set_plot_geoaxes` when `ax_init` is None.
    :param gridlines: Passed to `set_plot_geoaxes` when `ax_init` is None.
    :param target_crs: CRS passed to `get_plot_axes` and used as the
        transform for the mask contour. Defaults to a Lambert azimuthal
        equal area projection centred on the selected pole.
    :param transform_crs: CRS used as the transform when plotting against
        `lon`/`lat`. Defaults to `ccrs.PlateCarree()`.
    :param mask: Boolean mask with True over masked elements to overlay
        as a contour or filled contour. Defaults to None (no mask plotting).
    :param mask_type: 'contour' or 'contourf' dictating whether the mask is
        overlaid as a contour line or a filled contour.
    :param data_type: 'abs' or 'anom' describing whether the data is in absolute
        or anomaly format. If anomaly, the colorbar is centred on 0.
    :param video_dates: List of Pandas Timestamps or datetime.datetime objects
        to plot video from the dataset.
    :param crop: [(a, b), (c, d)] to crop the video from a:b and c:d
    :param clim: Colormap limits. Default is None, in which case the min and
        max values of the array are used.
    :param cmap: Matplotlib colormap object.
    :param figsize: Figure size in inches.
    :param dpi: Figure DPI.
    :param imshow_kwargs: Extra arguments for displaying array
    :param ax_init: pre-initialised axes object for display
    :param ax_extra: Extra method called with axes for additional plotting
    :param colorbar_label: Label for the colorbar; not set when empty.
    :return: The `matplotlib.animation.FuncAnimation` that was created.
    """
    assert north ^ south, "Only one hemisphere must be selected"
    pole = 1 if north else -1

    target_crs = ccrs.LambertAzimuthalEqualArea(central_latitude=pole*90, central_longitude=0) if target_crs is None else target_crs
    transform_crs = ccrs.PlateCarree() if transform_crs is None else transform_crs

    # Hack since cartopy needs transparency for nan regions to wraparound
    # correctly with pcolormesh, set nan areas as under range.
    if reproject:
        da = da.where(~np.isnan(da), -9999, drop=False)

    # Per-frame callback: `image` and `image_title` are bound later in this
    # function, before the animation is created.
    def update(date):
        logging.debug("Plotting {}".format(date.strftime("%D")))
        data = da.sel(time=date)
        image.set_array(data)
        image_title.set_text("{:04d}/{:02d}/{:02d}".format(
            date.year, date.month, date.day))
        return image, image_title

    logging.info("Inspecting data")

    if clim is not None:
        n_min = clim[0]
        n_max = clim[1]
    else:
        n_max = da.max().values
        n_min = da.min().values

        # NOTE(review): indentation was lost in the copy of this file under
        # review; this block is assumed to apply only when `clim` is not
        # supplied - confirm against the canonical source.
        # Make the colour limits symmetric about zero for anomaly data.
        if data_type == 'anom':
            if np.abs(n_max) > np.abs(n_min):
                n_min = -n_max
            elif np.abs(n_min) > np.abs(n_max):
                n_max = -n_min

    if video_dates is None:
        video_dates = [
            pd.Timestamp(date).to_pydatetime() for date in da.time.values
        ]

    if crop is not None:
        a = crop[0][0]
        b = crop[0][1]
        c = crop[1][0]
        d = crop[1][1]

        da = da.isel(xc=np.arange(a, b), yc=np.arange(c, d))
        # NOTE(review): `da` is cropped with a:b on xc and c:d on yc, while the
        # mask is sliced [a:b, c:d]; if the mask is laid out (yc, xc) the two
        # crops disagree - confirm the expected mask axis order.
        if mask is not None:
            mask = mask[a:b, c:d]

    logging.info("Initialising plot")

    if ax_init is None:
        fig, ax = get_plot_axes(
            geoaxes=True,
            north=north,
            south=south,
            target_crs=target_crs,
            figsize=figsize,
            dpi=dpi,
        )
        ax = set_plot_geoaxes(ax,
                              region_definition=region_definition,
                              extent=extent,
                              coastlines=coastlines,
                              gridlines=gridlines,
                              north=north,
                              south=south,
                              )
    else:
        ax = ax_init
        fig = ax.get_figure()

    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

    if ax_extra is not None:
        ax_extra(ax)

    #if extent and region_definition == "geographic":
    #    # ax.set_extent(extent, crs=transform_crs)
    #    set_plot_geoextent(ax, extent)

    # The first time step is drawn up front; later frames only update it.
    date = pd.Timestamp(da.time.values[0]).to_pydatetime()
    data = da.sel(time=date)

    # NOTE(review): the `image` assigned by the contour calls below is
    # overwritten by the pcolormesh call that follows; the contour itself
    # stays on the axes.
    if mask is not None:
        if mask_type == 'contour':
            image = ax.contour(data.xc.data, data.yc.data, mask,
                               levels=[.5, 1],
                               colors='k',
                               transform=target_crs,
                               zorder=3,
                               )
        elif mask_type == 'contourf':
            image = ax.contourf(data.xc.data, data.yc.data, mask,
                                levels=[.5, 1],
                                colors='k',
                                transform=target_crs,
                                zorder=3,
                                )

    # TODO: Tidy up, and cover all argument options
    # Hack since cartopy needs transparency for nan regions to wraparound
    # correctly with pcolormesh.
    custom_cmap = get_custom_cmap(cmap)
    if "lon" in data.coords and "lat" in data.coords:
        image = data.plot.pcolormesh("lon",
                                     "lat",
                                     ax=ax,
                                     transform=transform_crs,
                                     animated=True,
                                     zorder=1,
                                     add_colorbar=False,
                                     cmap=custom_cmap,
                                     vmin=n_min,
                                     vmax=n_max,
                                     **imshow_kwargs if imshow_kwargs is not None else {}
                                     )
    else:
        image = data.plot.pcolormesh("xc",
                                     "yc",
                                     ax=ax,
                                     animated=True,
                                     zorder=1,
                                     add_colorbar=False,
                                     cmap=custom_cmap,
                                     vmin=n_min,
                                     vmax=n_max,
                                     **imshow_kwargs if imshow_kwargs is not None else {}
                                     )

    image_title = ax.set_title("{:04d}/{:02d}/{:02d}".format(
        date.year, date.month, date.day),
                               fontsize="large",
                               zorder=2)

    # Colorbar setup is best-effort: a KeyError is logged, not raised.
    try:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05, zorder=2, axes_class=plt.Axes)
        cbar = plt.colorbar(image, ax=ax, cax=cax)
        if colorbar_label:
            cbar.set_label(colorbar_label)
        plt.subplots_adjust(right=0.9)
    except KeyError as ex:
        logging.warning("Could not configure locatable colorbar: {}".format(ex))

    logging.info("Animating")

    # Investigated blitting, but it causes a few problems with masks/titles.
    # NOTE(review): despite the comment above, blit=True is passed below -
    # confirm which is intended.
    animation = FuncAnimation(fig,
                              func=update,
                              frames=video_dates,
                              interval=1000 / fps,
                              repeat=False,
                              blit=True,
                              )

    plt.close()

    if not video_path:
        logging.info("Not saving plot, will return animation")
    else:
        logging.info("Saving plot to {}".format(video_path))
        animation.save(video_path, fps=fps, extra_args=['-vcodec', 'libx264'])
    return animation
[docs]
def recurse_data_folders(base_path: object,
                         lookups: object,
                         children: object,
                         filetype: str = "nc") -> object:
    """Walk a directory tree and collect the data files found at its leaves.

    Each call handles one directory level. While there are still filters to
    apply, every subdirectory of ``base_path`` whose name is in ``lookups``
    is descended into, using ``children[0]`` as the filter for the next
    level and ``children[1:]`` for the levels after that. Once both
    ``lookups`` and ``children`` are None the directory is treated as a
    leaf and its matching files are listed.

    A leaf file matches when its extension is ``filetype`` and its stem is
    either exactly four digits or ends in ``abs``, ``anom`` or
    ``linear_trend``.

    :param base_path: Directory to inspect at this level.
    :param lookups: Names of subdirectories to descend into; an empty
        collection (or None while ``children`` remain) means all of them.
    :param children: List of lookup collections for the deeper levels, or
        None when there are no further levels to filter.
    :param filetype: File extension, without the dot, of files to collect.
    :return: At a leaf, the sorted list of matching file paths, or None if
        there are none. Otherwise a list with one entry per non-empty
        matching subdirectory, in sorted subdirectory order, each entry
        being that subdirectory's own result.
    """
    logging.info("Looking at {}".format(base_path))

    if children is None and lookups is None:
        # TODO: should ideally use scandir for performance
        # TODO: naive hardcoded filtering of files
        logging.debug("CHILDREN: {} or LOOKUPS: {}".format(children, lookups))
        extension = ".{}".format(filetype)
        files = []
        for f in sorted(os.listdir(base_path)):
            stem, ext = os.path.splitext(f)
            if ext != extension:
                continue
            # Match on the stem so the name filter follows ``filetype``
            # rather than being tied to the ``.nc`` extension.
            if re.fullmatch(r'\d{4}', stem) or \
                    re.search(r'(abs|anom|linear_trend)\Z', stem):
                files.append(os.path.join(base_path, f))

        logging.debug("Files found: {}".format(", ".join(files)))
        if not files:
            return None
        return files

    # Build the name filter once; None means every subdirectory is wanted.
    wanted = {str(s) for s in lookups} if lookups else None
    next_lookups = children[0] \
        if children is not None and len(children) > 0 else None
    next_children = children[1:] \
        if children is not None and len(children) > 1 else None

    files = []
    # Sorted so the result does not depend on filesystem listing order
    for subdir in sorted(os.listdir(base_path)):
        logging.debug("SUBDIR: {}".format(subdir))
        new_path = os.path.join(base_path, subdir)

        if not os.path.isdir(new_path):
            continue
        if wanted is not None and subdir not in wanted:
            continue

        subdir_files = recurse_data_folders(new_path, next_lookups,
                                            next_children, filetype)
        if subdir_files:
            files.append(subdir_files)

    return files
[docs]
def video_process(files: object, numpy: object, output_dir: object,
                  fps: int) -> object:
    """Render one batch of data files to an mp4 video.

    The output filename is built by joining the directory components of the
    first file with underscores. An existing output file is never
    overwritten.

    :param files: Paths of the data files making up the video; the first
        one determines the hemisphere and the output name.
    :param numpy: Whether the files are numpy arrays, passed through to
        ``get_dataarray_from_files``.
    :param output_dir: Directory to write the video into, created if needed.
    :param fps: Frames per second of the video.
    :return: Path of the video produced, or None if it already existed.
    """
    path_comps = os.path.dirname(files[0]).split(os.sep)

    # Detect the hemisphere from whole directory components instead of a
    # '/north/' substring, so relative paths starting with "north" and
    # non-POSIX separators are handled as well.
    hemi_comps = os.path.normpath(os.path.dirname(files[0])).split(os.sep)
    north = "north" in hemi_comps
    south = not north

    os.makedirs(output_dir, exist_ok=True)
    output_name = os.path.join(output_dir,
                               "{}.mp4".format("_".join(path_comps)))

    if os.path.exists(output_name):
        logging.warning("Not overwriting existing: {}".format(output_name))
        return None

    logging.debug("Supplied: {} files for processing".format(len(files)))
    da = get_dataarray_from_files(files, numpy)

    logging.info("Saving to {}".format(output_name))
    xarray_to_video(da, fps, video_path=output_name, north=north,
                    south=south, mask=None, coastlines=None)
    return output_name
[docs]
@setup_logging
def cli_args():
    """Build the command line parser for the video tool and parse the CLI.

    The function is wrapped by ``setup_logging`` (imported from
    ``icenet.utils``).

    :return: The namespace produced by ``ArgumentParser.parse_args``.
    """

    def comma_list(value):
        # Split a comma separated option value into its items
        return value.split(",")

    parser = argparse.ArgumentParser()

    parser.add_argument("-f", "--fps", type=int, default=15)
    parser.add_argument("-n", "--numpy", default=False, action="store_true")
    parser.add_argument(
        "-o", "--output-dir", dest="output_dir", default="plot", type=str)
    parser.add_argument("-p", "--path", type=str, default="data")
    parser.add_argument("-w", "--workers", type=int, default=8)
    parser.add_argument("-v", "--verbose", default=False, action="store_true")

    # Positionals: the data source names, then an optional hemisphere
    parser.add_argument("data", type=comma_list)
    parser.add_argument(
        "hemisphere", nargs="?", choices=["north", "south"], default=[])

    parser.add_argument("--vars", type=comma_list, default=[])
    parser.add_argument("--years", type=comma_list, default=[])

    return parser.parse_args()
[docs]
def data_cli():
    """Command line entry point producing videos from a tree of data files.

    Parses the command line, collects batches of files under the requested
    path using the data, hemisphere and variable filters, optionally keeps
    only files whose names start with one of the requested years, and
    renders each batch with ``video_process`` in a pool of worker
    processes. Failures of individual batches are logged and do not stop
    the remaining ones.
    """
    args = cli_args()

    hemis = [args.hemisphere] if len(args.hemisphere) else ["north", "south"]

    logging.info("Looking into {}".format(args.path))

    path_children = [hemis, args.vars]
    video_batches = recurse_data_folders(
        args.path,
        args.data,
        path_children,
        filetype="nc" if not args.numpy else "npy")

    logging.debug("Batches: {}".format(video_batches))

    # Flatten data source -> hemisphere -> variable into a list of batches
    video_batches = [
        v_el for h_list in video_batches for v_list in h_list for v_el in v_list
    ]

    if len(args.years) > 0:
        new_batches = []
        for batch in video_batches:
            batch = [
                el for el in batch if os.path.basename(el)[0:4] in args.years
            ]

            if len(batch):
                new_batches.append(batch)
        video_batches = new_batches

    logging.debug("Batches {}".format(video_batches))

    # ProcessPoolExecutor rejects max_workers <= 0, so stop here with a
    # clear message when there is nothing to render.
    if not video_batches:
        logging.warning("No matching files found under {}".format(args.path))
        return

    max_workers = max(1, min(len(video_batches), args.workers))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(video_process, batch, args.numpy,
                            args.output_dir, args.fps)
            for batch in video_batches
        ]

        for future in as_completed(futures):
            try:
                res = future.result()
            except Exception as e:
                # Keep going with the other batches, but keep the traceback
                logging.error(e, exc_info=True)
            else:
                if res:
                    logging.info("Produced {}".format(res))