Source code for omfit_classes.omfit_toksearch

try:
    # framework is running
    from .startup_choice import *
except ImportError as _excp:
    # class is imported by itself
    if (
        'attempted relative import with no known parent package' in str(_excp)
        or 'No module named \'omfit_classes\'' in str(_excp)
        or "No module named '__main__.startup_choice'" in str(_excp)
    ):
        from startup_choice import *
    else:
        raise

from sys import version_info
from socket import gethostname
from psutil import virtual_memory
from omfit_classes.utils_fusion import is_device

__all__ = ['OMFITtoksearch', 'TKS_MdsSignal', 'TKS_PtDataSignal', 'TKS_Aligner', 'TKS_OMFITLoader']

# https://vali.gat.com/toksearch/toksearch

# CONSTANTS
LOGIN_NODE_LIMIT = 5  # max GB fetched on an iris login node, set by policy
SERVER_MEM = {'iris': 120, 'saga': 120}  # TODO: need to use remote_sysinfo to find this
SERVER_NAME = gethostname().split('.')[0]
PYTHON_VERSION = version_info[0]
TOKSEARCH_VERSION = 'toksearch/release'

allowed_partitions = set(['short', 'medium', 'long', 'nodes'])

bash_script = """#!/bin/sh -l
#SBATCH --job-name=toksearch_omfit_remote
#SBATCH --output=toksearch_omfit_remote.out
#SBATCH --partition={partition}
#SBATCH --ntasks={n}
#SBATCH --nodes=1
#SBATCH --mem={mem}G
#SBATCH --time={hours}:{minutes}:00
export CUDA_VISIBLE_DEVICES=''
module purge
module load %s
python3 toksearch_python.py
"""

executables = {'iris': 'sbatch %s',
               'saga': 'sbatch %s',
               'iris_batch': bash_script % (TOKSEARCH_VERSION),
               'saga_batch': bash_script % (TOKSEARCH_VERSION)}
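
# Example (illustrative): toksearchQueryRemote() below fills the batch-script template
# with the queue settings, e.g.
#
#     executables['saga_batch'].format(partition='short', n=10, mem=30, hours='01', minutes='00')
#
# yields an sbatch script requesting 10 tasks, 30 GB of RAM, and a 1-hour time limit on
# the 'short' partition, loading the toksearch/release module before running
# toksearch_python.py.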

login_nodes = set(('irisa', 'irisb', 'irisc', 'irisd'))

toksearchQueryString = """
def toksearchQueryRun(shots, signals, server_info, datasets=None, aligners=None, functions=None, where=None, keep=None,
                      compute_type='ray', return_data='by_shot', warn=True, use_dask=True, **compute_kwargs):
    from sys import path
    path.insert(0, '/fusion/projects/omfit-results/toksearch')
    from toksearchQuery import toksearchQuery
    return toksearchQuery(shots, signals, server_info, datasets, aligners, functions, where, keep,
                          compute_type, return_data, warn, use_dask, **compute_kwargs)
"""


# helper functions
def format_time_limit(hours=1, minutes=0):
    # avoid invalid time limits
    minutes = int(minutes)
    hours = int(hours)
    if minutes < 0 or minutes > 59:  # match the stated [0,60) range
        raise ValueError('Minutes for time limit must be an integer in [0,60), you entered: %s' % (minutes))
    if hours < 0 or hours > 24:
        raise ValueError('Hours for time limit must be an integer in [0,24], you entered: %s' % (hours))
    hours = '%02d' % (hours)
    minutes = '%02d' % (minutes)
    return hours, minutes
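
# Example: format_time_limit(hours=1, minutes=30) returns ('01', '30'), which fills
# the '--time={hours}:{minutes}:00' field of the batch script above.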

def server_info():
    from resource import getrusage, RUSAGE_SELF
    return {'start_memory_usage': getrusage(RUSAGE_SELF).ru_maxrss,
            'omfit_server': SERVER_NAME,
            'memory_available': virtual_memory().available,
            'python_version': PYTHON_VERSION}


def toksearch_wrapper(basename, *args, **kw):
    kw = {} if kw is None else kw
    args = [] if args is None else args
    return [basename, args, kw]


def TKS_MdsSignal(*args, **kw):
    return toksearch_wrapper('MdsSignal', *args, **kw)


def TKS_PtDataSignal(*args, **kw):
    return toksearch_wrapper('PtDataSignal', *args, **kw)


def TKS_OMFITLoader(*args, **kw):
    return toksearch_wrapper('OMFITLoader', *args, **kw)


def TKS_Aligner(align_with, **kw):
    '''
    Function that takes a signal name and keyword arguments and puts them in the format that
    the toksearch query method expects. The specified signal is the one that the dataset is
    intended to be aligned with.

    :param align_with: A string representing the name of the signal in 'signals' with respect
                       to which the dataset is to be aligned.
    '''
    return (align_with, kw)
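
# Example (illustrative): the TKS_* wrappers only package the [class_name, args, kwargs]
# triplet that the remote toksearch query unpacks, e.g.
#
#     TKS_PtDataSignal('ip')  # -> ['PtDataSignal', ('ip',), {}]
#     signals = {'ip': TKS_PtDataSignal('ip'),
#                'density': TKS_MdsSignal(r'\density', 'EFIT01')}
#
# (the MdsSignal arguments here are hypothetical and are passed through verbatim to
# toksearch's MdsSignal constructor)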
def WARN(size):
    message = "WARNING: It is estimated that you are requesting approx. %.2f GB of data. Are you sure you want to do this?" % (size / (1024 ** 3))
    if OMFITaux.get('GUI', None):
        from omfit_classes.OMFITx import Dialog

        return Dialog(message, ['Yes', 'No'], 'warning', "TOKSEARCH REQUESTING LARGE FILE SIZE")
    else:
        printw(message)


def NOTIFY(size):
    if SERVER_NAME in login_nodes:
        message = (
            "ERROR: You are trying to pull back approx. %.2f GB of data. This will put you beyond the allowed limit of %s GB "
            "of accumulated data on a login node.\nPlease reduce your query size, or use the SLURM queue or another machine to run OMFIT."
            % (size * 1e-9, LOGIN_NODE_LIMIT)
        )
    else:
        message = "ERROR: You are trying to pull back approx. %.2f GB of data, which is beyond the available limit on your machine." % (size / (1024 ** 3))
    from omfit_classes.OMFITx import Dialog

    return Dialog(message, ['Ok'], 'error', "TOKSEARCH MEMORY LIMIT HIT")


def predictTokUse(mem_required):
    # empirical estimate of the RAM needed to process a query of a given serialized size
    return round(mem_required * 4 + 25)


class OMFITtksValue:
    '''
    Wrapper for a toksearch record, to be interfaced with like an OMFIT MDSplus value.
    '''

    def __init__(self, record, server=None, treename=None, shot=None, TDI=None, quiet=False):
        self.record = record
        self.server = server
        self.treename = treename
        self.shot = shot
        self.TDI = TDI
        if shot is None or TDI is None:
            self.is_dataset = False
        else:
            self.is_dataset = isinstance(self.record[self.shot][self.TDI], xarray.core.dataset.Dataset)

    def data(self, sig_name=None):
        try:
            if self.is_dataset:
                if sig_name is None:
                    return self.record[self.shot][self.TDI].variables
                else:
                    return self.record[self.shot][self.TDI][sig_name]
            else:
                return self.record[self.shot][self.TDI]['data']
        except Exception as ex:
            print("Error while fetching data: " + str(ex))

    def dim_of(self, dim):
        if self.is_dataset:
            dim_name = list(self.record[self.shot][self.TDI].dims.keys())[dim]
        else:
            dim_name = self.record[self.shot][self.TDI]['dims'][dim]
        return self.record[self.shot][self.TDI][dim_name]

    def units(self, sig_name=None):
        try:
            if self.is_dataset:
                if sig_name is None:
                    return self.record[self.shot][self.TDI].attrs['units']
                else:
                    return self.record[self.shot][self.TDI][sig_name].attrs['units']
            else:
                return self.record[self.shot][self.TDI]['units']['data']
        except Exception:
            return '?'

    def units_dim_of(self, dim):
        try:
            if self.is_dataset:
                dim_name = list(self.record[self.shot][self.TDI].dims.keys())[dim]
                return self.record[self.shot][self.TDI][dim_name].attrs['units']
            else:
                dim_name = self.record[self.shot][self.TDI]['dims'][dim]
                return self.record[self.shot][self.TDI]['units'][dim_name]
        except Exception:
            return '?'
    def xarray(self):
        """
        :return: DataArray with information from this node
        """
        data = self.data()
        if self.is_dataset:
            dims = list(self.record[self.shot][self.TDI].dims.keys())
        else:
            dims = list(self.record[self.shot][self.TDI]['dims'])
        clist = []
        for k, c in enumerate(dims):
            clist.append(self.dim_of(k))
        # if the data shape does not match the coordinate shapes, the dimension order is reversed
        if data.shape != tuple([len(k) if np.ndim(k) == 1 else k.shape[ik] for ik, k in enumerate(clist)]):
            dims = dims[::-1]
            clist = clist[::-1]
        coords = {}
        for k, c in enumerate(dims):
            if np.ndim(clist[k]) == 1:
                ck = c
                coords[ck] = ([c], clist[k], {'units': self.units_dim_of(k)})
            else:
                ck = c + '_val'
                coords[ck] = (dims, clist[k], {'units': self.units_dim_of(k)})
        xdata = DataArray(data, dims=dims, coords=coords, attrs={'units': self.units()})
        return xdata


def toksearchQueryRemote(serverPicker, shots, signals, datasets=None, aligners=None, functions=None, where=None, keep=None,
                         compute_type='ray', mem_requested=30, return_data='by_shot', use_dask=True, load_data=True,
                         queue_args={}, warn=True, **compute_kwargs):
    '''
    This function creates a toksearch query on a designated server (the server must support toksearch).
    It takes a list of shot numbers and signal point names.

    :param serverPicker: (string) A string designating the server on which to create the toksearch query.

    :param shots: A list of shot numbers (ints) to be fetched

    :param signals: A dict where each key corresponds to the signal name returned by toksearch, and each entry is a list
                    which corresponds to a signal object fetched by toksearch. The first element of the list is the string
                    corresponding to the signal class name, e.g. 'PtDataSignal'; the 2nd and 3rd entries are the args (list)
                    and keyword args (dict), respectively. Ex) ['PtDataSignal', ['ip'], {}] corresponds to a fetch of the
                    PtData 'ip' signal.

    :param datasets: A dict describing the xarray datasets to be created from the fetched signals.

    :param aligners: A dict where each key is the name of a dataset to align and each entry is a corresponding list of aligners.

    :param functions: A list of functions or executable strings to be executed in the toksearch mapping stage

    :param where: (string) An evaluatable string (executed in the namespace of the record) which should return a boolean
                  indicating whether the record should be returned by the toksearch query. Use this to filter out shots by
                  certain criteria: a record is kept when the string evaluates to True and filtered out when it evaluates to False.

    :param keep: A list of strings naming which attributes (signal, dataset, etc.) of each record are returned by the
                 toksearch query. default: returns all attrs in the record namespace

    :param compute_type: (string) The method used to run the pipeline. Options: 'serial', 'spark', 'ray'.
                         compute_type='ray' gives better memory usage and parallelization.

    :param mem_requested: (int) GB of RAM requested for the SLURM job (used as the default for queue_args['mem'])

    :param return_data: (string) How the data fetched by toksearch should be structured. Options: 'by_shot', 'by_signal'.
                        'by_shot' returns a dictionary with shots as keys, with each record namespace stored in each entry.
                        'by_signal' returns a dictionary organized with the union of all record attrs as keys, and a
                        dictionary organized by shot numbers under it. default: 'by_shot'
                        NOTE: When fetching 'by_signal', datasets will be concatenated together over all valid shots.

    :param use_dask: (bool) Whether to load datasets using dask. Loading with dask reduces the amount of RAM used by
                     saving the data to disk and only loading it into memory in chunks. default: True

    :param load_data: return data, or return a list of files for the user to handle

    :param queue_args: (dict) SLURM queue settings: 'partition', 'n' (number of cpus), 'mem' (GB of RAM),
                       'hours' and 'minutes' (time limit)

    :param **compute_kwargs: keyword arguments to be passed into the toksearch compute functions

    :return: data, or a list of files for the user to handle, depending on the `load_data` switch
    '''
    queued_system = {'clean_after': True}
    queue_args = dict(queue_args)  # copy to avoid mutating the caller's dict (or the mutable default)
    queue_args.setdefault('partition', 'short' if is_server(serverPicker, 'iris') else 'nodes')  # type of queue resource
    queue_args.setdefault('n', 10)  # number of cpus
    queue_args.setdefault('mem', int(mem_requested))  # amount of RAM allocated
    queue_args.setdefault('hours', 1)
    queue_args.setdefault('minutes', 0)
    queue_args['hours'], queue_args['minutes'] = format_time_limit(queue_args['hours'], queue_args['minutes'])
    if queue_args['partition'] not in allowed_partitions:
        raise RuntimeError("ERROR: Partition type %s not supported. Choose: 'short', 'medium', 'long' or 'nodes'" % (queue_args['partition']))
    if is_server(serverPicker, 'saga'):
        queued_system['script'] = (executables['saga_batch'].format(**queue_args), 'batch.sh')
        queued_system['queued'] = True
        queued_system['std_out'] = 'toksearch_omfit_remote.out'
    else:
        raise NotImplementedError(str(serverPicker) + " is not a supported server.")
    namespace = {
        'shots': shots,
        'signals': signals,
        'functions': functions,
        'where': where,
        'keep': keep,
        'aligners': aligners,
        'datasets': datasets,
        'compute_type': compute_type,
        'use_dask': use_dask,
        'return_data': return_data.lower(),
        'warn': warn,
        'server_info': server_info(),
    }
    namespace.update(compute_kwargs)
    data_files = ['toksearch.toks']
    if return_data == 'by_signal':
        data_files.extend(ds + '.tokds' for ds in namespace['datasets'])
    from omfit_classes.OMFITx import remote_python
    from time import time

    date_time = now("%Y-%m-%d_%H_%M_%S_") + (str(hash(time()))[:6])
    run_folder = 'toksearch_run_' + date_time + '.tokrun'
    workDir = SERVER['localhost']['workDir']
    data_files, mem_usage, data_size, ret_code = remote_python(
        None,
        python_script=(toksearchQueryString, 'toksearch_python.py'),
        target_function='toksearchQueryRun',
        namespace=namespace,
        executable=executables[serverPicker],
        server=SERVER[serverPicker]['server'],
        workdir=workDir + run_folder,
        remotedir=SERVER[serverPicker]['workDir'] + run_folder,
        tunnel=SERVER[serverPicker]['tunnel'],
        outputs=data_files,
        **queued_system,
    )
    move_data = OMFITsessionDir not in workDir
    data_path = OMFITsessionDir + '/' + run_folder
    if move_data and not os.path.exists(data_path):
        os.makedirs(data_path)
    if load_data:
        with open('toksearch.toks', 'rb') as f:
            data = pickle.load(f)
        if return_data == 'by_signal' and use_dask:
            if move_data:
                shutil.move('toksearch.toks', data_path + '/toksearch_' + date_time + '.toks')
            for ds in namespace['datasets']:
                try:
                    if move_data:
                        shutil.move(ds + '.tokds', data_path + '/' + ds + '.tokds')
                    data[ds] = xarray.open_dataset(data_path + '/' + ds + '.tokds', chunks={'shot': 10})
                except Exception as ex:
                    print("FILE %s MISSING" % (ds + '.tokds'))
                    print(ex)
    else:
        if move_data:
            # move each output file under the session directory (run_folder names are unique per run)
            moved_files = []
            for file in data_files:
                dest = data_path + '/' + file
                shutil.move(file, dest)
                moved_files.append(dest)
            data = moved_files
        else:
            data = data_files
    if move_data:
        os.chdir(data_path)
        shutil.rmtree(workDir + run_folder)
    return data, mem_usage, data_size, ret_code
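
# Example (illustrative): calling toksearchQueryRemote() directly; this is normally done
# through the OMFITtoksearch class below, and assumes a configured OMFIT session:
#
#     data, mem_usage, data_size, ret_code = toksearchQueryRemote(
#         'saga', [133221], {'ip': TKS_PtDataSignal('ip')})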
class OMFITtoksearch(SortedDict):
    '''
    This class is used to query a database through the toksearch API.

    :param server: (string) A string designating the server on which to create the toksearch query (default: 'saga').

    :param shots: A list of shot numbers (ints) to be fetched

    :param signals: A dict where each key corresponds to the signal name returned by toksearch, and each entry is a list
                    which corresponds to a signal object fetched by toksearch. The first element of the list is the string
                    corresponding to the signal class name, e.g. 'PtDataSignal'; the 2nd and 3rd entries are the args (list)
                    and keyword args (dict), respectively. Ex) ['PtDataSignal', ['ip'], {}] corresponds to a fetch of the
                    PtData 'ip' signal.

    :param datasets: A dict describing the xarray datasets to be created from the fetched signals.

    :param aligners: A dict where each key is the name of a dataset to align and each entry is a corresponding list of aligners.

    :param functions: A list of functions or executable strings to be executed in the toksearch mapping stage

    :param where: (string) An evaluatable string (executed in the namespace of the record) which should return a boolean
                  indicating whether the record should be returned by the toksearch query. Use this to filter out shots by
                  certain criteria: a record is kept when the string evaluates to True and filtered out when it evaluates to False.

    :param keep: A list of strings naming which attributes (signal, dataset, etc.) of each record are returned by the
                 toksearch query. default: returns all attrs in the record namespace

    :param compute_type: (string) The method used to run the pipeline. Options: 'serial', 'spark', 'ray'.
                         compute_type='ray' gives better memory usage and parallelization.

    :param return_data: (string) How the data fetched by toksearch should be structured. Options: 'by_shot', 'by_signal'.
                        'by_shot' returns a dictionary with shots as keys, with each record namespace stored in each entry.
                        'by_signal' returns a dictionary organized with the union of all record attrs as keys, and a
                        dictionary organized by shot numbers under it. default: 'by_shot'
                        NOTE: When fetching 'by_signal', datasets will be concatenated together over all valid shots.

    :param warn: (bool) If this flag is True, the user will be warned before pulling back more than 50% of the available
                 memory, and can respond accordingly. This is a safety precaution when pulling back large datasets that
                 may cause you to run out of memory. (default: True)

    :param use_dask: (bool) If this flag is True, created datasets will be loaded using dask. Loading with dask reduces
                     the amount of RAM used by saving the data to disk and only loading it into memory in chunks.
                     (default: False)

    :param load_data: (bool) If this flag is False, the data will be transferred to disk under the OMFIT current working
                      directory, but not loaded into memory (RAM), and thus the OMFITtree will not be updated. Use this
                      when fetching data too large to fit into memory. (default: True)

    :param **compute_kwargs: keyword arguments to be passed into the toksearch compute functions
    '''

    def __save_kw__(self):
        return self.args

    def __init__(self, shots, signals, server='saga', datasets=None, aligners=None, functions=None, where=None, keep=None,
                 compute_type='spark', return_data='by_shot', warn=True, use_dask=False, load_data=True, queue_args={},
                 **compute_kwargs):
        SortedDict.__init__(self)
        self.kw = compute_kwargs
        self.kw['serverPicker'] = server
        self.kw['shots'] = shots
        self.kw['signals'] = signals
        self.kw['datasets'] = datasets
        self.kw['aligners'] = aligners
        self.kw['compute_type'] = compute_type
        self.kw['functions'] = functions
        self.kw['where'] = where
        self.kw['keep'] = keep
        self.kw['use_dask'] = use_dask
        self.kw['return_data'] = return_data.lower()
        self.kw['load_data'] = load_data
        self.kw['queue_args'] = queue_args.copy()
        self.kw['warn'] = warn
        self.args = self.kw.copy()
        self.kw['mem_requested'] = 60
        self.dynaLoad = True
    @dynaLoad
    def load(self):
        kw = self.kw.copy()
        objs_as_dict, memory_usage, serial_mem, ret_code = toksearchQueryRemote(**kw)
        self.dynaLoad = False
        if not kw['load_data']:
            return objs_as_dict
        if ret_code == 1:
            # pulling too much data
            ans = NOTIFY(serial_mem)
            self.dynaLoad = True
        elif ret_code == 3 and kw['warn']:
            # will warn only if self.warn == True
            pull_data = WARN(serial_mem)
            if pull_data == 'Yes':
                mem_needed = predictTokUse(serial_mem)
                kw['mem_requested'] = min(mem_needed, SERVER_MEM[kw['serverPicker']])
                kw['warn'] = False
                objs_as_dict, memory_usage, serial_mem, ret_code = toksearchQueryRemote(**kw)
                if ret_code == 1:
                    ans = NOTIFY(serial_mem)
                    self.dynaLoad = True
                else:
                    self.update(objs_as_dict)
            else:
                self.dynaLoad = True
        else:
            self.update(objs_as_dict)
    @dynaLoad
    def __call__(self, *args, **kw):
        if args or kw:
            return OMFITtksValue(self, *args, **kw)
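
# Example (illustrative): after load(), an OMFITtoksearch instance can be called to
# obtain an MDSplus-like view of a single fetched signal:
#
#     ts = OMFITtoksearch([133221], {'ip': TKS_PtDataSignal('ip')})
#     ip = ts(shot=133221, TDI='ip')  # returns an OMFITtksValue
#     ip.data(), ip.dim_of(0), ip.units()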
############################################
if '__main__' == __name__:
    test_classes_main_header()
    if SERVER['saga']['workDir'] is None:
        printe(_using_toksearch_outside_framework_warning)
    filter_func = "not record['errors']"
    # successful creation of an OMFIT object (no connections tested)
    OMFIT['ts'] = OMFITtoksearch([133221], {'ip': TKS_PtDataSignal('ip')}, where=filter_func)