# -*-Python-*-
# Created by prattq at 16 Mar 2026 13:53
"""
This script uses the DIII-D RDB to find interesting shots.
The script creates an interactive matplotlib plot with infoScatter and the
following key-press-events,
't' - print the text logs for the selected shots.
'e' - erase/de-select the selected shots.
'i' - isolate/de-isolate shots from the same 'run'.
'g' - compare gEQDSK files for the selected shots (at geqdsk_time)
'p' - compare profiles (zipfits) for the selected shots (at profiles_time)
The default plot is 'DENSITY_AVG' vs. 'pbeam'
See,
https://nomos.gat.com/DIII-D/comp/database/d3drdb/query_str.php
See also,
https://nomos.gat.com/DIII-D/comp/database/d3drdb/columnlist_process.php?database=D3DRDB&table=SUMMARIES
defaultVars parameters
----------------------
:param use_keywords: (bool) use a list of 'keywords' and 'exclude_keywords' to filter RDB.
:param use_runids: (bool) use a list of D3D run ids in the 'runids' kw to filter RDB.
:param shots_to_mark: (list) these shots will be circled in the plot.
:param shots_to_include: (list) these shots will be forcibly included in the RDB results.
"""
import pandas as pd
# NOTE: defaultVars are captured and passed to the D3D_RDB_Explorer class below.
# Script defaults. The returned mapping is unpacked into D3D_RDB_Explorer(**dfv)
# at the bottom of the file, where each key becomes an instance attribute.
dfv = defaultVars(
    # --- numeric limits: each non-None [min, max] pair becomes a SQL BETWEEN clause ---
    shot_lim=[202000, 204264],  # min, max shot to search,
    ne_lim=[0, 100e13],  # [1/cm^3]
    pnbi_lim=[1e6, 20e6],  # [W] or None
    pech_lim=None,  # [W] or None
    btor_lim=[1.0, 3.0],  # [T]; compared against ABS() of the column
    ip_lim=[0.6e6, 3.0e6],  # [A]; compared against ABS() of the column
    pulse_length_lim=[4, 10],  # [sec]
    # --- text / run filters (use_keywords and use_runids are mutually exclusive) ---
    use_keywords=True,
    keywords=["LH"],  # matched against entries.text with LIKE '%...%'
    exclude_keywords=["detach", "IR", "XPR"],  # entries.text NOT LIKE '%...%'
    cmap_str="viridis",  # colormap name used when color_by is not 'bt_ip'
    use_runids=False,
    runids=None,
    # --- plotting: plot_x / plot_y are column names of the result DataFrame ---
    plot_x="density",
    xlim=[0, 20],  # passed to ax.set_xlim
    plot_y="pnbi",
    ylim=[0, 12],  # passed to ax.set_ylim
    shots_to_mark=[200000],  # set [] for none.
    shots_to_include=[200000],  # Force including these shots, set None for none.
    color_by="bt_ip",  # 'bt_ip' (discrete Bt/Ip configuration) or a result column such as 'pech'
    debug=False,  # print the SQL query and the resulting DataFrame
    geqdsk_time=3000,  # [ms] time used by the 'g' key
    profiles_time=3000,  # [ms] time used by the 'p' key
    make_plot=True,
    save_to_tree=True,  # store the result DataFrame in root['OUTPUTS']
)
class D3D_RDB_Explorer:
    """Build a DIII-D RDB query, run it, and explore the result interactively.

    Every keyword argument passed to the constructor is stored as an instance
    attribute, and all methods read their settings from those attributes (not
    from the enclosing script namespace), so several explorers with different
    settings can coexist.

    The class-level values below are fall-backs for optional settings that are
    not supplied as keyword arguments.
    """

    xid = "D3D_RDB_Explorer"

    # Fall-back settings (overridden by constructor keyword arguments).
    use_keywords = False
    keywords = None
    exclude_keywords = None
    use_runids = False
    runids = None
    shots_to_mark = ()
    shots_to_include = None
    plot_x = "density"
    plot_y = "pnbi"
    xlim = None
    ylim = None
    color_by = "bt_ip"
    cmap_str = "viridis"
    debug = False
    geqdsk_time = 3000
    profiles_time = 3000
    make_plot = True
    save_to_tree = True

    # Scatter artists created by _scatter_by_topo (so only those get removed).
    _scatter_artists = None

    _SELECT = (
        "SELECT\n"
        "    entries.shot,\n"
        "    entries.text,\n"
        "    entries.username,\n"
        "    entries.run,\n"
        "    summaries.DENSITY_AVG,\n"
        "    summaries.pbeam,\n"
        "    summaries.pech,\n"
        "    summaries.topology,\n"
        "    summaries.btor,\n"
        "    summaries.ipsign,\n"
        "    summaries.TIME_OF_SHOT\n"
        "FROM summaries INNER JOIN entries ON (summaries.shot = entries.shot)\n"
    )

    # (summaries column, name of the '<name>_lim' setting, compare ABS() of the column)
    _LIMIT_COLUMNS = (
        ("shot", "shot", False),
        ("pulse_length", "pulse_length", False),
        ("DENSITY_AVG", "ne", False),
        ("pbeam", "pnbi", False),
        ("pech", "pech", False),
        ("btormax", "btor", True),
        ("ip_flat", "ip", True),
    )

    # result column -> (axis label, name of the limit setting, divisor applied in do_query)
    _AXIS_INFO = {
        "density": ("(avg) density [E19 1/m^3]", "ne_lim", 1e13),
        "pnbi": ("(max) Pinj [MW]", "pnbi_lim", 1e6),
        "pech": ("PECH [MW]", "pech_lim", 1e6),
    }

    # topology label -> scatter marker (anything else falls back to a square)
    _TOPO_MARKERS = dict(SNT="^", SNB="v", DN="D", IN="o")

    _HELP = "\n".join(
        [
            "Click shots in the plot to see their (shot, index-in-database).",
            "Click shots again to de-select.",
            "Use the following keys for more,",
            "'t' - print the logs (entries.text) for the selected shots.",
            "'e' - erase/clear the current selections.",
            "'i' - isolate runs for the selected shots.",
            "'h' - print this help.",
            "'g' - compare gEQDSK files at time=geqdsk_time",
            "'p' - compare ZipFit profiles at time=profiles_time.",
            "",
        ]
    )

    def __init__(self, **kwargs):
        """Store the settings, build and run the query, then plot/save.

        :param kwargs: settings; each key/value pair becomes an instance attribute.
        """
        self.xid = "D3D_RDB_Explorer"
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Creates the SQL query
        self.make_query()
        # Executes the SQL query
        self.do_query()
        # Interactive plot (nothing to draw when the query returned no rows),
        if self.make_plot and len(self.result) > 0:
            self.do_plot()
        if self.save_to_tree:
            printi(f"INFO: ({self.xid}) Saving RDB results as pandas dataframe in root['OUTPUTS']...")
            root['OUTPUTS']['D3D_RDB_Explorer_output'] = self.result

    # ------------------------------------------------------------------
    # Query construction
    # ------------------------------------------------------------------
    @staticmethod
    def _sql_escape(value):
        """Return ``value`` as text with single quotes doubled for use inside a SQL string literal."""
        return str(value).replace("'", "''")

    def _limit_clauses(self):
        """Return one ``BETWEEN`` clause per non-None ``<name>_lim`` setting."""
        clauses = []
        for column, name, use_abs in self._LIMIT_COLUMNS:
            limits = getattr(self, f"{name}_lim", None)
            if limits is None:
                continue
            if len(limits) != 2:
                printe(f"ERROR: ({self.xid}) length of {name} limits must be 2 - ending.")
                OMFITx.End()
            target = f"summaries.{column}"
            if use_abs:
                target = f"ABS({target})"
            clauses.append(f"{target} BETWEEN {limits[0]} AND {limits[1]}")
        return clauses

    def _keyword_clauses(self):
        """Return the clauses that filter ``entries.text`` by keywords.

        The included keywords are OR-ed inside one parenthesised group so that
        they cannot bypass the other (AND-ed) conditions; each excluded keyword
        is its own ``NOT LIKE`` clause.
        """
        if not self.keywords:
            printe(f"ERROR: ({self.xid}) cannot have use_keywords=True but no keywords - ending.")
            OMFITx.End()
        keywords = [self.keywords] if isinstance(self.keywords, str) else self.keywords
        likes = [f"entries.text LIKE '%{self._sql_escape(k)}%'" for k in keywords]
        clauses = ["(" + " OR ".join(likes) + ")"]
        if self.exclude_keywords is None:
            self.exclude_keywords = []
        clauses += [f"entries.text NOT LIKE '%{self._sql_escape(k)}%'" for k in self.exclude_keywords]
        return clauses

    def _runid_clause(self):
        """Return the clause restricting ``entries.run`` to ``self.runids``."""
        if not self.runids:
            printe(f"ERROR: ({self.xid}) cannot have use_runids=True but no runids - ending.")
            OMFITx.End()
        runs = ", ".join(f"'{self._sql_escape(r)}'" for r in self.runids)
        return f"entries.run IN ({runs})"

    def make_query(
        self,
    ):
        """Construct the multi-line SQL query and store it in ``self.query``.

        All limit/keyword/run conditions are AND-ed together. Shots listed in
        ``shots_to_include`` are OR-ed against that whole (parenthesised) group,
        so they are returned regardless of the other conditions.
        """
        if self.use_keywords and self.use_runids:
            printe(f"ERROR: ({self.xid}) Cannot use both keywords and runids at the same time - ending.")
            OMFITx.End()

        clauses = self._limit_clauses()
        if self.use_keywords:
            clauses += self._keyword_clauses()
        if self.use_runids:
            clauses.append(self._runid_clause())

        # Lastly, a simple "OR" to forcibly include specific shots,
        include = ""
        if self.shots_to_include is not None:
            self.shots_to_include = tuple(self.shots_to_include)
            if self.shots_to_include:  # an empty list would give the invalid 'IN ()'
                shot_list = ", ".join(str(int(s)) for s in self.shots_to_include)
                include = f"summaries.shot IN ({shot_list})"

        query = self._SELECT
        conditions = "\n    AND ".join(clauses)
        if conditions and include:
            query += f"WHERE (\n    {conditions}\n)\nOR {include}\n"
        elif conditions:
            query += f"WHERE\n    {conditions}\n"
        elif include:
            query += f"WHERE {include}\n"
        else:
            printw(f"WARN ({self.xid}): no limits or filters given, the query is unrestricted.")
        self.query = query

        if self.debug:
            print(f"DEBUG: ({self.xid}) query,")
            print(self.query)

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------
    @staticmethod
    def _as_float(values, fill):
        """Return ``values`` as a float array with ``None`` entries replaced by ``fill``."""
        return np.array([fill if v is None else v for v in values], dtype=float)

    def do_query(self, verbose=True):
        """Run ``self.query`` through OMFITrdb and store a DataFrame in ``self.result``.

        Powers are divided by 1e6 and the density by 1e13. Missing (None) powers
        become 0; missing density/btor/ipsign become NaN; missing log text
        becomes an empty string.

        :param verbose: (bool) print the query time and the number of rows returned.
        """
        t0 = time.time()
        q = OMFITrdb(query=self.query, db='d3drdb', server='d3drdb', by_column=True)
        elapsed = time.time() - t0

        def column(name):
            # NOTE(review): presumably an empty result has no columns at all - confirm.
            return q[name] if name in q else []

        shots = np.asarray(column("shot"))
        # with by_column=True the result is keyed by column, so count rows via a column.
        N = len(shots)
        if verbose:
            printi(f"INFO: ({self.xid}) Time for OMFITrdb = {elapsed:.3f} sec --> {N} entries")
        if N == 0:
            printw(f"WARN ({self.xid}): the query returned no entries.")

        btor = self._as_float(column("btor"), np.nan)
        topo = [t.strip() if isinstance(t, str) else "?" for t in column("topology")]
        runs = [r if r is not None else '' for r in column("run")]
        texts = ["\n".join(line.rstrip() for line in (t or "").strip().splitlines()) for t in column("text")]

        if self.use_runids:
            returned = set(runs)
            for r in self.runids:
                if r not in returned:
                    printw(f"WARN ({self.xid}): runid={r} was not returned.")

        # Format everything into a pandas.Dataframe.
        self.result = pd.DataFrame(
            {
                'shot': shots,
                'datetime': np.array(column("TIME_OF_SHOT")),
                'user': np.array(column("username")),
                'btsign': np.sign(btor),
                'topology': np.array(topo),
                'run': np.array(runs),
                'density': self._as_float(column("DENSITY_AVG"), np.nan) / 1e13,
                'btor': btor,
                'ipsign': self._as_float(column("ipsign"), np.nan),
                'pnbi': self._as_float(column("pbeam"), 0.0) / 1e6,
                'pech': self._as_float(column("pech"), 0.0) / 1e6,
                'logs': texts,
            }
        )
        if self.debug:
            print(f"DEBUG: ({self.xid}) Dataframe result from do_query,")
            print(self.result)

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def _setup_colors(self):
        """Set ``self.cmap``, ``self.norm`` and ``self.c_array`` according to ``self.color_by``."""
        if self.color_by == "bt_ip":
            btsign = self.result['btsign'].values
            ipsign = self.result['ipsign'].values
            # 0 = standard; later assignments override earlier ones on purpose.
            config = np.zeros(len(self.result))
            config[ipsign < 0] = -1  # reversed Ip
            config[btsign > 0] = 1  # reversed Bt
            config[(btsign > 0) & (ipsign < 0)] = -2  # both reversed
            # must add 'name' for this to work in OMFIT,
            cmap = matplotlib.colors.ListedColormap(['pink', 'r', 'b', 'c'], name="jet")
            # with_extremes returns a new colormap, it does not modify in place.
            self.cmap = cmap.with_extremes(under='yellow', over='magenta')
            bounds = [-2.5, -1.5, -0.5, 0.5, 1.5]
            self.norm = matplotlib.colors.BoundaryNorm(bounds, self.cmap.N)
            self.c_array = config
        elif self.color_by in self.result:
            # Continuous colorbar based on a column of the result,
            self.cmap = plt.get_cmap(self.cmap_str).with_extremes(under="b", over="r")
            self.c_array = self.result[self.color_by].values
            self.norm = matplotlib.colors.Normalize(np.nanmin(self.c_array), np.nanmax(self.c_array))
        else:
            printe(
                f"ERROR: ({self.xid}) 'color_by' not found in results. Must be 'bt_ip' or a numeric value returned by OMFITrdb - ending."
            )
            OMFITx.End()

    def _axis_label(self, quantity):
        """Return the axis label for a result column (the column name itself if unknown)."""
        return self._AXIS_INFO.get(quantity, (quantity,))[0]

    def _draw_limit_lines(self):
        """Draw dashed lines at the query limits of the quantities on the x and y axes."""
        for quantity, draw in ((self.plot_x, self.ax.axvline), (self.plot_y, self.ax.axhline)):
            info = self._AXIS_INFO.get(quantity)
            if info is None:
                continue
            _, limit_name, scale = info
            limits = getattr(self, limit_name, None)
            if limits is None:
                continue
            for v in limits:
                draw(v / scale, ls='--')

    def _mark_shots(self, shots):
        """Circle every shot of ``self.shots_to_mark`` that is present in ``shots``."""
        x = self.result[self.plot_x].values
        y = self.result[self.plot_y].values
        for s in self.shots_to_mark or []:
            rows = np.flatnonzero(shots == s)
            if len(rows) == 0:
                printw(f"WARN ({self.xid}): shot={s} was not returned by OMFITrdb, consider including it in 'shots_to_include'.")
                continue
            # highlight it in the plot (first matching row),
            i = rows[0]
            p = matplotlib.patches.Circle((x[i], y[i]), radius=0.5, lw=2, fc="none", ec="y")
            self.ax.add_artist(p)

    def _add_colorbar(self, scm):
        """Add the colorbar matching ``self.color_by`` for the scalar-mappable ``scm``."""
        if self.color_by == "bt_ip":
            cbar = self.fig.colorbar(
                scm,
                ax=self.ax,
                label="BT/Ip Config.",
                ticks=[-2, -1, 0, 1],
            )
            cbar.ax.set_yticklabels(["Rev. Ip \n+ Rev. Bt", "Rev. Ip", "Std.", "Rev. Bt"])
        elif self.color_by == "pech":
            self.fig.colorbar(scm, ax=self.ax, label="PECH [MW]", extend="both")
        else:
            self.fig.colorbar(scm, ax=self.ax, label=str(self.color_by), extend="both")

    def do_plot(self):
        """Create the interactive scatter plot of ``plot_y`` vs ``plot_x``."""
        shots = self.result['shot'].values
        N = len(shots)
        # Colors first, so that a bad 'color_by' does not leave an empty figure behind.
        self._setup_colors()
        # Generate figure env,
        fig, ax = plt.subplots(num="D3D_RDB_Explorer_main", figsize=(11, 5.5))
        self.fig = fig
        self.ax = ax
        # Scalar mappable for consistent colorbar,
        scm = matplotlib.cm.ScalarMappable(cmap=self.cmap, norm=self.norm)
        # Generate initial scatterplot,
        self._scatter_artists = []
        self._scatter_by_topo()
        # Deal with special shots to label,
        self._mark_shots(shots)
        self._add_colorbar(scm)
        # Make pretty,
        self._draw_limit_lines()
        # ----------------
        # Adds click-able labels for the shot number,
        txt = [f"{s},{i}" for i, s in enumerate(shots)]
        self.ifs = infoScatter(self.result[self.plot_x], self.result[self.plot_y], txt, axis=ax)
        # ----------------
        ax.set_xlabel(self._axis_label(self.plot_x))
        ax.set_ylabel(self._axis_label(self.plot_y))
        ax.legend()
        title = f"N={N} RDB results\n" "Annotations are: (shot, index-in-db)\n" "Press 'h' to print help."
        ax.set_title(title, fontsize=10)
        # ----------------
        # Must capture the mpl_connect obj. to keep events active.
        self.cid = ax.figure.canvas.mpl_connect('key_press_event', self._key_press_callback)
        # ----------------
        ax.set_xlim(self.xlim)
        ax.set_ylim(self.ylim)

    def _scatter_by_topo(self, mask=True, alpha=0.6, legend_handles=True):
        """Scatter the RDB results, one marker style per topology.

        :param mask: boolean array (or True) AND-ed with each topology selection.
        :param alpha: opacity of the scatter points.
        :param legend_handles: (bool) also add the empty "dummy" lines that carry
            the topology legend labels; pass False when re-drawing.
        """
        topo = self.result["topology"].values
        x_all = self.result[self.plot_x].values
        y_all = self.result[self.plot_y].values
        labels = np.unique(topo)
        if self._scatter_artists is None:
            self._scatter_artists = []
        total = 0
        for lbl in labels:
            inds = (topo == lbl) & mask  # this topology AND any additional mask.
            count = int(np.count_nonzero(inds))
            if count == 0:
                continue
            total += count
            artist = self.ax.scatter(
                x_all[inds],
                y_all[inds],
                100,
                alpha=alpha,
                marker=self._TOPO_MARKERS.get(lbl, "s"),
                c=self.c_array[inds],
                cmap=self.cmap,
                norm=self.norm,
                edgecolor="k",
            )
            self._scatter_artists.append(artist)
        if self.debug:
            print(f"INFO: ({self.xid}) Plotting a total of {total} points")
        if legend_handles:
            for lbl in labels:
                # Dummy plot for the topo label,
                self.ax.plot([], [], ls="none", color='grey', marker=self._TOPO_MARKERS.get(lbl, "s"), ms=8, label=lbl)

    def _clear_scatter(self):
        """Remove the scatter artists previously created by ``_scatter_by_topo``."""
        for artist in self._scatter_artists or []:
            artist.remove()
        self._scatter_artists = []

    # ------------------------------------------------------------------
    # Key-press handling
    # ------------------------------------------------------------------
    def _selected_entries(self):
        """Return ``[(shot, index-in-db), ...]`` for the currently visible annotations."""
        selected = []
        for txt_obj, _patch_obj in self.ifs.drawnAnnotations.values():
            if not txt_obj.get_visible():
                continue
            shot, ind = [int(s) for s in txt_obj.get_text().split(",")]
            selected.append((shot, ind))
        return selected

    def _print_logs(self):
        """Print run, user and log text of each selected entry."""
        for shot, ind in self._selected_entries():
            header = f"---- {shot},{ind} ----"
            print(header)
            print(f"Run={self.result['run'].iloc[ind]}")
            print(f"User={self.result['user'].iloc[ind]}")
            print("Entries.text=")
            print(self.result['logs'].iloc[ind])
            print("-" * len(header))

    def _erase_annotations(self):
        """Hide and forget all current annotations."""
        print("* erasing annotations")
        for txt_obj, patch_obj in self.ifs.drawnAnnotations.values():
            txt_obj.set_visible(False)
            patch_obj.set_visible(False)
        self.ifs.drawnAnnotations = dict()  # reset.
        self.ax.figure.canvas.draw_idle()

    def _isolate_runs(self):
        """Show the runs of the selected shots at full opacity and fade the rest.

        With nothing selected the plot is reset to its default opacity.
        """
        runs = self.result["run"].values
        # a set, so two selected shots from the same run are not plotted twice.
        runs_to_isolate = sorted({runs[ind] for _shot, ind in self._selected_entries()})
        # Only remove the scatter artists this class created.
        self._clear_scatter()
        if not runs_to_isolate:
            print("* resetting isolation")
            self._scatter_by_topo(legend_handles=False)
        else:
            print(f"* isolating runs: {runs_to_isolate}")
            in_runs = np.isin(runs, runs_to_isolate)
            # plot using the run-mask in full-opacity, then the rest faded,
            self._scatter_by_topo(mask=in_runs, alpha=1.0, legend_handles=False)
            self._scatter_by_topo(mask=~in_runs, alpha=0.05, legend_handles=False)
        self.ax.figure.canvas.draw_idle()

    def _compare_geqdsk(self):
        """Overlay the gEQDSK separatrix of each selected shot at ``geqdsk_time``."""
        shots = [shot for shot, _ind in self._selected_entries()]
        if not shots:
            print("* no shots selected")
            return
        # clear=True so that repeated key presses do not stack axes in the same figure.
        figg, axg = plt.subplots(1, 1, num=f"{self.xid}_geqdsk", clear=True)
        for s in shots:
            geqdsk = from_mds_plus("DIII-D", shot=s, times=[self.geqdsk_time], get_afile=False)['gEQDSK'][self.geqdsk_time]
            geqdsk.plot(only2D=True, ax=axg, levels=[1], label=s)  # sep only.
        for k in axg.spines:
            axg.spines[k].set_visible(False)
        axg.legend(frameon=False)
        axg.set_title(f"shape comparion at time={self.geqdsk_time} ms")

    def _compare_profiles(self):
        """Overlay the ZIPFIT ne (solid) and Te (dashed) fits of each selected shot at ``profiles_time``."""
        shots = [shot for shot, _ind in self._selected_entries()]
        if not shots:
            print("* no shots selected")
            return
        figp, axp = plt.subplots(1, 1, num=f"{self.xid}_profiles", clear=True)
        for s in shots:
            color = None  # the Te curve re-uses the color of the ne curve.
            for node, ls in (("EDENSFIT", "-"), ("ETEMPFIT", "--")):
                fit = (
                    OMFITmdsValue(
                        server='DIII-D',
                        treename='ELECTRONS',
                        shot=s,
                        TDI=f'\\ELECTRONS::TOP.PROFILE_FITS.ZIPFIT.{node}',
                    )
                    .xarray()
                    .sel(dim_1=self.profiles_time, method="nearest")
                )
                # one legend entry per shot: only the first (ne) curve carries the label.
                label = s if color is None else "_nolegend_"
                (line,) = axp.plot(fit["dim_0"].data, nominal_values(fit.data), ls=ls, color=color, label=label)
                color = line.get_color()
        axp.legend(frameon=False)
        axp.set_title(f"ne(-) and Te(--) at time={self.profiles_time} ms")
        axp.set_ylabel("n_e [E19 /m3]")
        axp.set_xlabel("rho")

    def _key_press_callback(self, event):
        """Dispatch a matplotlib key-press event of the main plot to its handler.

        :param event: matplotlib key event; only ``event.key`` is read.
        """
        sys.stdout.flush()
        handlers = {
            'h': lambda: printi(self._HELP),
            't': self._print_logs,
            'e': self._erase_annotations,
            'i': self._isolate_runs,
            'g': self._compare_geqdsk,
            'p': self._compare_profiles,
        }
        handler = handlers.get(event.key)
        if handler is not None:
            handler()
# NOTE(review): presumably "omfit_classes.omfit_python" is the __name__ OMFIT
# gives to scripts it executes, so this only runs inside OMFIT - confirm.
if __name__ == "omfit_classes.omfit_python":
    # Must capture the class instance to keep matplotlib events active.
    d3drdbexp = D3D_RDB_Explorer(**dfv)