from . import unix_os as os
import sys
import stat
import copy
import shutil
import time
import uuid
import zipfile
import ast
import platform
import warnings
import functools
import builtins
import queue as Queue
import pickle
from io import StringIO
import locale
# Keep track of what classes have been loaded
# This is used internally for identifying what classes are
# used by different OMFIT modules/scripts/workflows/regressions
_loaded_classes = set()
# set default locale to be en_US.UTF-8
for _item in ['LC_ALL', 'LC_CTYPE']:
os.environ[_item] = 'en_US.UTF-8'
# monkey patch open so to enforce UNIX style newline also on windows
[docs]def open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None):
if mode == 'w':
newline = '\n'
return builtins.open(
file=file, mode=mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, closefd=closefd, opener=opener
)
open.__doc__ = builtins.open.__doc__
if os.name == 'nt':
import wexpect as pexpect
EOF = pexpect.EOF
else:
import pexpect
EOF = pexpect.exceptions.EOF
if 'omfit_tree' in sys.modules:
print('Loading base utility functions...')
[docs]def deprecated(func):
"""This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used."""
@functools.wraps(func)
def new_func(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Call to deprecated function `{}`".format(func.__name__), category=DeprecationWarning)
return func(*args, **kwargs)
return new_func
[docs]def b2s(obj):
if isinstance(obj, bytes):
for error_handling in ['strict', 'replace', 'ignore']:
try:
return obj.decode("utf-8", errors=error_handling)
except Exception:
pass
raise RuntimeError('A bytes object was passed to b2s, but not handled')
import numpy as np
if isinstance(obj, np.ndarray) and (
obj.dtype.name.startswith('bytes')
or (obj.dtype.name.startswith('object') and np.all(map(lambda x: isinstance(x, (bytes, str)), obj.flat)))
):
return np.reshape(np.array(list(map(b2s, obj.flat))), obj.shape)
else:
return obj
# This is the pickle protocol used for long term storage and deployment of OMFIT objects
pickle.OMFIT_PROTOCOL = max([4, pickle.HIGHEST_PROTOCOL])
# Windows defines USERNAME instead of the Unix USER environmental variable
if os.name == 'nt' and 'USER' not in os.environ and 'USERNAME' in os.environ:
os.environ['USER'] = os.environ['USERNAME']
# Debug slow imports
if float(os.environ.get('OMFIT_TIME_IMPORTS', '0')) > 0:
_orig_import = builtins.__import__
def new_import(*args, **kw):
time = _orig_import('time')
t0 = time.time()
tmp = _orig_import(*args, **kw)
if args and (time.time() - t0) > float(os.environ['OMFIT_TIME_IMPORTS']):
sys.__stderr__.write('Import of `%s` has taken time: %3.3f s\n' % (args[0], time.time() - t0))
return tmp
builtins.__import__ = new_import
# remove SSH_ASKPASS environmental varaible to avoid graphical SSH prompt
if 'SSH_ASKPASS' in os.environ:
del os.environ['SSH_ASKPASS']
# set the environmental variable HOME based on HOMEPATH (for windows installations)
os.environ.setdefault('HOME', os.path.expanduser('~'))
# the OMFITsrc variable stores the directory where OMFIT is running from
# note that this is defined from the omfit/omfit_classes folder
OMFITsrc = os.path.abspath(os.path.dirname(__file__) + os.sep + '..')
# the OMFITsettingsDir variable stores the default users settins, bookmarks, open sessions, ...
OMFITsettingsDir = os.environ['HOME'] + os.sep + '.OMFIT'
firstOMFITexecution = False
if not os.path.exists(OMFITsettingsDir):
os.makedirs(OMFITsettingsDir)
firstOMFITexecution = True
os.chmod(OMFITsettingsDir, 0o700)
# The OMFITtmpDir variable stores the temporary directory where working
# directories from multiple instances of OMFIT coexist
# It is important that OMFITtmpDir is defined in a directory that is not shared among multiple nodes of a cluster
_tmp = os.path.abspath(os.environ.get('OMFIT_TMPDIR', os.sep + 'tmp'))
OMFITtmpDir = os.sep.join([_tmp, os.environ['USER'], 'OMFIT'])
if not os.path.exists(OMFITtmpDir):
try:
os.makedirs(OMFITtmpDir)
except OSError:
if os.path.exists(OMFITtmpDir):
pass # to handle simultaneous start of OMFIT sessions
else:
OMFITtmpDir = os.environ['HOME'] + os.sep + 'tmp' + os.sep + 'OMFIT'
if not os.path.exists(OMFITtmpDir):
try:
os.makedirs(OMFITtmpDir)
except OSError:
if os.path.exists(OMFITtmpDir):
pass # to handle simultaneous start of OMFIT sessions
else:
raise
# create an OMFITsessionsDir directory for each workstation
OMFITsessionsDir = OMFITtmpDir + '_local' + os.sep + 'sessions'
if not os.path.exists(OMFITsessionsDir):
try:
os.makedirs(OMFITsessionsDir)
except OSError:
if os.path.exists(OMFITsessionsDir):
pass # to handle simultaneous start of OMFIT sessions
else:
raise
# create an OMFITbinsDir directory for each workstation
OMFITbinsDir = OMFITtmpDir + '_local' + os.sep + 'bins'
if not os.path.exists(OMFITbinsDir):
try:
os.makedirs(OMFITbinsDir)
except OSError:
if os.path.exists(OMFITbinsDir):
pass # to handle simultaneous start of OMFIT sessions
else:
raise
# the OMFITcontrolmastersDir stores the sockets for SSH controlmaster functionality
OMFITcontrolmastersDir = OMFITtmpDir + '_local' + os.sep + 'controlmasters'
if not os.path.exists(OMFITcontrolmastersDir):
try:
os.makedirs(OMFITcontrolmastersDir)
except OSError:
if os.path.exists(OMFITcontrolmastersDir):
pass # to handle simultaneous start of OMFIT sessions
else:
raise
os.chmod(OMFITcontrolmastersDir, 0o700)
[docs]def safe_eval_environment_variable(var, default):
'''
Safely evaluate environmental variable
:param var: string with environmental variable to evaluate
:param default: default value for the environmental variable
'''
try:
return eval(os.environ.get(var, repr(default)))
except Exception:
return os.environ.get(var, repr(default))
# If we are runnnig the whole OMFIT framework, and it's a public installation
# then by default users' Python modules paths are rejected
# Users can set OMFIT_CLEAN_PYTHON_ENVIRONMENT=0 to disable all clearing
# or at least set OMFIT_CLEAN_PYTHON_ENVIRONMENT=1 to disable the warning related to such clearing
if 'omfit_classes.startup_framework' in sys.modules and safe_eval_environment_variable('OMFIT_CLEAN_PYTHON_ENVIRONMENT', True):
_unacceptable_paths = ['/usr/local', os.environ['HOME'] + '/.local', os.environ['HOME'] + '/Library'] + os.environ.get(
'PYTHONPATH', ''
).split(':')
_unacceptable_paths = [_up for _up in _unacceptable_paths if _up]
_invalid_paths = []
for _path in sys.path:
for _up in _unacceptable_paths:
if (
_path.startswith(os.path.abspath(_up))
and _path in sys.path
and _path not in _invalid_paths
and os.path.exists(_path)
and os.path.abspath(_path) != OMFITsrc
and not os.path.abspath(_path).startswith(sys.executable.split('bin')[0])
):
_invalid_paths.append(_path)
_invalid_paths = sorted(_invalid_paths)
if len(_invalid_paths):
if os.path.exists(os.sep.join([OMFITsrc, '..', 'public'])) or 'OMFIT_CLEAN_PYTHON_ENVIRONMENT' in os.environ:
if 'OMFIT_CLEAN_PYTHON_ENVIRONMENT' not in os.environ:
print('=' * 80)
print('Warning: The following user-defined paths have been removed from your Python environment:')
for _path in _invalid_paths:
print(' %s' % _path)
print('To use your original Python environment set OMFIT_CLEAN_PYTHON_ENVIRONMENT=0')
print('=' * 80)
for _path in _invalid_paths:
sys.path.remove(_path)
else:
print('=' * 80)
print('Warning: The following user-defined paths are in your Python environment:')
for _path in _invalid_paths:
print(' %s' % _path)
print('To use a clean Python environment set OMFIT_CLEAN_PYTHON_ENVIRONMENT=1')
print('To suppress this warning message set OMFIT_CLEAN_PYTHON_ENVIRONMENT=0')
print('=' * 80)
# Use OMAS as OMFIT-source git submoule
if "OMAS_ROOT" not in os.environ or not os.path.exists(os.environ['OMAS_ROOT']):
local_omas = os.path.abspath(OMFITsrc + os.sep + '..' + os.sep + "omas")
if local_omas in sys.path:
sys.path.remove(local_omas)
sys.path.insert(0, local_omas)
msg_submodule = 'using git submodule bundled OMAS installation'
if "OMAS_ROOT" not in os.environ:
print(f"$OMAS_ROOT not found: {msg_submodule}")
else:
print(f"$OMAS_ROOT: {os.environ['OMAS_ROOT']} does not exist: {msg_submodule}")
else:
print(f"$OMAS_ROOT: {os.environ['OMAS_ROOT']}")
if os.environ['OMAS_ROOT'] in sys.path:
sys.path.remove(os.environ['OMAS_ROOT'])
sys.path.insert(0, os.environ['OMAS_ROOT'])
local_nimpy = os.path.abspath(OMFITsrc + os.sep + '..' + os.sep + "nimpy")
# Keep track of original environment versions before they get modified within OMFIT
for k in ['PATH', 'LD_LIBRARY_PATH', 'DYLD_LIBRARY_PATH']:
if f'ORIGINAL_{k}' not in os.environ and k in os.environ:
os.environ[f'ORIGINAL_{k}'] = os.environ[k]
# Add directory of python executable to PATH
if os.path.split(sys.executable)[0] not in os.environ['PATH']:
os.environ['PATH'] = os.path.split(sys.executable)[0] + os.path.pathsep + os.environ['PATH']
# warnings
import warnings
[docs]def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
# Suppress DeprecationWarning and FutureWarning for imports to avoid lots of red-text from 3rd party packages
# when running Python 2 with `-3` option to warn about Python 3.x incompatibilities
if category in [DeprecationWarning, FutureWarning] and OMFITsrc not in filename and OMFITtmpDir not in filename:
return ''
message = str(message)
where = os.path.split(filename)[1] + '@' + str(lineno)
text = []
for k, line in enumerate(message.strip('\n').split('\n')):
if k == 0:
text.append(category.__name__ + ': ' + line + '\n')
else:
text.append(' ' * (len(category.__name__) + 2) + line + '\n')
text = ''.join(text).strip('\n')
# do not let Python 2 int division slip
if 'classic int division' in text or "'U' mode is deprecated" in text:
raise Exception(text)
return text + ' (' + where + ')\n'
warnings.formatwarning = warning_on_one_line
warnings.filterwarnings('always', message='classic int division')
if os.environ['USER'] in ['meneghini', 'smithsp', 'eldond']:
warnings.filterwarnings('always', category=DeprecationWarning)
warnings.filterwarnings('always', category=FutureWarning)
warnings.filterwarnings('ignore', message='invalid escape sequence')
if os.environ['USER'] in [
'meneghini',
'smithsp',
'eldond',
'bpatel2',
'logannc',
'thomek',
'avdeevag',
'jmcclena',
'slendebroekt',
'haskeysr',
]:
warnings.filterwarnings('error', category=DeprecationWarning, message='Using a DataArray object to construct a variable is ambiguous.*')
else:
warnings.filterwarnings(
'always', category=DeprecationWarning, message='Using a DataArray object to construct a variable is ambiguous.*'
)
if os.environ['USER'] in ['eldond']:
behavior = 'always' # Behavior for not serious but not 100% trivial things; choose 'ignore', 'once', or 'always'
warnings.filterwarnings('error') # TODO: take this out when done?
warnings.filterwarnings('ignore', category=ResourceWarning)
warnings.filterwarnings('ignore', message='OMFIT is unable to create script backup copies') # I don't care
warnings.filterwarnings('ignore', category=UserWarning, message='.*Font family .* not found.*') # Fallback is fine
warnings.filterwarnings('ignore', category=UserWarning, message='No `boutdata`.*')
warnings.filterwarnings(behavior, category=UserWarning, message='Tight layout not applied.*') # Not serious
warnings.filterwarnings(behavior, category=DeprecationWarning, message='the imp module is deprecated in favour of importlib.*')
warnings.filterwarnings(behavior, category=DeprecationWarning, message='PILLOW_VERSION is deprecated.*')
# This is not serious, but it is annoying
from numpy import VisibleDeprecationWarning
warnings.filterwarnings('once', category=VisibleDeprecationWarning)
# From importing old? xarray versions. Can't be solved but by changing/upgrading xarray, or ignoring.
warnings.filterwarnings("ignore", category=DeprecationWarning, message='Using or importing the ABCs*') # xarray
# TODO: update OMFIT's xarray requirement to a newer version
warnings.filterwarnings(behavior, category=SyntaxWarning) # Needed for MDSplus with py3.8: "is not" with literal
# These should never be elevated to errors:
warnings.filterwarnings('always', category=UserWarning, message='Unable to import omfit_plot.*')
# base imports
import ast
import socket
import re
import subprocess
import distutils
import time
import platform
import difflib
import tempfile
import datetime
import inspect
import functools
import glob
from collections import OrderedDict
from pprint import pprint
from omfit_classes.exceptions_omfit import *
special1 = []
# store file location for later OMFITobject files garbage collection
_allOMFITobjects = {}
# The content of this file is loaded by OMFITsetup and as such
# it should not depend on the add-on Python package that OMFIT
# requires (Numpy, Matplotlib, Tk, ...)
# ---------------------
# OMFITaux: dictionary is used to keep track of auxiliary informations of OMFIT
# ---------------------
class _OMFITauxiliary(dict):
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
instance.defaults = {}
return instance
def __init__(self):
self.defaults['lastUserError'] = ['']
self.defaults['lastReportedUserError'] = ['']
self.defaults['lastBrowsedDirectory'] = ''
self.defaults['lastBrowsed'] = {}
self.defaults['GUI'] = None
self.defaults['rootGUI'] = None
self.defaults['treeGUI'] = None
self.defaults['console'] = None
self.defaults['virtualKeys'] = False
self.defaults['hardLinks'] = False
self.defaults['quickPlot'] = {}
self.defaults['prun_process'] = []
self.defaults['prun_nprocs'] = []
self.defaults['pythonRunWindows'] = []
self.defaults['haltWindow'] = None
self.defaults['MDSserverReachable'] = {}
self.defaults['RDBserverReachable'] = {}
self.defaults['batch_js'] = {}
self.defaults['sysinfo'] = {}
self.defaults['sshTunnel'] = {}
self.defaults['lastActivity'] = time.time()
self.defaults['noCopyToCWD'] = False
self.defaults['lastRunModule'] = ''
self.defaults['moduleSkeletonCache'] = None
self.defaults['debug'] = 0
self.defaults['dynaLoad_switch'] = True
self.update(copy.deepcopy(self.defaults))
def __getitem__(self, key):
# if 'lastBrowsedDirectory' does not exisits recurse directory backwards to find valid directory root
if key == 'lastBrowsedDirectory' and key in self:
tmp = super().__getitem__(key)
tmp = os.path.abspath(os.path.expandvars(os.path.expanduser(tmp)))
for k in range(len(tmp)):
if os.path.exists(tmp):
return tmp
tmp = os.path.split(tmp)[0]
return os.environ['HOME']
return super().__getitem__(key)
OMFITaux = _OMFITauxiliary()
[docs]def hasattr_no_dynaLoad(object, attribute):
"""
same as `hasattr` function but does not trigger dynamic loading
"""
try:
dynaLoadBkp = OMFITaux['dynaLoad_switch']
OMFITaux['dynaLoad_switch'] = False
return hasattr(object, attribute)
finally:
OMFITaux['dynaLoad_switch'] = dynaLoadBkp
# ---------------------
# Decorators @_available_to_user are used to define which functions should appear in the OMFIT documentation
# ---------------------
OMFITaux['OMFITmath_functions'] = []
def _available_to_user_math(f):
OMFITaux['OMFITmath_functions'].append(f)
return f
OMFITaux['OMFITutil_functions'] = []
def _available_to_user_util(f):
OMFITaux['OMFITutil_functions'].append(f)
return f
OMFITaux['OMFITplot_functions'] = []
def _available_to_user_plot(f):
OMFITaux['OMFITplot_functions'].append(f)
return f
# ---------------------
# evaluate expressions
# ---------------------
[docs]def isinstance_str(inv, cls):
"""
checks if an object is of a certain type by looking at the class name (not the class object)
This is useful to circumvent the need to load import Python modules.
:param inv: object of which to check the class
:param cls: string or list of string with the name of the class(es) to be checked
:return: True/False
"""
if isinstance(cls, str):
cls = [cls]
if hasattr(inv, '__class__') and hasattr(inv.__class__, '__name__') and inv.__class__.__name__ in cls:
return True
return False
[docs]def evalExpr(inv):
"""
Return the object that dynamic expressions return when evaluated
This allows OMFITexpression('None') is None to work as one would expect.
Epxressions that are invalid they will raise an OMFITexception when evaluated
:param inv: input object
:return:
* If inv was a dynamic expression, returns the object that dynamic expressions return when evaluated
* Else returns the input object
"""
if isinstance_str(inv, 'OMFITexpressionError'):
raise OMFITexception('Invalid expression')
elif isinstance_str(inv, ['OMFITexpression', 'OMFITiterableExpression']) and hasattr(inv, '_value_'):
tmp = inv._value_()
if isinstance_str(tmp, 'OMFITexpressionError'):
raise OMFITexception('Invalid expression:\n' + inv.error)
return tmp
else:
return inv
[docs]def freezeExpr(me, remove_OMFITexpressionError=False):
"""
Traverse a dictionary and evaluate OMFIT dynamic expressions in it
NOTE: This function operates in place
:param me: input dictionary
:param remove_OMFITexpressionError: remove entries that evaluate as OMFITexpressionError
:return: updated dictionary
"""
for kid in list(me.keys()):
if isinstance_str(me[kid], ['OMFITexpression', 'OMFITiterableExpression']):
try:
me[kid] = evalExpr(me[kid])
except Exception:
del me[kid]
continue
elif isinstance_str(me[kid], 'OMFITexpressionError'):
if remove_OMFITexpressionError:
del me[kid]
continue
else:
raise OMFITexception('Invalid expression:\n' + me[kid].error)
if isinstance(me[kid], dict):
freezeExpr(me[kid], remove_OMFITexpressionError=remove_OMFITexpressionError)
# ---------------------
# checktypes
# ---------------------
[docs]def is_none(inv):
"""
This is a convenience function to evaluate if a object or an expression is None
Use of this function is preferred over testing if an expression is None
by using the == function. This is because np arrays evaluate == on a per item base
:param inv: input object
:return: True/False
"""
if inv is None:
return True
elif isinstance_str(inv, ['OMFITexpression', 'OMFITiterableExpression']):
return evalExpr(inv) is None
return False
[docs]def is_bool(value):
"""
Convenience function check if value is boolean
:param value: value to check
:return: True/False
"""
return value in [True, False]
[docs]def is_int(value):
"""
Convenience function check if value is integer
:param value: value to check
:return: True/False
"""
import numpy as np
return isinstance(value, (int, np.integer))
[docs]def is_float(value):
"""
Convenience function check if value is float
:param value: value to check
:return: True/False
"""
import numpy as np
return isinstance(value, (float, np.floating))
[docs]def is_numeric(value):
"""
Convenience function check if value is numeric
:param value: value to check
:return: True/False
"""
try:
0 + value
return True
except TypeError:
return False
[docs]def is_number_string(my_string):
"""
Determines whether a string may be parsed as a number
:param my_string: string
:return: bool
"""
try:
float(my_string)
except ValueError:
return False
else:
return True
[docs]def is_alphanumeric(value):
"""
Convenience function check if value is alphanumeric
:param value: value to check
:return: True/False
"""
if isinstance(value, str):
return True
try:
0 + value
return True
except TypeError:
return False
[docs]def is_array(value):
"""
Convenience function check if value is list/tuple/array
:param value: value to check
:return: True/False
"""
import numpy as np
return isinstance(value, (list, tuple, np.ndarray))
[docs]def is_string(value):
"""
Convenience function check if value is string
:param value: value to check
:return: True/False
"""
return isinstance(value, str)
[docs]def is_email(value):
if isinstance(value, str):
return re.findall('i?[\w\-\.]+@[\w\-\.]+\.[\w\-\.]+', value)
[docs]def is_int_array(val):
"""
Convenience function check if value is a list/tuple/array of integers
:param value: value to check
:return: True/False
"""
import numpy as np
if is_array(val):
try:
tmp = np.atleast_1d(val).astype(int)
except TypeError:
return False
if np.all(np.atleast_1d(val) == tmp):
return True
return False
# ---------------------
# Printing
# ---------------------
streams_q = Queue.Queue()
[docs]class qRedirector(object):
'''A class for redirecting stdout and stderr to this Text widget'''
def __init__(self, tag='STDOUT'):
self.tag = tag
[docs] def write(self, string):
streams_q.put((string, self.tag), block=False, timeout=0)
class _Streams(dict):
tags = {
'STDOUT': ('black', 'RESET'),
'STDERR': ('red3', 'RED'),
'DEBUG': ('gold4', 'YELLOW'),
'PROGRAM_OUT': ('blue', 'BLUE'),
'PROGRAM_ERR': ('purple', 'MAGENTA'),
'INFO': ('forest green', 'GREEN'),
'WARNING': ('DarkOrange2', 'YELLOW'),
'HIST': ('dark slate gray', 'CYAN'),
'HELP': ('PaleGreen4', 'CYAN'),
}
def __new__(cls):
return dict.__new__(cls)
def __init__(self):
self.stderr = sys.stderr
self.stdout = sys.stdout
self.setDefaults()
self.bkpStack = []
def setDefaults(self):
for k in self.tags:
if 'ERR' in k:
self[k] = self.stderr
else:
self[k] = self.stdout
def backup(self):
self.bkpStack.append(list(self.items()))
def restore(self):
for k, v in self.bkpStack.pop():
self[k] = v
sys.stdout = self['STDOUT']
sys.stderr = self['STDERR']
_streams = _Streams()
[docs]class console_color:
[docs] @staticmethod
def BLACK(x=''):
return '\033[30m' + str(x) + '\033[0m'
[docs] @staticmethod
def RED(x=''):
return '\033[31m' + str(x) + '\033[0m'
[docs] @staticmethod
def GREEN(x=''):
return '\033[32m' + str(x) + '\033[0m'
[docs] @staticmethod
def YELLOW(x=''):
return '\033[33m' + str(x) + '\033[0m'
[docs] @staticmethod
def BLUE(x=''):
return '\033[34m' + str(x) + '\033[0m'
[docs] @staticmethod
def MAGENTA(x=''):
return '\033[35m' + str(x) + '\033[0m'
[docs] @staticmethod
def CYAN(x=''):
return '\033[36m' + str(x) + '\033[0m'
[docs] @staticmethod
def WHITE(x=''):
return '\033[37m' + str(x) + '\033[0m'
[docs] @staticmethod
def UNDERLINE(x=''):
return '\033[4m' + str(x) + '\033[0m'
[docs] @staticmethod
def RESET(x=''):
return '\033[0m' + str(x) + '\033[0m'
[docs]def tag_print(*objects, **kw):
"""
Works like the print function, but used to print to GUI (if GUI is available).
The coloring of the GUI print is determined by the `tag` parameter.
:param \*objects: string/objects to be printed
:param sep: separator (default: ' ')
:param sep: new line character (default: '\\\\n')
:param tag: one of the following:
* 'STDOUT'
* 'STDERR'
* 'DEBUG'
* 'PROGRAM_OUT'
* 'PROGRAM_ERR'
* 'INFO'
* 'WARNING'
* 'HIST'
* 'HELP'
"""
tag = kw.get('tag', '')
tag_override = os.environ.get('OMFIT_TAG_PRINT_STREAM_OVERRIDE', '')
if tag_override in _streams:
tag = tag_override
if sys.stdout is not sys.__stdout__ and tag in _streams: # <--- check if the stdout is redirected and tag is recognized
return print(*objects, sep=kw.pop('sep', ' '), end=kw.pop('end', '\n'), file=_streams[tag])
else:
file = sys.__stderr__ if 'ERR' in tag else sys.__stdout__
# colored terminal
if sys.stdout.isatty():
tmp = StringIO()
print(*objects, sep=kw.pop('sep', ' '), end=kw.pop('end', '\n'), file=tmp)
return print(getattr(console_color, _streams.tags[tag][1])(tmp.getvalue()), sep='', end='', file=file)
# standard terminal
return print(*objects, sep=kw.pop('sep', ' '), end=kw.pop('end', '\n'), file=file)
[docs]def printi(*objects, **kw):
"""
Function to print with `INFO` style
:param \*objects: what to print
:param \**kw: keywords passed to the `print` function
:return: return from `print` function
"""
kw['tag'] = 'INFO'
if int(os.environ.get('OMFIT_VISUAL_CUES', '0')):
objects = ['^' + '\n^'.join(str(x).splitlines()) for x in objects]
return tag_print(*objects, **kw)
[docs]def pprinti(*objects, **kw):
"""
Function to pretty-print with `INFO` style
:param \*objects: what to print
:param \**kw: keywords passed to the `print` function
:return: return from `pprint` function
"""
kw['stream'] = _streams['INFO']
return pprint(*objects, **kw)
[docs]def printe(*objects, **kw):
"""
Function to print with `ERROR` style
:param \*objects: what to print
:param \**kw: keywords passed to the `print` function
:return: return from `print` function
"""
kw['tag'] = 'STDERR'
if int(os.environ.get('OMFIT_VISUAL_CUES', '0')):
objects = ['!' + '\n!'.join(str(x).splitlines()) for x in objects]
return tag_print(*objects, **kw)
[docs]def pprinte(*objects, **kw):
"""
Function to pretty-print with `STDERR` style
:param \*objects: what to print
:param \**kw: keywords passed to the `pprint` function
:return: return from `pprint` function
"""
kw['stream'] = _streams['STDERR']
return pprint(*objects, **kw)
[docs]def printw(*objects, **kw):
"""
Function to print with `WARNING` style
:param \*objects: what to print
:param \**kw: keywords passed to the `print` function
:return: return from `print` function
"""
kw['tag'] = 'WARNING'
if int(os.environ.get('OMFIT_VISUAL_CUES', '0')):
objects = ['@' + '\n@'.join(str(x).splitlines()) for x in objects]
return tag_print(*objects, **kw)
[docs]def pprintw(*objects, **kw):
"""
Function to pretty-print with `WARNING` style
:param \*objects: what to print
:param \**kw: keywords passed to the `pprint` function
:return: return from `pprint` function
"""
kw['stream'] = _streams['WARNING']
return pprint(*objects, **kw)
OMFITaux['debug_logs'] = _debug_logs = {}
[docs]def printd(*objects, **kw):
"""
Function to print with `DEBUG` style.
Printing is done based on environmental variable OMFIT_DEBUG
which can either be a string with an integer (to indicating a debug level)
or a string with a debug topic as defined in OMFITaux['debug_logs']
:param \*objects: what to print
:param level: minimum value of debug for which printing will occur
:param \**kw: keywords passed to the `print` function
:return: return from `print` function
"""
# log debug history
debug_topic = kw.pop('topic', 'uncategorized')
if debug_topic not in _debug_logs:
_debug_logs[debug_topic] = []
_tmp_stream = StringIO()
if int(os.environ.get('OMFIT_VISUAL_CUES', '0')):
objects = ['$' + '\n$'.join(str(x).splitlines()) for x in objects]
print(*objects, sep=kw.get('sep', ' '), end=kw.get('end', '\n'), file=_tmp_stream)
_debug_logs[debug_topic].append(str(time.time()) + ": " + _tmp_stream.getvalue())
_debug_logs[debug_topic] = _debug_logs[debug_topic][-100:]
kw['tag'] = 'DEBUG'
doPrint = False
try:
# print by level
debug_level = int(os.environ.get('OMFIT_DEBUG', '0')) # this will raise an ValueError if OMFIT_DEBUG is a string
if debug_level >= kw.pop('level', 1) or debug_level < 0:
doPrint = True
except ValueError:
# print by topic
if os.environ.get('OMFIT_DEBUG', '') == debug_topic:
doPrint = True
finally:
if doPrint:
terminal_debug = safe_eval_environment_variable('OMFIT_TERMINAL_DEBUG', False)
if terminal_debug > 0:
printt(*objects, **kw)
if (terminal_debug % 2) == 0: # Even numbers print to the OMFIT console
return tag_print(*objects, **kw)
[docs]def printt(*objects, **kw):
"""
Function to force print to terminal instead of GUI
:param \*objects: what to print
:param err: print to standard error
:param \**kw: keywords passed to the `print` function
:return: return from `print` function
"""
if int(os.environ.get('OMFIT_VISUAL_CUES', '0')):
objects = ['%' + '\n%'.join(str(x).splitlines()) for x in objects]
try:
file = sys.__stderr__ if kw.pop('err', False) else sys.__stdout__
return print(*objects, sep=kw.pop('sep', ' '), end=kw.pop('end', '\n'), file=file)
except IOError:
pass
[docs]class quiet_environment(object):
"""
This environment quiets all output (stdout and stderr)
>> print('hello A')
>> with quiet_environment() as f:
>> print('hello B')
>> print('hello C')
>> tmp=f.stdout
>> print('hello D')
>> print(tmp)
"""
def __enter__(self):
self.streams_bkp = {}
self.streams_bkp.update(_streams)
sys.stdout = StringIO()
sys.stderr = StringIO()
for k in _streams:
if 'ERR' in k:
_streams[k] = sys.stderr
else:
_streams[k] = sys.stdout
return self
@property
def stdout(self):
return sys.stdout.getvalue()
@property
def stderr(self):
return sys.stderr.getvalue()
def __exit__(self, type, value, traceback):
_streams.update(self.streams_bkp)
sys.stdout = _streams['STDOUT']
sys.stderr = _streams['STDERR']
[docs]def size_of_dir(folder):
"""
function returns the folder size as a number
:param folder: directory path
:return: size in bytes
"""
size = 0
if isinstance(folder, str) and os.path.exists(folder) and os.path.isdir(folder):
for path, dirs, files in os.walk(folder):
for f in files:
try:
size += os.path.getsize(os.path.join(path, f))
except OSError:
pass
return size
[docs]def sizeof_fmt(filename, separator='', format=None, unit=None):
"""
function returns a string with nicely formatted filesize
:param filename: string with path to the file or integer representing size in bytes
:param separator: string between file size and units
:param format: default None, format for the number
:param unit: default None, unit of measure
:return: string with file size
"""
if unit in ['b', 'B']:
unit = 'bytes'
def _size(num):
for u in ['bytes', 'kB', 'MB', 'GB']:
if u == unit or (unit is None and num < 1024.0):
break
num /= 1024.0
else:
u = 'TB'
if isinstance(format, str):
f = format
elif isinstance(format, dict) and u in list(format.keys()):
f = format[u]
else:
f = '%3.1f'
return f % num + separator + u
if is_int(filename) and filename >= 0:
return _size(filename)
if isinstance(filename, str) and os.path.exists(filename):
return _size(os.path.getsize(filename))
else:
return 'N/A'
[docs]def encode_ascii_ignore(string):
"""
This function provides fail-proof conversion of str to ascii
Note: not ASCII characters are ignored
:param string: str string
:return: ascii scring
"""
import numpy as np
string = np.array([ord(x) for x in string])
return ''.join([chr(x) for x in string[np.where(string < 128)[0]]])
[docs]def is_binary_file(filename):
"""
Detect if a file is binary or ASCII
:param filename: path to the file
:return: True if binary file else False
"""
textchars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
with open(filename, 'rb') as f:
return bool(f.read().translate(None, textchars))
[docs]def wrap(text, width):
r"""
A word-wrap function that preserves existing line breaks
and most spaces in the text. Expects that existing line
breaks are posix newlines (\n).
:param text: text to be wrapped
:param width: maximum line width
"""
from functools import reduce
return reduce(
lambda line, word, width=width: '%s%s%s'
% (line, ' \n'[(len(line) - line.rfind('\n') - 1 + len(word.split('\n', 1)[0]) >= width)], word),
text.split(' '),
)
[docs]@_available_to_user_util
def ascii_progress_bar(
n,
a=0,
b=100,
mess='',
newline=False,
clean=False,
width=20,
fill='#',
void='-',
style=' [{sfill}{svoid}] {perc:3.2f}% {mess}',
tag='INFO',
quiet=None,
):
"""
Displays an ASCII progress bar
:param n: current value OR iterable
:param a: default 0, start value (ignored if n is an iterable)
:param b: default 100, end value (ignored if n is an iterable)
:param mess: default blank, message to be displayed
:param newline: default False, use newlines rather than carriage returns
:param clean: default False, clean out progress bar when end is reached
:param width: default 20, width in characters of the progress bar
:param fill: default '#', filled progress bar character
:param void: default '-', empty progress bar character
:param style: default ' [{sfill}{svoid}] {perc:3.2f}% {mess}' full format string
:param tag: default 'HELP', see tag_print()
:param quiet: do not print; default is quiet=None; if quiet is None, attempt to pick up value from
bool(eval(os.environ['OMFIT_PROGRESS_BAR_QUIET']))
Example::
for n in ascii_progress_bar(np.linspace(12.34, 56.78, 4), mess=lambda x:f'{x:3.3}'):
OMFITx.Refresh() # will slow things down
"""
def ascii_progress_bar_base(n, a, b, mess, newline, clean, width, fill, void, style, tag, quiet, dtime=0):
# handle manual iteration
if not (a < b and a <= n and n <= b):
return
perc = 100.0 * (n + 1 - a) / (b + 1 - a)
nfill = int(round(width * perc / 100.0))
sfill = fill * nfill
svoid = void * (width - nfill)
if newline:
buff = ''
else:
buff = '\r'
buff += style.format(**locals())
if newline or (n == b and not clean):
buff += '\n'
elif n == b and clean:
buff = '\r' + ' ' * len(buff) + '\r'
if not quiet:
tag_print(buff, tag=tag, end='')
return n
def ascii_progress_bar_iterable(n, **kw):
data = list(n) # we need to know the length of the data in case `n` was a generator
kw['a'] = 0
kw['b'] = len(data) - 1
if callable(kw['mess']):
messages = list(map(kw['mess'], data))
else:
messages = kw['mess']
if isinstance(kw['mess'], str):
messages = [kw['mess']] * len(data)
elif isinstance(kw['mess'], np.ndarray):
messages = list(messages)
if not kw['newline']:
messages = copy.copy(messages)
if not len(messages):
messages.append('')
messages[-1] = ''
for n, d in enumerate(data):
kw['mess'] = messages[n]
if '{dtime' in style:
t0 = time.time()
yield d
kw['dtime'] = time.time() - t0
ascii_progress_bar_base(n, **kw)
else:
ascii_progress_bar_base(n, **kw)
yield d
return
import numpy as np
import os
if quiet is None:
quiet = bool(eval(os.environ.get('OMFIT_PROGRESS_BAR_QUIET', '0')))
if np.iterable(n):
return ascii_progress_bar_iterable(
n, a=a, b=b, mess=mess, newline=newline, clean=clean, width=width, fill=fill, void=void, style=style, tag=tag, quiet=quiet
)
else:
return ascii_progress_bar_base(
n, a=a, b=b, mess=mess, newline=newline, clean=clean, width=width, fill=fill, void=void, style=style, tag=tag, quiet=quiet
)
[docs]def load_dynamic(module, path):
"""
Load and initialize a module implemented as a dynamically loadable shared library and return its module object.
If the module was already initialized, it will be initialized again.
Re-initialization involves copying the __dict__ attribute of the cached instance of the module over the value used in the module cached in sys.modules.
Note: using shared libraries is highly system dependent, and not all systems support it.
:param name: name used to construct the name of the initialization function: an external C function called initname() in the shared library is called
:param pathname: path to the shared library.
"""
import importlib
spec = importlib.util.spec_from_file_location(module, path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# ---------------------
# processes
# ---------------------
[docs]def is_running(process):
"""
This function retuns True or False depending on whether a process is running or not
This relies on grep of the `ps axw` command
:param process: string with process name or process ID
:return: False if process is not running, otherwise line of `ps` command
"""
process = str(process)
s = subprocess.Popen([system_executable('ps'), "axw"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
std_out, std_err = list(map(b2s, s.communicate()))
for item in std_out.split('\n'):
if re.search(' ' + process + ' ', item) or re.search(r'^' + process + ' ', item):
return item
return False
[docs]def kill_subprocesses(process=None):
"""
kill all of the sub-processes of a given process
:param process: process of which sub-processes will be killed
"""
if process is None:
process = os.getpid()
process = str(process)
if int(process) == 1:
raise Exception('Cowardly refusing to kill INIT process')
if os.name == 'nt': # Windows
print('Killing ', process, os.getpid())
os.system(f'taskkill /PID {process} /T /F')
else:
s = subprocess.Popen(
[system_executable('pgrep'), "-P", process], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
std_out, std_err = list(map(b2s, s.communicate()))
output = [_f for _f in [x.strip() for x in std_out.split('\n')] if _f]
printi('killing sub-processes: ' + ', '.join(output))
for item in output:
ps = is_running(item)
if ps:
printi(ps)
os.kill(int(item), 9)
[docs]def memuse(as_bytes=False, **kw):
"""
return memory usage by current process
:param as_bytes: return memory as number of bytes
:param \**kw: keywords to be passed to `sizeof_fmt()`
:return: formatted string with usage expressed in kB, MB, GB, TB
"""
try:
import psutil
process = psutil.Process(os.getpid())
memory = int(process.memory_info().rss)
except ImportError:
memory = -1
if as_bytes:
return memory
memuse = sizeof_fmt(memory, **kw)
return memuse
# maximum length of line in shell
def _arg_max():
s = subprocess.Popen([system_executable('getconf'), 'ARG_MAX'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
std_out, std_err = list(map(b2s, s.communicate()))
ARG_MAX = int(std_out.strip())
return ARG_MAX
try:
ARG_MAX = _arg_max()
except Exception:
# http://www.in-ulm.de/~mascheck/various/argmax/
ARG_MAX = 4096
# hardwire to 4096
ARG_MAX = 4096
_executables_cache = {}
[docs]def system_executable(executable, return_details=False, force_path=None):
"""
function that returns the full path of the executable
:param executable: executable to return the full path of
:param return_details: return additional info for some commands (e.g. rsync)
:return: string with full path of the executable
"""
if force_path is not None:
return force_path
if executable not in _executables_cache or os.sep not in _executables_cache[executable]:
_executables_cache[executable] = {}
if os.name == 'nt':
cygwin_executable = os.path.join(os.environ['CYGWIN'], 'bin', executable + '.exe')
if os.path.exists(cygwin_executable):
_executables_cache[executable]['path'] = cygwin_executable
else:
_executables_cache[executable]['path'] = executable
else:
import distutils.spawn
_executables_cache[executable]['path'] = distutils.spawn.find_executable(executable)
if executable == 'rsync' and _executables_cache[executable]['path']:
kwarg = dict(stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
if os.name == 'nt':
kwarg['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
else:
kwarg['preexec_fn'] = os.setsid
s = subprocess.Popen([_executables_cache['rsync']['path'], '--version'], shell=False, **kwarg)
std_out, std_err = list(map(b2s, s.communicate()))
_executables_cache['rsync']['version'] = re.findall('version [0-9.]+', std_out.strip().split('\n')[0])[0].split()[1]
if float('.'.join(_executables_cache['rsync']['version'].split('.')[:2])) >= 3.1:
_executables_cache['rsync']['progress'] = '--info=progress2'
else:
_executables_cache['rsync']['progress'] = '--progress'
if return_details:
return _executables_cache[executable]
else:
return _executables_cache[executable]['path']
[docs]def python_environment():
"""
returns string with module names that have __version__ attribute (similar to what `pip freeze` would do)
"""
python_environment = {}
for item in sorted(sys.modules.keys()):
if not item.startswith('_'):
try:
has_version_info = hasattr(sys.modules[item], '__version__')
except Exception:
continue
else:
if has_version_info:
version = str(sys.modules[item].__version__)
version = re.sub('Revision:*', '', version)
version = re.sub('\$', '', version)
version = version.strip()
if version:
python_environment[item] = version
# remove subpackage version that matches parent package version
for item in list(python_environment.keys()):
if '.' in item and item.split('.')[0] in python_environment:
if python_environment[item] == python_environment[item.split('.')[0]]:
del python_environment[item]
# sorted list of lists
return [[item, python_environment[item]] for item in sorted(list(python_environment.keys()), key=lambda x: x.lower())]
# ---------------------
# identification
# ---------------------
@_available_to_user_util
def is_localhost(server, looseCheck=True):
"""
Checks if `server` is the localhost.
If `server` is None or an empty string then returns True.
`server` can be a string in the format username@server:port
:param server: server string
:param looseCheck: loosely check if `server` string is contained in the localhost names
:return: True/False
"""
import numpy as np
if server == None or not len(server) or server in ['localhost', '127.0.0.1']:
return 'localhost'
server = str(server)
server = parse_server(server)[2]
if server == 'localhost':
return 'localhost'
if re.match('127\.0+\.0+\.0*1', server):
return 'localhost'
if not len(server):
return 'localhost'
else:
if set(splitDomain(server)).intersection(_localhost_names):
return 'localhost'
elif looseCheck:
for servlet in splitDomain(server):
if np.any([item.lower().startswith(servlet.lower()) for item in _localhost_names if item]):
return 'localhost'
for item in _localhost_names:
if not item:
continue
if np.any([servlet.lower().startswith(item.lower()) for servlet in splitDomain(server)]):
return 'localhost'
return False
@_available_to_user_util
def is_server(serverA, serverB):
"""
Checks if `serverA` and `serverB` are the same server
:param serverA: server string
:param serverB: server string
:return: True/False
"""
if is_localhost(serverA) and is_localhost(serverB):
return True
elif str(serverB) in str(serverA) or str(serverA) in str(serverB):
return True
else:
return False
[docs]@_available_to_user_util
def is_institution(instA, instB):
instA = tolist(evalExpr(instA), [None])
instB = tolist(evalExpr(instB), [None])
for A in instA:
for B in instB:
if A.upper() == B.upper():
return True
return False
# ---------------------
# Network
# ---------------------
[docs]def splitDomain(namesIn):
if isinstance(namesIn, set):
namesIn = list(namesIn)
elif not isinstance(namesIn, list):
namesIn = [namesIn]
names = copy.copy(namesIn)
for item in names:
if isinstance(item, str) and re.findall('[a-zA-Z]', item):
tmp = item.split('.')
if len(tmp) > 1:
names.append(tmp[0])
return names
def _localhost():
os.environ.setdefault('HOST', '')
def returnHostname():
if 'NERSC_HOST' in os.environ:
return os.environ['NERSC_HOST']
try:
# issues with .get local hostname
socket.getaddrinfo(socket.gethostname(), 22)
return socket.gethostname()
except Exception:
return '127.0.0.1'
def find_synonims(myself):
etc_hosts = []
if os.path.exists('/etc/hosts'):
with open('/etc/hosts', 'r') as f:
for line in f.readlines():
tmp = line.strip().split('#')[0].split()
if len(tmp):
etc_hosts.append(tmp)
synonims = set(splitDomain(myself))
for name in list(synonims):
for equivalent_names in etc_hosts:
equivalent_names = splitDomain(equivalent_names)
if name in equivalent_names:
synonims = synonims.union(set(equivalent_names))
return synonims
myself = [returnHostname(), '127.0.0.1', 'localhost', '']
myself.extend([k[4][0] for k in socket.getaddrinfo(myself[0], 22)])
if 'HOST' in os.environ and os.environ['HOST']:
myself.append(os.environ['HOST'])
# running find_sysnonims twice allows to deduce more synonims based on the ip addresses in /etc/hosts
myself = find_synonims(find_synonims(myself))
return myself
_localhost_names = _localhost()
[docs]def ping(host, timeout=2):
"""
Function for pinging remote server
:param host: host name to be pinged
:param timeout: timeout time
:return: boolean indicating if ping was successful
"""
# -W option is waiting time in seconds
# -c packets count
with open(os.devnull, 'w') as nul_f:
child = subprocess.Popen([system_executable('ping'), "-c1", '-W' + str(timeout), host], stderr=nul_f, stdout=subprocess.PIPE)
while child.poll() is None:
pass
return not bool(child.poll())
[docs]def parse_server(server, default_username=os.environ['USER']):
"""
parses strings in the form of username:password@server.somewhere.com:port
:param server: input string
:param default_username: what username to return if username@server is not specified
:return: tuple of four strings with user,password,server,port
"""
server = evalExpr(server)
if server is None:
server = ''
if '@' in server:
us = server.split('@')
username = '@'.join(us[:-1])
server = us[-1]
else:
username = default_username
if ':' in username:
username, password = username.split(':', 1)
else:
password = ''
if ':' in server:
server, port = server.split(':')
port = str(int(port)) # to make sure the port is an integer
else:
port = '22'
return username.strip(), password.strip(), server.strip(), port.strip()
[docs]def assemble_server(username='', password='', server='', port=''):
"""
returns assembled server string in the form username:password@server.somewhere.com:port
:param username: username
:param password: password
:param server: server
:param port: port
:returns: assembled server string
"""
out = ''
if username:
out = username
if password:
out = out + ':' + password
out = out + '@'
if server:
out = out + server
else:
out = out + 'localhost'
if port:
out = out + ':' + str(port)
return out
[docs]def test_connection(username, server, port, timeout=1.0, ntries=1, ssh_path=None):
"""
Function to test if connection is available on a given host and port number
:param server: server to connect to
:param port: TCP port to connect to
:param timeout: wait for `timeout` seconds for connection to become available
:param ntries: number of retries
* ntries==1 means: check if connection is up
* ntries>1 means: wait for connection to come up
:return: boolean indicating if connection is available
"""
connected = None
# This is needed because apparently OMFIT['MainSettings'] doesn't exist outside the framework
try:
is_local = is_localhost(server)
except KeyError:
is_local = False
# if there is a controlmaster present check if it is live
if username is not None and not is_local:
connected = controlmaster(username, server, port, None, check=True, ssh_path=ssh_path)
connection = (str(server), int(port))
if ntries == 1:
printd('Check if connection ' + str(connection) + ' is up...', topic='connection')
else:
printd('Waiting for connection ' + str(connection) + ' to become available...', topic='connection')
timeout = float(timeout)
ntries = max([int(ntries), 1])
dt = timeout / ntries
t0 = time.time()
t = 0
while connected is None or (not connected and t <= timeout):
t1 = time.time()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(dt)
try:
connected = s.connect_ex(connection) == 0
s.close()
except Exception:
pass
if connected:
printd('Connection ' + str(connection) + ' is open!', topic='connection')
else:
if ntries == 1:
printd('Connection ' + str(connection) + ' is not open', topic='connection')
break
else:
printd('Re-checking connection ' + str(connection), topic='connection')
sleep(max([dt - (time.time() - t1), 0])) # make sure to wait at least a total of `dt` seconds before the next retry
t = time.time() - t0
if not connected and ntries > 1 and t > timeout:
printd('Connection ' + str(connection) + ' is not open: wait timed out', topic='connection')
return connected
[docs]def getSSHlink(sshName='OMFITssh', ssh_path=None):
sshlink = OMFITbinsDir + os.sep + sshName
if not os.path.exists(sshlink):
if os.path.islink(sshlink):
try:
os.remove(sshlink) # remove a broken link
except IOError:
pass
try:
os.symlink(system_executable('ssh', force_path=ssh_path), sshlink)
except OSError: # avoid issues with parallel processes creating the symlink at the same time
pass
return sshlink
[docs]def sshOptions(sshName='OMFITssh', ssh_path=None):
sshlink = getSSHlink(sshName, ssh_path)
xauth_exe = system_executable('xauth')
if xauth_exe is None:
xauth_opt = ''
printd('No xauth found', topic='connection')
else:
xauth_opt = f' -o XAuthLocation={xauth_exe}'
return (
sshlink + ' -o StrictHostKeyChecking=no'
' -o CheckHostIP=no'
' -o Compression=yes'
' -o ForwardAgent=yes'
' -o NoHostAuthenticationForLocalhost=yes'
' -o PermitLocalCommand=no'
' -o ConnectTimeout=10'
' -o ServerAliveCountMax=5'
' -o ServerAliveInterval=3'
' -o TCPKeepAlive=no' + xauth_opt
)
_nprocs_per_controlmaster = 5
_pexpect_cmd = {} # this is necessary to prevent garbage collection
[docs]def ssh_connect_with_password(cmd, credential, test_connected, timeout=10, check_every=0.1):
if not OMFITaux['GUI'] and os.environ.get('OMFIT_SETUP', '0') == '1':
# this is used by `omfit -s` and may be done for non GUI workflows
return os.system(cmd)
if test_connected():
return True
# only the first process every _nprocs_per_controlmaster processes
# setup the connection while the others wait for timeout seconds
if 'prun_process' in OMFITaux and len(OMFITaux['prun_process']):
if not len(OMFITaux['prun_process']):
index = 0
elif len(OMFITaux['prun_process']) == 1:
index = OMFITaux['prun_process'][-1]
else:
import numpy as np
index = np.prod((np.array(OMFITaux['prun_nprocs']) * np.array(OMFITaux['prun_process']))[:-1]) + OMFITaux['prun_process'][-1]
setter = (index % _nprocs_per_controlmaster) == 0
if setter:
# sleep time is necessary to stagger ssh connections
time.sleep(index // _nprocs_per_controlmaster)
else:
maxk = k = int(timeout / float(check_every))
while k > 0: # wait 10 seconds for controlmaster of this group to be setup
if test_connected():
return True
time.sleep(check_every)
k -= 1
# execute command
printd('ssh_connect_with_password: ' + cmd, topic='connection')
p = pexpect.spawn(cmd, env={v: os.environ[v] if v != 'SSH_ASKPASS' else False for v in os.environ}, timeout=None)
_pexpect_cmd[str(credential), str(cmd)] = p
# check if connected
connected = False
printd('Waiting for connection to ' + credential, topic='connection')
maxk = k = int(timeout / float(check_every))
npwd = 0
try:
while k > 0: # wait 10 seconds
if p.expect([pexpect.TIMEOUT, "(?i)Verification code:"], timeout=check_every):
npwd += 1
# do not ask for password more than 3 times
if npwd >= 4:
break
printi(b2s(p.before + p.after).strip())
# if there is a GUI, then OMFITaux['GUI'] is not None, and we can use AskPassGUI
if OMFITaux.get('GUI', None):
from utils import AskPassGUI
pwd, otp = AskPassGUI(credential, OTP=True).decrypt()
else:
pwd, otp = AskPass(credential, OTP=True).decrypt()
p.sendline(otp)
p.expect([pexpect.TIMEOUT, "(?i)password:", "(?i)password\s*\+\s*OTP:"], timeout=10)
p.sendline(pwd)
k = maxk
elif p.expect([pexpect.TIMEOUT, "(?i)password:", "(?i)password\s*\+\s*OTP:"], timeout=check_every):
npwd += 1
# do not ask for password more than 3 times
if npwd >= 4:
break
printi(b2s(p.before + p.after).strip())
# if there is a GUI, then OMFITaux['GUI'] is not None, and we can use AskPassGUI
if OMFITaux.get('GUI', None):
from utils import AskPassGUI
pwd_otp = ''.join(AskPassGUI(credential, force_ask=(npwd > 2), OTP=True).decrypt())
else:
pwd_otp = ''.join(AskPass(credential, force_ask=(npwd > 2), OTP=True).decrypt())
p.sendline(pwd_otp)
k = maxk
# we need to wait a minimum amount of time before declaring victory
# this is to allow the password request to show up on the terminal
elif (maxk - k) * check_every > 2.0 and test_connected():
connected = True
break
k -= 1
except EOF as _excp:
connected = test_connected()
if not connected:
tmp = re.findall('before \(last 100 chars\):.*', str(_excp))
if tmp:
raise EOF(':'.join(tmp[0].split(':')[1:]).strip('\'" \r\n'))
else:
raise
if connected:
printd('Connected to ' + credential, topic='connection')
else:
printd('Failed to connect to ' + credential, topic='connection')
return connected
[docs]def controlmaster(username, server, port, credential, check=False, ssh_path=None):
"""
Setup controlmaster ssh socket
:param username: username of ssh connection
:param server: host of ssh connection
:param port: port of ssh connection
:param credential: credential file to use for connection
:param check: check that control-master exists and is open
:return: [True/False] if `check=True` else string to be included to ssh command to use controlmaster
"""
printd(username, server, port, credential, check, topic='connection')
# use 10 connections per controlmaster to reflect default `maxsessions` setting of sshd servers
if not len(OMFITaux['prun_process']):
index = 0
elif len(OMFITaux['prun_process']) == 1:
index = OMFITaux['prun_process'][-1]
else:
import numpy as np
index = np.prod((np.array(OMFITaux['prun_nprocs']) * np.array(OMFITaux['prun_process']))[:-1]) + OMFITaux['prun_process'][-1]
filename = (
OMFITcontrolmastersDir
+ os.sep
+ str(omfit_numeric_hash(str(assemble_server(username, '', server, port)), 8))
+ '__%d' % (index // _nprocs_per_controlmaster)
)
# test if controlmaster exists and is open
if os.path.exists(filename):
tmp = system_executable('ssh', force_path=ssh_path) + ' -o ControlPath=%s -O check %s' % (filename, filename)
child = subprocess.Popen(tmp, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
for k in range(100):
if child.poll() is not None:
break
sleep(0.1)
if k == 10:
printi('Waiting for server response on controlmaster %s' % filename)
if child.poll() is None or child.poll() == 255 or not os.path.exists(filename):
printd('Controlmaster exists but is closed: ' + filename, topic='connection')
if not check and os.path.exists(filename):
try:
os.remove(filename)
except FileNotFoundError:
pass
valid = False
if child.poll() is None:
child.terminate()
else:
printd('Controlmaster exists and is open: ' + filename, topic='connection')
valid = True
child.stdout.close()
child.stderr.close()
else:
printd('Controlmaster does not exist: ' + filename, topic='connection')
valid = False
if check:
printd('Controlmaster checked, and now returning', topic='connection')
return valid
if valid:
printd('Using controlmaster: ' + filename, topic='connection')
else:
# create controlmaster
if os.name == 'nt':
ssh_command = sshOptions('OMFITsshControlmaster', ssh_path=ssh_path)
cmd = ssh_command + ' -N -t -t -M -Y -S ' + filename + ' ' + username + '@' + server + ' -p ' + str(port)
else:
ssh_command = sshOptions('OMFITsshControlmaster', ssh_path=ssh_path)
cmd = (
ssh_command
+ ' -o ControlPersist=24h -o ControlMaster=auto -N -t -t -M -Y -S '
+ filename
+ ' '
+ username
+ '@'
+ server
+ ' -p '
+ str(port)
)
printd('controlmaster: ' + cmd, topic='connection')
connected = ssh_connect_with_password(
cmd, credential, lambda: controlmaster(username, server, port, credential, True, ssh_path=ssh_path), timeout=10, check_every=1
)
if connected:
printd('Created controlmaster: ' + filename, topic='connection')
else:
raise OMFITexception("%s\n\n\nCould not connect with:\n%s" % (credential, cmd))
return ' -S ' + filename + ' '
[docs]def setup_ssh_tunnel(server, tunnel, forceTunnel=False, forceRemote=False, allowEmptyServerUsername=False, ssh_path=None):
"""
This function sets up the remote ssh tunnel (if necessary) to connect to the server
and returns the username,server,port triplet onto which to connect.
:param server: string with remote server
:param tunnel: string with via tunnel (multi-hop tunneling with comma separated list of hops)
:param forceTunnel: force tunneling even if server is directly reachable (this is useful when servers check provenance of an IP)
:param forceRemote: force remote connection even if server is localhost
:param allowEmptyServerUsername: allow empty server username
:return: username, server, port
"""
if not isinstance(evalExpr(server), str):
raise OMFITexception('server must be a string')
if not isinstance(evalExpr(tunnel), str):
raise OMFITexception('tunnel must be a string')
# save original entries
server0 = server
tunnel0 = tunnel
# see if server is localhost
local = is_localhost(server0)
if forceRemote or forceTunnel:
if local and not forceTunnel:
return (os.environ['USER'], 'localhost', '22')
local = False
# resolve nested multi-hop tunnels
if ',' in tunnel0:
hops = tunnel.split(',')
tunnel0 = tunnel = hops[0]
for server in hops[1:]:
u, s, p = setup_ssh_tunnel(
server, tunnel, forceTunnel=False, forceRemote=False, allowEmptyServerUsername=False, ssh_path=ssh_path
)
tunnel0 = tunnel = assemble_server(u, '', s, p)
# calculate tunnel port and parse server/tunnel strings
port_l = str(omfit_numeric_hash((str(server0) + str(tunnel0))) % (65535 - 49152) + 49152)
username, password, server, port = parse_server(str(server0), default_username='')
username_t, password_t, server_t, port_t = parse_server(str(tunnel), default_username='')
# buffer localshost given as hostname
if local:
if len(server) and server0 + '-' + tunnel0 not in OMFITaux['sshTunnel'] and server != 'localhost':
printi(server + ' is the local workstation')
OMFITaux['sshTunnel'][server0 + '-' + tunnel0] = (username, '', port)
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
# buffered connections
if server0 + '-' + tunnel0 in OMFITaux['sshTunnel']:
printd('Found buffered connection %s@%s:%s' % tuple(OMFITaux['sshTunnel'][server0 + '-' + tunnel0][:]), topic='connection')
# since connection should be buffered we expect a quick response
connected = test_connection(
OMFITaux['sshTunnel'][server0 + '-' + tunnel0][0],
OMFITaux['sshTunnel'][server0 + '-' + tunnel0][1],
int(OMFITaux['sshTunnel'][server0 + '-' + tunnel0][2]),
timeout=0.1,
ntries=1,
ssh_path=ssh_path,
)
if connected:
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
else:
del OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
# start by seeing if the tunnel was already established
# checking for tunnel on localhost is faster than checking remote...
if len(server_t):
connected = test_connection(username, '127.0.0.1', int(port_l), timeout=0.1, ntries=1, ssh_path=ssh_path)
if connected:
printi('Found existing SSH tunnel ' + username + '@127.0.0.1:' + port_l + ' = ' + tunnel0 + ' --> ' + server0)
OMFITaux['sshTunnel'][server0 + '-' + tunnel0] = (username, '127.0.0.1', port_l)
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
# check that username was specified for server connection
if not allowEmptyServerUsername and not username:
raise OMFITexception(
f'Start OMFIT and specify username explicitly under `File > Preferences > Remote Servers` for connection to server `{server0}`'
)
# see if server is directly reachable
# remote: here where we need to give it some time, and we cannot query too much for risk of getting blacklisted
if not forceTunnel:
try:
if not controlmaster(username, server, port, server0, check=True, ssh_path=ssh_path):
try:
controlmaster(username, server, port, server0, ssh_path=ssh_path)
except OMFITexception:
connected = False
connected = controlmaster(username, server, port, server0, check=True, ssh_path=ssh_path)
except EOF: # do not raise pexpect.EOF, we need to keep going because some servers (eg. lac.epfl.ch) keep the door open but will drop the packet if not contacted via the tunnel
connected = False
if connected:
printd('Server is directly reachable at %s@%s:%s' % (username, server, port), topic='connection')
OMFITaux['sshTunnel'][server0 + '-' + tunnel0] = (username, server, port)
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
# make or re-use tunnel
if len(server_t):
# check that username was specified for tunnel connection
if not username_t:
raise OMFITexception('Specify username explicitly for connection to tunnel `%s`' % str(tunnel0))
connected = False
if len(OMFITaux['prun_process']):
import numpy as np
n = (
np.random.rand() * 10
) # give it max 10 seconds to check if other processes have setup the connection (since localhost, we can check often)
connected = test_connection(username, '127.0.0.1', int(port_l), timeout=n, ntries=10 * np.ceil(n), ssh_path=ssh_path)
if connected:
printi('Found existing SSH tunnel ' + username + '@127.0.0.1:' + port_l + ' = ' + tunnel0 + ' --> ' + server0)
else:
# setup the tunnel
printi(
'Starting SSH tunnel '
+ username
+ '@127.0.0.1:'
+ port_l
+ ' = '
+ username_t
+ '@'
+ server_t
+ ':'
+ port_t
+ ' --> '
+ username
+ '@'
+ server
+ ':'
+ port
)
# The connection to the tunnel server gets setup by the following controlmaster call
cmd = (
sshOptions('OMFITsshTunnel', ssh_path=ssh_path)
+ controlmaster(username_t, server_t, port_t, tunnel0, ssh_path=ssh_path)
+ " -t -t -Y -N -L "
+ port_l
+ ":"
+ server
+ ":"
+ port
+ " "
+ username_t
+ "@"
+ server_t
+ " -p "
+ port_t
)
printd('tunneling command: ' + cmd, topic='connection')
# Establish the actual tunnel
connected = ssh_connect_with_password(
cmd,
server0,
lambda username=username, port_l=port_l: test_connection(
username, '127.0.0.1', int(port_l), timeout=0.1, ntries=1, ssh_path=ssh_path
),
)
if connected:
printi('Created SSH tunnel ' + username + '@127.0.0.1:' + port_l + ' = ' + tunnel0 + ' --> ' + server0)
if connected:
OMFITaux['sshTunnel'][server0 + '-' + tunnel0] = (username, '127.0.0.1', port_l)
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
else:
if not internet_on():
raise OMFITexception('Failed to connect since network appears to be down')
raise OMFITexception('Failed to create tunnel ' + username + '@127.0.0.1:' + port_l + ' = ' + tunnel0 + ' --> ' + server0)
if len(tunnel0):
raise OMFITexception('Failed to connect to `%s` via `%s`' % (server0, tunnel0))
OMFITaux['sshTunnel'][server0 + '-' + tunnel0] = (username, server, port)
return OMFITaux['sshTunnel'][server0 + '-' + tunnel0]
[docs]def setup_socks(tunnel, ssh_path=None):
"""
Specifies a local "dynamic" application-level port forwarding.
Whenever a connection is made to a defined port on the local side, the connection is forwarded over the secure channel, and the application protocol is then used to determine where to connect to from the remote machine.
The SOCKS4 and SOCKS5 protocols are supported, and ssh will act as a SOCKS server.
:param tunnel: tunnel
:return: local sock connection (username, localhost, port)
"""
# save original entries
tunnel0 = tunnel
# buffered connections
if tunnel0 in OMFITaux['sshTunnel']:
printd('Found buffered connection %s@%s:%s' % tuple(OMFITaux['sshTunnel'][tunnel0][:]), topic='connection')
# since connection should be buffered we expect a quick response
connected = test_connection(
OMFITaux['sshTunnel'][tunnel0][0],
OMFITaux['sshTunnel'][tunnel0][1],
int(OMFITaux['sshTunnel'][tunnel0][2]),
timeout=0.1,
ntries=1,
ssh_path=ssh_path,
)
if connected:
return OMFITaux['sshTunnel'][tunnel0]
else:
del OMFITaux['sshTunnel'][tunnel0]
# calculate tunnel port and parse server/tunnel strings
port_l = str(omfit_numeric_hash(tunnel0) % (65535 - 49152) + 49152)
username, password, tunnel, port = parse_server(str(tunnel0), default_username='')
# Check if existing connection exists
connected = test_connection(username, '127.0.0.1', int(port_l), timeout=0.1, ntries=1, ssh_path=ssh_path)
if connected:
printi('Found existing SOCKS tunnel ' + username + '@127.0.0.1:' + port_l + ' --> ' + username + '@' + tunnel + ':' + port)
OMFITaux['sshTunnel'][tunnel0] = (username, '127.0.0.1', port_l)
# Establish the actual tunnel
cmd = sshOptions('OMFITsshTunnel', ssh_path=ssh_path) + " -D " + str(port_l) + " " + username + "@" + tunnel + " -p " + port
printd('sock command: ' + cmd, topic='connection')
connected = ssh_connect_with_password(
cmd,
tunnel0,
lambda username=username, port_l=port_l: test_connection(username, '127.0.0.1', port_l, timeout=0.1, ntries=1, ssh_path=ssh_path),
)
if connected:
OMFITaux['sshTunnel'][tunnel0] = (username, '127.0.0.1', port_l)
return OMFITaux['sshTunnel'][tunnel0]
else:
raise Exception('Failed to create tunnel ' + username + '@127.0.0.1:' + port_l + ' --> ' + username + '@' + tunnel + ':' + port)
[docs]def internet_on(website='http://www.bing.com', timeout=5):
"""
Check internet connection by returning the ability to connect to a given website
:param website: website to test (by default 'http://www.bing.com' since google is not accessible from China)
:param timeout: timeout in seconds
:return: ability to connect or not
"""
import requests
if bool(len(OMFITaux['prun_process'])):
printd('prun: skipping internet connection check.', topic='connection')
return True
try:
printd('Internet connection check...', topic='connection')
response = requests.get(website, verify=False, timeout=timeout)
response.raise_for_status() # Raises a HTTPError if the status is 4xx, 5xxx
return True
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.HTTPError):
return False
[docs]def get_ip():
"""
https://stackoverflow.com/a/28950776/6605826
:return: current system IP address
"""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
s.connect(('10.255.255.255', 1))
ip_addr = s.getsockname()[0]
except Exception:
ip_addr = '127.0.0.1'
finally:
s.close()
return ip_addr
[docs]class AskPass:
"""
Class that asks for password and one-time-password secret if credential file does not exists
"""
explain = (
"NOTE: Your saved credentials will be encrypted\n"
+ " with your ~/.ssh/id_rsa private key\n"
+ "and stored under "
+ os.path.abspath(OMFITsettingsDir + os.sep + 'credentials')
)
def __init__(self, credential, force_ask=False, OTP=False, **kw):
r"""
:param credential: credential filename to look for under the OMFITsettingsDir+os.sep+'credentials' directory
:param force_ask: force asking for the password (if file exists pwd is shown for amendment)
:param OTP: ask for one-time-password secret; if 'raw' do not pass through pyotp.now()
:param \**kw: extra arguments used in self.ask()
"""
self.credential = credential
pwd, otp = decrypt_credential(self.credential)
store = True
if not len(pwd + otp) and os.path.exists(credential_filename(self.credential)):
store = False
self.store = store
self.OTP = OTP
if len(pwd + otp) and not force_ask:
return
elif len(OMFITaux['prun_process']):
raise RuntimeError('PWD+OTP for %s is not valid. Parallel run cannot continue.' % credential)
self.ask(pwd, otp, OTP, store, **kw)
[docs] def ask(self, pwd, otp, OTP, store, **kw):
"""
Ask user for the password and one-time-password secret
:param pwd: password
:param otp: one time password
:param OTP: ask for one-time-password secret; if 'raw' do not pass through pyotp.now()
:param store: save pwd + otp to credential file
"""
print(self.explain)
OTP_text = []
if OTP and OTP != 'raw':
OTP_text = ['OTP secret (32 digits)']
else:
OTP_text = ['OTP']
defaults = OrderedDict()
username, password, server, port = parse_server(self.credential)
print(f'\nCredentials for: {username}@{server}:{port}\n')
defaults['Password'] = pwd
if OTP_text:
defaults[OTP_text[0]] = otp
defaults['Save'] = store
defaults_txt = {}
for field in defaults:
if defaults[field] or field == 'Save':
if field in ['Password'] + OTP_text:
defaults_txt[field] = f'[{"*" * len(defaults[field])}]'
else:
defaults_txt[field] = f'[{defaults[field]}]'
else:
defaults_txt[field] = ''
response = {}
for k, field in enumerate(defaults.keys()):
response[field] = input(f'{field} {defaults_txt[field]}: ')
if not len(response[field]):
response[field] = str(defaults[field])
self.store = bool(eval(response['Save']))
self.pwd = response['Password']
self.otp = ''
if OTP_text:
self.otp = response[OTP_text[0]]
# if we decide not to store the password, the credential file is still generated but it will contain an empty password and OTP
if self.store:
print('STORE')
self.encrypt(self.pwd, self.otp)
else:
print('EMPTY')
self.encrypt('', '')
[docs] def destroy(self):
pass
[docs] def encrypt(self, pwd, otp):
"""
save user password and one-time-password secret to credential file
"""
encrypt_credential(self.credential, pwd, otp)
self.destroy()
[docs] def decrypt(self):
"""
:return: password with addition of six digits one-time-password
"""
if self.store:
pwd, secret = decrypt_credential(self.credential)
else:
pwd, secret = self.pwd, self.otp
if self.OTP == 'raw':
return pwd, secret
elif len(secret):
import pyotp
otp = pyotp.TOTP(secret).now()
return pwd, otp
else:
return pwd, ''
# ---------------------
# cryptography
# ---------------------
def _crypto(in_string, encrypt=True, keys=None):
"""
:param in_string: string to be encrypted
:param encrypt: if True, encrypt, else decrypt
:param keys: filename or list of filenames of RSA private keys to be used for encryption (by default `~/.ssh/id_rsa`)
:return: encrypted/decrypted string
"""
def _exe_GUI(cmd, key):
printd('Executing command: `{}`'.format(cmd), topic='Crypto')
p = pexpect.spawn(cmd, timeout=20)
pass_str = 'Enter pass phrase for %s:' % key
pass_str_re = pass_str.replace(key, '.*')
index = p.expect([pass_str_re, EOF])
if index == 0:
key = re.match('Enter pass phrase for (.*):', b2s(p.after)).groups()[0]
if 'rootGUI' in OMFITaux and OMFITaux['rootGUI']:
# Get password
from utils_widgets import password_gui
password = password_gui(title=pass_str, key=key)
else:
from utils_widgets import stored_passwords
import getpass
if key not in stored_passwords:
stored_passwords[key] = getpass.getpass(pass_str)
password = stored_passwords[key]
printd('Entering password', topic='Crypto')
p.sendline(password.strip())
p.expect(EOF)
if keys is None:
keys = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
keys = tolist(keys, [])
# private keys
for key in keys:
if not os.path.exists(key):
raise ValueError(f'Cannot find private RSA key: `{key}`!\nGenerate one with `ssh-keygen -b 4096 -m PEM`')
# filenames
import tempfile
tmpdir = tempfile._get_default_tempdir()
filename_in = os.path.join(tmpdir, next(tempfile._get_candidate_names()))
filename_out = os.path.join(tmpdir, next(tempfile._get_candidate_names()))
filename_cert = {}
try:
# input file
with open(filename_in, mode='w') as f:
pass
os.chmod(filename_in, 0o600)
if encrypt:
with open(filename_in, mode='w+') as f:
f.write(in_string)
else:
with open(filename_in, mode='wb+') as f:
f.write(in_string)
# output file
with open(filename_out, mode='w') as f:
pass
os.chmod(filename_out, 0o600)
if encrypt:
for key in keys:
# certificate
filename_cert[key] = os.path.join(tmpdir, next(tempfile._get_candidate_names()))
cmd = "openssl req -key {key} -new -x509 -nodes -subj '/' -out {certificate}".format(
key=key, certificate=filename_cert[key]
)
_exe_GUI(cmd, key)
# encrypt
cmd = "openssl smime -encrypt -aes256 -in {filename_in} -out {filename_out} -outform PEM ".format(
filename_in=filename_in, filename_out=filename_out
)
cmd += ' '.join(filename_cert.values())
printd('Executing command: `{}`'.format(cmd), topic='Crypto')
os.system(cmd)
else:
# decrypt
with open(filename_in, 'rb') as f:
header = b'-----BEGIN PKCS7-----'
# look for SMIME signature
if header in f.readline()[: len(header)]:
cmd = 'openssl smime -decrypt -inform PEM'
# otherwise fallback on RSAUTL for backward compatibility
else:
cmd = 'openssl rsautl -decrypt'
cmd = (cmd + ' -in {filename_in} -out {filename_out} -inkey {key}').format(
key=key, filename_in=filename_in, filename_out=filename_out
)
_exe_GUI(cmd, key)
# read output file
with open(filename_out, mode='rb') as f:
tmp = f.read()
if not encrypt:
tmp = b2s(tmp)
finally:
# clean up temporary files
if os.path.exists(filename_in):
os.remove(filename_in)
if os.path.exists(filename_out):
os.remove(filename_out)
for key in filename_cert:
if os.path.exists(filename_cert[key]):
os.remove(filename_cert[key])
return tmp
[docs]@_available_to_user_util
def encrypt(in_string, keys=None):
"""
Encrypts a string with user RSA private key
:param in_string: input string to be encrypted with RSA private key
:param keys: private keys to encrypt with. if None, default to RSA key.
:return: encrypted string
"""
return _crypto(in_string, encrypt=True, keys=keys)
_tested_PEM = []
[docs]@_available_to_user_util
def decrypt(in_string):
"""
Decrypts a string with user RSA private key
:param in_string: input string to be decrypted with RSA private key
:return: decrypted string
"""
if not len(_tested_PEM):
try:
_tested_PEM.append(True)
# the first time encryption is used, make sure ssh keys support encryption
assert decrypt(encrypt('OMFIT')) == 'OMFIT', 'Your SSH keys do not support PEM! Regenerate with `ssh-keygen -b 4096 -m PEM`'
except AssertionError:
_tested_PEM[:] = []
raise
return _crypto(in_string, encrypt=False)
[docs]def credential_filename(server):
"""
generates credential filename given username@server:port
:param server: username@server:port
:return: credential filename
"""
tmp = list(parse_server(server))
tmp[1] = ''
credential = assemble_server(*tmp)
if not os.path.exists(OMFITsettingsDir + os.sep + 'credentials'):
os.makedirs(OMFITsettingsDir + os.sep + 'credentials')
return OMFITsettingsDir + os.sep + 'credentials' + os.sep + credential
[docs]def encrypt_credential(credential, password, otp, keys=None):
"""
generate encrypted credential file
:param credential: credential string
:param password: user password
:param otp: user one-time-password secret
:param keys: private keys to encrypt with. if None, default to RSA key.
"""
if password == 'd3dpub' or (password.strip() == '' and otp.strip() == ''):
printd(f'Refusing to store credential {credential}, since password is {password} and otp is {otp}', topic='Crypto')
return
filename = credential_filename(credential)
with open(filename, 'wb') as f:
f.write(encrypt('\n'.join([password, otp]), keys=keys))
os.chmod(filename, 0o600)
[docs]def decrypt_credential(credential):
"""
read and decrypt credential file
:param credential: credential string
:return: password and one-time-password secret
"""
filename = credential_filename(credential)
if not os.path.exists(filename):
return '', ''
with open(filename, 'rb') as f:
tmp = b2s(decrypt(f.read()))
tmp = tmp.split('\n')
otp = ''
pwd = tmp[0]
if len(tmp) > 1 and len(tmp[1]):
otp = tmp[1]
return pwd.strip(), otp.strip()
[docs]def reset_credential(credential):
"""
delete credential file
:param credential: credential string
"""
filename = credential_filename(credential)
if os.path.exists(filename):
os.remove(filename)
printi('Removed credential file for %s' % credential)
[docs]@_available_to_user_util
def password_encrypt(data, password, encoding="utf-8", bufferSize=64 * 1024):
'''
Encrypt using AES cryptography
:param data: data in clear to be encrypted
:param password: password
:param encoding: encoding to be used to translate from string to bytes
:param bufferSize: buffer size
:return: encrypted data (bytes)
'''
import pyAesCrypt
import io
# from string to bytes
if encoding:
data = data.encode(encoding)
# input plaintext binary stream
fIn = io.BytesIO(data)
# initialize ciphertext binary stream
fCiph = io.BytesIO()
# initialize decrypted binary stream
fDec = io.BytesIO()
# encrypt stream
pyAesCrypt.encryptStream(fIn, fCiph, password, bufferSize)
# encrypted data
return fCiph.getvalue()
[docs]@_available_to_user_util
def password_decrypt(data, password, encoding="utf-8", bufferSize=64 * 1024):
'''
Deccrypt using AES cryptography
:param data: encrypted data (bytes) to be decrypted
:param password: password
:param encoding: encoding to be used to translate from bytes to string
:param bufferSize: buffer size
:return: data in clear
'''
import pyAesCrypt
import io
# initialize decrypted binary stream
fDec = io.BytesIO()
# initialize ciphertext binary stream
fCiph = io.BytesIO()
fCiph.write(data)
# get ciphertext length
ctlen = len(fCiph.getvalue())
# go back to the start of the ciphertext stream
fCiph.seek(0)
# decrypt stream
pyAesCrypt.decryptStream(fCiph, fDec, password, bufferSize, ctlen)
# decrypted data
data = fDec.getvalue()
# turn into string
if encoding:
data = data.decode(encoding)
return data
# ---------------------
# PDF and DMP
# ---------------------
[docs]def PDF_add_file(pdf, file, name=None, delete_file=False):
"""
Embed file in PDF
:param pdf: PDF filename
:param file: filename or file extension
:param name: name of attachment in PDF. Uses file filename if is None.
:param delete_file: remove file after embedding
:return: full path to PDF file
"""
try:
import PyPDF2
except ImportError:
printw('Could not embed data in PDF because Python `PyPDF2` package is missing')
return
if file.startswith('.') and os.sep not in file:
file = re.sub('.pdf$', '', pdf) + file
unmeta = PyPDF2.PdfFileReader(pdf, "rb")
meta = PyPDF2.PdfFileWriter()
meta.appendPagesFromReader(unmeta)
with open(file, "rb") as fp:
meta.addAttachment(os.path.split(file)[1], fp.read())
with open(pdf, "wb") as fp:
meta.write(fp)
# delete file if requested
if delete_file:
os.remove(file)
# return path to PDF
return os.path.abspath(pdf)
[docs]def PDF_get_file(pdf, file, name='.*'):
"""
Extract file from PDF
:param pdf: PDF filename
:param file: filename or file extension
:param name: regular expression with name(s) of attachment in PDF to get
:return: full path to file
"""
try:
import PyPDF2
except ImportError:
printw('Could not de-embed data in PDF because Python `PyPDF2` package is missing')
return
if file.startswith('.') and os.sep not in file:
file = re.sub('.pdf$', '', pdf) + file
unmeta = PyPDF2.PdfFileReader(pdf, "rb")
embedded = unmeta.trailer['/Root']['/Names']['/EmbeddedFiles']['/Names']
list_files = []
done = False
for obj in range(len(embedded)):
if not isinstance(embedded[obj], dict):
continue
filename = embedded[obj]['/F']
list_files.append(filename)
if re.match(name, filename):
data = embedded[obj]['/EF']['/F']._data
with open(file, 'wb') as f:
f.write(data)
done = True
break
# raise error if file is not found
if not len(list_files):
raise RuntimeError('The PDF does not contain attachments')
elif not done:
raise RuntimeError('The PDF does not contain `%s`.\nPossible options are: %s' % (name, list_files))
# return path to file
return os.path.abspath(file)
[docs]def PDF_set_DMP(pdf, dmp='.h5', delete_dmp=False):
"""
Embed DMP file in PDF
:param pdf: PDF filename
:param dmp: DMP filename or extension
:param delete_dmp: remove DMP file after embedding
:return: full path to PDF file
"""
if dmp.startswith('.') and os.sep not in dmp:
dmp = re.sub('.pdf$', '', pdf) + dmp
return PDF_add_file(pdf, dmp, name='DMP_' + os.path.splitext(os.path.split(dmp)[1])[0], delete_file=delete_dmp)
[docs]def PDF_get_DMP(pdf, dmp='.h5'):
"""
Extract DMP file from PDF
:param pdf: PDF filename
:param dmp: filename or file extension
:return: full path to DMP file
"""
if dmp.startswith('.') and os.sep not in dmp:
dmp = re.sub('.pdf$', '', pdf) + dmp
return PDF_get_file(pdf, dmp, name='^DMP_.*')
# ---------------------
# Python introspection
# ---------------------
[docs]def python_imports(namespace, submodules=True, onlyWithVersion=False):
"""
function that lists the Python modules that have been imported in a namespace
:param namespace: usually this function should be called with namespace=globals()
:param submodules: list only top level modules or also the sub-modules
:param onlyWithVersion: list only (sub-)modules with __version__ attribute
"""
import types
mods = []
for name, val in list(namespace.items()):
if isinstance(val, types.ModuleType):
version = 'N/A'
if hasattr(val, '__version__'):
version = val.__version__
elif onlyWithVersion:
continue
if submodules:
mods.append([val.__name__, version])
else:
if not '.' in val.__name__:
mods.append([val.__name__, version])
return mods
[docs]def function_arguments(f, discard=None, asString=False):
"""
Returns the arguments that a function takes
:param f: function to inspect
:param discard: list of function arguments to discard
:param asString: concatenate arguments to a string
:return: tuple of four elements
* list of compulsory function arguments
* dictionary of function arguments that have defaults
* True/False if the function allows variable arguments
* True/False if the function allows keywords
"""
import inspect
the_argspec = inspect.getfullargspec(f)
the_keywords = the_argspec.varkw
args = []
kws = OrderedDict()
string = ''
for k, arg in enumerate(the_argspec.args):
if (discard is not None) and (arg in tolist(discard)):
continue
d = ''
if the_argspec.defaults is not None:
if (-len(the_argspec.args) + k) >= -len(the_argspec.defaults):
d = the_argspec.defaults[-len(the_argspec.args) + k]
kws[arg] = d
string += arg + '=' + repr(d) + ',\n'
else:
args.append(arg)
string += arg + ',\n'
else:
args.append(arg)
string += arg + ',\n'
if the_argspec.varargs:
string += '*[],\n'
if the_keywords:
string += '**{},\n'
string = string.strip()
if asString:
return string
else:
return args, kws, the_argspec.varargs is not None, the_keywords is not None
[docs]def args_as_kw(f, args, kw):
"""
Move positional arguments to kw arguments
:param f: function
:param args: positional arguments
:param kw: keywords arguments
:return: tuple with positional arguments moved to keyword arguments
"""
a, k, astar, kstar = function_arguments(f)
n = 0
for name, value in zip(a + list(k.keys()), args):
if name not in kw:
kw[name] = value
n += 1
return args[n:], kw
[docs]def only_valid_kw(f, kw0=None, **kw1):
"""
Function used to return only entries of a dictionary
that would be accepted by a function and avoid
TypeError: ... got an unexpected keyword argument ...
:param f: function
:param kw0: dictionary with potential function arguments
:param \**kw1: keyword dictionary with potential function arguments
>> f(**only_valid_kw(f, kw))
"""
if kw0 is None:
kw0 = {}
else:
kw = copy.copy(kw0)
kw.update(kw1)
a, k, astar, kstar = function_arguments(f)
if kstar:
return kw
kw_out = {}
for item, value in kw.items():
if item in a or item in k:
kw_out[item] = value
return kw_out
[docs]def functions_classes_in_script(filename):
"""
Parses a Python script and returns list of functions and classes
that are declared there (does not execute the script)
:params filename: filename of the Python script to parse
:returns: dictionary with lists of 'functions' and 'classes'
"""
with open(filename, 'r') as file:
ast_tree = ast.parse(file.read(), filename=filename)
objects = {'functions': [], 'classes': []}
for item in ast_tree.body:
if isinstance(item, ast.FunctionDef):
objects['functions'].append(item.name)
if isinstance(item, ast.ClassDef):
objects['classes'].append(item.name)
return objects
[docs]def dump_function_usage(post_operator=None, pre_operator=None):
"""
Decorator function used to collect arguments passed to a function
>> def printer(d):
>> print(d)
>> @dump_function_usage(printer)
>> def hello(a=1):
>> print('hello')
:param post_operator: function called after the decorated function
(a dictionary with the function name, arguments, and keyword arguments gets passed)
:param pre_operator: function called before the decorated function
(a dictionary with the function name, arguments, and keyword arguments gets passed)
"""
def dumpArgsDeco(func):
def wrapper(*func_args, **func_kwargs):
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
defaults = func.__defaults__ or ()
args = func_args[: len(arg_names)]
args = args + defaults[len(defaults) - (func.__code__.co_argcount - len(args)) :]
params = list(zip(arg_names, args))
args = func_args[len(arg_names) :]
if args:
params.append(('args', args))
if func_kwargs:
params.append(('kwargs', func_kwargs))
func_name = func.__name__
func_name = func.__name__
params.append(('func', func_name))
# printd(func_name + ' (' + ', '.join('%s = %r' % p for p in params) + ' )',topic=func_name)
if pre_operator is not None:
pre_operator(dict(params))
tmp = func(*func_args, **func_kwargs)
if post_operator is not None:
post_operator(dict(params))
return tmp
return wrapper
return dumpArgsDeco
[docs]def function_to_tree(funct, self_ref):
"""
Converts a function to an OMFITpythonTask instance that can be saved in the tree
:param funct: function
The function you want to export
:param self_ref: object
Reference to the object that would be called self within the script.
Its location in the tree will be looked up and used to replace 'self' in the code.
This is used to add a line defining the variable self within the new OMFITpythonTask's source. If the function
doesn't use self, then it just has to be something that won't throw an exception, since it won't be used
(e.g. self_ref=OMFIT should work if you're not using self)
:return: An OMFITpythonTask instance
"""
from inspect import cleandoc
from omfit_classes.sortedDict import treeLocation
from omfit_classes.omfit_python import OMFITpythonTask
src = inspect.getsource(funct)
def_part = cleandoc(src.split('):')[0]) + ')'
body = cleandoc('):'.join(src.split('):')[1:]))
lines = body.split('\n')
while lines[-1].strip().startswith('return') or (not len(lines[-1].strip())):
lines = lines[:-1]
body = '\n'.join(lines)
self_def = 'self = {}\n\n'.format(treeLocation(self_ref)[-1])
name = def_part.split('(')[0].split('def ')[1]
args = '('.join([arg.lstrip('\n ') for arg in def_part.split('(')[1:]])
# Remove leading "self" argument and trailing ", " or ",", if present
cut_start = ['self, ', 'self,', 'self']
for cs in cut_start:
if args.startswith(cs):
args = args[len(cs) :]
break
args = '\n'.join([line for line in args.split('\n') if len(line.strip())])
args = re.sub('\*\*(?P<name>[A-Z,a-z,0-9,_]*)', '\g<name>={}', args)
return OMFITpythonTask(
'{}.py'.format(name),
fromString='# TEMPORARY FILE: EDITS WILL NOT BE SAVED!\n'
'# This file was generated by omfit_classes.utils_base.function_to_tree()\n\n'
'defaultVars(' + args + '\n' + self_def + body,
)
OMFIT_backward_compatibility_mapper = {}
# ---------------------
# lists and dicts
# ---------------------
[docs]@_available_to_user_util
def tolist(data, empty_lists=None):
"""
makes sure that the returned item is in the format of a list
:param data: input data
:param empty_lists: list of values that if found will be filtered out from the returned list
:return: list format of the input data
"""
import numpy as np
data = evalExpr(data)
if isinstance(data, str):
data = [data]
if empty_lists:
if not np.iterable(data):
if data in empty_lists:
return []
else:
data0 = data
data = []
for k in data0:
if k not in empty_lists:
data.append(k)
if isinstance(data, np.ndarray):
return np.atleast_1d(data).tolist()
elif isinstance(data, dict):
return [data]
try:
return list(data)
except TypeError:
return [data]
[docs]def common_in_list(input_list):
"""
Finds which list element is most common (most useful for a list of strings or mixed strings & numbers)
:param input_list: list with hashable elements (no nested lists)
:return: The list element with the most occurrences. In a tie, one of the winning elements is picked arbitrarily.
"""
return max(set(input_list), key=input_list.count)
[docs]def keyword_arguments(dictionary):
"""
Returns the string that can be used to generate a dictionary from keyword arguments
eg. keyword_arguments({'a':1,'b':'hello'}) --> 'a'=1,'b'='hello'
:param dictionary: input dictionary
:return: keyword arguments string
"""
return ','.join([k[0] + '=' + repr(k[1]) for k in zip(list(dictionary.keys()), list(dictionary.values()))])
[docs]@_available_to_user_util
def select_dicts_dict(dictionary, **selection):
"""
Select keys from a dictionary of dictionaries. This is useful to select data from a dictionary that uses a hash
as the key for it's children dictionaries, and the hash is based on the content of the children.
eg:
>> parent={}
>> parent['child1']={'a':1,'b':1}
>> parent['child2']={'a':1,'b':2}
>> select_dicts_dict(parent, b=1) #returns: ['child1']
>> select_dicts_dict(parent, a=1) #returns: ['child1', 'child2']
:param dictionary: parent dictionary
:param \**selection: keywords to select on
:return: list of children whose content matches the selection
"""
ret = []
for item in list(dictionary.keys()):
found = True
for sel in list(selection.keys()):
if not (sel in dictionary[item] and selection[sel] == dictionary[item][sel]):
found = False
if found:
ret.append(item)
return ret
[docs]@_available_to_user_util
def bestChoice(options, query, returnScores=False):
"""
This fuction returns the best heuristic choice from a list of options
:param options: dictionary or list of strings
:param query: string to look for
:param returnScores: whether to return the similarity scores for each of the options
:return: the tuple with best choice from the options provided and its matching score, or match scores if returnScores=True
"""
scores = []
for item in options:
m = difflib.SequenceMatcher(None, item.lower(), query.lower())
scores.append(m.ratio())
i = sorted(list(range(len(scores))), key=scores.__getitem__)
if returnScores:
return scores
elif isinstance(options, dict):
return list(options.values())[i[-1]], scores[i[-1]]
else:
return options[i[-1]], scores[i[-1]]
[docs]def flip_values_and_keys(dictionary, modify_original=False, add_key_to_value_first=False):
"""
Flips values and keys of a dictionary
People sometime search the help for swap_keys, switch_keys, or flip_keys to find this function.
:param dictionary: dict
input dictionary to be processed
:param modify_original: bool
whether the original dictionary should be modified
:param add_key_to_value_first: bool
Append the original key to the value (which will become the new key).
The new dictionary will look like: {'value (key)': key},
where key and value were the original key and value.
This will force the new key to be a string.
:return: dict
flipped dictionary
"""
keys = list(dictionary.keys())
values = list(dictionary.values())
if modify_original:
dictionary.clear()
else:
try:
dictionary = dictionary.__class__()
except TypeError: # OMFITjson and other file classes won't init without a filename
dictionary = OrderedDict()
for k, v in zip(keys, values):
if add_key_to_value_first:
dictionary['{} ({})'.format(v, k)] = k
else:
dictionary[v] = k
return dictionary
[docs]def dir2dict(startpath, dir_dict=OrderedDict):
"""
python dictionary hierarchy based on filesystem hierarchy
:param startpath: filesystem path
:return: python dictionary
"""
def set_leaf(tree, branches, leaf):
"""Set a terminal element to *leaf* within nested dictionaries.
*branches* defines the path through dictionaries.
Example:
>>> t = {}
>>> set_leaf(t, ['b1','b2','b3'], 'new_leaf')
>>> print(t)
{'b1': {'b2': {'b3': 'new_leaf'}}}
"""
if len(branches) == 1:
tree[branches[0]] = leaf
return
if branches[0] not in tree:
tree[branches[0]] = dir_dict()
set_leaf(tree[branches[0]], branches[1:], leaf)
tree = dir_dict()
for root, dirs, files in os.walk(startpath):
dirs.sort(key=lambda x: x.lower() + x)
files.sort(key=lambda x: x.lower() + x)
branches = [startpath]
if root != startpath:
branches.extend(os.path.relpath(root, startpath).split('/'))
set_leaf(tree, branches, dir_dict([(d, dir_dict()) for d in dirs] + [(f, None) for f in files]))
return tree
# ---------------------
# git
# ---------------------
OMFITreposDir = None
[docs]def parse_git_message(message, commit=None, tag_commits=[]):
ctype_color = 'RoyalBlue3'
message = re.sub(r'<<<>>>(.*)<<<>>>', r'\n<<<>>>\1<<<>>>', message)
summary = message.split('\n')[0]
ctype = re.findall(r'<<<>>>.*<<<>>>', message, re.MULTILINE | re.DOTALL)
def ctype_colors(in_c):
in_c = in_c.lower()
if 'bug' in in_c or 'fix' in in_c:
return 'firebrick1'
if 'document' in in_c or 'regression' in in_c:
return 'orange'
if 'module' in in_c:
return 'purple'
return 'RoyalBlue3'
if commit in tag_commits:
ctype = 'New release'
ctype_color = 'forest green'
elif len(ctype):
message = message.replace(ctype[0], '')
ctype = ctype[0].replace('\n', ' ').replace('<<<>>>', '').strip('><')
ctype_color = ctype_colors(ctype)
elif ':' in summary:
summ_split = summary.split(':', 3)
message = summ_split[-1] + ('\n' * 3)
summary = message.strip()
ctype_color = ctype_colors(':'.join(summ_split[:-1]))
ctype = ':'.join(summ_split[:-1]).lstrip('!')
else:
ctype = 'Update'
message = message.split('\n')
return ctype, summary, message, ctype_color
[docs]class OMFITgit(object):
"""
An OMFIT interface to a git repo, which does NOT leave dangling open files
:param git_dir: the root of the git repo
:param n: number of visible commits
"""
def __init__(self, git_dir, n=25):
if not os.path.exists(git_dir):
raise OSError(f'{git_dir} does not exist')
if not os.path.exists(git_dir + os.sep + '.git') and not git_dir.endswith('.git'):
raise OSError(f'{git_dir} is not a git repository!')
self.work_repo = None
self.git_cmd = system_executable('git')
os.environ.get(
'GIT_SSH_COMMAND', sshOptions('OMFITsshGit')
) # +' -o ControlMaster=auto -o ControlPath='+OMFITcontrolmastersDir+os.sep+'%r@%h:%p -o ControlPersist=24h'
self.git_dir = os.path.abspath(git_dir)
self._n_visible = n
self.install_OMFIT_hooks()
[docs] def install_OMFIT_hooks(self, quiet=True):
import filecmp
if self.is_OMFIT_source and os.access(os.sep.join([self.git_dir, '.git', 'hooks']), os.W_OK):
for source in glob.glob(os.sep.join([self.git_dir, 'install', 'git-hooks', '*'])):
target = os.sep.join([self.git_dir, '.git', 'hooks', os.path.split(source)[1]])
if not os.path.exists(target) or not filecmp.cmp(source, target, shallow=True):
try:
shutil.copy2(source, target)
printw("Installed '%s' git hook" % os.path.split(target)[1])
except Exception as e:
printe("Could not install '%s' git hook: %s" % (os.path.split(target)[1], repr(e)))
elif not quiet:
raise Exception('Could not install .git hooks')
[docs] def is_OMFIT_source(self):
if os.path.exists(os.sep.join([self.git_dir, 'omfit', 'omfit_tree.py'])):
return True
return False
def __call__(self, command, verbose=False, returns='out+err', subdir=''):
"""
Return the results of `git <command> [args]` as a string
:param command: (str or iterable) The git command (or iterable of commands) to carry out
:param returns: list with choice of `['out','err','code']` or 'out+err'
:param subdir: subdirectory in the git repo where to execute the command
:return: stdout, stderr, return code depending on `returns` list
"""
return_dict = {}
if verbose:
printi('%s %s' % ('git', command))
if 'DYLD_LIBRARY_PATH' in os.environ:
print('Removed DYLD_LIBRARY_PATH=%s' % os.environ['DYLD_LIBRARY_PATH'])
del os.environ['DYLD_LIBRARY_PATH']
composite_command = '&&'.join([' %s %s' % (self.git_cmd, x) for x in tolist(command)])
if verbose:
print(composite_command)
command = 'cd %s && %s' % (self.git_dir + '/' + subdir, composite_command)
if os.name == 'nt':
if len(command) > 8191:
raise OMFITexception('Command for git is too long for windows command line')
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return_dict['out'], return_dict['err'] = list(map(b2s, process.communicate()))
return_dict['code'] = process.returncode
return_dict['out'] = return_dict['out'].rstrip()
return_dict['err'] = return_dict['err'].rstrip()
if verbose and (return_dict['out'] + return_dict['err']).rstrip():
print((return_dict['out'] + return_dict['err']).rstrip())
if returns == 'out+err':
return return_dict['out'] + return_dict['err']
elif len(returns) == 1:
return return_dict[returns[0]]
else:
return {ret: return_dict[ret] for ret in returns}
def __getattr__(self, attr):
if attr in ['status', 'log', 'tag', 'describe', 'branch']:
return self(attr)
if attr == 'head':
return self.get_hash('HEAD')
raise AttributeError('%s: Unknown attribute' % attr)
[docs] def tag_hashes(self):
"""
Get the full hashes for all of the tags
:return: A list of hashes, corresponding to self.tag.splitlines()
"""
return self.get_hash(self.tag.splitlines()).splitlines()
[docs] def get_hash(self, ref):
"""
Get the full hash for the given reference (tag, branch, or commit)
:param ref: A commitish reference (tag, commit hash, or branch) or iterable of references
"""
return self(['log -1 --pretty="%H" ' + str(x) + ' --' for x in tolist(ref)], returns=['out'])
[docs] def get_commit_message(self, commit):
"""
Given the commit hash, return the commit message
:param commit: (str) The hash of the commit
"""
return self('log -1 --pretty="%B" ' + commit, returns=['out'])
[docs] def get_commit_date(self, commit):
"""
Given the commit hash, return the commit date (in form similar to `time.time`)
:param commit: (str) The hash of the commit
:return: An int
"""
return int(self('log -1 --pretty="%at" ' + str(commit), returns=['out']))
[docs] def get_visible_commits(self, order=['author-date-order', 'date-order', 'topo-order'][0]):
"""
Get the commits that don't have Hidden, Hide, or Minor comment tags
:return: A tuple of the commits, their messages, their authors, their dates, and the tag_hashes
"""
sep = '#<-:-:->#'
recent = self(
'log -' + str(int(self._n_visible)) + ' --no-merges --' + order + ' --pretty="%H; %aN; %at%n%B%n' + sep + '" HEAD',
returns=['out'],
).split('\n' + sep + '\n')[:-1]
commits = []
messages = []
authors = []
dates = []
tag_commits = self.tag_hashes()
for line in recent:
split_line = line.strip().splitlines()
commit, author, date = split_line[0].split(';')
message_orig = '\n'.join(split_line[1:])
ctype, summary, message, ctype_color = parse_git_message(message_orig, commit, tag_commits)
if message[0] == summary:
title = summary.strip()
else:
title = ''
if ctype.lower() not in ['hide', 'hidden', 'minor'] and (commit in tag_commits or not title.startswith("Merge ")):
commits.append(commit)
messages.append(message_orig)
authors.append(author)
dates.append(date)
return commits, messages, authors, dates, tag_commits
[docs] def get_tag_version(self, tag_family):
"""
Get a version similar to the results of self.describe, but restricted to tags containing a specific string. This
is for finding the tagged version of a module: one might use ``repo.get_tag_version('cake_')`` to get the
version of the CAKE module (like 'cake_00.01.684e4d226a' or 'cake_00.01')
:param tag_family: A substring defining a family of tags. It is assumed that splitting this substring out of the
git tags will leave behind version numbers for that family.
:return: A string with the most recent tag in the family, followed by a commit short hash if there have been
commits since the tag was defined.
"""
printd('Getting version for tag_family = {}...'.format(tag_family))
return self('describe --match "{}[0-9]*" --tags'.format(tag_family))
[docs] def active_branch(self):
"""
Get the active branch, the (shortened) hash of HEAD, and the standard string containing those
where the standard string is "Commit <shortened hash> on branch <active_branch>"
:return: active_branch, active_branch_hash, standard_string
"""
t0 = time.time()
active_branch = 'DETACHED_HEAD'
for branch in self('branch').splitlines():
if branch.startswith('*') and 'HEAD detached' not in branch:
active_branch = branch.split('*')[1].strip()
break
active_branch_commit = self.get_hash('HEAD')
tags = list([x for x in self('tag --contains', returns=['out']).split() if x not in ['HEAD', active_branch_commit]])
if len(tags):
active_branch_commit = ','.join(tags)
else:
active_branch_commit = active_branch_commit[:10]
repo_str = '%s on branch %s' % (self('describe'), active_branch)
return active_branch, active_branch_commit, repo_str
[docs] def get_module_params(self, path='modules', key='ID', quiet=True, modules=None):
"""
Return a dictionary whose keys are the modules available in this repo
:param path: The path (relative to the root of the repo) where the modules are stored
:param key: attribute used for the returned dictorary (eg. `ID` (default) or `path`)
:param quiet: whether to print the modules loaded (if `None` uses progress bar)
:param modules: list of modules to track (full path, minus OMFITsave.txt)
:return: A dictionary whose keys are the modules available in this repo,
the values are dictionaries whose keys are author, date, commit,
path, version.
A return example::
{
'EFITtime':
{
'edited_by': 'Sterling Smith',
'commit': '64a213fb03e14a154567f8eb7b260f10acbe48f3',
'date': '1456297324', #time.time format
'path': '/Users/meneghini/Coding/atom/OMFIT-source/modules/EFITtime/OMFITsave.txt',
'version': u'Module to execute time-dependent EFIT equilibrium reconstructions'
}
}
"""
def sortByModule(command):
byMod = {}
for filename in self(command + ' ' + path).split():
tmp = list([_f for _f in filename[len(path) :].split(os.sep) if _f])
if len(tmp):
byMod.setdefault(tmp[0], [])
byMod[tmp[0]].append(self.git_dir + os.sep + filename)
return byMod
tracked_files = sortByModule('ls-tree --name-only -r HEAD')
untracked_files = sortByModule('ls-files --other --exclude-standard')
modified_files = sortByModule('ls-files -m --exclude-standard')
if modules is None:
modules = [
os.path.split(x)[0] for x in list(map(os.path.abspath, glob.glob(os.sep.join([self.git_dir, path, '*', 'OMFITsave.txt']))))
]
else:
modules = [re.sub(os.sep + 'OMFITsave.txt$', '', x) for x in modules]
result = {}
commands = []
mod_keys = sorted(modules, key=lambda x: x.lower())
for k, mod in enumerate(mod_keys):
commands.append('log -1 --pretty="%H; %aN; %at" HEAD -- ' + mod + '; echo ' + mod)
mod_tmp = [_f for _f in self(commands, returns=['out']).splitlines() if _f]
mod_out = {}
details = untracked_details = '? ; %s ; %d' % (os.environ['USER'], int(time.time()))
for line in mod_tmp:
if line.count(';'):
details = line
elif not line.count(';'):
mod_out[line] = details
details = untracked_details
for mod in ascii_progress_bar(mod_keys, mess=mod_keys, quiet=quiet is not None):
try:
mod_name = os.path.split(mod)[1]
tmp = {}
tmp['ID'] = mod_name
tmp['path'] = os.sep.join([mod, 'OMFITsave.txt'])
tmp['untracked'] = untracked_files.get(mod_name, '')
tmp['modified'] = modified_files.get(mod_name, '')
tmp['date'] = time.time()
tmp['edited_by'] = os.environ['USER']
tmp['commit'] = ''
tmp['description'] = ''
if len(tmp['modified']) or os.path.abspath(mod + os.sep + 'OMFITsave.txt') in tmp['untracked']:
pass
else:
for param, val in zip(['commit', 'edited_by', 'date'], mod_out[mod].split(';')):
tmp[param] = val.strip()
if quiet is False:
if os.path.abspath(mod + os.sep + 'OMFITsave.txt') in tmp['untracked']:
printi(tmp['ID'].ljust(20) + "(untracked)")
elif len(tmp['modified']):
printi(tmp['ID'].ljust(20) + "(modified)")
else:
printi(tmp['ID'])
tmp['date'] = convertDateFormat(int(tmp['date']))
# get additional module informations that is in the modules settings
from omfit_classes.omfit_base import OMFITmodule
info = OMFITmodule.info(tmp['path'])
info.pop('date', None)
info.pop('edited_by', None)
info.pop('commit', None)
tmp.update(info)
result[tmp[key]] = tmp
except Exception as _excp:
printe('git error for module %s: %s' % (mod_name, repr(_excp)))
raise
return result
[docs] def get_remotes(self):
'''returns dictionary with remotes as keys and their properties'''
remotes = {}
remote_keys = self('remote').split()
for i, remote in enumerate(remote_keys):
remotes[remote] = {}
remotes[remote]['url'] = self('ls-remote --get-url ' + remote)
return remotes
[docs] def get_branches(self, remote=''):
'''returns dictionary with branches as keys and their properties'''
branches = {}
if not remote:
for branch in [_f for _f in self('branch').strip().split('\n') if _f]:
current = False
if branch.startswith('*'):
current = True
branch = branch.strip('* ')
branches[branch] = {}
branches[branch]['remote'] = self('config branch.%s.remote' % branch).strip()
branches[branch]['current'] = current
else:
returns = self('ls-remote --heads %s' % remote, returns=['out', 'err'])
for branch in [_f for _f in returns['out'].strip('\n ').split('\n') if _f]:
branch = '/'.join(branch.split()[1].split('/')[2:])
branches[branch] = {}
branches[branch]['remote'] = remote
return branches
[docs] def clone(self):
"""
Clone of the repository in the OMFITworking environment `OMFITtmpDir+os.sep+'repos'`
and maintain remotes information. Note: `original_git_repository` is the remote that points
to the original repository that was cloned
:return: OMFITgit object pointing to cloned repository
"""
if not os.path.exists(OMFITreposDir):
os.makedirs(OMFITreposDir)
repository_directory = OMFITreposDir + os.sep + os.path.split(os.path.abspath(self.git_dir))[1] + "_" + omfit_hash(self.git_dir, 10)
if not os.path.exists(repository_directory):
tmp = self.active_branch()
self('clone -n file://$PWD %s' % repository_directory, verbose=True)
self.work_repo = OMFITgit(repository_directory, n=self._n_visible)
self.work_repo('remote rename %s %s' % ('origin', 'original_git_repository'), verbose=True)
# copy remotes (origin points to original repo)
remotes = self.get_remotes()
for remote in list(remotes.keys()):
if 'gafusion/OMFIT-source.git' in remotes[remote]['url'] and 'gafusion' not in remotes:
remotes['gafusion'] = remotes[remote]
for remote in list(remotes.keys()):
self.work_repo('remote add %s %s' % (remote, remotes[remote]['url']), verbose=True)
else:
if not self.work_repo:
self.work_repo = OMFITgit(repository_directory, n=self._n_visible)
out = self.work_repo('status')
if 'fatal: Not a git repository' in out:
shutil.rmtree(repository_directory)
self.work_repo = self.clone()
printd('work repo already exists: %s' % self.work_repo.git_dir, topic='OMFITgit')
return self.work_repo
[docs] def switch_branch(self, branch, remote=''):
"""
Switch to branch
:param branch: branch
:param remote: optional remote repository
"""
# get local branches
branches = self.get_branches()
# get remote branches
if remote:
remote_branches = self.get_branches(remote=remote)
# fetch remote/branch
if branch in remote_branches:
self('fetch %s %s' % (remote, branch), verbose=True)
# checkout new branch and make it point to the remote
if branch not in branches and remote and branch not in remote_branches:
self('checkout -b %s' % (branch), verbose=True)
out = self('branch %s -u %s/%s' % (branch, remote, branch), verbose=True)
if 'error: the requested upstream branch ' in out and ' does not exist' in out:
self('push --set-upstream %s %s' % (remote, branch), verbose=True)
# if there is a local copy already
elif branch in branches:
if branches[branch]['remote'] == remote or not branches[branch]['remote'] or not remote:
self('checkout %s' % (branch), verbose=True)
else:
self('branch %s -u %s/%s' % (branch, remote, branch), verbose=True)
if remote and branches[branch]['remote']:
self('merge --no-edit %s/%s' % (remote, branch), verbose=True)
# if there is a remote branch
else:
self('checkout --track %s/%s' % (remote, branch), verbose=True)
[docs] def branch_containing_commit(self, commit):
"""
Returns a list with the names of the branches that contain the specified commit
:param commit: commit to search for
:return: list of strings (remote is separated by `/`)
"""
tmp = self('branch -a --contains %s' % commit)
tmp = [x.strip(' *') for x in tmp.strip().split('\n')]
lst = []
for item in list(tmp):
if 'remotes/origin/HEAD' in item:
continue
if item.startswith('remotes/'):
lst.append('/'.join(item.split('/')[1:]))
else:
lst.append(item)
return sorted(set(lst))
[docs] def switch_branch_GUI(self, branch='', remote='', parent=None, title=None, only_existing_branches=False):
import tkinter as tk
from tkinter import ttk
from utils_tk import tk_center
top = tk.Toplevel()
top.withdraw()
if title is None:
title = 'Switch branch'
top.wm_title(title)
out = [None, None]
if parent is None:
parent = OMFITaux['rootGUI']
top.wm_transient(parent)
def onRemote(event=None):
top.update_idletasks()
branchSelectorOptions = list(self.get_branches(remoteSelector.get()).keys())
branchSelectorOptions = sorted(branchSelectorOptions, key=lambda x: x.lower())
branchSelector.configure(values=tuple(branchSelectorOptions))
def onReturn(event=None):
top.update_idletasks()
self.switch_branch(branchSelector.get(), remoteSelector.get())
out[:] = [branchSelector.get(), remoteSelector.get()]
top.destroy()
def onEscape(event=None):
top.destroy()
frm = ttk.Frame(top)
frm.pack(side=tk.TOP, padx=5, pady=2, fill=tk.X, expand=tk.NO)
ttk.Label(frm, text='git repository: ' + self.git_dir, justify=tk.LEFT).pack(side=tk.LEFT, anchor=tk.W)
frm = ttk.Frame(top)
frm.pack(side=tk.TOP, padx=5, pady=2, fill=tk.X, expand=tk.NO)
ttk.Label(frm, text='Remote: ').pack(side=tk.LEFT)
remoteSelector = ttk.Combobox(frm, width=50)
remoteSelector.pack(side=tk.LEFT, fill=tk.X, expand=tk.YES)
remoteSelector.bind('<<ComboboxSelected>>', func=onRemote)
remoteSelectorOptions = sorted(self.get_remotes(), key=lambda x: x.lower())
remoteSelector.configure(values=tuple(remoteSelectorOptions))
remoteSelector.configure(state='readonly')
frm = ttk.Frame(top)
frm.pack(side=tk.TOP, padx=5, pady=2, fill=tk.X, expand=tk.NO)
ttk.Label(frm, text='Branch: ').pack(side=tk.LEFT)
branchSelector = ttk.Combobox(frm)
branchSelector.pack(side=tk.LEFT, fill=tk.X, expand=tk.YES)
if only_existing_branches:
branchSelector.configure(state='readonly')
top.bind('<Return>', onReturn)
top.bind('<Escape>', onEscape)
remoteSelector.set(remote)
branchSelector.set(branch)
onRemote()
top.protocol("WM_DELETE_WINDOW", top.destroy)
top.resizable(False, False)
tk_center(top, parent)
top.deiconify()
top.update_idletasks()
top.wait_window(top)
if out[0] == out[1] == None:
return None
else:
return out[0], out[1]
[docs] def remote_branches_details(self, remote='origin', skip=['unstable', 'master', 'daily_.*', 'gh-pages']):
"""
returns a list of strings with the details of the branches on a given remote
This function is useful to identify stale branches.
:param remote: remote repository
:param skip: branches to skip
:return: list of strings
"""
skip = [remote + '/' + item for item in skip]
lst = []
for remote_branch in self('branch -r').split():
if remote_branch.startswith(remote + '/') and not np.any([re.findall('^' + item + '$', remote_branch) for item in skip]):
lst.append(self('show --format="%ci %cr %an" ' + remote_branch).split('\n')[0] + ' ' + remote_branch)
return list(reversed(sorted(lst)))
# ---------------------
# miscellaneous
# ---------------------
[docs]def omfit_hash(string, length=-1):
"""
Hash a string using SHA1 and truncate hash at given length
Use this function instead of Python hash(string) since with
Python 3 the seed used for hashing changes between Python sessions
:param string: input string to be hashed
:param length: lenght of the hash (max 40)
:return: SHA1 hash of the string in hexadecimal representation
"""
import hashlib
m = hashlib.sha1()
m.update(string.encode("UTF-8"))
return m.hexdigest()[:length]
[docs]def omfit_numeric_hash(string, length=-1):
"""
Hash a string using SHA1 and truncate integer at given length
Use this function instead of Python hash(string) since with
Python 3 the seed used for hashing changes between Python sessions
:param string: input string to be hashed
:param length: lenght of the hash (max 47)
:return: SHA1 hash of the string in integer representation
"""
return int(str(int(omfit_hash(string), 16))[:length])
[docs]def find_library(libname, default=None):
"""
This function returns the path of the matching library name
:param libname: name of the library to look for (without `lib` and extension)
:param default: what to return if library is not found
:return:
"""
from ctypes.util import find_library as _find_library
lib_tmp = 'lib' + libname + '.so'
if platform.system() == 'Darwin':
import ctypes.macholib.dyld
dlib = ctypes.macholib.dyld.DEFAULT_LIBRARY_FALLBACK
added = False
if sys.prefix + '/lib' not in dlib:
added = True
dlib.insert(0, sys.prefix + '/lib')
lib = _find_library(lib_tmp)
if added:
ctypes.macholib.dyld.DEFAULT_LIBRARY_FALLBACK.remove(sys.prefix + '/lib')
if 'DYLD_LIBRARY_PATH' in os.environ:
del os.environ['DYLD_LIBRARY_PATH']
else:
lib = _find_library(libname)
for k in os.environ.get('LD_LIBRARY_PATH', '').split(':'):
if os.path.exists(k + os.sep + lib_tmp):
lib = k + os.sep + lib_tmp
break
if lib is None:
if default is not None:
lib = default
else:
raise Exception(lib_tmp + ' not found in LD_LIBRARY_PATH')
return lib
[docs]def find_file(reg_exp_filename, path):
"""
find all filenames matching regular expression under a path
:param reg_exp_filename: regular expression for the file to match
:param path: folder where to look
:return: list of filenames matching regular expression with full path
"""
match = []
for root, dirs, names in os.walk(path):
for name in names:
if reg_exp_filename == name or re.match(reg_exp_filename, name):
match.append(os.path.join(root, name))
return match
def _isdebugging():
"""
Function that returns True/False if function is run in debugging mode
:return: True/False
"""
for frame in inspect.stack():
if frame[1].endswith("pydevd.py"):
return True
return False
OMFIT_is_in_debug_mode = _isdebugging()
[docs]def sleep(sleeptime):
"""
Non blocking sleep
:param sleeptime: time to sleep in seconds
:return: None
"""
if sleeptime <= 0:
return
if 'rootGUI' in OMFITaux and OMFITaux['rootGUI'] is not None and not len(OMFITaux['prun_process']):
import tkinter as tk
finished = tk.BooleanVar()
finished.set(False)
def goOn():
OMFITaux['rootGUI'].update_idletasks()
finished.set(True)
OMFITaux['rootGUI'].update_idletasks()
OMFITaux['rootGUI'].after(int(sleeptime * 1000), goOn)
OMFITaux['rootGUI'].wait_variable(finished)
else:
time.sleep(sleeptime)
[docs]def now(format_out='%d %b %Y %H:%M', timezone=None):
"""
:param format_out: https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior
:param timezone: [string] look at /usr/share/zoneinfo for available options
CST6CDT Europe/? Hongkong MST Portugal WET
Africa/? Canada/? Factory Iceland MST7MDT ROC
America/? Chile/? GB Indian/? Mexico/? ROK
Antarctica/? Cuba GB-Eire Iran NZ Singapore
Arctic/? EET GMT Israel NZ-CHAT Turkey
Asia/? EST GMT+0 Jamaica Navajo UCT
Atlantic/? EST5EDT GMT-0 Japan PRC US/?
Australia/? Egypt GMT0 Kwajalein PST8PDT UTC
Brazil/? Eire Greenwich Libya Pacific/? Universal
CET Etc/? HST MET Poland W-SU
Zulu
:return: formatted datetime string
if format_out is None, return datetime object
"""
if timezone is not None:
from dateutil import tz
resolved_timezone = tz.gettz(timezone)
if resolved_timezone is None:
raise ValueError('Timezone `%s` not recognized! see /usr/share/zoneinfo/' % timezone)
timezone = resolved_timezone
dt = datetime.datetime.now(timezone)
if format_out is None:
return dt
return dt.strftime(format_out)
[docs]def update_dir(root_src_dir, root_dst_dir):
"""
Go through the source directory, create any directories that do not already exist in destination directory,
and move files from source to the destination directory
Any pre-existing files will be removed first (via os.remove) before being replace by the corresponding source file.
Any files or directories that already exist in the destination but not in the source will remain untouched
:param root_src_dir: Source directory
:param root_dst_dir: Destination directory
"""
for src_dir, dirs, files in os.walk(root_src_dir):
dst_dir = src_dir.replace(root_src_dir, root_dst_dir, 1)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
for file_ in files:
src_file = os.path.join(src_dir, file_)
dst_file = os.path.join(dst_dir, file_)
if os.path.exists(dst_file):
os.remove(dst_file)
shutil.move(src_file, dst_dir)
[docs]def permissions(path):
"""
:path: file path
:return: file permissions as a (user, group, other) (read, write, execute) string, such as: rwxr-xr-x
"""
s = os.lstat(path)
p = ['-'] * 9
signs = 'rwx'
for i, flag in enumerate(
[stat.S_IRUSR, stat.S_IWUSR, stat.S_IXUSR, stat.S_IRGRP, stat.S_IWGRP, stat.S_IXGRP, stat.S_IROTH, stat.S_IWOTH, stat.S_IXOTH]
):
if s.st_mode & flag:
p[i] = signs[i % 3]
return ''.join(p)
[docs]def zipfolder(foldername, filename, compression=zipfile.ZIP_STORED, allowZip64=True):
"""
compress folder to zip archive
:param foldername: folder to compress
:param filename: zip filename to use
:param compression: compression algorythm
:param allowZip64: use 64bit extension to handle files >4GB
"""
filename = os.path.abspath(filename)
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
os.chdir(foldername)
foldername = './'
empty_dirs = []
with zipfile.ZipFile(filename, 'w', compression=compression, allowZip64=allowZip64) as zip:
for dirpath, dirs, files in os.walk(foldername):
empty_dirs.extend([dir for dir in dirs if not len(os.listdir(os.sep.join([dirpath, dir])))])
for name in files:
zip.write(os.path.join(dirpath, name))
# include empty folders
for dir in empty_dirs:
zif = zipfile.ZipInfo(os.sep.join([dirpath, dir]) + "/")
zip.writestr(zif, "")
empty_dirs = []
[docs]def omfit_patch(obj, fun):
"""
Patch a standard module/class with a new function/method.
Moves original attribute to _original_<name> ONLY ONCE! If done
blindly you will go recursive when reloading modules
"""
import types
name = fun.__name__.lstrip('_')
ismod = isinstance(obj, types.ModuleType)
if hasattr(obj, name) and not hasattr(obj, '_original_' + name):
orig = getattr(obj, name)
if ismod:
setattr(obj, '_original_' + name, orig) # save copy of original function
else:
setattr(obj, '_original_' + name, types.MethodType(orig, obj)) # save copy of original method
if ismod:
setattr(obj, name, fun) # replace with modified method
else:
setattr(obj, name, types.MethodType(fun, obj)) # replace with modified method
[docs]def funny_random_name_generator(use_mood=False, digits=2):
"""
Makes up a random name with no spaces in it. Funnier than timestamps.
:param use_mood: bool
Use a mood instead of a color
:param digits: int
Number of digits in the random number (default: 2)
:return: string
The default format is [color]_[animal]_[container]_[two digit number]
Example: "blueviolet_kangaroo_prison_26"
Colors come from matplotlib's list.
Alternative formats selected by keywords:
[mood]_[animal]_[container]_[two digit number]
Example: "menacing_guppy_pen_85"
"""
from matplotlib import colors as mcolors
import random as rand
import numpy as np
colors = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)
# Sort colors by hue, saturation, value and name.
by_hsv = sorted((tuple(mcolors.rgb_to_hsv(mcolors.to_rgba(color)[:3])), name) for name, color in colors.items())
sorted_names = [name for hsv, name in by_hsv if len(name) > 1]
color = rand.choice(sorted_names)
moods = ['angry', 'happy', 'hangry', 'hungry', 'sad', 'baleful', 'menacing', 'joyful', 'benevolent', 'vindictive']
mood = rand.choice(moods)
animals = [
'alligator',
'bear',
'beaver',
'cat',
'capricorn',
'dog',
'doggo',
'dolphin',
'emu',
'fish',
'goat',
'gorilla',
'guppy',
'horse',
'iguana',
'ibex',
'jackal',
'kangaroo',
'koala',
'kitty',
'lemur',
'monkey',
'markhor',
'newt',
'octopus',
'penguin',
'quail',
'rat',
'snake',
'turtle',
'tortoise',
'vulture',
'walrus',
'yak',
'zebra',
]
animal = rand.choice(animals)
containers = ['box', 'carton', 'vault', 'bag', 'purse', 'bowl', 'prison', 'bottle', 'cup', 'refrigerator', 'pen', 'paddock', 'barn']
container = rand.choice(containers)
number = rand.choice(range(int(10**digits)))
name_format = '{}_{}_{}_{:0' + str(int(np.ceil(digits))) + 'd}'
name = name_format.format(mood if use_mood else color, animal, container, number).replace(' ', '-')
return name
[docs]def developer_mode(module_link=None, **kw):
"""
Calls the OMFITmodule.convert_to_developer_mode() method for every top-level module in OMFIT
Intended for convenient access in a command box script that can be called as soon as the project is reloaded.
Accepts keywords related to OMFITmodule.convert_to_developer_mode() and passes them to that function.
:param module_link: OMFITmodule or OMFIT instance
Refence to the module to be converted or to top-level OMFIT
If None, defaults to OMFIT (affecting all top-level modules)
"""
from omfit_classes.omfit_base import OMFIT, OMFITmodule, relativeLocations
if module_link is None:
module_link = OMFIT
MainScratch = OMFIT['scratch']
MainScratch.setdefault('developerMode_convert_submodules', True)
MainScratch.setdefault('developerMode_load_new_settings', True)
MainScratch.setdefault('developerMode_operation', 'DEVEL')
dirs = OMFITmodule.directories(checkIsWriteable=(MainScratch['developerMode_operation'] == 'DEVEL'))
MainScratch.setdefault('developerMode_modules_directory', dirs[0])
MainScratch.setdefault('developerMode_quiet', True)
# Fill in defaults that don't depend on other settings
keywords = dict(processSubmodules=True, loadNewSettings=True, operation='DEVEL', quiet=True)
# Replace defaults with MainScratch or supplied kw if available
translations = dict(
processSubmodules='developerMode_convert_submodules',
loadNewSettings='developerMode_load_new_settings',
operation='developerMode_operation',
modulesDir='developerMode_modules_directory',
quiet='developerMode_quiet',
)
for k, v in translations.items():
if v in MainScratch:
keywords[k] = copy.copy(MainScratch[v])
if k in kw:
keywords[k] = kw[k]
# Finish assigning defaults for keywords whose default values depend on other keywords
if 'modulesDir' not in keywords:
dirs = OMFITmodule.directories(checkIsWriteable=(keywords['operation'] == 'DEVEL'))
keywords['modulesDir'] = dirs[0]
def announce_and_convert(the_module_, **kw2):
options = {'Developer mode': 'DEVEL', 'Standard mode': 'FREEZE', 'Standard mode (RELOAD)': 'RELOAD'}
printi(
flip_values_and_keys(options)[MainScratch['developerMode_operation']]
+ ' '
+ relativeLocations(the_module_)['rootName']
+ ['', ' and all sub-modules'][MainScratch['developerMode_convert_submodules']]
)
the_module_.convert_to_developer_mode(**keywords)
if module_link is OMFIT:
for mname in OMFIT.moduleDict(level=0).keys():
the_module = eval(f"OMFIT{mname}")
announce_and_convert(the_module, **keywords)
# the_module.convert_to_developer_mode(**keywords)
else:
announce_and_convert(module_link, **keywords)
# module_link.convert_to_developer_mode(**keywords)
# ---------------------
# versions
# ---------------------
[docs]def sanitize_version_number(version):
"""Removes common non-numerical characters from version numbers obtained from git tags, such as '_rc', etc."""
if version.startswith('.'):
version = '-1' + version
# Replace alpha, beta, release candidate *-a, *-b *-c endings with -3, -2, -1
version = re.sub(r'([0-9]+)-?c', r'\1.-1', version)
version = re.sub(r'([0-9]+)-?b', r'\1.-2', version)
version = re.sub(r'([0-9]+)-?a', r'\1.-3', version)
# More release candidate things
version = re.sub(r'([0-9\-]+)_?rc([0-9\-]+)', r'\1\.-1\.\2', version)
# Get rid of remaining non-numerics except for .-*
version = re.sub(r'[^0-9\.\*\-]', '.', version)
# Remove any -[char]
version = re.sub(r'-[a-zA-Z\.]', '.', version)
# Suppress repeated '.'
while '..' in version:
version = version.replace('..', '.')
return version
[docs]def compare_version(version1, version2):
"""
Compares two version numbers and determines which one, if any, is greater.
This function can handle wildcards (eg. 1.1.*)
Most non-numeric characters are removed, but some are given special treatment.
a, b, c represent alpha, beta, and candidate versions and are replaced by numbers -3, -2, -1.
So 4.0.1-a turns into 4.0.1.-3, 4.0.1-b turns into 4.0.1.-2, and then -3 < -2
so the beta will be recognized as newer than the alpha version.
rc# is recognized as a release candidate that is older than the version without the rc
So 4.0.1_rc1 turns into 4.0.1.-1.1 which is older than 4.0.1 because 4.0.1 implies 4.0.1.0.0.
Also 4.0.1_rc2 is newer than 4.0.1_rc1.
:param version1: str
First version to compare
:param version2: str
Second version to compare
:return: int
1 if version1 > version2
-1 if version1 < version2
0 if version1 == version2
0 if wildcards allow version ranges to overlay. E.g. 4.* vs. 4.1.5 returns 0 (equal)
"""
version1 = sanitize_version_number(version1)
version2 = sanitize_version_number(version2)
# Handle version wildcards
if '*' in version1 or '*' in version2:
version1 = version1.split('.')
version2 = version2.split('.')
start_asterix = False
for k in range(max([len(version1), len(version2)])):
if (k < len(version1) and version1[k] == '*') or (k < len(version2) and version2[k] == '*'):
start_asterix = True
if start_asterix:
if k < len(version1):
version1[k] = '*'
else:
version1.append('*')
if k < len(version2):
version2[k] = '*'
else:
version2.append('*')
version1 = '.'.join(version1)
version2 = '.'.join(version2)
def version_int(x):
if x in ['', ' ']:
return 0
elif x in '*':
return x
else:
return int(x)
def normalize(v):
return [version_int(x) for x in re.sub(r'(\.0+)*$', '', v).split('.')]
n1 = normalize(version1)
n2 = normalize(version2)
dn1 = len(n1) - len(n2)
if dn1 < 0:
n1 += [0] * -dn1
elif dn1 > 0:
n2 += [0] * dn1
return (n1 > n2) - (n1 < n2)
[docs]def find_latest_version(versions):
"""
Given a list of strings with version numbers like 1.2.12, 1.2, 1.20.5, 1.2.3.4.5, etc., find the maximum version
number. Test with: print(repo.get_tag_version('v'))
:param versions: List of strings like ['1.1', '1.2', '1.12', '1.1.13']
:return: A string from the list of versions corresponding to the maximum version number.
"""
printd('Finding latest version among {}...'.format(versions))
versions = tolist(copy.deepcopy(versions))
# Pre-sanitization makes find_latest_version faster: 0.15 vs 0.18 s.
versions = [sanitize_version_number(version) for version in versions]
v = versions.pop()
while len(versions):
vt = versions.pop()
if compare_version(v, vt) < 0:
v = vt
printd('Found latest version = {}'.format(v))
return v
[docs]def check_installed_packages(requirements=OMFITsrc + '/../install/requirements.txt'):
"""
Check version of required OMFIT packages
:param requirements: path to the requirements.txt file
:return: summary dictionary
"""
wrong_version_text = 'python_version >' if sys.version_info < (3, 0) else 'python_version <'
with open(requirements, 'r') as f:
lines = [_f for _f in [x.split('#')[0].strip() for x in f.readlines()] if _f and wrong_version_text not in _f]
lines = [line.split(';')[0].strip() for line in lines]
# map package names with their python import name
import yaml
dependency_file = OMFITsrc + '/../install/omfit_dependencies.yaml'
with open(dependency_file, 'r') as f:
deps_doc = yaml.load(f, Loader=yaml.Loader)
package_mapper = {}
for k, item in enumerate(deps_doc['dependencies']):
name = item['name']
if item.get('pip', None) and item['pip'].get('name', None):
name = item['pip']['name']
if 'import_name' in item and item['import_name']:
package_mapper[name] = item['import_name']
summary = {}
bad = {}
for line in lines:
try:
package, tmp, required_version = re.match('([\w\d-]+)([=\>\<,]+)(.*)', line).groups()
required_version = tmp + required_version
except AttributeError:
printe('Error getting package version from `%s`' % line)
continue
package = package_mapper.get(package, package)
summary[package] = {}
summary[package]['required_version'] = required_version
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
exec('import ' + package)
except ImportError as e:
summary[package]['installed_version'] = None
summary[package]['state'] = None
bad[package] = summary[package]
continue
if not (hasattr(eval(package), '__version__') or hasattr(eval(package), 'version')):
summary[package]['installed_version'] = ''
summary[package]['state'] = True
else:
if hasattr(eval(package), '__version__'):
installed_version = eval(package + '.__version__')
elif hasattr(eval(package), 'version'):
installed_version = eval(package + '.version')
if ' ' in installed_version:
installed_version = installed_version.split(' ')[0]
summary[package]['installed_version'] = installed_version
state = version_conditions_checker(installed_version, required_version)
summary[package]['state'] = state
if not state:
bad[package] = summary[package]
return {'summary': summary, 'bad': bad}
[docs]def version_conditions_checker(version, conditions):
"""
Check that a given version passes all version conditions
:param version: version to check
:param conditions: conditions to be met (multiple conditions are separated by comma)
:return: True if all conditions are satisfied, otherwise False
"""
condition = conditions
if ',' in conditions:
bit = True
for condition in conditions.split(','):
bit = bit and version_conditions_checker(version, condition)
return bit
equality, required_version = re.match('([!=\>\<]+)(.*)', condition).groups()
state = compare_version(version, required_version)
if equality in ['==', '='] and state != 0:
return False
elif equality == '!=' and state == 0:
return False
elif equality == '<=' and state > 0:
return False
elif equality == '>=' and state < 0:
return False
elif equality == '>' and state <= 0:
return False
elif equality == '<' and state >= 0:
return False
return True
[docs]def summarize_installed_packages(required=True, optional=True, verbose=True):
"""
:param required: report on required packages
:param optional: report on optional packages
:param verbose: print to console
:return: status code and text
"""
from warnings import filterwarnings
filterwarnings('ignore')
pline = ' {note:1} {package:22} : {installed_version:15} {required_version:15}'
phead = {'note': '', 'package': 'package', 'installed_version': 'install', 'required_version': 'compare'}
plist = []
if optional:
plist += [('OPTIONAL', 'optional.txt')]
if required:
plist += [('REQUIRED', 'requirements.txt')]
_print = {' ': printi, '-': printw, '?': printw, 'x': printe}
rc = 0
txt = []
for ptype, pfile in plist:
txt.append('\nStatus of ' + ptype + ' Python packages:\n')
if verbose:
print(txt[-1])
txt.append(pline.format(**phead))
if verbose:
print(txt[-1])
res = check_installed_packages(os.path.join(OMFITsrc, '..', 'install', pfile))['summary']
for p in sorted(res, key=lambda k: k.lower()):
if res[p]['state']:
note = ' '
elif res[p]['state'] is False:
if res[p]['installed_version'] == '':
note = '?'
else:
note = '-'
elif ptype == 'REQUIRED':
note = 'x'
rc += 1
else:
note = '-'
if res[p]['installed_version'] is None:
res[p]['installed_version'] = ''
txt.append(pline.format(note=note, package=p, **res[p]))
if verbose:
_print[note](txt[-1])
if verbose:
print('')
return rc, '\n'.join(txt)
[docs]def installed_packages_summary_as_text(installed_packages):
"""
Return a formatted string of the dictionary returned by check_installed_packages()
:param installed_packages: dictionary generated by check_installed_packages()
:return: text representation of the check_installed_packages() dictionary
"""
text = []
for package in sorted(installed_packages['summary']):
text.append(
'{package:22} {installed_version:15} {required_version:15}{bad}'.format(
package=package.ljust(max([len(x) + 1 for x in list(installed_packages['summary'].keys())])),
bad=['', ' (!)'][package in installed_packages['bad']],
**installed_packages['summary'][package],
)
)
return '\n'.join(text)
# The current session temporary directory
OMFITsessionDir = OMFITtmpDir + os.sep + 'OMFIT_' + now("%Y-%m-%d_%H_%M_%S_%f")
[docs]def purge_omfit_temporary_files():
"""
Utility function to purge OMFIT temporary files
"""
from omfit_classes.utils_base import OMFITtmpDir, size_of_dir, sizeof_fmt
from omfit_classes.startup_framework import OMFITglobaltmpDir
def ask(question):
return input('\n' + question + ' y/[n] ') in ['y', 'Y']
def print_single_folder(folder, size=None, sub=False):
if size is None:
size = size_of_dir(folder) if os.path.exists(folder) else 0
_print = printi if size < 1024**3 else printw
_print(' \_' if sub else ' \___', '{size:12} {folder}'.format(size=sizeof_fmt(size, separator=' '), folder=folder))
def print_multi_folder(folder):
empties = 0
print_single_folder(folder)
for path, dirs, files in os.walk(folder):
for subdir in dirs:
subfolder = os.path.join(path, subdir)
subsize = size_of_dir(subfolder)
if subsize > 0:
print_single_folder(subfolder, sub=True)
else:
empties += 1
break
if empties:
print_single_folder('%d empty subfolders' % empties, size=0, sub=True)
def rm_dir(d):
if os.path.exists(d):
try:
shutil.rmtree(d)
printe('Deleted:', d)
except Exception as e:
printe('Error:', repr(e))
# main folders
folders = [OMFITtmpDir]
if OMFITglobaltmpDir != OMFITtmpDir:
folders.append(OMFITglobaltmpDir)
# initial summary
print('\nList and size of OMFIT temporary folders:\n')
for folder in folders:
print_multi_folder(folder)
# specific logic
if ask('Delete OMFIT_TMP in its entirety?'):
rm_dir(OMFITtmpDir)
else:
if ask('Delete OMFIT_TMP/OMFIT_2* session subfolders?'):
for path, dirs, files in os.walk(OMFITtmpDir):
for subdir in dirs:
if subdir.startswith('OMFIT_2'):
rm_dir(os.path.join(path, subdir))
break
if OMFITglobaltmpDir != OMFITtmpDir:
if ask('Delete OMFIT_GLOBAL_TMP in its entirety?'):
rm_dir(OMFITglobaltmpDir)
else:
subfolder = os.path.join(OMFITglobaltmpDir, 'runs')
if os.path.exists(subfolder) and ask('Delete OMFIT_GLOBAL_TMP/runs simulation subfolder?'):
rm_dir(subfolder)
# final summary
print('\nResult:\n')
for folder in folders:
print_multi_folder(folder)
print()
# quiet cleanup
if os.path.exists(OMFITsessionDir):
shutil.rmtree(OMFITsessionDir)