try:
# framework is running
from .startup_choice import *
except ImportError as _excp:
# class is imported by itself
if (
'attempted relative import with no known parent package' in str(_excp)
or 'No module named \'omfit_classes\'' in str(_excp)
or "No module named '__main__.startup_choice'" in str(_excp)
):
from startup_choice import *
else:
raise
from omfit_classes.utils_fusion import is_device, tokamak
import numpy as np
try:
import pyodbc
except Exception as excp:
warnings.warn('No RDB support: ' + repr(excp))
__all__ = ['OMFITrdb', 'available_efits_from_rdb', 'translate_RDBserver', 'set_rdb_password', 'get_rdb_password']
_RDBserverDict = {
# dictionary to translate mnemonic names to RDB servers stored in MainSettings['SERVER']
'd3drdb': 'd3drdb',
'atlas.gat.com': 'd3drdb',
'DIIID': 'd3drdb',
'DIII-D': 'd3drdb',
'D3D': 'd3drdb',
'gat': 'd3drdb',
'd3dpub': 'huez',
'huez.gat.com': 'huez',
'CMOD': 'alcdb2',
'C-Mod': 'alcdb2',
'EAST': 'east_database',
'east': 'east_database',
'nstx': 'nstxrdb',
'ignition': 'loki',
}
def printq(*args):
"""Shortcut so that the topic for the file can be set consistently and easily"""
printd(*args, topic='omfit_rdb')
def translate_RDBserver(server, servers=_RDBserverDict):
"""This function maps mnemonic names to real RDB servers"""
server = tokamak(server)
if server.split(':')[0] not in list(OMFIT['MainSettings']['SERVER'].keys()):
score_table = []
for item in list(servers.keys()):
m = difflib.SequenceMatcher(None, item.lower(), server.lower())
score_table.append([m.ratio(), item, servers[item]])
# return best match
scores = [float(s) for s in np.array(score_table)[:, 0]]
if max(scores) > 0.80:
server = score_table[np.array(scores).argmax()][2]
else:
raise OMFITexception(
"Server '"
+ server
+ "' was not recognized.\nYou need to add '"
+ server
+ "' to the list of servers in OMFIT['MainSettings']['SERVER']"
)
return server
def _get_rdb_credential(server):
"""
Provides the name of the encrypted credentials file
:param server: string
SQL database server. Each server can have its own file.
:return: string
Credential name
"""
return 'rdb_{}:0'.format(translate_RDBserver(server))
def _read_rdb_password_from_file(server=None, filename=None):
"""
Attempts to read a typical login file and get the username and password
:param server: string
Sets the default filenames to search
:param filename: string
Override the filename to search
:return: (string, string)
Username and password if successful or ('', '') if unsuccessful
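Illustrative usage (a sketch; assumes a suitable login file exists in $HOME):
>>> username, password = _read_rdb_password_from_file('d3drdb')
Both the colon-separated .pgpass layout (hostname:port:database:username:password) and files whose
last two lines are the username and the password (as in the *.sybase_login files) are supported.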
"""
username = ''
password = ''
files = {
'huez': ['.pgdb', '.pgpass'],
'pppl_postgres': ['.pgdb', '.pgpass'],
'd3drdb': ['D3DRDB.sybase_login', 'd3drdb.sybase_login'],
'd3dsqlsrvr': ['D3DRDB.sybase_login', 'd3drdb.sybase_login'],
'd3sqlsrvr': ['D3SQLSRVR.sybase_login', 'd3sqlsrvr.sybase_login'],
'east_database': ['east_database.sybase_login'],
'energy': ['ENERGY.sybase_login'],
'nstxrdb': ['nstxlogs.sybase_login'],
'loki': ['loki.sybase_login'],
}.get(translate_RDBserver(server), [''])
files = [filename] if filename is not None else files
def parse_login1(ff):
return map(lambda x: x.strip(), ff.read().split(':')[-2:])
def parse_login2(ff):
return map(lambda x: x.strip(), ff.readlines()[-2:])
parsers = [parse_login1, parse_login2]
printq('Searching for {} login info in files {}'.format(server, files))
for login_file in files:
login_file = os.environ['HOME'] + os.sep + login_file
if os.path.exists(login_file):
for i, parser in enumerate(parsers):
try:
printq(' Attempt {}: read {} login from {}'.format(i, server, login_file))
with open(login_file, 'r') as _f:
username, password = parser(_f)
printq(' Success: username {} from file on attempt {}.'.format(username, i))
break
except Exception:
etype, value, tb = sys.exc_info()
exc_info = traceback.format_exception(etype, value, tb)
printq(' Fail attempt {} at reading {}; {} attempts left'.format(i, login_file, len(parsers) - 1 - i), exc_info)
else:
printq(" File {} doesn't exist; can't get login for {} from it.".format(login_file, server))
return username, password
def get_public_rdb_login(server):
"""
Returns the public or guest username and password for server
:param server: string
:return: (string, string)
"""
default_user_pass = ('guest', 'guest_pwd') # Applies to d3drdb, d3dsqlsrvr, d3sqlsrvr, east_database, energy
return {'huez': ('d3dpub', 'd3dpub'), 'pppl_postgres': ('nstxpub', 'nstxpub')}.get(translate_RDBserver(server), default_user_pass)
def set_rdb_password(server, username=None, password=None, guest=False):
"""
Sets up an encrypted password for OMFIT to use with SQL databases on a specific server
:param server: string
The server this credential applies to
:param username: string
Username on server. If a password is specified, this defaults to os.environ['USER'].
If neither username nor password is specified, OMFIT will try to read both from a login file.
:param password: string
The password to be encrypted. Set to '' to erase the existing password for username@server, if there is one.
If None, OMFIT will attempt to read it from a default login file, like .pgpass. This may or may not be the right
password for server.
:param guest: bool
Use guest login and save it. Each server has its own, with the default being guest // guest_pwd .
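Illustrative examples (a sketch; the username and password shown are placeholders):
>>> set_rdb_password('d3drdb', username='my_user', password='my_pass')  # encrypt explicit credentials
>>> set_rdb_password('d3drdb')  # read credentials from a default login file, if one exists
>>> set_rdb_password('d3drdb', guest=True)  # store the public/guest login for the server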
"""
cred = _get_rdb_credential(server)
if password == '':
printq('Reset credentials because password was ""')
return reset_credential(credential=cred)
elif guest:
username, password = get_public_rdb_login(server)
elif password is None:
username, password = _read_rdb_password_from_file(server)
elif username is None:
username = os.environ['USER']
if username is None or password is None:
raise ValueError(
'Failed to figure out a username and password combination to use for server {}. '
'Try again with guest=True to load public/guest login.'.format(server)
)
userpass = '{} {}'.format(username, password)
encrypt_credential(credential=cred, password='', otp=userpass)
printi('Encrypted credentials for {}@{}'.format(username, (translate_RDBserver(server))))
return
def get_rdb_password(server, username=None, password=None, guest=False, guest_fallback_allowed=True):
"""
Returns the RDB username and password for server
:param server: string
Servers can have different passwords, so you have to tell us which one you're after
:param username: string
Defaults to os.environ['USER']
:param password: string [optional]
Override decrypted credential and just return this input instead
:param guest: bool
Return guest login for server. Each server has its own, with the default being guest // guest_pwd .
:param guest_fallback_allowed: bool
Allowed to return guest credentials if lookup fails. Otherwise, raise ValueError
:return: (string, string)
The username and password for server
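Illustrative example (returns the credentials stored with set_rdb_password, or guest credentials):
>>> username, password = get_rdb_password('d3drdb')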
"""
if guest:
return get_public_rdb_login(server)
if username is None or password is None:
userpass = decrypt_credential(_get_rdb_credential(server))[1].split(' ')
username, password = userpass[0], ' '.join(userpass[1:])
if not len(password):
try:
set_rdb_password(server)
except Exception:
pass
else:
userpass = decrypt_credential(_get_rdb_credential(server))[1].split(' ')
username, password = userpass[0], ' '.join(userpass[1:])
if not len(password):
if guest_fallback_allowed:
printq(f'Password lookup failed for {server}. Using guest/public credentials...')
return get_public_rdb_login(server)
raise ValueError(
'''Use set_rdb_password('{server:}') to encrypt your password for {server:} (translates to {t_server:}).
If your password for {server:} is in .pgpass in your home directory, it will be read from there.
Otherwise, you must provide it to set_rdb_password().
Your password for huez would be in .pgpass or .pgdb
Your password for d3drdb would be in D3DRDB.sybase_login or d3drdb.sybase_login
Your password for d3dsqlsrvr would be in D3SQLSRVR.sybase_login or d3sqlsrvr.sybase_login
'''.format(
server=server, t_server=translate_RDBserver(server)
)
)
else:
return username, password
class OMFITrdb(SortedDict):
"""
Class used to connect to relational databases
"""
def __init__(self, query=None, db='d3drdb', server='d3drdb', by_column=False):
"""
:param query: string
SQL SELECT query
:param db: string
database to connect to
:param server: string
SQL server to connect to (e.g. `d3drdb` as listed under OMFIT['MainSettings']['SERVER'])
:param by_column: bool
False: return results by rows (can be slow for large number of records)
Result of .select() is a SortedDict with keys numbered 0, 1, ... and each holding another SortedDict,
within which keys correspond to columns
True: return results by columns
Result of .select() is a SortedDict with keys corresponding to columns,
each holding an array with length set by the number of rows selected.
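Illustrative example (a sketch; the query, database and table names are placeholders and a
reachable server with valid credentials is assumed):
>>> rdb = OMFITrdb("SELECT * FROM plasmas WHERE shot=123456", db='code_rundb', server='d3drdb')
>>> rows = rdb.select()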
"""
SortedDict.__init__(self)
self.dynaLoad = True
self.query = query
self.db = db
self.server = server
self.cnxn = None
self.cursor = None
self.by_column = by_column
@property
def postgresql(self):
return self.db == 'd3d' or self.server in ('huez', 'pppl_postgres')
def _connect(self, redact_password=True):
"""Connects to the database"""
connection = self._setup_connection()
try:
if self.postgresql:
import psycopg2
try:
self.cnxn = psycopg2.connect(**connection)
except psycopg2.OperationalError:
connection = self._setup_connection(force_tunnel=True)
self.cnxn = psycopg2.connect(**connection)
else:
self.cnxn = pyodbc.connect(**connection)
self.cursor = self.cnxn.cursor()
except Exception as _excp1:
if "Login failed for user " in repr(_excp1) and self.server == 'd3drdb':
raise ValueError(
"""
Connection to the D3DRDB relational database requires the file D3DRDB.sybase_login to be present in the $HOME directory of the user on the workstation where OMFIT is running.
You should be able to copy this file from the $HOME folder on the GA iris cluster.
If this file is not available to you contact the GA computer group: iris-support@fusion.gat.com
"""
)
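# The direct connection failed: fall back to a DSN defined in the user's ~/.odbc.ini
# (this fallback is not attempted for PostgreSQL connections, which re-raise instead)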
connection2 = {}
try:
if self.postgresql:
connection2 = connection
raise
dsn = 'OMFIT_tmp'
self._buildUserConfig(connection, dsn)
connection2['DSN'] = dsn
connection2['UID'] = connection['UID']
connection2['PWD'] = connection['PWD']
self.cnxn = pyodbc.connect(**connection2)
self.cursor = self.cnxn.cursor()
except Exception as _excp2:
msg1 = '\nFailed connection details: ' + repr(_excp1)
msg2 = '\nFailed connection details: ' + repr(_excp2)
if redact_password:
msg1 = msg1.replace(connection['PWD'], '<PASSWORD REDACTED>')
msg2 = msg2.replace(connection['PWD'], '<PASSWORD REDACTED>')
printe(msg1)
printe(connection)
printe(msg2)
printe(connection2)
printe('\n')
raise
@staticmethod
def _buildUserConfig(connection, dsn):
"""
Sets up OMFIT-related controls in the .odbc.ini file in the user's home directory
:param connection: dict
Contains connection information. Typically formed by self._setup_connection()
:param dsn: string
This is a tag or header name for the relevant section of the .odbc.ini file
"""
filename = os.environ['HOME'] + os.sep + '.odbc.ini'
# Create the file if it does not exist; OMFIT manages everything below the header line
header = '#--below this line OMFIT is in control of this file--\n'
if not os.path.exists(filename):
with open(filename, 'w') as f:
f.write(header)
# Retain what is above the header line
with open(filename, 'r') as _f:
top = _f.read().split(header)[0]
# Text to add to .odbc.ini
txt = []
c = ''
if '[' + dsn + ']' in top:
c = '# '
txt.append('')
txt.append(c + '[' + dsn + ']')
for item in connection:
txt.append(c + item + ' = ' + str(connection[item]))
txt.append('')
# Write updated .odbc.ini
with open(filename, 'w') as f:
f.write(top + header + '\n'.join(txt))
return
def select(self, query=None, by_column=None):
"""
Pass a query to the database, presumably with a SELECT statement in it
:param query: string
A string such as "SELECT * from profile_runs where shot=126006";
if None then use the query used to instantiate the object
:param by_column: bool [optional]
If True or False, override the self.by_column set at instantiation
:return: SortedDict
by_column = False:
Keys are numbered 0, 1, 2, ...
The value behind each key is a SortedDict
The keys of the child SortedDict are the columns selected
by_column = True:
Keys are column names
Values are arrays with length set by the number of rows selected
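Illustrative example (a sketch; the table and column names are placeholders):
>>> rdb = OMFITrdb(db='code_rundb', server='d3drdb')
>>> shots = rdb.select("SELECT shot, run_by FROM plasmas WHERE shot=123456", by_column=True)['shot']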
"""
if query is None and self.query is None:
return SortedDict()
if self.cnxn is None or self.cursor is None:
self._connect()
if not self.postgresql:
self.cursor.execute("use " + self.db)
if query is None:
query = self.query
if by_column is None:
by_column = self.by_column
self.cursor.execute(query)
rows = self.cursor.fetchall()
if self.postgresql:
colnames = [desc[0] for desc in self.cursor.description]
else:
colnames = None
res = SortedDict()
# Not by column can be very computationally expensive if there are many records
if not by_column:
for r, row in enumerate(rows):
res[r] = SortedDict()
for c, col in enumerate(row):
if not self.postgresql:
name = row.cursor_description[c][0]
else:
name = colnames[c]
res[r][name] = row[c]
# Faster, by column organization
else:
for r, row in enumerate(rows):
names = []
for c, col in enumerate(row):
if not self.postgresql:
name = row.cursor_description[c][0]
else:
name = colnames[c]
if name in names:
continue
names.append(name)
if r == 0:
res[name] = []
res[name].append(row[c])
for name in list(res.keys()):
res[name] = np.array(res[name])
return res
def custom_procedure(self, procedure, commit=True, **arguments):
"""
Pass an arbitrary custom procedure to the sql database
:param procedure: string
A string that represents the custom procedure to be called
:param commit: bool
If set to False, the data will not be committed to the code run database. Use this when running a
Jenkins test, which might otherwise attempt to write data for the same shot/runid twice and raise an error.
:return: list
List of pyodbc rows returned by the custom procedure (or the exception raised by fetchall, if no rows were returned)
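Illustrative example (a sketch; the procedure name and its arguments are placeholders), given an
OMFITrdb instance `rdb`:
>>> rows = rdb.custom_procedure('my_stored_procedure', shot=123456, run_id='EFIT01')
which executes roughly: exec my_stored_procedure @shot=123456, @run_id='EFIT01'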
"""
arguments = ', '.join(
[
"@%s='%s'" % (arg, arguments[arg]) if isinstance(arguments[arg], str) else "@%s=%s" % (arg, arguments[arg])
for arg in arguments
]
)
command = "exec %s " % (procedure) + arguments
if self.cnxn is None or self.cursor is None:
self._connect()
self.cursor.execute("use " + self.db)
try:
self.cursor.execute(command)
except pyodbc.ProgrammingError as e:
printi("Failed to execute the following command\n")
printi(command)
raise e
try:
rows = self.cursor.fetchall()
except Exception as err:
rows = err
if commit:
self.cnxn.commit()
return rows
def alter_add(self, table, column_name, column_type, commit=True, verbose=True):
"""
Alter table in SQL database by adding a column
:param table: string
table name
:param column_name: string
column name
:param column_type: string
column type, has to be an SQL DataType, e.g. BOOLEAN, FLOAT etc..
:param commit: bool
commit alter command to SQL
:param verbose: bool
print SQL command being used
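Illustrative example (a sketch; table and column names are placeholders), given an OMFITrdb instance `rdb`:
>>> rdb.alter_add('my_table', 'quality_flag', 'FLOAT')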
"""
if self.postgresql:
raise NotImplementedError('File issue on Github if this is needed')
if self.cnxn is None or self.cursor is None:
self._connect()
self.cursor.execute("use " + self.db)
sql_cmd = "ALTER TABLE %s ADD %s %s" % (table, str(column_name), str(column_type))
if verbose:
printi(sql_cmd)
self.cursor.execute(sql_cmd)
if commit:
self.commit()
return sql_cmd
def alter_drop(self, table, column_name, commit=True, verbose=True):
"""
Alter table in SQL database by dropping a column
:param table: string
table name
:param column_name: string
column name
:param commit: bool
commit alter command to SQL
:param verbose: bool
print SQL command being used
:return: string
SQL command
"""
if self.postgresql:
raise NotImplementedError('File issue on Github if this is needed')
if self.cnxn is None or self.cursor is None:
self._connect()
self.cursor.execute("use " + self.db)
sql_cmd = "ALTER TABLE %s DROP COLUMN %s" % (table, str(column_name))
if verbose:
printi(sql_cmd)
self.cursor.execute(sql_cmd)
if commit:
self.commit()
return sql_cmd
def delete(self, table, where, commit=True, verbose=True):
"""
Delete row(s) in SQL database
:param table: string
table from which rows will be deleted
:param where: string or dict
Which record or records should be deleted
NOTE that all records that satisfy this condition will be deleted!
A dict will be converted into a string of the form "key1=value1 and key2=value2 ..."
:param commit: bool
commit delete to SQL
:param verbose: bool
print SQL command being used
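Illustrative example (a sketch; table and column names are placeholders), given an OMFITrdb instance `rdb`:
>>> rdb.delete('my_table', {'shot': 123456, 'run_id': 'MYRUN01'})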
"""
if self.cnxn is None or self.cursor is None:
self._connect()
self.cursor.execute("use " + self.db)
if isinstance(where, dict):
where = ' and '.join([str(x[0]) + '=' + repr(x[1]) for x in zip(list(where.keys()), list(where.values()))])
sql_cmd = "DELETE FROM %s WHERE %s" % (table, where)
if verbose:
printi(sql_cmd)
self.cursor.execute(sql_cmd)
if commit:
self.commit()
return sql_cmd
def commit(self):
"""
Commit commands in SQL database
"""
printq('committing')
self.cnxn.commit()
def update(self, table, data, where, commit=True, overwrite=1, verbose=True):
"""
Update row(s) in SQL database
:param table: string
Table to update.
:param data: dict
Keys are columns to update and values are values to put into those columns.
:param where: dict or string
Which record or records should be updated.
NOTE that all records that satisfy this condition will be updated!
If it's a dictionary, the columns/data condition will be concatenated with " AND ", so that
{'my_column': 5.2, 'another_col': 7} becomes "my_column=5.2 AND another_col=7".
A string will be used directly.
:param commit: bool
Commit update to SQL. Set to false for testing without editing the database.
:param overwrite: bool or int
0/False: If any of the keys in data already have entries in the table, do not update anything
1/True: Update everything. Don't even check.
2: If any of the keys in data already have entries in the table, don't update those,
but DO write missing entries.
:param verbose: bool
Print SQL command being used.
:return: string
The SQL command that would be used.
If the update is aborted due to overwrite avoidance, the SQL command will be prefixed by "ABORT:"
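Illustrative example (a sketch; table and column names are placeholders), given an OMFITrdb instance `rdb`:
>>> rdb.update('my_table', data={'quality_flag': 1.0}, where={'shot': 123456}, overwrite=2)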
"""
if self.cnxn is None or self.cursor is None:
self._connect()
if not self.postgresql:
self.cursor.execute("use " + self.db)
def mk_kv_list(a):
"""
Turns a dictionary of columns and values into a list of strings of the form 'k=v'
to be later concatenated by ", " or " AND "
:param a: dict-like
:return: list of strings
"""
return [str(x[0]) + '=' + x[1] for x in zip(list(a.keys()), tolist(self._sqlRepr(list(a.values()))))]
if isinstance(where, dict):
where = ' AND '.join(mk_kv_list(where))
if len(where):
where = ' WHERE ' + where
columns_values = ', '.join(mk_kv_list(data))
sql_cmd = "UPDATE %s SET %s%s" % (table, columns_values, where)
if verbose:
printi(sql_cmd)
if overwrite not in [1, True, 1.0]:
# Need to check for pre-existing data
columns = ', '.join(list(data.keys()))
query = 'SELECT {columns:} FROM {table:}{where:}'.format(**locals())
results = self.select(query=query, by_column=False)
results_with_content = []
for result in results.values():
cleaned_result = {k: v for k, v in result.items() if v is not None}
if len(cleaned_result):
results_with_content += [cleaned_result]
if len(results_with_content) and overwrite == 0:
printw('Requested update aborted to avoid overwriting some data.')
return 'ABORT:' + sql_cmd
elif len(results_with_content) and overwrite > 1:
skipped_columns = []
for cleaned_result in results_with_content:
for key in cleaned_result:
data.pop(key, None)
skipped_columns += [key]
if len(data) == 0:
printw('To avoid overwriting, all data were removed from the update list & there is nothing to do.')
return 'ABORT:' + sql_cmd
else:
columns_values = ', '.join(mk_kv_list(data))
sql_cmd = "UPDATE %s SET %s%s" % (table, columns_values, where)
printq('New update list: {}'.format(columns_values))
printq('Data removed: {}'.format(list(set(skipped_columns))))
if verbose:
print('Update command after avoiding overwriting anything:')
printi(sql_cmd)
self.cursor.execute(sql_cmd)
if commit:
self.commit()
return sql_cmd
def insert(self, table, data, duplicate_key_update=False, commit=True, verbose=True):
"""
Insert row in SQL database
:param table: string
table where data will be inserted
:param data: dict, list, or tuple
dict: keys are column names and values are the values to insert
list or tuple: values to insert, given in the table's column order
:param duplicate_key_update: bool
append ' ON DUPLICATE KEY UPDATE' to INSERT command
:param commit: bool
commit insert to SQL
:param verbose: bool
print SQL command being used
:return: The SQL command that was used or would be used
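Illustrative example (a sketch; table and column names are placeholders), given an OMFITrdb instance `rdb`:
>>> rdb.insert('my_table', {'shot': 123456, 'run_by': 'my_user', 'quality_flag': 1.0})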
"""
if self.cnxn is None or self.cursor is None:
self._connect()
if not self.postgresql:
self.cursor.execute("use " + self.db)
if isinstance(data, (list, tuple)):
values = (','.join(["%s"] * len(data))) % tuple(self._sqlRepr(data))
sql_cmd = "INSERT INTO %s VALUES (%s)" % (table, values)
else:
columns = (','.join(["%s"] * len(data))) % tuple(data.keys())
values = (','.join(["%s"] * len(data))) % tuple(self._sqlRepr(list(data.values())))
sql_cmd = "INSERT INTO %s (%s) VALUES (%s)" % (table, columns, values)
if duplicate_key_update:
sql_cmd += ' ON DUPLICATE KEY UPDATE'
if verbose:
printi(sql_cmd)
self.cursor.execute(sql_cmd)
if commit:
self.commit()
return sql_cmd
def copy_row(self, table, **kw):
r"""
Copy one row of a table to a new row of a table
:param table: string
table in the database for which to copy a row
:param \**kw: The keywords passed must be the primary keys of the table.
The values of the keywords must be two-element containers: (copy_from, copy_to)
:return: string
SQL command
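Illustrative example (a sketch; assumes a PostgreSQL table whose primary keys are 'shot' and 'run_id'),
given an OMFITrdb instance `rdb`:
>>> rdb.copy_row('my_table', shot=(123456, 123456), run_id=('RUN01', 'RUN02'))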
"""
if not self.postgresql:
raise NotImplementedError('File issue on GitHub if needed')
if self.cnxn is None or self.cursor is None:
self._connect()
if not self.postgresql:
self.cursor.execute("use " + self.db)
primary = self.primary_keys(table)
for k in kw:
if k not in primary:
raise ValueError('%s is not a primary key; only primary keys allowed for copying a row' % k)
where = []
where2 = []
for k, v in list(kw.items()):
where.append('%s=%s' % (k, self._sqlRepr(v[0])))
where2.append('%s=%s' % (k, self._sqlRepr(v[1])))
query = 'SELECT * from %s where %s' % (table, ' and '.join(where))
copy_from = self.select(query)[0]
if 'userid' in copy_from:
copy_from['userid'] = self.username
for k, v in list(kw.items()):
copy_from[k] = v[1]
insert_row = 'INSERT INTO {0}({1}) VALUES ({2});'.format(
table, ','.join(list(copy_from.keys())), ','.join(['%s'] * len(list(copy_from.values())))
)
self.cursor.execute(insert_row, list(copy_from.values()))
self.commit()
return insert_row
def primary_keys(self, table):
"""
Return the keys that are the primary keys of the table
:param table: string
table for which to evaluate the primary keys
:return: list of strings
"""
if not self.postgresql:
raise NotImplementedError('File issue on GitHub if needed')
if self.cnxn is None or self.cursor is None:
self._connect()
if not self.postgresql:
self.cursor.execute("use " + self.db)
self.cursor.execute("SELECT indexdef from pg_indexes where tablename = '%s'" % table)
return list(map(lambda x: str(x).strip(), self.cursor.fetchall()[0][0].split('(')[1].split(')')[0].split(',')))
def _sqlRepr(self, data):
"""
Forms the SQL representation of data
:param data: value or iterable of values to be represented
:return: string or list of strings with the SQL representation(s)
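Illustrative behavior of the PostgreSQL branch: ['abc', None, [1, 2]] -> ["'abc'", 'NULL', "'{1,2}'"]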
"""
if self.postgresql:
# Adapted from Osborne tools setup_run_table._value_to_dbstring:
# CONVERT PYTHON VALUE TO EXPRESSION FOR ENTRY INTO DB TABLE
sl = []
for v in np.atleast_1d(data):
if isinstance(v, str):
s = "'" + v + "'"
elif v is None:
s = 'NULL'
elif np.iterable(v):
s = '{'
for x in v:
if isinstance(x, str):
s = s[:] + '"' + x + '",'
else:
s = s[:] + str(x) + ','
if len(v):
s = s[:-1] + '}'
else:
s = s[:] + '}'
s = "'" + s + "'"
else:
s = str(v)
sl.append(s)
if len(sl) == 1:
return sl[0]
return sl
else:
data_out = []
for k, d in enumerate(np.atleast_1d(data)):
if d is None or (is_numeric(d) and np.isnan(d)):
data_out.append('NULL')
else:
data_out.append(repr(d))
return data_out
def __getstate__(self):
return {
'query': self.query,
'db': self.db,
'server': self.server,
'cnxn': None,
'cursor': None,
'dynaLoad': self.dynaLoad,
'by_column': self.by_column,
}
def __setstate__(self, new_state):
"""
:param new_state: dict or tuple containing dict in [0]
"""
if isinstance(new_state, tuple):
new_state = new_state[0]
for k in list(new_state.keys()):
setattr(self, k, new_state[k])
if 'by_column' not in new_state:
self.by_column = False
return
@dynaLoad
def load(self):
"""
Connect to the database and retrieve its content
"""
if self.query is not None:
self.clear()
res = self.select()
SortedDict.update(self, res)
def _setup_connection(self, force_tunnel=False):
# Get server/tunnel from OMFIT['MainSettings']['SERVER'] list
server0 = translate_RDBserver(self.server, _RDBserverDict)
MS_server0 = OMFIT['MainSettings']['SERVER'][server0]
server = MS_server0['RDB_server']
tunnel = MS_server0.get('tunnel', '')
# Try directly connecting to the RDB server
if server not in OMFITaux['RDBserverReachable']:
OMFITaux['RDBserverReachable'][server] = test_connection(None, *parse_server(server)[2:], timeout=1, ntries=1)
rdb_server = list(parse_server(server))
username, password = get_rdb_password(server0)
if username != rdb_server[0]:
printw(
"`get_rdb_password('{}')` username is `{}` is different from username `{}` indicated in MainSettings\n"
"Use `set_rdb_password('{}', username='{}', password=...)` to setup username/password associated with this server.".format(
server0, username, rdb_server[0], server0, rdb_server[0]
)
)
self.username = username
# If direct connection fails, use tunneling
if not OMFITaux['RDBserverReachable'][server] or force_tunnel:
# Handle server tunneling (this function does already buffering)
rdb_server[2:] = list(map(str, setup_ssh_tunnel(server, tunnel, True, ssh_path=SERVER['localhost'].get('ssh_path', None))))[1:]
connection = {}
if not self.postgresql:
connection['UID'] = username
connection['PWD'] = password
connection['SERVER'] = rdb_server[2]
connection['PORT'] = rdb_server[3]
if 'instance' in MS_server0:
connection['SERVER'] = rf"{connection['SERVER']},{connection['PORT']}"
del connection['PORT']
driver = 'tdsodbc'
if 'driver' in MS_server0:
connection['DRIVER'] = driver = evalExpr(MS_server0['driver'])
if isinstance(driver, str) and driver and os.sep not in driver and driver not in pyodbc.drivers():
connection['DRIVER'] = find_library(driver)
if not isinstance(connection['DRIVER'], str) or not connection['DRIVER']:
raise OMFITexception(
'The library to drive your `' + str(driver) + '` SQL connection is not setup properly.\n' + repr(connection)
)
if 'tdsodbc' in driver:
connection['TDS_Version'] = '8.0'
for k in MS_server0:
if k not in list(connection.keys()) + ['driver', 'server', 'tunnel', 'RDB_server', 'instance']:
connection[k] = MS_server0[k]
else:
connection['user'] = username
connection['password'] = password
connection['host'] = rdb_server[2]
connection['port'] = rdb_server[3]
connection['database'] = self.db
printq('')
printq(server)
for k in list(connection.keys()):
printq(k + '\t\t:\t ' + str(connection[k]))
return connection
def get_databases(self):
'''
Return a list of databases on the server
'''
return sorted(list(self.select('select name from sys.databases', by_column=True)['name']))
def get_tables(self):
'''
Return a list of tables on the given database
'''
return sorted(list(self.select('select table_name from information_schema.tables', by_column=True)['table_name']))
def get_columns(self, table):
'''
Return a list of columns in the given table
'''
if table not in self.get_tables():
raise RuntimeError(f'Table {table} not in {self.get_tables()}')
return sorted(
list(
self.select(f"select column_name from information_schema.columns where table_name='{table}'", by_column=True)['column_name']
)
)
def get_db_structure(self):
'''
Return a nested list of tables and columns of those tables within the current database
'''
result = {}
for table in self.get_tables():
result[table] = []
for column in self.get_columns(table):
result[table].append(column)
return result
def available_efits_from_rdb(scratch, device, shot, default_snap_list=None, format='{tree}', mdsplus_treename=None, **kw):
"""
Retrieves EFIT runids for a given shot from the rdb.
:param scratch: dict
dictionary where the information is cached (to avoid querying SQL database multiple times)
:param device: string
device for which to query list of available EFITs
:param shot: int
shot for which to query list of available EFITs
:param default_snap_list: dict
dictionary to which list of available efits will be passed
:param format: string
format in which to write the list of available EFITs; possible fields are {tree}, {by}, {com}, {drun}, {tag}, {runid}
:param mdsplus_treename: string [optional]
name of the MDSplus tree associated with the code; defaults to the code name ('EFIT')
:param **kw: quietly accepts and ignores other keywords for compatibility with other similar functions
:return: (dict, str)
dictionary with list of available options formatted as {text:format}
information about the discovered EFITs
"""
return available_code_output_from_rdb(
scratch, device, shot, default_snap_list=default_snap_list, format=format, code_name='EFIT', mdsplus_treename=mdsplus_treename, **kw
)
def available_code_output_from_rdb(
scratch, device, shot, default_snap_list=None, format='{tree}', code_name='EFIT', mdsplus_treename=None, **kw
):
"""
Retrieves runids for a given code name and shot from the rdb.
:param scratch: dict
dictionary where the information is cached (to avoid querying SQL database multiple times)
:param device: string
device for which to query list of available EFITs
:param shot: int
shot for which to query list of available EFITs
:param default_snap_list: dict
dictionary to which list of available efits will be passed
:param format: string
format in which to write the list of available runs; possible fields are {tree}, {by}, {com}, {drun}, {tag}, {runid}
:param code_name: string
code name to search for; 'EFIT' and 'OMFITprofiles' have been tested and others may work as well
:param mdsplus_treename: string [optional]
name of the MDSplus tree associated with the code; defaults to code_name
:param **kw: quietly accepts and ignores other keywords for compatibility with other similar functions
:return: (dict, str)
dictionary with list of available options formatted as {text:format}
information about the discovered EFITs
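Illustrative example (a sketch; the device and shot are placeholders):
>>> snap_list, info = available_code_output_from_rdb({}, 'DIII-D', 123456, code_name='EFIT')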
"""
if mdsplus_treename is None:
mdsplus_treename = code_name
efit_sql = f'EFIT_SQL_{device}_{shot}'
if default_snap_list is None:
default_snap_list = {'': ''}
snap_list = copy.deepcopy(default_snap_list)
help = ''
scratch.setdefault('searched_in', [])
scratch['searched_in'] += ['rdb']
if is_device(device, ['DIII-D', 'NSTX', 'NSTXU']):
try:
query = f"SELECT * FROM plasmas WHERE shot={shot} AND code_name='{code_name}' AND deleted!=1 ORDER BY tree"
try:
if scratch[efit_sql].query != query:
scratch[efit_sql] = OMFITrdb(query, 'code_rundb', server=device)
except Exception:
scratch[efit_sql] = OMFITrdb(query, 'code_rundb', server=device)
runids = scratch[efit_sql].across("['*']['shot']")
try:
runids = scratch[efit_sql].across("['*']['run_id']")
except Exception:
pass
for tree, by, com, drun, tag, runid in zip(
*(
scratch[efit_sql].across("['*']['tree']"),
scratch[efit_sql].across("['*']['run_by']"),
scratch[efit_sql].across("['*']['run_comment']"),
scratch[efit_sql].across("['*']['date_run']"),
scratch[efit_sql].across("['*']['runtag']"),
runids,
)
):
if tree == mdsplus_treename:
item = by + ' ' + tag
tree = int(runid)
elif by in ['autoload', 'mdsadmin']:
item = f'{tree} {tag}'
else:
item = f'{tree} by {by}'
help += '[' + item + '] : "' + com.strip() + '"\n'
if drun is not None:
snap_list['[' + item + '] ' + drun.strftime(' (%d %b %Y @ %H:%M)')] = format.format(
tree=tree, runid=runid, by=by, com=com, drun=drun, tag=tag
)
else:
snap_list['[' + item + '] ' + ' (???)'] = format.format(tree=tree, runid=runid, by=by, com=com, drun=drun, tag=tag)
scratch['available_efits_from_rdb_success'] = True
except Exception as _excp:
scratch['available_efits_from_rdb_success'] = False
printe('Could not retrieve list of available EFIT runs: ' + repr(_excp))
snap_list = {'EFIT01': 'EFIT01', 'EFIT02': 'EFIT02'}
help = 'Enter EFIT tree by hand\n\nCould not retrieve list of available EFIT runs:\n' + repr(_excp)
else:
scratch['available_efits_from_rdb_success'] = False
help = help.strip()
return snap_list, help