# SCRIPTS XARRAY basic

# -*-Python-*-
# Created by bgrierson at 13 Dec 2016  20:50

# xarray (xarray.pydata.org) is a convenient container for data
# that provides more functionality than a simple array or dictionary.
# The basic object is an xarray DataArray: an array with named dimensions,
# coordinates, and attributes.
# An xarray Dataset is a dict-like collection of one or more DataArrays
# whose shared dimensions are aligned on common coordinates.

# Make sure root['OUTPUTS']['XARRAY']['BASIC'] exists in the OMFIT tree before
# anything is stored there. setdefault (dict semantics) only inserts a fresh
# OMFITtree when the key is missing, so existing entries are left in place.
outputs_tree = root['OUTPUTS']
outputs_tree.setdefault('XARRAY', OMFITtree())
outputs_tree['XARRAY'].setdefault('BASIC', OMFITtree())

# A 1-D "time history": 101 samples of sin(4*pi*t) for t in [0, 1] (two full
# periods), each carrying a random uncertainty drawn uniformly from [0, 0.1)
t = linspace(0, 1, 101)
w = 4.0 * np.pi
vals = np.sin(w * t)
err = np.random.rand(t.size) * 0.1

# Bundle values and uncertainties into one uncertainty array and keep a
# copy of it in the OMFIT tree
uar = uarray(vals, err)
root['OUTPUTS']['XARRAY']['BASIC']['uar'] = uar

# A DataArray can be built in a few equivalent ways. Each assignment below
# replaces the previous one; only the last version of `da` is kept.
# 1) coords given as a list of (dimension name, coordinate values) pairs
da = DataArray(uar, coords=[('time', t)])
# 2) dimension name given separately, coords as a positional list
da = DataArray(uar, coords=[t], dims='time')
# 3) one true dimension ('time') plus extra scalar, non-dimension coordinates
da = DataArray(uar, coords={'time': t, 'R0': 1.0, 'R1': 2.0}, dims=['time'])

# Metadata: a name for the array and a free-form attribute
da.name = 'Sinlike'
da.attrs.update({'Calib. date': 'Nov. 8 2016'})

# Keep the DataArray in the OMFIT tree
root['OUTPUTS']['XARRAY']['BASIC']['da'] = da

# Now we can plot it and display some information about it.
fig, ax = plt.subplots()
da.plot(ax=ax)

print('Our 1d da coordinates')
print(da.coords)
print('Our 1d da dimensions')
print(da.dims)
print('Some elements of our data')
print(da.values[0:5])
# Label-based selection. The time coordinate comes from floating-point
# arithmetic (linspace), so ask for the nearest sample rather than relying on
# 0.1 being stored exactly; an exact-label lookup raises KeyError on a miss.
print('Our data at time 0.1')
print(da.sel(time=0.1, method='nearest').values)
# Position-based selection: the first sample, regardless of its time value
print('Our data element 0')
print(da.isel(time=0).values)

# Make another DataArray on the same time base: a slightly faster sine
# (frequency 4.5*pi instead of 4*pi) with larger random uncertainties,
# drawn uniformly from [0, 0.2)
w2 = 4.5 * np.pi
uar2 = uarray(np.sin(w2 * t), 0.2 * np.random.rand(len(t)))
da2 = DataArray(uar2, coords=[('time', t)])

# Store the DataArray
root['OUTPUTS']['XARRAY']['BASIC']['da2'] = da2

# Combine both DataArrays into a single Dataset; the keyword names become the
# data-variable names. Both were built on the same `t`, so they share 'time'.
ds = Dataset(dict(da1=da, da2=da2))

# Keep the Dataset in the OMFIT tree
root['OUTPUTS']['XARRAY']['BASIC']['ds'] = ds

# Get the coordinates and all the DataArrays in this Dataset
print('DataArrays in our Dataset')
print(list(ds.keys()))
print('Dimensions in our Dataset')
# Dataset.sizes maps dimension name -> length. Relying on Dataset.dims to
# return a mapping is deprecated in recent xarray releases.
print(list(ds.sizes.keys()))
print('Coordinates in our Dataset')
print(list(ds.coords.keys()))

# Plot all the data in our Dataset that are not coordinates (i.e. time, R0, R1).
# data_vars already excludes every coordinate variable.
toplot = list(ds.data_vars)

fig, ax = plt.subplots()
for tpl in toplot:
    ds[tpl].plot(ls='None', ax=ax, label=tpl)
ax.legend()

# Manipulations
# Suppose the sample at time = 0.77 is invalid and should be removed. The time
# values come from floating-point arithmetic, so compare with a tolerance
# (np.isclose) instead of exact equality, which could silently match nothing.
valid = where(~np.isclose(ds['time'].values, 0.77))
# Indexing by the surviving integer positions gives a new Dataset without
# that time point; the original `ds` is unchanged.
ds_snipped = ds.isel(time=valid[0])

# Store the Dataset
root['OUTPUTS']['XARRAY']['BASIC']['ds_snipped'] = ds_snipped