from omfit_classes.utils_base import *
from omfit_classes.utils_base import _available_to_user_math, _available_to_user_util, _available_to_user_plot
import numpy as np
import pickle
import functools
import ast
import filecmp
import types
import shutil
import copy
import difflib
import re
import pprint
import warnings
import traceback
from collections.abc import MutableMapping
# Public names exported by `from ... import *`
# (`pickle` is the module imported at the top of this file, re-exported as-is)
__all__ = [
    'hide_ptrn',
    'private_ptrn',
    'comment_ptrn',
    'comment_ptrn_in_brackets',
    'number_ptrn',
    'sortHuman',
    'get_bases',
    'parseBuildLocation',
    'parseLocation',
    'traverseLocation',
    'buildLocation',
    'setLocation',
    'dirbaseLocation',
    'traverse',
    'treeLocation',
    'recursiveUpdate',
    'pretty_diff',
    'prune_mask',
    'dynaLoad',
    'dynaLoadKey',
    'dynaSave',
    'dynaLoader',
    'dynaSaver',
    'SortedDict',
    'OMFITdataset',
    'pickle',
    'size_tree_objects',
    'sorted_join_lists',
]
# Useful patterns
# names that start AND end with a double underscore (e.g. `__comment__`)
hide_ptrn = re.compile(r'^__.*__$')
# names that start with an underscore and do not end with an underscore
private_ptrn = re.compile(r'^_.*[^_]+$')
# comment entries: names like `__comment__`, `__comment1__`, ...
comment_ptrn = re.compile(r'^__comment.*__$')
# strings containing a bracketed comment key, e.g. "OMFIT['a']['__comment__']"
comment_ptrn_in_brackets = re.compile(r'''.*\[['"]__comment.*__['"]\].*''')
# numbers, tried in order: (1) numbers with an exponent, including Fortran-style `d`/`D`
# exponents (e.g. `1.5d-3`), (2) numbers with a fractional part (e.g. `-1.25`), (3) runs of digits
number_ptrn = re.compile(r"[-+]?\d*\.?\d+[eEdD][-+\d]+|[-+\d]+\.\d+|\d+")
# NOTE(review): not referenced within the visible part of this file
_special1 = []
@_available_to_user_util
def sortHuman(inStr):
    """
    Sort the given list the way that humans expect (to be used as sorting `key`)

    Every number found in the lower-cased string representation of `inStr` is
    replaced by a signed, zero-padded, fixed-width representation, so that plain
    string comparison orders the numbers by value (e.g. 'a9' before 'a10').
    Fortran-style exponents (`1d5`) are understood. Finally signs are mapped to
    letters ('-' -> 'm', '+' -> 'p').

    :param inStr: object for which to generate the sorting key (converted with `str`)

    :return: string to be used as sorting key
    """

    def _pad_number(match):
        text = match.group(0)
        try:
            return format(float(re.sub(r'[dD]', 'e', text)), "+040.16f")
        except ValueError:
            # the pattern also matches things that are not valid floats (e.g. `1-2.5`):
            # leave them untouched instead of misaligning the numbers that follow
            return text

    # one pass with a callback: no placeholder character that could clash with the input
    outStr = re.sub(number_ptrn, _pad_number, str(inStr).lower())
    outStr = outStr.replace('-', 'm')
    return outStr.replace('+', 'p')
def _insort(a, x, caseInsensitive=True):
lo = 0
hi = len(a)
while lo < hi:
mid = (lo + hi) // 2
if caseInsensitive and str(x).lower() < str(a[mid]).lower():
hi = mid
elif not caseInsensitive and str(x) < str(a[mid]):
hi = mid
else:
lo = mid + 1
a.insert(lo, x)
def get_bases(clss, tp=None):
    """
    Returns a list of strings describing the dependencies of a class

    The list starts with the name of `clss`, followed by the names of its base
    classes walked depth-first in `__bases__` order. Names of classes reachable
    along several paths (e.g. `object`) are repeated.

    :param clss: class to inspect

    :param tp: list to which the names are appended (used by the recursion).
               When None or empty, a new list starting with `clss.__name__` is created.

    :return: list of class names
    """
    # `None` default avoids the shared-mutable-default pitfall of `tp=[]`
    if tp is None or tp == []:
        tp = [clss.__name__]
    for base in clss.__bases__:
        tp.append(base.__name__)
        get_bases(base, tp)
    return tp
@_available_to_user_util
def different(a, b, precision=0.0):
    """
    Evaluates if two objects are different

    * two OMFIT expressions are compared by their `.expression`
    * objects of different classes are always different
    * objects that both have a `filename` attribute are compared by the base name
      and the content of their files (a missing file counts as a difference)
    * anything else is compared with `numpy.testing.assert_equal` and, if that fails
      and `precision` is non-zero, with `numpy.testing.assert_allclose(rtol=precision)`

    :param a: first object to compare

    :param b: second object to compare

    :param precision: relative precision to which objects are compared

    :return: integer to indicate equal (0) or different (1)
    """
    expression_classes = ['OMFITexpression', 'OMFITiterableExpression']
    if isinstance_str(a, expression_classes) and isinstance_str(b, expression_classes):
        if a.expression != b.expression:
            return 1
        return 0

    if a.__class__ != b.__class__:
        return 1

    if hasattr(a, 'filename') and hasattr(b, 'filename'):
        if not os.path.exists(a.filename) or not os.path.exists(b.filename):
            return 1
        # cheap check first
        if os.path.split(a.filename)[1] != os.path.split(b.filename)[1]:
            return 1
        # a single deep comparison decides whether the contents are identical
        # (a shallow comparison in front of it would be redundant I/O)
        if not filecmp.cmp(a.filename, b.filename, shallow=False):
            return 1
        return 0

    try:
        np.testing.assert_equal(a, b)
        return 0
    except Exception:
        if precision == 0.0:
            return 1
    try:
        np.testing.assert_allclose(a, b, rtol=precision)
        return 0
    except Exception:
        return 1
@_available_to_user_util
def sorted_join_lists(a, b, favor_order_of_a=False, case_insensitive=False):
    """
    Join two lists in a way that minimizes the distance between them and the merged list

    Elements are compared through their `repr` and restored with `ast.literal_eval`,
    so they must be Python literals (strings, numbers, tuples, ...).

    :param a: first list

    :param b: second list

    :param favor_order_of_a: favor order of list `a` over order of list `b`

    :param case_insensitive: merge list in a case-insensitive way

    :return: merged list
    """
    a = [repr(item) for item in a]
    b = [repr(item) for item in b]

    # use difflib.Differ (which operates on strings) to find out differences
    differ = difflib.Differ(linejunk=None)
    diff = list(differ.compare(b, a) if favor_order_of_a else differ.compare(a, b))

    # NOTE: change of order are done by removing (-) and adding (+) an entry:
    # drop a removal when the same entry is added back somewhere else.
    # A set makes this lookup O(1) instead of scanning the whole diff for each line.
    diff_lines = set(diff)
    keys = [
        ast.literal_eval(line[2:])
        for line in diff
        if line[2:] and line[0] != '?' and (line[0] != '-' or '+' + line[1:] not in diff_lines)
    ]

    # unique keys, keep the ordering, allow case-insensitive comparison
    def normalize(item):
        if case_insensitive and isinstance(item, str):
            return item.lower()
        return item

    return _unique_preserving_order(keys, normalize)


def _unique_preserving_order(items, normalize):
    """Return `items` without duplicates (compared after `normalize`), keeping the first occurrences"""
    seen = set()
    seen_unhashable = []
    unique = []
    for item in items:
        marker = normalize(item)
        try:
            if marker in seen:
                continue
            seen.add(marker)
        except TypeError:
            # unhashable literals (e.g. lists) fall back to a linear search
            if marker in seen_unhashable:
                continue
            seen_unhashable.append(marker)
        unique.append(item)
    return unique
@_available_to_user_util
def parseLocation(inv):
    """
    Parse string representation of the dictionary path and return list including root name

    This function can parse things like: OMFIT['asd'].attributes[u'aiy' ]["[ 'bla']['asa']"][3][1:5]
    which gives ["OMFIT['asd'].attributes", 'aiy', "[ 'bla']['asa']", 3, '1:5']

    Everything up to the last character found outside of square brackets is the root name.
    The content of each trailing `[...]` is evaluated with `ast.literal_eval`, except for
    slices (e.g. `1:5`) which are kept as strings.

    :param inv: string representation of the dictionary path

    :return: list of dictionary keys including rootname

    :raises SyntaxError: if brackets/quotes are not balanced or a key cannot be evaluated
    """
    # look for matching dictionary blocks with matching quotes
    inv = inv.strip()
    # False: outside brackets; True: inside unquoted brackets; otherwise the quote character in use
    quote = False
    starts_at = 0
    splits = []
    head = ['']
    last_char = None  # last non-blank character seen (used to detect the closing quote)
    for k, char in enumerate(inv):
        if char in ' \t':
            continue
        if char == '[' and not quote:
            # peek with slices so that a trailing '[' does not raise IndexError
            first = inv[k + 1 : k + 2]
            second = inv[k + 2 : k + 3]
            if first in ("'", '"'):
                quote = first
            elif first and first in 'bruBRU' and second in ("'", '"'):
                quote = second
            else:
                quote = True
            starts_at = k + 1
        if char == ']' and (quote is True or last_char == quote):
            quote = False
            splits.append(inv[starts_at:k])
        if not quote and char not in '[]':
            # anything outside of brackets belongs to the root name: restart the keys
            splits = []
            head = [inv[: k + 1]]
        last_char = char

    # if quote was not closed then the parentheses do not match
    if quote:
        raise SyntaxError('Unbalanced parentheses in ' + inv)

    # eval splits
    for k, item in enumerate(splits):
        try:
            # left-strip so that `[ 'a']` also works where literal_eval does not do it itself
            splits[k] = ast.literal_eval(item.lstrip(' \t'))
        except SyntaxError:
            if ':' in item:
                splits[k] = item
            else:
                raise
    return head + splits
@_available_to_user_util
def buildLocation(inv):
    """
    Assemble list of keys into dictionary path string

    Keys are rendered with `repr`, except strings that look like slices
    (containing ':' and no letters, e.g. '1:5') which are inserted verbatim.

    :param inv: list of dictionary keys including rootname

    :return: string representation of the dictionary path
    """

    def _bracketed(key):
        looks_like_slice = isinstance(key, str) and ':' in key and not re.findall('[a-zA-Z]', key)
        text = key if looks_like_slice else repr(key)
        return '[' + text + ']'

    location = inv[0]
    for key in inv[1:]:
        location += _bracketed(key)
    return location
@_available_to_user_util
def dirbaseLocation(location):
    """
    Takes a string or a list of strings output by parseLocation()
    and returns the parent path and the leaf key, for convenient setting of dictionary locations

    >> d, b=dirbaseLocation("OMFIT['dir']['base']")
    >> eval(d)[b]

    d == "OMFIT['dir']"
    b == 'base'

    :param location: string or a list of strings output by parseLocation()

    :return: the path string leading to the leaf, and the leaf key itself (not necessarily a string, e.g. an int index)
    """
    keys = location if isinstance(location, list) else parseLocation(location)
    parent_path = buildLocation(keys[:-1])
    return parent_path, keys[-1]
@_available_to_user_util
def setLocation(location, value, globals=None, locals=None):
    """
    Takes a string or a list of strings output by parseLocation()
    and sets the leaf to the value provided

    NOTE: the parent path is evaluated with `eval`, so `location` must come from a trusted source.

    :param location: string or a list of strings output by parseLocation()

    :param value: value to set the leaf

    :param globals: global namespace for the evaluation of the location

    :param locals: local namespace for the evaluation of the location

    :return: value
    """
    # the local names `d` and `b` are kept as-is because, when `locals` is None,
    # `eval` sees this function's local namespace
    d, b = dirbaseLocation(location)
    parent = eval(d, globals, locals)
    parent[b] = value
    return value
def parseBuildLocation(inv):
    """
    DEPRECATED: use `parseLocation` and `buildLocation` functions instead

    Function to handle locations in the OMFIT tree (i.e. python dictionaries)

    :param inv: input location

    :return:
        * if `inv` is a list, the dictionary path is assembled and a string is returned (the root name is assumed to be missing)
        * if `inv` is a string, the dictionary path is split and a list is returned (the root name is stripped)

    :raises ValueError: if `inv` is neither a string nor a list
    """
    if isinstance(inv, list):
        return buildLocation([''] + inv)
    if isinstance(inv, str):
        return parseLocation(inv)[1:]
    raise ValueError('parseBuildLocation accepts either a string or a list')
@_available_to_user_util
def traverseLocation(inv):
    """
    returns list of locations to reach input location

    >> traverseLocation("OMFIT['a']['b']")
    ['OMFIT', "OMFIT['a']", "OMFIT['a']['b']"]

    :param inv: string representation of the dictionary path

    :return: list of locations including rootname to reach input location
    """
    keys = parseLocation(inv)
    locations = []
    for depth in range(1, len(keys) + 1):
        locations.append(buildLocation(keys[:depth]))
    return locations
@_available_to_user_util
def traverse(
    self,
    string='',
    level=100,
    split=True,
    onlyDict=False,
    onlyLeaf=False,
    skipDynaLoad=False,
    noSubmodules=False,
    traverse_classes=(MutableMapping,),
):
    """
    Returns a string or list of strings describing the path of every entry/subentries in the dictionary

    Mappings are walked through their keys, lists and tuples through their indices;
    each path is built by appending `[repr(key)]` to `string`.

    :param string: string to be appended in front of all entries

    :param level: maximum depth

    :param split: split the output string into a list of strings

    :param onlyDict: return only dictionary entries (can be a tuple of classes)

    :param onlyLeaf: return only non-dictionary entries (can be a tuple of classes)

    :param skipDynaLoad: skip entries that have .dynaLoad==True

    :param noSubmodules: controls whether to traverse submodules or not

    :param traverse_classes: tuple of classes to traverse

    :return: list of strings (`split=True`) or a single string with one path per line (`split=False`)
    """
    if isinstance(self, MutableMapping):
        kids = list(self.keys())
    elif isinstance(self, (list, tuple)):
        kids = list(range(len(self)))
    else:
        kids = []

    chunks = []
    for kid in kids:
        path = string + '[' + repr(kid) + ']'
        child = self[kid]

        # skip also expressions when skipDynaLoad
        if skipDynaLoad and isinstance_str(child, ['OMFITexpression', 'OMFITiterableExpression']):
            continue

        is_traversable = isinstance(child, traverse_classes)

        # mention this entry according to `onlyDict` and `onlyLeaf` filters
        wanted = (
            (not onlyDict and not onlyLeaf)
            or (onlyDict is True and is_traversable)
            or (isinstance(onlyDict, tuple) and isinstance(child, onlyDict))
            or (onlyLeaf is True and not is_traversable)
            or (isinstance(onlyLeaf, tuple) and isinstance(child, onlyLeaf))
        )
        if wanted:
            chunks.append(path + '\n')

        # do not go deeper if skipDynaLoad and the object has not been loaded yet
        if skipDynaLoad:
            try:
                if hasattr(child, 'dynaLoad') and child.dynaLoad:
                    continue
            except RecursionError:
                printe(f'Error recursing {path}')
                continue

        # go deeper
        if noSubmodules:
            from omfit_classes.omfit_base import OMFITmodule
        descend = (
            is_traversable
            and (not isinstance(onlyDict, tuple) or isinstance(child, onlyDict))
            and level > 0
            and len(child)
            and not (noSubmodules and isinstance(child, OMFITmodule))
        )
        if descend:
            chunks.append(
                traverse(child, path, level - 1, False, onlyDict, onlyLeaf, skipDynaLoad, noSubmodules, traverse_classes)
            )

    string_out = ''.join(chunks)
    if not split:
        return string_out
    return string_out.strip().strip('\n').split('\n') if string_out else []
def treeLocation(obj, memo=None):
    """
    Identifies location in the OMFIT tree of an OMFIT object

    NOTE: Typical users should not need to use this function as part of their modules.
    If you find yourself using this function in your modules, it is likely that OMFIT
    already provides the functionality that you are looking for in some other way.
    We recommend reaching out the OMFIT developers team to see if there is an easy
    way to get what you want.

    The location is rebuilt by walking up the `_OMFITparent` chain and
    concatenating the `_OMFITkeyName` of every object along the way.

    :param obj: object in the OMFIT tree

    :param memo: used internally to cache results (keyed by `id`) and to avoid infinite recursions

    :return: list with the cumulative location string of every ancestor, the last entry being
             the location of `obj` (e.g. ['OMFIT', "OMFIT['a']"]); None for objects that cannot
             be placed in the tree (e.g. int, float, None) or whose parent chain is cyclic
    """
    if hasattr(obj, '_OMFITcopyOf') and obj._OMFITcopyOf is not None:
        obj = obj._OMFITcopyOf()

    _nil = []
    if memo is None:
        memo = {}
    y = memo.get(id(obj), _nil)
    if y is not _nil:
        return y

    if not hasattr(obj, '_OMFITparent'):
        try:
            obj._OMFITkeyName = ''
            obj._OMFITparent = None
        except AttributeError:
            # this is for objects which do not accept ._OMFITparent (e.g. int, None, float,...)
            # these are only leafs in the tree and do not need their location in the tree
            # to function anyways.
            return None

    # mark this object as "in progress": a cycle in the parent chain then resolves to None
    memo[id(obj)] = None
    _keep_alive(obj, memo)

    if obj._OMFITparent is None:
        # this is when to treat the head node
        tmp = [obj._OMFITkeyName]
    else:
        # this is for all of the middle nodes (the memo is shared with the recursion)
        parent_location = treeLocation(obj._OMFITparent, memo)
        if parent_location is None:
            tmp = None
        else:
            # build a new list: the parent's list may be cached in `memo` and must not be mutated
            tmp = parent_location + [parent_location[-1] + obj._OMFITkeyName]
    memo[id(obj)] = tmp
    return tmp


def _keep_alive(x, memo):
    """Keeps a reference to the object x in the memo.

    Because we remember objects by their id, we have
    to assure that possibly temporary objects are kept
    alive by referencing them.
    We store a reference at the id of the memo, which should
    normally not be used unless someone tries to deepcopy
    the memo itself...
    """
    try:
        memo[id(memo)].append(x)
    except KeyError:
        # aha, this is the first one :-)
        memo[id(memo)] = [x]
@_available_to_user_util
def recursiveUpdate(A, B, overwrite=True, **kw):
    """
    Recursive update of dictionary A based on data from dictionary B

    Sub-dictionaries present in both A and B are merged recursively. Any other
    value of B is deep-copied into A when the key is missing from A, or when
    `overwrite` is True (this includes replacing a non-dictionary value of A
    with a dictionary from B). A is modified in place.

    :param A: dictionary A

    :param B: dictionary B

    :param overwrite: force overwrite of duplicates

    :return: updated dictionary A

    :raises TypeError: on unexpected keyword arguments
    """
    # for backward compatibility
    if 'overWrite' in kw:
        overwrite = kw.pop('overWrite')
    if len(kw):
        raise TypeError('recursiveUpdate() got an unexpected keyword argument: ' + str(list(kw.keys())))

    def f_traverse(A, B):
        for kid in list(B.keys()):
            if kid not in A:
                A[kid] = copy.deepcopy(B[kid])
            elif isinstance(A[kid], dict) and isinstance(B[kid], dict):
                f_traverse(A[kid], B[kid])
            elif overwrite:
                # previously a dict in B never replaced a non-dict value of A, even with overwrite=True
                A[kid] = copy.deepcopy(B[kid])

    f_traverse(A, B)
    return A
def pretty_diff(d, ptrn=None):
    """
    generate "human readable" dictionary output from SortedDict.diff()

    :param d: output of `SortedDict.diff()`. `d[0]` maps each key to a sequence whose first
              element is either a mapping of the same form as `d[0]` (handled recursively)
              or the description of the difference (e.g. 'added')

    :param ptrn: dictionary to be filled in (a new one is created when None)

    :return: dictionary mapping each key to its description, nested for sub-dictionaries
    """
    # the default used to be a shared `{}` that is filled below, so results accumulated across calls
    if ptrn is None:
        ptrn = {}
    for k in list(d[0].keys()):
        if isinstance(d[0][k][0], dict):
            ptrn[k] = SortedDict()
            pretty_diff(d[0][k], ptrn=ptrn[k])
        else:
            ptrn[k] = d[0][k][0]
    return ptrn
def prune_mask(what, ptrn):
    """
    prune dictionary structure based on mask

    The mask can be in the form of a `pretty_diff` dictionary or a list of `traverse` strings

    :param what: dictionary to prune (modified in place)

    :param ptrn: mask: either a (nested) dictionary whose keys are kept, or a list/tuple
                 of location strings such as "['a']['b']"

    :return: the pruned `what`

    :raises Exception: if `ptrn` is neither a dict nor a list/tuple
    """
    if isinstance(ptrn, dict):
        return _prune_mask_by_dict(what, ptrn)
    if isinstance(ptrn, (tuple, list)):
        return _prune_mask_by_paths(what, ptrn)
    raise Exception('prune_mask: only list/tuple/dict supported')


def _prune_mask_by_dict(what, ptrn):
    """Delete the keys of `what` missing from `ptrn`, recursing where both hold sub-dictionaries"""
    mask_keys = list(ptrn.keys())
    for k in list(what.keys()):
        if k not in mask_keys:
            del what[k]
        elif isinstance(what[k], dict) and isinstance(ptrn[k], dict):
            _prune_mask_by_dict(what[k], ptrn[k])
    return what


def _prune_mask_by_paths(what, ptrn):
    """Keep only the entries of `what` whose location is listed in `ptrn`, or is a parent/child of a listed one"""
    # NOTE: locations are evaluated with eval/exec on `what`, so `ptrn` must come from a trusted source
    paths = list(ptrn)

    # disregard non-existent paths (walk backwards so that popping keeps the indices valid)
    for k in range(len(paths) - 1, -1, -1):
        try:
            eval('what' + paths[k])
        except Exception:
            paths.pop(k)

    # expand subtrees
    for item in list(paths):
        if isinstance(eval('what' + item), SortedDict):
            paths.extend([item + x for x in eval('what' + item).traverse()])

    # expand add parents
    keep = set(paths)
    for item in list(keep):
        parent = []
        for level in parseLocation(item)[:-1]:
            parent = parent + [level]
            keep.add(buildLocation(parent))

    # do the pruning
    for item in what.traverse():
        if item not in keep:
            try:
                exec('del what' + item, globals(), locals())
            except Exception:
                # best effort: children of an entry that was already deleted cannot be deleted again
                pass
    return what
def size_tree_objects(location):
    """
    Returns file sizes of objects in the dictionary based on the size of their filename attribute

    Entries are listed with `traverse(..., skipDynaLoad=True)`. Objects with a non-empty
    `filename` attribute are grouped by the size (in bytes, from `os.stat`) of that file.
    Errors while sizing an object are reported with `printe` and the object is skipped.

    :param location: string of the tree location to be analyzed (evaluated with `eval`)

    :return: dictionary mapping each file size to the list of locations of the objects with that size
    """
    sizes = {}
    for item in traverse(eval(location), onlyDict=False, skipDynaLoad=True):
        path = location + item
        if not hasattr(eval(path), 'filename'):
            continue
        try:
            obj = eval(path)
            if not obj.filename:
                continue
            sizes.setdefault(os.stat(obj.filename).st_size, []).append(path)
        except Exception as _excp:
            printe('Error sizing object %s : %s' % (path, repr(_excp)))
    return sizes
def dynaLoad(f):
    """
    Decorator which calls `dynaLoader` method

    Every call of the decorated method first runs `dynaLoader(self, f)`,
    so that an object waiting to be dynamically loaded gets loaded beforehand.

    :param f: function to decorate

    :return: decorated function
    """

    def dynamicLoading(self, *args, **kw):
        dynaLoader(self, f)
        return f(self, *args, **kw)

    return functools.wraps(f)(dynamicLoading)
def dynaLoadKey(f):
    """
    Decorator which calls `dynaLoad` method only if key is not found

    The first positional argument of the decorated method is taken as the key:
    `dynaLoader(self, f)` runs only when that key is not in `self.keyOrder`.

    :param f: function to decorate

    :return: decorated function
    """

    @functools.wraps(f)
    def dynamicLoading(self, *args, **kw):
        key = args[0]
        if key not in self.keyOrder:
            dynaLoader(self, f)
        return f(self, *args, **kw)

    return dynamicLoading
def dynaSave(f):
    """
    Decorator which calls `dynaSaver` method

    Before running the decorated (save) method `f`:

    * read-only objects with `modifyOriginal` set, whose `link` and `filename` refer to
      the same existing file, are not saved at all
    * otherwise, read-only objects with a `modifyOriginal` attribute and a `_save_by_copy`
      method are saved through `_save_by_copy` (only keyword arguments are forwarded)
    * if `dynaSaver(self)` succeeds, `f` is not called
    * otherwise `f` is called

    :param f: function to decorate

    :return: decorated function
    """

    def doNothing():
        pass

    @functools.wraps(f)
    def dynamicSaving(self, *args, **kw):
        if hasattr_no_dynaLoad(self, 'readOnly') and self.readOnly:
            # NOTE(review): indentation was lost in this copy of the file; the `elif` below is assumed
            # to belong to the inner `if self.modifyOriginal ...` — confirm against the original source
            if hasattr_no_dynaLoad(self, 'modifyOriginal'):
                if self.modifyOriginal and os.path.exists(self.filename) and os.path.samefile(self.link, self.filename):
                    # NOTE(review): returns the `doNothing` function object itself, not its result (None)
                    return doNothing
                elif hasattr_no_dynaLoad(self, '_save_by_copy'):
                    return self._save_by_copy(**kw)
        # objects that were never loaded can be saved by linking/copying their original file
        if dynaSaver(self):
            return doNothing
        return f(self, *args, **kw)

    return dynamicSaving
def _docFromDict(f):
"""
Use the same docstring as for dict
:param f: function to decorate
:return: decorated function
"""
try:
f.__doc__ = getattr(dict, f.__name__).__doc__
except AttributeError:
pass
return f
def dynaLoader(self, f=None):
    """
    Call `load` function if object has `dynaLoad` attribute set to True
    After calling `load` function the `dynaLoad` attribute is set to False

    Nothing happens when `OMFITaux['dynaLoad_switch']` is off or `self.dynaLoad` is False.
    The `dynaLoad` flag is cleared before `load` runs. When `f` is the `load` method
    itself, only the flag is cleared. If `load` fails, the object is cleared and
    `dynaLoad` is set back to True, so that loading can be re-tried; the exception
    raised by `load` is then re-raised.

    :param f: function on behalf of which the loading is requested

    :return: the return value of `load`, or None when `load` was not called
    """
    if not (OMFITaux['dynaLoad_switch'] and self.dynaLoad):
        return None
    self.dynaLoad = False
    if f is not None and f.__name__ == 'load':
        return None
    try:
        return self.load()
    except Exception:
        # note: errors could occur because user stops the process
        try:
            self.clear()
        except Exception as _excp_clear:
            # if clear() fails, then its exception should be printed but not raised.
            # the user is interested in the original exception raised by load().
            printe('The clear() method raised an exception: ' + repr(_excp_clear))
        self.dynaLoad = True
        raise
def dynaSaver(self):
    """
    This function is meant to be called in the .save() function of objects of the class
    `OMFITobject` that support dynamic loading. The idea is that if an object has not
    been loaded, then its file representation has not changed and the original file can be reused.
    This function returns True/False to say if it was successful at saving.
    If True, then the original .save() function can return, otherwise it should go through
    saving the data from memory to file.

    When `self.link` and `self.filename` are different paths, `self.filename` is
    replaced by a hard link to `self.link` (only if `OMFITaux['hardLinks']` is set)
    or by a copy of it, unless both files already have identical content.
    On success `self.link` is set to `self.filename`.

    :return: True if the file representation was saved (or needed no saving), False otherwise
    """
    # only objects not loaded yet, which know both `link` and `filename`, qualify
    if self.dynaLoad and hasattr_no_dynaLoad(self, 'link') and hasattr_no_dynaLoad(self, 'filename'):
        try:
            printd('Dynamic save: ' + self.filename, level=2, topic='save')
            if os.path.abspath(self.link) != os.path.abspath(self.filename):
                if os.path.exists(self.filename):
                    # destination already has the same content: nothing to copy
                    if filecmp.cmp(self.link, self.filename, shallow=False):
                        self.link = self.filename
                        return True
                    else:
                        # NOTE(review): os.remove does not remove directories; such an error ends up
                        # in the `except` below and False is returned
                        os.remove(self.filename)
                try:
                    if OMFITaux.setdefault('hardLinks', False):
                        os.link(self.link, self.filename)
                        printd('Hard link: %s --> %s' % (self.link, self.filename), level=2, topic='save')
                    else:
                        # jump to the copy fallback below
                        raise Exception('skip')
                except Exception:
                    # hard links disabled or failed: copy the directory/file instead
                    if os.path.isdir(self.link):
                        shutil.copytree(self.link, self.filename)
                    else:
                        shutil.copy2(self.link, self.filename)
            # NOTE(review): indentation was lost in this copy of the file; these two lines are assumed
            # to also run when `link` and `filename` are the same path — confirm against the original source
            self.link = self.filename
            return True
        except Exception as _excp:
            printe('Error dynamic save: ' + self.filename + '\n' + repr(_excp))
            return False
    return False
[docs]class SortedDict(dict):
# originally inspired from django/trunk/django/utils/datastructures.py @ 17464
r"""
A dictionary that keeps its keys in the order in which they're inserted
:param data: A dict object or list of (key,value) tuples from which to initialize the new SortedDict object
:param \**kw: Optional keyword arguments given below
kw:
:param caseInsensitive: (bool) If True, allows self['UPPER'] to yield self['upper'].
:param sorted: (bool) If True, keep keys sorted alphabetically, instead of by insertion order.
:param limit: (int) keep only the latest `limit` number of entries (useful for data cashes)
:param dynaLoad: (bool) Not sure what this does
"""
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
instance._OMFITkeyName = ''
instance._OMFITparent = None
instance.keyOrder = []
instance.caseInsensitive = kwargs.pop('caseInsensitive', False)
instance.sorted = kwargs.pop('sorted', False)
instance.limit = kwargs.pop('limit', 0)
instance.dynaLoad = False
return instance
def __init__(self, data=None, *args, **kwargs):
self.clear()
self.caseInsensitive = kwargs.pop('caseInsensitive', False)
self.sorted = kwargs.pop('sorted', False)
self.limit = kwargs.pop('limit', 0)
self.dynaLoad = False
if data is None:
data = {}
elif isinstance(data, types.GeneratorType):
# Unfortunately we need to be able to read a generator twice. Once
# to get the data into self with our super().__init__ call and a
# second time to setup keyOrder correctly
data = list(data)
if isinstance(data, dict):
for key in list(data.keys()):
self[key] = data[key]
else:
for key, value in data:
self[key] = value
if self.sorted:
self.sort()
self.dynaLoad = kwargs.pop('dynaLoad', False)
# keep track of what classes have been loaded
from omfit_classes.utils_base import _loaded_classes
_loaded_classes.add(self.__class__.__name__)
def __getattr__(self, attr):
if attr.startswith('_OMFIT') or attr.startswith('OMFIT'):
raise AttributeError('bad attribute `%s`' % attr)
if (
OMFITaux['dynaLoad_switch']
and self.dynaLoad
and attr not in ['__save_kw__', '__tree_repr__', 'modifyOriginal', 'readOnly']
and 'getattr_infiniteloop_block' not in self.__dict__
):
try:
if os.environ['USER'] == 'meneghini':
print('%s dynaloading because %s attribute was requested' % (self.__class__.__name__, attr), file=sys.__stderr__)
traceback.print_stack(file=sys.__stderr__)
self.__dict__['getattr_infiniteloop_block'] = True
dynaLoader(self)
return getattr(self, attr)
finally:
del self.__dict__['getattr_infiniteloop_block']
raise AttributeError('bad attribute `%s`' % attr)
def _setLocation(self, key, value):
if not hasattr(self, '_OMFITparent') or value is not self._OMFITparent:
# check if the parent is the OMFIT tree
inOMFITtree = False
tmp = self
while tmp != None and hasattr(tmp, '_OMFITparent'):
if tmp._OMFITparent is None and tmp._OMFITkeyName == 'OMFIT':
inOMFITtree = True
break
tmp = tmp._OMFITparent
if not (key in self and id(self[key]) == id(value)) and inOMFITtree:
if isinstance_str(value, ['OMFITexpression', 'OMFITiterableExpression']):
value = copy.deepcopy(value)
try:
value._OMFITkeyName = "[" + repr(key) + "]"
value._OMFITparent = self
value._OMFITcopyOf = None # this copy goes directly into the tree, so we can set it to None
except (AttributeError, TypeError):
pass
return value
@_docFromDict
@dynaLoad
def __len__(self):
return super().__len__()
@_docFromDict
@dynaLoad
def __hash__(self):
return ''.join(list(self.keys())).__hash__()
@_docFromDict
@dynaLoad
def __setitem__(self, key, value):
key, value = self._checkSetitem(key, value)
if isinstance_str(key, ['OMFITexpression', 'OMFITiterableExpression']):
raise ValueError('OMFITexpressions are not valid keys for ' + self.__class__.__name__)
keyL = self._keyCaseInsensitive(key)
if keyL != key:
tmp = self.index(keyL)
del self[keyL]
self.keyOrder.insert(tmp, key)
if key not in self.keyOrder:
if hasattr(self, 'sorted') and self.sorted:
_insort(self.keyOrder, key)
else:
self.keyOrder.append(key)
elif self.caseInsensitive:
self.keyOrder[self.index(key)] = key
super().__setitem__(key, self._setLocation(key, value))
# make whatever SortedDict is under a caseInsensitive SortedDict, caseInsensitive itself
if isinstance(self[key], SortedDict) and self.caseInsensitive:
self[key].caseInsensitive = self.caseInsensitive
# if limit>0 limit the number of entries
while self.limit > 0 and len(self.keyOrder) > self.limit:
delkey = self.keyOrder[0]
super().__delitem__(delkey)
self.keyOrder.remove(delkey)
def _checkSetitem(self, key, value):
"""
This method is provided so that subclasses can use it to either:
1) change the key/value tuple as passed to the __setitem__ method
2) raise an error because the key-value pair is not acceptable
:param key: key as passed by the user to the __setitem__ method
:param value: value as passed by the user to the __setitem__ method
:return: updated (key, value) tuple
"""
return key, value
@_docFromDict
@dynaLoad
def __delitem__(self, key):
key = self._keyCaseInsensitive(key)
super().__delitem__(key)
self.keyOrder.remove(key)
# does not need @dynaLoadKey, because functions that call _keyCaseInsensitive already do
def _keyCaseInsensitive(self, key):
if not hasattr(self, 'caseInsensitive'):
self.caseInsensitive = False
if not self.caseInsensitive or not isinstance(key, str):
return key
original_key = key
if original_key in self.keyOrder:
return original_key
original_key_lower = key.lower()
if original_key_lower in self.keyOrder:
return original_key_lower
original_key_upper = key.upper()
if original_key_upper in self.keyOrder:
return original_key_upper
for key in self.keyOrder:
try:
if key.lower() == original_key_lower:
return key
except AttributeError:
pass
return original_key
@_docFromDict
@dynaLoadKey
def __getitem__(self, key):
key = self._keyCaseInsensitive(key)
try:
return super().__getitem__(key)
except KeyError:
# if this instance has a fetch method, then call it and try it __getitem__ again
if hasattr(self, 'fetch'):
self.fetch()
return super().__getitem__(key)
else:
raise
@_docFromDict
@dynaLoadKey
def __contains__(self, key):
return super().__contains__(self._keyCaseInsensitive(key))
def __getstate__(self):
tmp = copy.copy(self.__dict__)
for k in list(tmp.keys()):
if k[:6] == '_OMFIT':
del tmp[k]
return tmp, list(self.values())
def __setstate__(self, tmp):
if isinstance(tmp, dict):
# old way of loading sortedDict
self.__dict__ = tmp
for key in list(self.keys()):
self[key] = self._setLocation(key, self[key])
else:
# new way of loading sortedDict
self.limit = False
self.__dict__.update(tmp[0])
for key, value in zip(self.keyOrder, tmp[1]):
self[key] = self._setLocation(key, value)
[docs] @dynaLoad
def index(self, item):
"""
returns the index of the item
"""
return self.keyOrder.index(self._keyCaseInsensitive(item))
@_docFromDict
@dynaLoad
def __iter__(self):
return iter(self.keyOrder)
[docs] @_docFromDict
@dynaLoad
def pop(self, key, *args):
key = self._keyCaseInsensitive(key)
result = super().pop(key, *args)
try:
self.keyOrder.remove(key)
except ValueError:
# Key wasn't in the dictionary in the first place. No problem.
pass
return result
[docs] @_docFromDict
@dynaLoad
def popitem(self):
result = super().popitem()
self.keyOrder.remove(result[0])
return result
[docs] @_docFromDict
@dynaLoad
def items(self):
return list(zip(self.keyOrder, list(self.values())))
[docs] @_docFromDict
@dynaLoad
def iteritems(self):
for key in self.keyOrder:
yield key, self[key]
[docs] @dynaLoad
def keys(self, filter=None, matching=False):
"""
returns the sorted list of keys in the dictionary
:param filter: regular expression for filtering keys
:param matching: boolean to indicate whether to return the keys that match (or not)
:return: list of keys
"""
if filter is None:
return self.keyOrder[:]
elif not matching:
return [kkk for kkk in self.keyOrder[:] if not re.match(filter, str(kkk))]
else:
return [kkk for kkk in self.keyOrder[:] if re.match(filter, str(kkk))]
[docs] @_docFromDict
@dynaLoad
def iterkeys(self):
return iter(self.keyOrder)
[docs] @_docFromDict
@dynaLoad
def values(self):
return list(map(self.__getitem__, self.keyOrder))
[docs] @_docFromDict
@dynaLoad
def itervalues(self):
for key in self.keyOrder:
yield self[key]
[docs] @_docFromDict
@dynaLoad
def update(self, dict_):
for key, value in list(dict_.items()):
self[key] = self._setLocation(key, value)
[docs] @dynaLoad
def recursiveUpdate(self, other, overwrite=False):
return recursiveUpdate(self, other, overwrite)
[docs] @dynaLoad
def setdefault(self, key, default):
"""
The method setdefault() is similar to get(), but will set dict[key]=default if key is not already in dict
:param key: key to be accessed
:param default: default value if key does not exist
:return: value
"""
if key not in self:
if hasattr(self, 'sorted') and self.sorted:
_insort(self.keyOrder, key)
else:
self.keyOrder.append(key)
self[key] = default
return self[key]
[docs] @_docFromDict
@dynaLoad
def get(self, key, default):
if key not in self:
return default
return self[key]
[docs] @dynaLoad
def value_for_index(self, index):
"""Returns the value of the item at the given zero-based index"""
return self[self.keyOrder[index]]
[docs] @dynaLoad
def insert(self, index, key, value):
"""Inserts the key, value pair before the item with the given index"""
key = self._keyCaseInsensitive(key)
if key in self.keyOrder:
n = self.keyOrder.index(key)
del self.keyOrder[n]
if n < index:
index -= 1
self.keyOrder.insert(index, key)
self[key] = value
[docs] @dynaLoad
def copy(self):
"""Returns a copy of this object"""
obj = self.__class__(self)
obj.keyOrder = self.keyOrder[:]
return obj
@dynaLoad
def __repr__(self):
"""returns the keys in their sorted order"""
return '{%s}' % ', '.join(['%r: %r' % (k, v) for k, v in list(self.items())])
[docs] @_docFromDict
def clear(self):
super().clear()
self.keyOrder = []
[docs] @dynaLoad
def moveUp(self, index):
"""
Shift up in key list the item at a given index
:param index: index to be shifted
:return: None
"""
if index < len(self.keyOrder):
self.keyOrder.insert(index + 1, self.keyOrder.pop(index))
[docs] @dynaLoad
def moveDown(self, index):
"""
Shift down in key list the item at a given index
:param index: index to be shifted
:return: None
"""
if index > 0:
self.keyOrder.insert(index - 1, self.keyOrder.pop(index))
def __repr__(self):
return self.__class__.__name__ + '(' + str(list(self.items())) + ')'
[docs] @dynaLoad
def across(self, what='', sort=False, returnKeys=False):
"""
Aggregate objects across the tree
:param what: string with the regular expression to be cut across
:param sort: sorting of results alphabetically
:param returnKeys: return keys of elements in addition to objects
:return: list of objects or tuple with with objects and keys
>> OMFIT['test']=OMFITtree()
>> for k in range(5):
>> OMFIT['test']['aaa'+str(k)]=OMFITtree()
>> OMFIT['test']['aaa'+str(k)]['var']=k
>> OMFIT['test']['bbb'+str(k)]=-1
>> print(OMFIT['test'].across("['aaa*']['var']"))
"""
location = parseBuildLocation(what)
keys = []
for k in list(self.keys()):
if isinstance(k, str):
if re.match('%r' % location[0], '%r' % k):
keys.append(k)
else:
if re.match('%r' % location[0], '%r' % repr(k)):
keys.append(k)
if len(location) > 1:
what = parseBuildLocation(location[1:])
else:
what = ''
if sort:
index = np.argsort(list(map(float, keys)))
else:
index = list(range(len(keys)))
tmp = []
for k in index:
tmp_ = self[keys[k]]
tmp.append(b2s(eval("tmp_" + what)))
tmp_
if returnKeys:
return tmp, [k for k in np.array(keys)[np.array(index, int)]]
else:
return tmp
[docs] @dynaLoad
def sort(self, key=None, **kw):
r"""
:param key: function that returns a string that is used for sorting or dictionary key whose content is used for sorting
>> tmp=SortedDict()
>> for k in range(5):
>> tmp[k]={}
>> tmp[k]['a']=4-k
>> # by dictionary key
>> tmp.sort(key='a')
>> # or equivalently
>> tmp.sort(key=lambda x:tmp[x]['a'])
:param \**kw: additional keywords passed to the underlying list sort command
:return: sorted keys
"""
if key is None:
self.keyOrder.sort(key=sortHuman, **kw)
elif not callable(key):
self.sort(key=lambda x: self[x][key])
else:
self.keyOrder.sort(key=key, **kw)
return self.keyOrder
[docs] def sort_class(self, class_order=[dict]):
"""
sort items based on their class
:param class_order: list containing order of class
:return: sorted keys
"""
lst = {}
for k in self.keyOrder:
oo = len(class_order)
for o, c in list(enumerate(class_order)):
if isinstance(self[k], c):
oo = o
break
if hasattr(self[k], '__class__'):
for o, c in list(enumerate(class_order)):
if self[k].__class__.__name__ == c.__name__:
oo = o
break
lst.setdefault(oo, []).append(k)
self.keyOrder = []
for k in range(len(class_order) + 1):
if k in lst:
self.keyOrder += lst[k]
return self.keyOrder
@dynaLoad
def diff(
    self,
    other,
    ignoreComments=False,
    ignoreContent=False,
    skipClasses=(),
    noloadClasses=(),
    precision=0.0,
    order=True,
    favor_my_order=False,
    modify_order=False,
    quiet=True,
):
    """
    Comparison of a SortedDict

    :param other: other dictionary to compare to

    :param ignoreComments: ignore keys that start and end with "__" (e.g. "__comment__")

    :param ignoreContent: ignore content of the objects

    :param skipClasses: list of class of objects to ignore

    :param noloadClasses: list of class of objects to not load

    :param precision: relative precision to which the comparison is carried out

    :param order: does the order of the keys matter

    :param favor_my_order: favor my order of keys

    :param modify_order: update order of input dictionaries based on diff

    :param quiet: verbosity of the comparison

    :return: comparison list `[switch, False, keys]`, where `switch` is a SortedDict mapping each differing key to
        `['added', False]`, `['removed', False]`, `['changed', False]`, `['noLoad', False]` or, for two
        SortedDicts that differ, to the (nested) result of their `diff`; `keys` is the sorted join of the keys
    """
    # todo: should allow taking differences among any type of dictionary, not only sorted dict

    # sorted join of self.keys() with other.keys()
    keys = sorted_join_lists(self.keys(), other.keys(), favor_my_order, self.caseInsensitive or other.caseInsensitive)

    # update order of input dictionaries based on diff
    if modify_order:
        self.keyOrder = [k for k in keys if k in self.keys()]
        other.keyOrder = [k for k in keys if k in other.keys()]

    if ignoreComments:
        keys = [key for key in keys if not re.match(comment_ptrn, str(key))]

    switch = SortedDict()
    for key in keys:
        if not quiet:
            printi('Compare: ' + str(key))

        if key not in self:
            switch[key] = ['added', False]
            continue
        if key not in other:
            switch[key] = ['removed', False]
            continue

        if isinstance(self[key], skipClasses) or isinstance(other[key], skipClasses):
            continue

        if isinstance(self[key], noloadClasses) or isinstance(other[key], noloadClasses):
            # subtrees that are not loaded yet are flagged rather than loaded to be compared
            if (
                isinstance(self[key], SortedDict)
                and isinstance(other[key], SortedDict)
                and (self[key].dynaLoad or other[key].dynaLoad)
            ):
                switch[key] = ['noLoad', False]
                continue

        if isinstance(self[key], SortedDict) and isinstance(other[key], SortedDict):
            # any changes in the subdirs, even order
            tmp = self[key].diff(
                other[key],
                ignoreComments=ignoreComments,
                ignoreContent=ignoreContent,
                skipClasses=skipClasses,
                noloadClasses=noloadClasses,
                precision=precision,
                order=order,
                favor_my_order=favor_my_order,
                modify_order=modify_order,
                quiet=quiet,
            )
            # positional comparison of the keys: zip() walks both key lists once (no O(n^2) rebuild of the
            # other keys) and stops at the shorter one instead of raising IndexError when `self[key]` has more keys
            if np.sum([len(tmp[0][k]) for k in tmp[0]]) or (
                order and any(theirs != mine for mine, theirs in zip(self[key], other[key].keys()))
            ):
                switch[key] = tmp

        elif not ignoreContent and different(self[key], other[key], precision=precision):
            switch[key] = ['changed', False]

    return [switch, False, keys]
@dynaLoad
def pretty_diff(
    self,
    other,
    ignoreComments=False,
    ignoreContent=False,
    skipClasses=(),
    noloadClasses=(),
    precision=0.0,
    order=True,
    favor_my_order=False,
    quiet=True,
):
    """
    Human readable comparison with another SortedDict

    The comparison is computed by the `diff` method and then formatted by the module-level
    `pretty_diff` function.

    :param other: other dictionary to compare to

    :param ignoreComments: ignore keys that start and end with "__" (e.g. "__comment__")

    :param ignoreContent: ignore content of the objects

    :param skipClasses: list of class of objects to ignore

    :param noloadClasses: list of class of objects to not load

    :param precision: relative precision to which the comparison is carried out

    :param order: does the order of the keys matter

    :param favor_my_order: favor my order of keys

    :param quiet: verbosity of the comparison

    :return: comparison dictionary in pretty to print format
    """
    diff_options = dict(
        ignoreComments=ignoreComments,
        ignoreContent=ignoreContent,
        skipClasses=skipClasses,
        noloadClasses=noloadClasses,
        precision=precision,
        order=order,
        favor_my_order=favor_my_order,
        quiet=quiet,
    )
    comparison = self.diff(other, **diff_options)
    # bare name: resolves to the module-level pretty_diff function, not to this method
    return pretty_diff(comparison)
@dynaLoad
def diffKeys(self, other):
    """
    Measure how similar the key structure of two dictionaries is

    The locations returned by the module-level `traverse` for both dictionaries are lower-cased,
    truncated at the first `(`, and entries matching `comment_ptrn_in_brackets` are excluded.
    The score is the fraction of the locations of the smaller set (the one of `other` when both
    have the same size) that are also present in the other set.

    :param other: other dictionary to compare to

    :return: floating point to indicate the ratio of keys that are similar
        (0.0 when the smaller set has fewer than 2 locations)
    """

    def _locations(tree):
        return {
            location.lower().split('(')[0]
            for location in traverse(tree)
            if isinstance(location, str) and not re.match(comment_ptrn_in_brackets, location)
        }

    mine = _locations(self)
    theirs = _locations(other)

    if len(mine) < len(theirs):
        reference, target = mine, theirs
    else:
        reference, target = theirs, mine

    if len(reference) < 2:
        return 0.0

    return len(reference & target) * 1.0 / len(reference)
@dynaLoad
def changeKeysCase(self, case=None, recursive=False):
    """
    Change all the string keys in the dictionary to be upper/lower case

    Keys that are not strings are kept as they are (previously they were popped and then lost
    when `.upper()`/`.lower()` raised). The order of the keys is preserved.

    :param case: 'upper' or 'lower' (any other value, including None, leaves the dictionary untouched)

    :param recursive: apply this recursively to the values that are SortedDicts

    :return: None
    """
    if case != 'upper' and case != 'lower':
        return

    # every key is popped and re-inserted so that the original key order is preserved
    for kid in list(self.keys()):
        value = self.pop(kid)
        if isinstance(kid, str):
            kid = kid.upper() if case == 'upper' else kid.lower()
        self[kid] = value

    if recursive:
        for kid in list(self.keys()):
            if isinstance(self[kid], SortedDict):
                self[kid].changeKeysCase(case, recursive=True)
@dynaLoad
def traverse(self, string='', level=100, onlyDict=False, onlyLeaf=False, skipDynaLoad=False):
    """
    Equivalent to the `tree` command in UNIX

    Thin wrapper around the module-level `traverse` function, called with `split=True`.

    :param string: prepend this string

    :param level: depth

    :param onlyDict: only subtrees and not the leafs

    :param onlyLeaf: forwarded to the module-level `traverse` (presumably: only the leafs and not the subtrees)

    :param skipDynaLoad: forwarded to the module-level `traverse`

    :return: list of strings containing the dictionary path to each object
    """
    # bare name: resolves to the module-level traverse function, not to this method
    return traverse(
        self,
        string,
        level,
        split=True,
        onlyDict=onlyDict,
        onlyLeaf=onlyLeaf,
        skipDynaLoad=skipDynaLoad,
    )
@dynaLoad
def walk(self, function, **kw):
    r"""
    Recursively apply a function to every leaf of the dictionary

    Values that have a `walk` method are walked in turn with the same `function` and keywords;
    every other value is replaced in place by the return value of `function(self, kid, **kw)`.

    :param function: `function(self,kid,**kw)`, whose return value replaces the entry

    :param \**kw: kw passed to the function

    :return: self
    """
    for kid in list(self.keys()):
        child = self[kid]
        if not hasattr(child, 'walk'):
            self[kid] = function(self, kid, **kw)
            continue
        child.walk(function, **kw)
    return self
@dynaLoad
def safe_del(self, key):
    """
    Delete a key entry, doing nothing if the key is not present

    :param key: key to be deleted
    """
    if key not in self:
        return
    del self[key]
@dynaLoad
def flatten(self):
    """
    Flatten the hierarchical structure of nested SortedDicts into a single level

    Leaf values are stored under their own key, so keys that appear in more than one
    branch collide in the result.

    :return: SortedDict (with the same caseInsensitive setting) populated with the flattened content of the dictionary
    """
    flat = SortedDict(caseInsensitive=self.caseInsensitive)
    for name in list(self.keys()):
        entry = self[name]
        if not isinstance(entry, SortedDict):
            flat[name] = entry
            continue
        flat.update(entry.flatten())
    return flat
@dynaLoad
def setFlat(self, key, value):
    """
    recursively searches dictionary for key in order to set a value
    raises KeyError if key could not be found, so this method cannot
    be used to set new entries in the dictionary

    The first match is set: this level first, then the SortedDict values (recursively) in key order.

    :param key: key to be set

    :param value: value to set
    """
    if key in self:
        self[key] = value
        return
    for item in list(self.keys()):
        if isinstance(self[item], SortedDict):
            try:
                self[item].setFlat(key, value)
                return
            except KeyError:
                pass
    # `(key,)` so that tuple keys are formatted instead of being unpacked by `%`, which
    # would raise a TypeError that escapes the `except KeyError` of the recursive callers
    raise KeyError('`%s` could not be found throughout the dictionary' % (key,))
@dynaLoad
def check_location(self, location, value=_special1):
    """
    check if location exist and equals value (if value is specified)

    :param location: location as string (e.g. "['EXPERIMENT']['shot']"), which is appended to `self` and evaluated

    :param value: value for which to return equal (compared after `evalExpr` of both sides)

    :return: True/False

    >> root['SETTINGS'].check_location("['EXPERIMENT']['shot']", 133221)
    """
    # NOTE: `location` is passed to eval(), so it must never come from an untrusted source
    try:
        content = eval('self' + location)
    except (KeyError, IndexError):
        # a missing dictionary key or an out-of-range list index: the location does not exist
        return False
    if value is not _special1 and evalExpr(content) != evalExpr(value):
        return False
    return True
def __popup_menu__(self):
    """
    Placeholder popup menu, so that classes that do not override it need no dynamic loading

    :return: a new empty list (no menu entries)
    """
    return list()
def todict(self):
    """
    Return a standard Python dictionary representation of the SortedDict

    Each location returned by `self.traverse()` is copied into a new nested dictionary:
    values that are dicts are replaced by new empty dictionaries, all other values are stored as they are.
    (NOTE(review): this presumably relies on `traverse` listing a parent before its children — confirm.)

    :return: dictionary
    """
    # the name `tmp` must not change: it is referenced by the location strings evaluated below
    tmp = {}
    for item in self.traverse():
        content = eval(f'self{item}')
        parent, base = dirbaseLocation('tmp' + item)
        # the container that receives the entry is looked up by evaluating its location inside `tmp`
        eval(parent)[base] = {} if isinstance(content, dict) else content
    return tmp
def _import_xarray_Dataset():
    """
    Import xarray.Dataset while silencing the warnings emitted by the import

    :return: the xarray.Dataset class
    """
    # record=True collects (rather than displays) the warnings that are not explicitly ignored
    with warnings.catch_warnings(record=True):
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=FutureWarning)
        from xarray import Dataset
    return Dataset


class OMFITdataset(object):
    """
    Subclassing from this class is like subclassing from the xarray.Dataset class
    but without having to deal with the hassle of inheriting from xarrays
    (internally this class uses class composition rather than subclassing).

    Also this class makes it possible to use the OMFIT dynamic loading capabilities.
    All classes that subclass OMFITdataset must define the `.dynaLoad` attribute.

    NOTE: Classes that subclass from OMFITdataset will be identified
    as an xarray.Dataset when using isinstance(..., xarray.Dataset)
    within OMFIT
    """

    def __init__(self, data_vars=None, coords=None, attrs=None):
        r"""
        :param data_vars: see xarray.Dataset

        :param coords: see xarray.Dataset

        :param attrs: see xarray.Dataset

        :raises Exception: if OMFITdataset is instantiated directly, or if the subclass did not set `.dynaLoad`
        """
        if self.__class__ == OMFITdataset:
            raise Exception('OMFITdataset is a class that can only be used as a subclass')
        if not hasattr(self, 'dynaLoad'):
            raise Exception(self.__class__.__name__ + ' must set .dynaLoad before calling OMFITdataset.__init__')
        Dataset = _import_xarray_Dataset()
        self._dataset = Dataset(data_vars=data_vars, coords=coords, attrs=attrs)

    def __getattr__(self, attr):
        # only called when the normal attribute lookup fails: forward to the underlying xarray.Dataset
        if attr == '__deepcopy__':
            raise AttributeError('Raise not implemented so that we do not use the ._dataset method')
        Dataset = _import_xarray_Dataset()
        # Looking `attr` up on the xarray.Dataset class first raises AttributeError for names that the class does
        # not define (e.g. `dynaLoad` while unpickling). This is what gets pickling and unpickling to work: without it
        # the `self.dynaLoad` access below used to get stuck in a recursive loop on "dynaLoad"
        getattr(Dataset, attr)
        if self.dynaLoad and (not attr.startswith('_') or attr in ['__getitem__', '__setitem__']):
            self.load()
            self.dynaLoad = False
        return getattr(self._dataset, attr)

    @dynaLoad
    def to_dataset(self):
        """
        Return an xarray.Dataset representation of the data

        NOTE(review): only `data_vars` and `attrs` are passed on; coordinates that are not
        attached to any data variable are likely not included — confirm this is intended

        :return: xarray.Dataset
        """
        Dataset = _import_xarray_Dataset()
        return Dataset(self._dataset.data_vars, attrs=self._dataset.attrs)

    @dynaLoad
    def from_dataset(self, dataset):
        """
        Create from xarray.Dataset representation of the data

        :param dataset: xarray.Dataset that replaces the current content (stored as is, not copied);
            any other input is rejected with a warning
        """
        Dataset = _import_xarray_Dataset()
        # was `datatset` (NameError)
        if isinstance(dataset, Dataset):
            self._dataset = dataset
        else:
            printw("Input is not a proper xarray.Dataset!")

    @dynaLoad
    def __getitem__(self, key):
        # getitem is necessary so that the object is subscriptable
        return self._dataset.__getitem__(key)

    @dynaLoad
    def __setitem__(self, key, value):
        # setitem is necessary so that the object is assignable
        return self._dataset.__setitem__(key, value)

    @dynaLoad
    def __len__(self):
        # len is necessary so that the object is iterable
        return self._dataset.__len__()

    @dynaLoad
    def __iter__(self):
        # iter is necessary so that the object is iterable
        return self._dataset.__iter__()

    @dynaLoad
    def __contains__(self, b):
        # contains is necessary so that the object is iterable
        return self._dataset.__contains__(b)

    @dynaLoad
    def __delitem__(self, b):
        # delete from data set (dynaLoad like the other item accessors, so that the data is loaded first)
        del self._dataset[b]
# automatic handle removing of _OMFITxxx attributes when pickling
# (OMFIT dynamic loading is also switched off while pickling and unpickling)


def _strip_OMFIT_attributes(x, saved):
    """
    Remove the `_OMFIT*` entries from the instance dictionary of `x`

    :param x: object about to be pickled

    :param saved: dictionary filled with the removed entries, so that they can be restored
        even if this function fails midway
    """
    instance_dict = getattr(x, '__dict__', None)
    # e.g. the __dict__ of a class is a read-only mappingproxy: nothing can be removed from it
    if not isinstance(instance_dict, MutableMapping):
        return
    for k in list(instance_dict.keys()):
        if k[:6] == '_OMFIT':
            saved[k] = instance_dict[k]
            del instance_dict[k]


def _restore_OMFIT_attributes(x, saved):
    """
    Put back the entries removed by `_strip_OMFIT_attributes`

    :param x: object that was pickled

    :param saved: dictionary of the removed entries
    """
    for k, v in saved.items():
        x.__dict__[k] = v


def _OMFIT_pickler(dump_function):
    """
    Wrap a pickle dumping function so that the `_OMFIT*` attributes of the pickled object
    are not serialized and OMFIT dynamic loading is disabled while dumping

    :param dump_function: `pickle.dump` or `pickle.dumps`

    :return: wrapped function with the same call signature
    """

    @functools.wraps(dump_function)
    def wrapper(x, *args, **kw):
        dynaLoadBkp = OMFITaux['dynaLoad_switch']
        OMFITaux['dynaLoad_switch'] = False
        try:
            saved = {}
            try:
                _strip_OMFIT_attributes(x, saved)
                return dump_function(x, *args, **kw)
            finally:
                _restore_OMFIT_attributes(x, saved)
        finally:
            # restored even if restoring the attributes fails
            OMFITaux['dynaLoad_switch'] = dynaLoadBkp

    return wrapper


def _OMFIT_unpickler(load_function):
    """
    Wrap a pickle loading function so that OMFIT dynamic loading is disabled while loading
    and `encoding` defaults to 'latin1'

    :param load_function: `pickle.load` or `pickle.loads`

    :return: wrapped function with the same call signature
    """

    @functools.wraps(load_function)
    def wrapper(*args, **kw):
        dynaLoadBkp = OMFITaux['dynaLoad_switch']
        OMFITaux['dynaLoad_switch'] = False
        try:
            # `encoding` only affects 8-bit strings pickled by Python 2; latin1 decodes them byte for byte
            kw.setdefault('encoding', 'latin1')
            return load_function(*args, **kw)
        finally:
            OMFITaux['dynaLoad_switch'] = dynaLoadBkp

    return wrapper


_dumps = pickle.dumps
_OMFITdumps = _OMFIT_pickler(_dumps)
pickle.dumps = _OMFITdumps

_dump = pickle.dump
_OMFITdump = _OMFIT_pickler(_dump)
pickle.dump = _OMFITdump

_load = pickle.load
_OMFITload = _OMFIT_unpickler(_load)
pickle.load = _OMFITload

_loads = pickle.loads
_OMFITloads = _OMFIT_unpickler(_loads)
pickle.loads = _OMFITloads
if __name__ == "__main__":
    aa = SortedDict({'foo': 1, 'bar': 2, 'yo': 4})
    a = SortedDict({'foo': 1, 'bar': 2, 'yo': 4, 'dd': aa})
    bb = SortedDict({'foo': 0, 'foobar': 3, 'yo': 4})
    b = SortedDict({'foo': 0, 'foobar': 3, 'yo': 4, 'dd': bb})
    # SortedDict.diff returns [switch, False, keys]: the first element holds the differences
    switch, _, keys = a.diff(b)
    print(switch)