SCRIPTS RDB d3d_rdb_explorer¶

# -*-Python-*-
# Created by prattq at 16 Mar 2026  13:53

"""
This script uses the DIII-D RDB to find interesting shots.

The script creates an interactive matplotlib plot (using infoScatter) that
responds to the following key presses:
    'h' - print the help (this list of keys).
    't' - print the text logs for the selected shots.
    'e' - erase/de-select the selected shots.
    'i' - isolate/de-isolate shots from the same 'run'.
    'g' - compare gEQDSK files for the selected shots (at geqdsk_time)
    'p' - compare profiles (zipfits) for the selected shots (at profiles_time)

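By default the plot shows the average density ('DENSITY_AVG') on the x-axis against the
neutral-beam power ('pbeam') on the y-axis; set 'plot_x'/'plot_y' to plot other result columns.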

See,
    https://nomos.gat.com/DIII-D/comp/database/d3drdb/query_str.php
See also,
    https://nomos.gat.com/DIII-D/comp/database/d3drdb/columnlist_process.php?database=D3DRDB&table=SUMMARIES

defaultVars parameters
----------------------
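:param use_keywords: (bool) use a list of 'keywords' and 'exclude_keywords' to filter RDB.
:param use_runids: (bool) use a list of D3D run ids in the 'runids' kw to filter RDB.
    'use_keywords' and 'use_runids' cannot both be True.
:param shots_to_mark: (list) these shots will be circled in the plot.
:param shots_to_include: (list) these shots will be forcibly included in the RDB results.
:param <name>_lim: (list or None) [min, max] limits for shot, ne, pnbi, pech, btor, ip and
    pulse_length; None skips that limit. The btor and ip limits apply to absolute values.
:param plot_x, plot_y: (str) result columns to plot, e.g. 'density', 'pnbi', 'pech', 'btor'.
:param color_by: (str) 'bt_ip' (discrete Bt/Ip configuration) or a numeric result column.
:param geqdsk_time: (float) [ms] time used by the 'g' key.
:param profiles_time: (float) [ms] time used by the 'p' key.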
"""
import pandas as pd

# NOTE: defaultVars are captured and passed to the D3D_RDB_Explorer class below.
dfv = defaultVars(
    shot_lim=[202000, 204264],  # [min, max] shot number to search, or None to skip
    ne_lim=[0, 100e13],  # [1/cm^3] limits on summaries.DENSITY_AVG, or None to skip
    pnbi_lim=[1e6, 20e6],  # [W] limits on summaries.pbeam, or None to skip
    pech_lim=None,  # [W] limits on summaries.pech, or None to skip
    btor_lim=[1.0, 3.0],  # [T] limits on |summaries.btormax|, or None to skip
    ip_lim=[0.6e6, 3.0e6],  # [A] limits on |summaries.ip_flat|, or None to skip
    pulse_length_lim=[4, 10],  # [sec] limits on summaries.pulse_length, or None to skip
    use_keywords=True,  # filter entries.text with 'keywords' / 'exclude_keywords'
    keywords=["LH"],  # entries.text LIKE '%<keyword>%' terms
    exclude_keywords=["detach", "IR", "XPR"],  # drop entries whose text contains any of these
    cmap_str="viridis",  # matplotlib colormap name, used when color_by != 'bt_ip'
    use_runids=False,  # filter on entries.run using 'runids' (not together with use_keywords)
    runids=None,  # list of run ids used when use_runids=True
    plot_x="density",  # result column on the x-axis (density in E19 1/m^3)
    xlim=[0, 20],  # x-axis limits, in the plotted column's units
    plot_y="pnbi",  # result column on the y-axis (pnbi in MW)
    ylim=[0, 12],  # y-axis limits, in the plotted column's units
    shots_to_mark=[200000],  # set [] for none.
    shots_to_include=[200000],  # Force including these shots, set None for none.
    color_by="bt_ip",  # 'bt_ip' or a numeric result column such as 'pech'
    debug=False,  # print the SQL query, the result DataFrame and the plotted point count
    geqdsk_time=3000,  # [ms] time used by the 'g' (gEQDSK comparison) key
    profiles_time=3000,  # [ms] time used by the 'p' (profile comparison) key
    make_plot=True,  # build the interactive plot
    save_to_tree=True,  # store the result DataFrame in root['OUTPUTS']['D3D_RDB_Explorer_output']
)


class D3D_RDB_Explorer:
    """Build a DIII-D RDB query from the script parameters, run it and
    (optionally) plot the results interactively.

    Every keyword argument passed to the constructor becomes an instance
    attribute (see the defaultVars parameters of this script).
    """

    # Marker for each plasma topology; topologies not listed here use "s".
    _TOPO_MARKERS = dict(SNT="^", SNB="v", DN="D", IN="o")
    # Axis/colorbar labels for result columns with known units; other columns use their name.
    _AXIS_LABELS = dict(density="(avg) density [E19 1/m^3]", pnbi="(max) Pinj [MW]", pech="PECH [MW]")
    # Result column -> (name of the *_lim attribute, divisor converting that limit to plot units).
    _LIMIT_LINES = dict(density=("ne_lim", 1e13), pnbi=("pnbi_lim", 1e6), pech=("pech_lim", 1e6))

    def __init__(self, **kwargs):
        """D3D_RDB_Explorer class to manage constructing a complicated RDB query and
        plotting the results interactively.

        Stores ``kwargs`` as attributes, then runs :meth:`make_query` and
        :meth:`do_query`. Calls :meth:`do_plot` when ``make_plot`` is true and stores
        ``self.result`` in ``root['OUTPUTS']['D3D_RDB_Explorer_output']`` when
        ``save_to_tree`` is true.
        """
        self.xid = "D3D_RDB_Explorer"
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Creates the SQL query
        self.make_query()
        # Executes the SQL query
        self.do_query()
        # Interactive plot,
        if self.make_plot:
            self.do_plot()
        if self.save_to_tree:
            printi(f"INFO: ({self.xid}) Saving RDB results as pandas dataframe in root['OUTPUTS']...")
            root['OUTPUTS']['D3D_RDB_Explorer_output'] = self.result

    # ------------------------------------------------------------------ query

    @staticmethod
    def _sql_quote(value):
        """Return ``str(value)`` with single quotes doubled, for use inside a SQL string literal."""
        return str(value).replace("'", "''")

    def make_query(
        self,
    ):
        """Construct the multi-line SQL query to be sent to the RDB and store it in ``self.query``.

        The WHERE clause ANDs together:
          * ``summaries.<col> BETWEEN min AND max`` for every ``<name>_lim`` attribute
            that is not None (btor and ip compare absolute values);
          * when ``use_keywords``: entries.text must match ANY of ``keywords`` and NONE
            of ``exclude_keywords``;
          * when ``use_runids``: entries.run must be one of ``runids``.
        Finally ``shots_to_include`` (if non-empty) is ORed in so those shots are
        returned regardless of the filters. ``use_keywords`` and ``use_runids`` are
        mutually exclusive; invalid settings print an error and call ``OMFITx.End()``.
        """
        select = """SELECT
    entries.shot,
    entries.text,
    entries.username,
    entries.run,
    summaries.DENSITY_AVG,
    summaries.pbeam,
    summaries.pech,
    summaries.topology,
    summaries.btor,
    summaries.ipsign,
    summaries.TIME_OF_SHOT
FROM summaries INNER JOIN entries ON (summaries.shot = entries.shot)
WHERE"""

        # mapping between the name for the summaries entry and the kwargs passed to this script.
        name_map = dict(shot="shot", pulse_length="pulse_length", DENSITY_AVG="ne", pbeam="pnbi", pech="pech", btormax="btor", ip_flat="ip")

        conditions = []
        for sname, name in name_map.items():
            limits = getattr(self, f"{name}_lim", None)
            if limits is None:
                continue
            if len(limits) != 2:
                printe(f"ERROR: ({self.xid}) length of {name} limits must be 2 - ending.")
                OMFITx.End()
            # btor and ip are signed quantities (see btsign/ipsign), so filter on magnitude.
            column = f"ABS(summaries.{sname})" if name in ("btor", "ip") else f"summaries.{sname}"
            conditions.append(f"{column} BETWEEN {limits[0]} AND {limits[1]}")

        use_keywords = getattr(self, "use_keywords", False)
        use_runids = getattr(self, "use_runids", False)
        if use_keywords and use_runids:
            printe(f"ERROR: ({self.xid}) Cannot use both keywords and runids at the same time - ending.")
            OMFITx.End()

        if use_keywords:
            keywords = list(getattr(self, "keywords", None) or [])
            if not keywords:
                printe(f"ERROR: ({self.xid}) cannot have use_keywords=True but no keywords - ending.")
                OMFITx.End()
            # The OR group must be parenthesized: SQL AND binds tighter than OR, so a bare
            # "... AND kw1 OR kw2 ..." would return kw2 matches that ignore every other filter.
            any_keyword = " OR ".join(f"entries.text LIKE '%{self._sql_quote(k)}%'" for k in keywords)
            conditions.append(f"({any_keyword})")
            for k in getattr(self, "exclude_keywords", None) or []:
                conditions.append(f"entries.text NOT LIKE '%{self._sql_quote(k)}%'")

        if use_runids:
            runids = list(getattr(self, "runids", None) or [])
            if not runids:
                printe(f"ERROR: ({self.xid}) cannot have use_runids=True but no runids - ending.")
                OMFITx.End()
            quoted = [f"'{self._sql_quote(r)}'" for r in runids]
            if len(quoted) == 1:
                conditions.append(f"entries.run = {quoted[0]}")
            else:
                conditions.append(f"entries.run IN ({','.join(quoted)})")

        # "1=1" keeps the WHERE clause valid when no filter was requested at all.
        where = "\n    AND ".join(conditions) if conditions else "1=1"
        query = f"{select} ({where})\n"

        # Lastly, add a simple "OR" to include specific shots in the DB.
        shots_to_include = getattr(self, "shots_to_include", None)
        if shots_to_include is not None:
            self.shots_to_include = tuple(int(s) for s in shots_to_include)
            if len(self.shots_to_include) == 1:
                query += f"OR summaries.shot = {self.shots_to_include[0]}\n"
            elif self.shots_to_include:  # an empty list would give the invalid "IN ()"
                shot_list = ", ".join(str(s) for s in self.shots_to_include)
                query += f"OR summaries.shot IN ({shot_list})\n"

        self.query = query

        if getattr(self, "debug", False):
            print(f"DEBUG: ({self.xid}) query,")
            print(self.query)

    @staticmethod
    def _to_float_array(values, fill=float("nan")):
        """Return ``values`` as a float numpy array, replacing ``None`` entries by ``fill``."""
        return np.array([fill if v is None else float(v) for v in values], dtype=float)

    def do_query(self, verbose=True):
        """Run ``self.query`` through OMFITrdb and store the rows in ``self.result``.

        :param verbose: (bool) print the OMFITrdb wall time and the number of rows returned.

        ``self.result`` is a pandas.DataFrame with one row per returned entry and the
        columns shot, datetime, user, btsign, topology, run, density [E19 1/m^3],
        btor, ipsign, pnbi [MW], pech [MW] and logs.
        """
        t0 = time.time()
        # by_column=True: q is indexed by column name (see the q["..."] lookups below).
        q = OMFITrdb(query=self.query, db='d3drdb', server='d3drdb', by_column=True)
        elapsed = time.time() - t0

        shots = q["shot"]
        # Count rows, not columns: len(q) is the number of selected columns.
        n_rows = len(shots)
        if verbose:
            printi(f"INFO: ({self.xid}) Time for OMFITrdb = {elapsed:.3f} sec --> {n_rows} entries")

        # Missing (NULL) beam/ECH powers count as 0 W; other missing numbers become NaN.
        ne = self._to_float_array(q["DENSITY_AVG"]) / 1e13
        btor = self._to_float_array(q["btor"])
        ipsign = self._to_float_array(q["ipsign"])
        pnbi = self._to_float_array(q["pbeam"], fill=0.0) / 1e6
        pech = self._to_float_array(q["pech"], fill=0.0) / 1e6
        dtime = q["TIME_OF_SHOT"]
        topo = [t.strip() if isinstance(t, str) else "?" for t in q["topology"]]
        runs = [r if r is not None else '' for r in q["run"]]
        users = q["username"]
        # Logbook text with surrounding blank space removed and per-line trailing spaces stripped.
        texts = ["\n".join(line.rstrip() for line in (t or "").strip().splitlines()) for t in q["text"]]

        if getattr(self, "use_runids", False):
            returned = set(runs)
            for r in self.runids:
                if r not in returned:
                    printw(f"WARN ({self.xid}): runid={r} was not returned.")

        # Format everything into a pandas.Dataframe.
        self.result = pd.DataFrame(
            {
                'shot': shots,
                'datetime': np.array(dtime),
                'user': np.array(users),
                'btsign': np.sign(btor),
                'topology': np.array(topo),
                'run': np.array(runs),
                'density': ne,
                'btor': btor,
                'ipsign': ipsign,
                'pnbi': pnbi,
                'pech': pech,
                'logs': texts,
            }
        )

        if self.debug:
            print(f"DEBUG: ({self.xid}) Dataframe result from do_query,")
            print(self.result)

    # ------------------------------------------------------------------- plot

    def do_plot(self):
        """Create the interactive scatter plot of ``self.result``.

        Points are the ``plot_x`` vs ``plot_y`` result columns; the marker encodes the
        topology and the colour encodes ``color_by``. Shots in ``shots_to_mark`` are
        circled, the query limits of the plotted columns are drawn as dashed lines, and
        key presses are handled by :meth:`_key_press_callback`.
        """
        shots = self.result['shot'].values
        btsign = self.result['btsign'].values
        ipsign = self.result['ipsign'].values
        n_rows = len(self.result)

        # Generate figure env,
        fig, ax = plt.subplots(num="D3D_RDB_Explorer_main", figsize=(11, 5.5))

        if self.color_by == "bt_ip":
            # Discrete DIII-D Bt/Ip configuration (std. is Bt < 0 and Ip > 0):
            #   0 = std, -1 = rev. Ip, 1 = rev. Bt, -2 = both reversed.
            rev_ip = ipsign < 0
            rev_bt = btsign > 0
            bt_ip_config = np.zeros(n_rows)
            bt_ip_config[rev_ip] = -1
            bt_ip_config[rev_bt] = 1
            bt_ip_config[rev_ip & rev_bt] = -2
            # must add 'name' for this to work in OMFIT; with_extremes returns a NEW colormap.
            self.cmap = matplotlib.colors.ListedColormap(['pink', 'r', 'b', 'c'], name="jet").with_extremes(
                under='yellow', over='magenta'
            )
            # Bin edges centred on the integer codes -2, -1, 0, 1.
            self.norm = matplotlib.colors.BoundaryNorm([-2.5, -1.5, -0.5, 0.5, 1.5], self.cmap.N)
            self.c_array = bt_ip_config
        elif self.color_by in self.result and pd.api.types.is_numeric_dtype(self.result[self.color_by]):
            # Continuous colorbar spanning the data range; under/over colours are
            # used by the colorbar's extend triangles.
            self.cmap = plt.get_cmap(self.cmap_str).with_extremes(under="b", over="r")
            self.c_array = self.result[self.color_by].values.astype(float)
            self.norm = matplotlib.colors.Normalize(np.nanmin(self.c_array), np.nanmax(self.c_array))
        else:
            printe(
                f"ERROR: ({self.xid}) 'color_by' not found in results. Must be 'bt_ip' or a numeric value returned by OMFITrdb - ending."
            )
            OMFITx.End()

        # Scalar mappable for consistent colorbar,
        scm = matplotlib.cm.ScalarMappable(cmap=self.cmap, norm=self.norm)

        # Generate initial scatterplot,
        self.ax = ax
        self._scatter_by_topo()

        # Circle the special shots (first matching row if a shot has several entries).
        shot_list = shots.tolist()
        marks = [] if self.shots_to_mark is None else self.shots_to_mark
        for s in marks:
            if s not in shot_list:
                printw(f"WARN ({self.xid}): shot={s} was not returned by OMFITrdb, consider including it in 'shots_to_include'.")
                continue
            i = shot_list.index(s)
            xy = (self.result[self.plot_x].iloc[i], self.result[self.plot_y].iloc[i])
            ax.add_artist(matplotlib.patches.Circle(xy, radius=0.5, lw=2, fc="none", ec="y"))

        if self.color_by == "bt_ip":
            cbar = fig.colorbar(
                scm,
                ax=ax,
                label="BT/Ip Config.",
                ticks=[-2, -1, 0, 1],
            )
            cbar.ax.set_yticklabels(["Rev. Ip \n+ Rev. Bt", "Rev. Ip", "Std.", "Rev. Bt"])
        else:
            fig.colorbar(scm, ax=ax, label=self._AXIS_LABELS.get(self.color_by, self.color_by), extend="both")

        # Dashed lines at the query limits of whichever columns are on the x and y axes.
        for column, draw_line in ((self.plot_x, ax.axvline), (self.plot_y, ax.axhline)):
            lim_name, scale = self._LIMIT_LINES.get(column, (None, 1.0))
            limits = getattr(self, lim_name, None) if lim_name is not None else None
            if limits is None:
                continue
            for v in limits:
                draw_line(v / scale, ls='--')

        # ----------------
        # Adds click-able labels for the shot number,
        txt = [f"{s},{i}" for i, s in enumerate(shots)]
        self.ifs = infoScatter(self.result[self.plot_x], self.result[self.plot_y], txt, axis=ax)
        # ----------------

        ax.set_xlabel(self._AXIS_LABELS.get(self.plot_x, self.plot_x))
        ax.set_ylabel(self._AXIS_LABELS.get(self.plot_y, self.plot_y))
        ax.legend()
        title = f"""N={n_rows} RDB results
Annotations are: (shot, index-in-db)
Press 'h' to print help."""
        ax.set_title(title, fontsize=10)

        # ----------------
        # Must capture the mpl_connect obj. to keep events active.
        self.cid = ax.figure.canvas.mpl_connect('key_press_event', self._key_press_callback)
        # ----------------

        ax.set_xlim(self.xlim)
        ax.set_ylim(self.ylim)

        self.fig = fig
        self.ax = ax

    def _scatter_by_topo(self, mask=True, alpha=0.6):
        """Scatter ``plot_x`` vs ``plot_y`` with one marker style per topology.

        :param mask: (bool or bool array) extra row selection ANDed with each topology.
        :param alpha: (float) marker opacity.

        The created collections are remembered in ``self._scatter_artists`` so they can
        be removed later without touching other artists on the axes. One empty grey
        legend handle is added per topology not already present.
        """
        if not hasattr(self, "_scatter_artists"):
            self._scatter_artists = []
        topo = self.result["topology"].values
        topo_lbls = np.unique(topo)
        x_all = self.result[self.plot_x].values
        y_all = self.result[self.plot_y].values
        total = 0
        for lbl in topo_lbls:
            inds = (topo == lbl) & mask  # this topology AND any additional mask.
            total += np.count_nonzero(inds)
            artist = self.ax.scatter(
                x_all[inds],
                y_all[inds],
                100,
                alpha=alpha,
                marker=self._TOPO_MARKERS.get(lbl, "s"),
                c=self.c_array[inds],
                cmap=self.cmap,
                norm=self.norm,
                edgecolor="k",
            )
            self._scatter_artists.append(artist)

        if self.debug:
            print(f"INFO: ({self.xid}) Plotting a total of {total} points")

        # Dummy plots for the topology legend (added once per label).
        existing = {line.get_label() for line in self.ax.get_lines()}
        for lbl in topo_lbls:
            if lbl in existing:
                continue
            self.ax.plot([], [], ls="none", color='grey', marker=self._TOPO_MARKERS.get(lbl, "s"), ms=8, label=lbl)

    def _remove_scatter(self):
        """Remove the scatter collections created by :meth:`_scatter_by_topo`."""
        for artist in getattr(self, "_scatter_artists", []):
            artist.remove()
        self._scatter_artists = []

    # -------------------------------------------------------------- callbacks

    def _selected_entries(self):
        """Return ``(shot, index-in-db)`` pairs for the visible infoScatter annotations."""
        selected = []
        for txt_obj, _patch_obj in self.ifs.drawnAnnotations.values():
            if not txt_obj.get_visible():
                continue
            shot, ind = (int(s) for s in txt_obj.get_text().split(","))
            selected.append((shot, ind))
        return selected

    def _selected_shots(self):
        """Return the unique selected shot numbers in selection order."""
        return list(dict.fromkeys(shot for shot, _ind in self._selected_entries()))

    def _key_press_callback(self, event):
        """Method defining the key-press callback for the main database plot."""
        sys.stdout.flush()
        actions = {
            "h": self._print_help,
            "t": self._print_selected_logs,
            "e": self._erase_selection,
            "i": self._isolate_runs,
            "g": self._compare_geqdsk,
            "p": self._compare_profiles,
        }
        action = actions.get(event.key)
        if action is not None:
            action()

    def _print_help(self):
        """'h': print the key bindings."""
        help_str = """Click shots in the plot to see their (shot, index-in-database).
Click shots again to de-select.
Use the following keys for more,
    't' - print the logs (entries.text) for the selected shots.
    'e' - erase/clear the current selections.
    'i' - isolate runs for the selected shots.
    'h' - print this help.
    'g' - compare gEQDSK files at time=geqdsk_time
    'p' - compare ZipFit profiles at time=profiles_time.
"""
        printi(help_str)

    def _print_selected_logs(self):
        """'t': print run, user and logbook text for every selected entry."""
        for shot, ind in self._selected_entries():
            header = f"---- {shot},{ind} ----"
            print(header)
            print(f"Run={self.result['run'].iloc[ind]}")
            print(f"User={self.result['user'].iloc[ind]}")
            print("Entries.text=")
            print(self.result['logs'].iloc[ind])
            print("-" * len(header))

    def _erase_selection(self):
        """'e': hide all annotations and reset the infoScatter selection."""
        print("* erasing annotations")
        for txt_obj, patch_obj in self.ifs.drawnAnnotations.values():
            txt_obj.set_visible(False)
            patch_obj.set_visible(False)
        self.ifs.drawnAnnotations = dict()  # reset.
        self.ax.figure.canvas.draw_idle()

    def _isolate_runs(self):
        """'i': highlight the runs of the selected shots, or reset if nothing is selected."""
        runs_to_isolate = [self.result["run"].iloc[ind] for _shot, ind in self._selected_entries()]
        self._remove_scatter()
        if not runs_to_isolate:
            print("* resetting isolation")
            self._scatter_by_topo()
        else:
            print(f"* isolating runs: {runs_to_isolate}")
            isolated = np.isin(self.result['run'].values, runs_to_isolate)
            # Selected runs in full opacity, everything else nearly transparent.
            self._scatter_by_topo(mask=isolated, alpha=1.0)
            self._scatter_by_topo(mask=~isolated, alpha=0.05)
        self.ax.figure.canvas.draw_idle()

    def _compare_geqdsk(self):
        """'g': overlay the separatrix of the selected shots at ``geqdsk_time``."""
        shots = self._selected_shots()
        if not shots:
            printw(f"WARN ({self.xid}): no shots selected - click shots in the plot first.")
            return
        figg, axg = plt.subplots(1, 1, num=f"{self.xid}_geqdsk", clear=True)
        for s in shots:
            geqdsk = from_mds_plus("DIII-D", shot=s, times=[self.geqdsk_time], get_afile=False)['gEQDSK'][self.geqdsk_time]
            geqdsk.plot(only2D=True, ax=axg, levels=[1], label=s)  # sep only.
        for spine in axg.spines.values():
            spine.set_visible(False)
        axg.legend(frameon=False)
        axg.set_title(f"shape comparison at time={self.geqdsk_time} ms")

    def _zipfit(self, shot, node):
        """Return ZIPFIT ``node`` (e.g. 'EDENSFIT') for ``shot`` at the time nearest ``profiles_time``."""
        return (
            OMFITmdsValue(
                server='DIII-D',
                treename='ELECTRONS',
                shot=shot,
                TDI=f'\\ELECTRONS::TOP.PROFILE_FITS.ZIPFIT.{node}',
            )
            .xarray()
            .sel(dim_1=self.profiles_time, method="nearest")
        )

    def _compare_profiles(self):
        """'p': overlay ZIPFIT ne (solid) and Te (dashed) of the selected shots at ``profiles_time``."""
        shots = self._selected_shots()
        if not shots:
            printw(f"WARN ({self.xid}): no shots selected - click shots in the plot first.")
            return
        figp, axp = plt.subplots(1, 1, num=f"{self.xid}_profiles", clear=True)
        for s in shots:
            ne = self._zipfit(s, "EDENSFIT")
            (line,) = axp.plot(ne["dim_0"].data, nominal_values(ne.data), label=s)
            te = self._zipfit(s, "ETEMPFIT")
            # Same colour as the density line; unlabeled so each shot appears once in the legend.
            axp.plot(te["dim_0"].data, nominal_values(te.data), '--', color=line.get_color())

        axp.legend(frameon=False)
        axp.set_title(f"ne(-) and Te(--) at time={self.profiles_time} ms")
        axp.set_ylabel("n_e [E19 /m3]")
        axp.set_xlabel("rho")


if __name__ == "omfit_classes.omfit_python":

    # Runs only when __name__ is "omfit_classes.omfit_python" (presumably the module
    # name OMFIT assigns when executing this script - confirm), not on a plain import.
    # Must capture the class instance to keep matplotlib events active.
    d3drdbexp = D3D_RDB_Explorer(**dfv)