import collections
from copy import copy
from pprint import pformat as prettyformat
from functools import partial
from itertools import chain
from pathlib import Path
import warnings
import gc
import xarray as xr
import animatplot as amp
from matplotlib import pyplot as plt
from matplotlib.animation import PillowWriter
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from dask.diagnostics import ProgressBar
from .geometries import apply_geometry
from .plotting.animate import (
animate_poloidal,
animate_pcolormesh,
animate_line,
_add_controls,
_normalise_time_coord,
_parse_coord_option,
)
from .region import _from_region
from .utils import _get_bounding_surfaces, _split_into_restarts
@xr.register_dataset_accessor("bout")
class BoutDatasetAccessor:
"""
Contains BOUT-specific methods to use on BOUT++ datasets opened using
`open_boutdataset()`.
These BOUT-specific methods and attributes are accessed via the bout
accessor, e.g. `ds.bout.options` returns a `BoutOptionsFile` instance.
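Examples
--------
A minimal sketch (the file paths here are assumptions)::
from xbout import open_boutdataset
ds = open_boutdataset("data/BOUT.dmp.*.nc", inputfilepath="data/BOUT.inp")
print(ds.bout)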
"""
def __init__(self, ds):
self.data = ds
self.metadata = ds.attrs.get("metadata") # None if just grid file
self.options = ds.attrs.get("options") # None if no inp file
def __str__(self):
"""
String representation of the BoutDataset.
Accessed by print(ds.bout)
"""
styled = partial(prettyformat, indent=4, compact=True)
text = (
"<xbout.BoutDataset>\n"
+ "Contains:\n{}\n".format(str(self.data))
+ "Metadata:\n{}\n".format(styled(self.metadata))
)
if self.options:
text += "Options:\n{}".format(self.options)
return text
# def __repr__(self):
# return 'boutdata.BoutDataset(', {}, ',', {}, ')'.format(self.datapath,
# self.prefix)
def get_field_aligned(self, name, caching=True):
"""
Get a field-aligned version of a variable, calculating (and caching in the
Dataset) if necessary
Parameters
----------
name : str
Name of the variable to get field-aligned version of
caching : bool, optional
Save the field-aligned variable in the Dataset (default: True)
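Examples
--------
For example (a sketch; the variable name "n" is an assumption)::
n_aligned = ds.bout.get_field_aligned("n")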
"""
aligned_name = name + "_aligned"
try:
result = self.data[aligned_name]
if result.direction_y != "Aligned":
raise ValueError(
aligned_name + " exists, but is not field-aligned, it "
"has direction_y=" + result.direction_y
)
return result
except KeyError:
if caching:
self.data[aligned_name] = self.data[name].bout.to_field_aligned()
return self.data[aligned_name]
def to_field_aligned(self):
"""
Create a new Dataset with all 3d variables transformed to field-aligned
coordinates, which are shifted with respect to the base coordinates by an angle
zShift
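Examples
--------
For example::
ds_aligned = ds.bout.to_field_aligned()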
"""
result = self.data.copy()
for v in chain(result, result.coords):
da = result[v]
# Need to transform any z-dependent variables or coordinates
if (result.metadata["bout_zdim"] in da.dims) and (
da.attrs.get("direction_y", None) == "Standard"
):
result[v] = da.bout.to_field_aligned()
return result
def from_field_aligned(self):
"""
Create a new Dataset with all 3d variables transformed to non-field-aligned
coordinates
"""
result = self.data.copy()
for v in chain(result, result.coords):
da = result[v]
# Need to transform any 3d variables or coordinates
if (result.metadata["bout_zdim"] in da.dims) and (
da.attrs.get("direction_y", None) == "Aligned"
):
result[v] = da.bout.from_field_aligned()
return result
def from_region(self, name, with_guards=None):
"""
Get a logically-rectangular section of data from a certain region.
Includes guard cells from neighbouring regions.
Parameters
----------
name : str
Region to get data for
with_guards : int or dict of int, optional
Number of guard cells to include, by default use MXG and MYG from BOUT++.
Pass a dict to set different numbers for different coordinates.
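Examples
--------
For example (a sketch; the region name and guard-cell keys are assumptions)::
core = ds.bout.from_region("core", with_guards={"theta": 2, "x": 0})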
"""
return _from_region(self.data, name, with_guards)
@property
def _regions(self):
if "regions" not in self.data.attrs:
raise ValueError(
"Called a method requiring regions, but these have not been created. "
"Please set the 'geometry' option when calling open_boutdataset() to "
"create regions."
)
return self.data.attrs["regions"]
@property
def fine_interpolation_factor(self):
"""
The default factor to increase resolution when doing parallel interpolation
"""
return self.data.metadata["fine_interpolation_factor"]
@fine_interpolation_factor.setter
def fine_interpolation_factor(self, n):
"""
Set the default factor to increase resolution when doing parallel interpolation.
Parameters
----------
n : int
Factor to increase parallel resolution by
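For example: ``ds.bout.fine_interpolation_factor = 16``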
"""
ds = self.data
ds.metadata["fine_interpolation_factor"] = n
for da in ds.data_vars.values():
da.metadata["fine_interpolation_factor"] = n
def interpolate_parallel(self, variables, **kwargs):
"""
Interpolate in the parallel direction to get a higher resolution version of a
subset of variables.
Note that the high-resolution variables are all loaded into memory, so it will
usually be necessary to select only a small number. The toroidal_points
argument can also be used to reduce the memory demand.
Parameters
----------
variables : str or sequence of str or ...
The names of the variables to interpolate. If 'variables=...' is passed
explicitly, then interpolate all variables in the Dataset.
n : int, optional
The factor to increase the resolution by. Defaults to the value set via the
`fine_interpolation_factor` property, or 10 if that has not been set.
toroidal_points : int or sequence of int, optional
If int, number of toroidal points to output, applies a stride to toroidal
direction to save memory usage. If sequence of int, the indexes of toroidal
points for the output.
method : str, optional
The interpolation method to use. Options from xarray.DataArray.interp(),
currently: linear, nearest, zero, slinear, quadratic, cubic. Default is
'cubic'.
Returns
-------
A new Dataset containing high-resolution versions of the variables. The new
Dataset is a valid BoutDataset, although containing only the specified variables.
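Examples
--------
For example (a sketch; the variable names are assumptions)::
ds_fine = ds.bout.interpolate_parallel(["n", "T"], n=8)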
"""
if variables is ...:
variables = [v for v in self.data]
if isinstance(variables, str):
variables = [variables]
if isinstance(variables, tuple):
variables = list(variables)
# Need to start with a Dataset with attrs as merge() drops the attrs of the
# passed-in argument.
# Make sure the first variable has all dimensions so we don't lose any
# coordinates
def find_with_dims(first_var, dims):
if first_var is None:
dims = set(dims)
for v in variables:
if set(self.data[v].dims) == dims:
first_var = v
break
return first_var
tcoord = self.data.metadata.get("bout_tdim", "t")
zcoord = self.data.metadata.get("bout_zdim", "z")
first_var = find_with_dims(None, self.data.dims)
first_var = find_with_dims(first_var, set(self.data.dims) - set([tcoord]))
first_var = find_with_dims(first_var, set(self.data.dims) - set([zcoord]))
first_var = find_with_dims(
first_var, set(self.data.dims) - set([tcoord, zcoord])
)
if first_var is None:
raise ValueError(
f"Could not find variable to interpolate with both "
f"{ds.metadata.get('bout_xdim', 'x')} and "
f"{ds.metadata.get('bout_ydim', 'y')} dimensions"
)
variables.remove(first_var)
ds = self.data[first_var].bout.interpolate_parallel(
return_dataset=True, **kwargs
)
xcoord = ds.metadata.get("bout_xdim", "x")
ycoord = ds.metadata.get("bout_ydim", "y")
for var in variables:
da = self.data[var]
if xcoord in da.dims and ycoord in da.dims:
ds = ds.merge(
da.bout.interpolate_parallel(return_dataset=True, **kwargs)
)
elif ycoord not in da.dims:
ds[var] = da
# Can't interpolate a variable that depends on y but not x, so just skip
# Apply geometry
ds = apply_geometry(ds, ds.geometry)
return ds
def integrate_midpoints(self, variable, *, dims=None, cumulative_t=False):
"""
Integrate using the midpoint rule for spatial dimensions, and trapezium rule for
time.
The quantity being integrated is assumed to be a scalar variable.
When doing a 1d integral in the 'y' dimension, the integral is calculated as a
poloidal integral if the variable is on the standard grid ('direction_y'
attribute is "Standard"), or as a parallel-to-B integral if the variable is on
the field-aligned grid ('direction_y' attribute is "Aligned").
When doing a 2d integral over 'x' and 'y' dimensions, the integral will be over
poloidal cross-sections if the variable is not field-aligned (direction_y ==
"Standard") and over field-aligned surfaces if the variable is field-aligned
(direction_y == "Aligned"). The latter seems unlikely to be useful as the
surfaces depend on the arbitrary origin used for zShift.
This is a method of the BoutDataset accessor rather than of BoutDataArray so that
other variables like `J`, `g11`, `g_22` can be used for the integration.
Note the xarray.DataArray.integrate() method uses the trapezium rule, which is
not consistent with the way BOUT++ defines grid spacings as cell widths. The
midpoint rule also means that, for example::
inner = ds.isel(x=slice(i)).bout.integrate_midpoints(variable)
outer = ds.isel(x=slice(i, None)).bout.integrate_midpoints(variable)
total = ds.bout.integrate_midpoints(variable)
inner + outer == total
while with the trapezium rule you would have to select ``x=slice(i + 1)`` for
inner to get a similar relation to be true.
Parameters
----------
variable : str or DataArray
Name of the variable to integrate, or the variable itself as a DataArray.
dims : str, list of str or ...
Dimensions to integrate over. Can be any combination of the dimensions of
the Dataset. Defaults to integration over all spatial dimensions. If `...`
is passed, integrate over all dimensions including time.
cumulative_t : bool, default False
If integrating in time, return the cumulative integral (integral from the
beginning up to each point in the time dimension) instead of the definite
integral.
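Examples
--------
A minimal sketch (assumes a variable "n" and dimensions named "y" and "z")::
total = ds.bout.integrate_midpoints("n")
flux_surface_integral = ds.bout.integrate_midpoints("n", dims=["y", "z"])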
"""
ds = self.data
if isinstance(variable, str):
variable = ds[variable]
location = variable.cell_location
suffix = "" if location == "CELL_CENTRE" else f"_{location}"
tcoord = ds.metadata["bout_tdim"]
xcoord = ds.metadata["bout_xdim"]
ycoord = ds.metadata["bout_ydim"]
zcoord = ds.metadata["bout_zdim"]
if dims is None:
dims = []
if xcoord in ds.dims:
dims.append(xcoord)
if ycoord in ds.dims:
dims.append(ycoord)
if zcoord in ds.dims:
dims.append(zcoord)
elif dims is ...:
dims = []
if tcoord in ds.dims:
dims.append(tcoord)
if xcoord in ds.dims:
dims.append(xcoord)
if ycoord in ds.dims:
dims.append(ycoord)
if zcoord in ds.dims:
dims.append(zcoord)
elif isinstance(dims, str):
dims = [dims]
dx = ds[f"dx{suffix}"]
dy = ds[f"dy{suffix}"]
if ds.metadata["BOUT_VERSION"] >= 5.0:
dz = ds[f"dz{suffix}"]
else:
dz = ds["dz"]
# Work out the spatial volume element
if xcoord in dims and ycoord in dims and zcoord in dims:
# Volume integral, use the 3d Jacobian "J"
spatial_volume_element = ds[f"J{suffix}"] * dx * dy * dz
elif xcoord in dims and ycoord in dims:
# 2d integral on poloidal planes
if variable.direction_y == "Standard":
# Need to use a metric constructed from basis vectors within the
# poloidal plane, so use 'reciprocal basis vectors' Grad(x^i)
# J = 1/sqrt(det(g_2d))
# det(g_2d) = g11*g22 - g12**2
g = ds[f"g11{suffix}"] * ds[f"g22{suffix}"] - ds[f"g12{suffix}"] ** 2
J = 1.0 / np.sqrt(g)
elif variable.direction_y == "Aligned":
# Need to work out area element from metric coefficients. See book by
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51).
# Need to use a metric constructed from basis vectors within the
# field-aligned x-y plane, so use 'tangent basis vectors' e_i
# J = sqrt(g_11*g_22 - g_12**2)
J = np.sqrt(
ds[f"g_11{suffix}"] * ds[f"g_22{suffix}"] - ds[f"g_12{suffix}"] ** 2
)
spatial_volume_element = J * dx * dy
elif xcoord in dims and zcoord in dims:
# 2d integral on toroidal planes
# Need to work out area element from metric coefficients. See book by
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51)
# J = sqrt(g_11*g_33 - g_13**2)
J = np.sqrt(
ds[f"g_11{suffix}"] * ds[f"g_33{suffix}"] - ds[f"g_13{suffix}"] ** 2
)
spatial_volume_element = J * dx * dz
elif ycoord in dims and zcoord in dims:
# 2d integral on flux-surfaces
# Need to work out area element from metric coefficients. See book by
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51)
# J = sqrt(g_22*g_33 - g_23**2)
J = np.sqrt(
ds[f"g_22{suffix}"] * ds[f"g_33{suffix}"] - ds[f"g_23{suffix}"] ** 2
)
spatial_volume_element = J * dy * dz
elif xcoord in dims:
if variable.direction_y == "Aligned":
raise ValueError(
"Variable is field-aligned, but radial integral along coordinate "
"line in globally field-aligned coordinates not supported"
)
# 1d radial integral, line element is sqrt(g_11)*dx
spatial_volume_element = np.sqrt(ds[f"g_11{suffix}"]) * dx
elif ycoord in dims:
if variable.direction_y == "Standard":
# Poloidal integral, line element is e_y projected onto a unit vector in
# the poloidal direction. e_z is in the toroidal direction and Grad(x)
# is orthogonal to flux surfaces, so their cross product is in the
# poloidal direction (within flux surfaces). e_z and Grad(x) are also
# always orthogonal, so the magnitude of their cross product is the
# product of their magnitudes. Therefore
# e_y.hat{e}_pol = e_y.(e_z x Grad(x))/|Grad(x)||e_z|
# e_y.hat{e}_pol = e_y.(e_z x Grad(x))/sqrt(g11*g_33)
# and using eqs. (2.3.12) and (2.5.22a) from D'haeseleer
# e_y.hat{e}_pol = e_y.(e_z x (e_y x e_z / J))/sqrt(g11*g_33)
# e_y.hat{e}_pol = e_y.(e_z x (e_y x e_z))/ (J*sqrt(g11*g_33))
# The double cross product identity is A x (B x C) = (A.C)B - (A.B)C.
# e_y.hat{e}_pol = e_y.((e_z.e_z)*e_y - (e_z.e_y)*e_z)/(J*sqrt(g11*g_33))
# e_y.hat{e}_pol = e_y.(g_33*e_y - g_23*e_z)/(J*sqrt(g11*g_33))
# e_y.hat{e}_pol = (g_33*g_22 - g_23*g_23)/(J*sqrt(g11*g_33))
# For 'orthogonal' coordinates (radial and poloidal directions are
# orthogonal) this is equal to 1/sqrt(g22)
spatial_volume_element = (
(
ds[f"g_22{suffix}"] * ds[f"g_33{suffix}"]
- ds[f"g_23{suffix}"] ** 2
)
/ (
ds[f"J{suffix}"]
* np.sqrt(ds[f"g11{suffix}"] * ds[f"g_33{suffix}"])
)
* dy
)
elif variable.direction_y == "Aligned":
# Parallel integral, line element is sqrt(g_22)*dy
spatial_volume_element = np.sqrt(ds[f"g_22{suffix}"]) * dy
elif zcoord in dims:
# Toroidal integral, line element is sqrt(g_33)*dz
spatial_volume_element = np.sqrt(ds[f"g_33{suffix}"]) * dz
else:
# No spatial integral
spatial_volume_element = 1.0
spatial_dims = set(dims) - set([tcoord])
# Need to check if the variable being integrated is a Field2D, which does not
# have a z-dimension to sum over. Other variables are OK because metric
# coefficients, dx and dy all have both x- and y-dimensions so variable would be
# broadcast to include them if necessary
missing_z_sum = zcoord in dims and zcoord not in variable.dims
integrand = variable * spatial_volume_element
integral = integrand.sum(dim=spatial_dims)
# If integrand is a Field2D, need to multiply by nz if integrating over z
if missing_z_sum:
integral = integral * ds.sizes[zcoord]
if tcoord in dims:
if cumulative_t:
integral = integral.cumulative_integrate(coord=tcoord)
else:
integral = integral.integrate(coord=tcoord)
return integral
def interpolate_from_unstructured(
self,
variables,
*,
fill_value=np.nan,
structured_output=True,
unstructured_dim_name="unstructured_dim",
**kwargs,
):
"""Interpolate Dataset onto new grids of some existing coordinates
Parameters
----------
variables : str or sequence of str or ...
The names of the variables to interpolate. If 'variables=...' is passed
explicitly, then interpolate all variables in the Dataset.
**kwargs : (str, array)
Each keyword is the name of a coordinate in the Dataset; the argument is a
1d array giving the values of that coordinate on the output grid
fill_value : float
fill_value passed through to scipy.interpolate.griddata
structured_output : bool, default True
If True, treat output coordinates values as a structured grid.
If False, output coordinate values must all have the same length and are not
broadcast together.
unstructured_dim_name : str, default "unstructured_dim"
Name used for the dimension in the output that replaces the dimensions of
the interpolated coordinates. Only used if structured_output=False.
Returns
-------
Dataset
Dataset interpolated onto a new, structured grid
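Examples
--------
A minimal sketch (assumes "R" and "Z" coordinates exist; the output grids are
arbitrary)::
ds_RZ = ds.bout.interpolate_from_unstructured(
"n", R=np.linspace(0.2, 0.8, 100), Z=np.linspace(-1.0, 1.0, 200)
)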
"""
if variables is ...:
variables = [v for v in self.data]
explicit_variables_arg = False
else:
explicit_variables_arg = True
if isinstance(variables, str):
variables = [variables]
if isinstance(variables, tuple):
variables = list(variables)
coords_to_interpolate = []
for coord in self.data.coords:
if coord not in variables and coord not in kwargs:
coords_to_interpolate.append(coord)
ds = xr.Dataset()
for v in variables + coords_to_interpolate:
if np.all([c in self.data[v].coords for c in kwargs]):
ds = ds.merge(
self.data[v]
.bout.interpolate_from_unstructured(
fill_value=fill_value,
structured_output=structured_output,
unstructured_dim_name=unstructured_dim_name,
**kwargs,
)
.to_dataset()
)
elif explicit_variables_arg and v in variables:
# User explicitly requested v to be interpolated
raise ValueError(
f"Could not interpolate {v} because it does not depend on all "
f"coordinates {[c for c in kwargs]}"
)
elif v in coords_to_interpolate:
coords_to_interpolate.remove(v)
ds = ds.set_coords(coords_to_interpolate)
return ds
def remove_yboundaries(self, **kwargs):
"""
Remove y-boundary points, if present, from the Dataset
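For example: ``ds_trimmed = ds.bout.remove_yboundaries()``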
"""
variables = []
xcoord = self.data.metadata["bout_xdim"]
ycoord = self.data.metadata["bout_ydim"]
new_metadata = None
for v in self.data:
if xcoord in self.data[v].dims and ycoord in self.data[v].dims:
variables.append(
self.data[v].bout.remove_yboundaries(return_dataset=True, **kwargs)
)
new_metadata = variables[-1].metadata
elif ycoord in self.data[v].dims:
raise ValueError(
f"{v} only has a {ycoord}-dimension so cannot split "
f"into regions."
)
else:
variable = self.data[v]
if "keep_yboundaries" in variable.metadata:
variable.attrs["metadata"] = copy(variable.metadata)
variable.metadata["keep_yboundaries"] = 0
variables.append(variable.bout.to_dataset())
if new_metadata is None:
# There were no 2d or 3d variables, so we do not have updated jyseps*, ny_inner,
# etc., but this does not matter because the updated metadata is only needed for
# 2d or 3d variables
new_metadata = variables[0].metadata
result = xr.merge(variables)
result.attrs = copy(self.data.attrs)
# Copy metadata to get possibly modified jyseps*, ny_inner, ny
result.attrs["metadata"] = new_metadata
if "regions" in result.attrs:
# regions are not correct for modified BoutDataset
del result.attrs["regions"]
# call to re-create regions
result = apply_geometry(result, self.data.geometry)
return result
def get_bounding_surfaces(self, coords=("R", "Z")):
"""
Get bounding surfaces.
Surfaces are returned as arrays of points describing a polygon, assuming the
third spatial dimension is a symmetry direction.
Parameters
----------
coords : (str, str), default ("R", "Z")
Pair of names of coordinates whose values are used to give the positions of
the points in the result
Returns
-------
result : list of DataArrays
Each DataArray in the list contains points on a boundary, with size
(<number of points in the bounding polygon>, 2). Points wind clockwise around
the outside of the domain, and anti-clockwise around the inside (if there is an
inner boundary).
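Examples
--------
For example (a sketch, assuming an R-Z toroidal geometry)::
boundaries = ds.bout.get_bounding_surfaces()
plt.plot(boundaries[0][:, 0], boundaries[0][:, 1])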
"""
return _get_bounding_surfaces(self.data, coords)
def save(
self,
savepath="./boutdata.nc",
filetype="NETCDF4",
variables=None,
save_dtype=None,
separate_vars=False,
pre_load=False,
):
"""
Save data variables to a netCDF file.
Parameters
----------
savepath : str, optional
Path to save the data to; defaults to "./boutdata.nc" in the current directory.
filetype : str, optional
netCDF format passed to xarray's to_netcdf(); defaults to "NETCDF4".
variables : list of str, optional
Variables from the dataset to save. Default is to save all of them.
save_dtype : dtype, optional
If given, all saved variables are converted to this dtype (applied via the
netCDF encoding passed to to_netcdf).
separate_vars : bool, optional
If this is true then every variable which depends on time (but not
solely on time) will be saved into a different output file.
The files are labelled by the name of the variable. Variables which
don't meet this criterion will be present in every output file.
pre_load : bool, optional
When saving separate variables, will load each variable into memory
before saving to file, which can be considerably faster.
Examples
--------
If `separate_vars=True`, then multiple files will be created. These can
all be opened and merged in one go using a call of the form:
ds = xr.open_mfdataset('boutdata_*.nc', combine='nested', concat_dim=None)
"""
if variables is None:
# Save all variables
to_save = self.data
else:
to_save = self.data[variables]
if savepath == "./boutdata.nc":
print(
"Will save data into the current working directory, named as"
" boutdata_[var].nc"
)
if savepath is None:
raise ValueError("Must provide a path to which to save the data.")
# make shallow copy of Dataset, so we do not modify the attributes of the data
# when we change things to save
to_save = to_save.copy()
options = to_save.attrs.pop("options")
if options:
# TODO Convert Ben's options class to a (flattened) nested
# dictionary then store it in ds.attrs?
warnings.warn(
"Haven't decided how to write options file back out yet - deleting "
"options for now. To re-load this Dataset, pass the same inputfilepath "
"to open_boutdataset when re-loading."
)
# Delete placeholders for options on each variable and coordinate
for var in chain(to_save.data_vars, to_save.coords):
try:
del to_save[var].attrs["options"]
except KeyError:
pass
# Store the metadata as individual attributes instead because
# netCDF can't handle storing arbitrary objects in attrs
def dict_to_attrs(obj, section):
for key, value in obj.attrs.pop(section).items():
obj.attrs[section + ":" + key] = value
dict_to_attrs(to_save, "metadata")
# Must do this for all variables and coordinates in dataset too
for varname, da in chain(to_save.data_vars.items(), to_save.coords.items()):
try:
dict_to_attrs(da, "metadata")
except KeyError:
pass
if "regions" in to_save.attrs:
# Do not need to save regions as these can be reconstructed from the metadata
try:
del to_save.attrs["regions"]
except KeyError:
pass
for var in chain(to_save.data_vars, to_save.coords):
try:
del to_save[var].attrs["regions"]
except KeyError:
pass
if save_dtype is not None:
encoding = {v: {"dtype": save_dtype} for v in to_save}
else:
encoding = None
if separate_vars:
# Save each major variable to a different netCDF file
# Determine which variables are "major"
# Defined as time-dependent, but not solely time-dependent
major_vars, minor_vars = _find_major_vars(to_save)
print("Will save the variables {} separately".format(str(major_vars)))
# Save each one to separate file
# TODO perform the save in parallel with save_mfdataset?
for major_var in major_vars:
# Group variables so that there is only one time-dependent
# variable saved in each file
minor_data = [to_save[minor_var] for minor_var in minor_vars]
single_var_ds = xr.merge([to_save[major_var], *minor_data])
# Add the attrs back on
single_var_ds.attrs = to_save.attrs
if pre_load:
single_var_ds.load()
# Include the name of the variable in the name of the saved
# file
path = Path(savepath)
var_savepath = (
str(path.parent / path.stem) + "_" + str(major_var) + path.suffix
)
if encoding is not None:
var_encoding = {major_var: encoding[major_var]}
else:
var_encoding = None
print("Saving " + major_var + " data...")
with ProgressBar():
single_var_ds.to_netcdf(
path=str(var_savepath),
format=filetype,
compute=True,
encoding=var_encoding,
)
# Force memory deallocation to limit RAM usage
single_var_ds.close()
del single_var_ds
gc.collect()
else:
# Save data to a single file
print("Saving data...")
with ProgressBar():
to_save.to_netcdf(
path=savepath, format=filetype, compute=True, encoding=encoding
)
return
def to_restart(
self,
variables=None,
*,
savepath=".",
nxpe=None,
nype=None,
tind=-1,
prefix="BOUT.restart",
overwrite=False,
):
"""
Write out a timestep as a set of netCDF BOUT.restart files.
If processor decomposition is not specified then data will be saved
using the decomposition it had when loaded.
Parameters
----------
variables : str or sequence of str, optional
The evolving variables needed in the restart files. If not given explicitly,
all time-evolving variables in the Dataset will be used, which may result in
larger restart files than necessary.
savepath : str, default '.'
Directory to save the created restart files under
nxpe : int, optional
Number of processors in the x-direction. If not given, keep the number used
for the original simulation
nype : int, optional
Number of processors in the y-direction. If not given, keep the number used
for the original simulation
tind : int, default -1
Time-index of the slice to write to the restart files
prefix : str, default "BOUT.restart"
Prefix to use for names of restart files
overwrite : bool, default False
By default, raises if restart file already exists. Set to True to overwrite
existing files
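Examples
--------
For example (a sketch; the variable names and decomposition are assumptions)::
ds.bout.to_restart(["n", "T"], savepath="restarts", nxpe=4, nype=8)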
"""
if isinstance(variables, str):
variables = [variables]
# Set processor decomposition if not given
if nxpe is None:
nxpe = self.metadata["NXPE"]
if nype is None:
nype = self.metadata["NYPE"]
# Is this even possible without saving the guard cells?
# Can they be recreated?
restart_datasets, paths = _split_into_restarts(
self.data,
variables,
savepath,
nxpe,
nype,
tind,
prefix,
overwrite,
)
with ProgressBar():
xr.save_mfdataset(restart_datasets, paths, compute=True)
def animate_list(
self,
variables,
animate_over=None,
save_as=None,
show=False,
fps=10,
nrows=None,
ncols=None,
poloidal_plot=False,
axis_coords=None,
subplots_adjust=None,
vmin=None,
vmax=None,
logscale=None,
titles=None,
aspect=None,
extend=None,
controls="both",
tight_layout=True,
**kwargs,
):
"""
Create a single animation with a separate subplot for each variable in the list.
Parameters
----------
variables : list of str or BoutDataArray
The variables to plot. For any string passed, the corresponding
variable in this DataSet is used - then the calling DataSet must
have only 3 dimensions. It is possible to pass BoutDataArrays to
allow more flexible plots, e.g. with different variables being
plotted against different axes.
animate_over : str, optional
Dimension over which to animate, defaults to the time dimension
save_as : str, optional
If passed, a gif is created with this filename
show : bool, optional
Call pyplot.show() to display the animation
fps : float, optional
Indicates the number of frames per second to play
nrows : int, optional
Specify the number of rows of plots
ncols : int, optional
Specify the number of columns of plots
poloidal_plot : bool or sequence of bool, optional
If set to True, make all 2D animations in the poloidal plane instead of using
grid coordinates, per variable if sequence is given
axis_coords : None, str, dict or list of None, str or dict
Coordinates to use for axis labelling.
- None: Use the dimension coordinate for each axis, if it exists.
- "index": Use the integer index values.
- dict: keys are dimension names, values set axis_coords for each axis
separately. Values can be: None, "index", the name of a 1d variable or
coordinate (which must have the dimension given by 'key'), or a 1d
numpy array, dask array or DataArray whose length matches the length of
the dimension given by 'key'.
Only affects time coordinate for plots with poloidal_plot=True.
If a list is passed, it must have the same length as 'variables' and gives
the axis_coords setting for each plot individually.
The setting to use for the 'animate_over' coordinate can be passed in one or
more dict values, but must be the same in all dicts if given more than once.
subplots_adjust : dict, optional
Arguments passed to fig.subplots_adjust()
vmin : float or sequence of floats
Minimum value for color scale, per variable if a sequence is given
vmax : float or sequence of floats
Maximum value for color scale, per variable if a sequence is given
logscale : bool or float, sequence of bool or float, optional
If True, default to a logarithmic color scale instead of a linear one.
If a non-bool type is passed it is treated as a float used to set the linear
threshold of a symmetric logarithmic scale as
linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is
passed.
Per variable if sequence is given.
titles : sequence of str or None, optional
Custom titles for each plot. Pass None in the sequence to use the default for
a certain variable
aspect : str or None, or sequence of str or None, optional
Argument to set_aspect() for each plot. Defaults to "equal" for poloidal
plots and "auto" for others.
extend : str or None, optional
Passed to fig.colorbar()
controls : string or None, default "both"
By default, add both the timeline and play/pause toggle to the animation. If
"timeline" is passed add only the timeline, if "toggle" is passed add only
the play/pause toggle. If None or an empty string is passed, add neither.
tight_layout : bool or dict, optional
If set to False, don't call tight_layout() on the figure.
If a dict is passed, the dict entries are passed as arguments to
tight_layout()
**kwargs : dict, optional
Additional keyword arguments are passed on to each animation function, per
variable if a sequence is given.
Returns
-------
animation
An animatplot.Animation object.
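Examples
--------
For example (a sketch; the variable names are assumptions)::
ds.bout.animate_list(
["n", "T", "phi"], poloidal_plot=[True, True, False], save_as="overview"
)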
"""
if animate_over is None:
animate_over = self.metadata.get("bout_tdim", "t")
nvars = len(variables)
if nrows is None and ncols is None:
ncols = int(np.ceil(np.sqrt(nvars)))
nrows = int(np.ceil(nvars / ncols))
elif nrows is None:
nrows = int(np.ceil(nvars / ncols))
elif ncols is None:
ncols = int(np.ceil(nvars / nrows))
else:
if nrows * ncols < nvars:
raise ValueError("Not enough rows*columns to fit all variables")
fig, axes = plt.subplots(nrows, ncols, squeeze=False)
axes = axes.flatten()
ncells = nrows * ncols
if nvars < ncells:
for index in range(ncells - nvars):
fig.delaxes(axes[ncells - index - 1])
if subplots_adjust is not None:
fig.subplots_adjust(**subplots_adjust)
def _expand_list_arg(arg, arg_name):
if isinstance(arg, collections.abc.Sequence) and not isinstance(arg, str):
if len(arg) != len(variables):
raise ValueError(
"if %s is a sequence, it must have the same "
'number of elements as "variables"' % arg_name
)
else:
arg = [arg] * len(variables)
return arg
poloidal_plot = _expand_list_arg(poloidal_plot, "poloidal_plot")
vmin = _expand_list_arg(vmin, "vmin")
vmax = _expand_list_arg(vmax, "vmax")
logscale = _expand_list_arg(logscale, "logscale")
titles = _expand_list_arg(titles, "titles")
aspect = _expand_list_arg(aspect, "aspect")
extend = _expand_list_arg(extend, "extend")
axis_coords = _expand_list_arg(axis_coords, "axis_coords")
for k in kwargs:
kwargs[k] = _expand_list_arg(kwargs[k], k)
blocks = []
def is_list(variable):
return (
isinstance(variable, list)
or isinstance(variable, tuple)
or isinstance(variable, set)
)
for i, subplot_args in enumerate(
zip(
variables,
axes,
poloidal_plot,
axis_coords,
vmin,
vmax,
logscale,
titles,
aspect,
extend,
)
):
(
v,
ax,
this_poloidal_plot,
this_axis_coords,
this_vmin,
this_vmax,
this_logscale,
this_title,
this_aspect,
this_extend,
) = subplot_args
this_kwargs = {k: v[i] for k, v in kwargs.items()}
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
if is_list(v):
for i in range(len(v)):
if isinstance(v[i], str):
v[i] = self.data[v[i]]
# list of variables for one subplot only supported for line plots with 1
# dimension plus time
ndims = 2
dims = v[0].dims
if len(dims) != 2:
raise ValueError(
"Variables in sublist must be 2d - can only overlay line plots"
)
for w in v:
if not w.dims == dims:
raise ValueError(
f"All variables in sub-list must have same dimensions."
f"{v[0].name} had {v[0].dims} but {w.name} had {w.dims}."
)
else:
if isinstance(v, str):
v = self.data[v]
data = v.bout.data
ndims = len(data.dims)
ax.set_title(data.name)
if ndims == 2:
if not is_list(v):
blocks.append(
animate_line(
data=data,
ax=ax,
animate_over=animate_over,
animate=False,
axis_coords=this_axis_coords,
aspect=this_aspect,
**this_kwargs,
)
)
else:
for w in v:
blocks.append(
animate_line(
data=w,
ax=ax,
animate_over=animate_over,
animate=False,
axis_coords=this_axis_coords,
aspect=this_aspect,
label=w.name,
**this_kwargs,
)
)
legend = ax.legend()
legend.set_draggable(True)
# set 'v' to use for the timeline below
v = v[0]
elif ndims == 3:
if this_poloidal_plot:
var_blocks = animate_poloidal(
data,
ax=ax,
cax=cax,
animate_over=animate_over,
animate=False,
axis_coords=this_axis_coords,
vmin=this_vmin,
vmax=this_vmax,
logscale=this_logscale,
aspect=this_aspect,
extend=this_extend,
**this_kwargs,
)
for block in var_blocks:
blocks.append(block)
else:
blocks.append(
animate_pcolormesh(
data=data,
ax=ax,
cax=cax,
animate_over=animate_over,
animate=False,
axis_coords=this_axis_coords,
vmin=this_vmin,
vmax=this_vmax,
logscale=this_logscale,
aspect=this_aspect,
extend=this_extend,
**this_kwargs,
)
)
else:
raise ValueError(
"Unsupported number of dimensions "
+ str(ndims)
+ ". Dims are "
+ str(v.dims)
)
if this_title is not None:
# Replace default title with user-specified one
ax.set_title(this_title)
if np.all([a == "index" for a in axis_coords]):
time_opt = "index"
elif np.any([isinstance(a, dict) and animate_over in a for a in axis_coords]):
given_values = [
a[animate_over]
for a in axis_coords
if isinstance(a, dict) and animate_over in a
]
time_opt = given_values[0]
if len(given_values) > 1 and not np.all(
[v == time_opt for v in given_values[1:]]
):
raise ValueError(
f"Inconsistent axis_coords values given for animate_over "
f"coordinate ({animate_over}). Got {given_values}."
)
else:
time_opt = None
time_values, time_label = _parse_coord_option(animate_over, time_opt, self.data)
time_values, time_suffix = _normalise_time_coord(time_values)
timeline = amp.Timeline(time_values, fps=fps, units=time_suffix)
anim = amp.Animation(blocks, timeline)
if tight_layout:
if subplots_adjust is not None:
warnings.warn(
"tight_layout argument to animate_list() is True, but "
"subplots_adjust argument is not None. subplots_adjust "
"is being ignored."
)
if not isinstance(tight_layout, dict):
tight_layout = {}
fig.tight_layout(**tight_layout)
_add_controls(anim, controls, time_label)
if save_as is not None:
anim.save(save_as + ".gif", writer=PillowWriter(fps=fps))
if show:
plt.show()
return anim
def _find_major_vars(data):
"""
Splits data into those variables likely to require a lot of storage space
(defined as those which depend on time and at least one other dimension).
These are normally the variables of physical interest.
"""
# TODO Use an Ordered Set instead to preserve order of variables in files?
tcoord = data.attrs.get("metadata:bout_tdim", "t")
major_vars = set(
var
for var in data.data_vars
if (tcoord in data[var].dims) and data[var].dims != (tcoord,)
)
minor_vars = set(data.data_vars) - set(major_vars)
return list(major_vars), list(minor_vars)