try:
    # framework is running
    from .startup_choice import *
except ImportError as _excp:
    # class is imported by itself
    if (
        'attempted relative import with no known parent package' in str(_excp)
        or 'No module named \'omfit_classes\'' in str(_excp)
        or "No module named '__main__.startup_choice'" in str(_excp)
    ):
        from startup_choice import *
    else:
        raise
from omfit_classes.omfit_ascii import *
from omfit_classes.omfit_gato import OMFITdskgato
import numpy as np
# Public names exported by a star-import of this module.
__all__ = [
    'OMFITtoqProfiles',
    'OMFITdskeqdata',
]
class OMFITtoqProfiles(SortedDict, OMFITascii):
    """TOQ profiles data files

    After :meth:`load` the object contains:

    * one 1D numpy array per quantity in ``_PROFILE_QUANTITIES`` (``'PsiN'``,
      ``'q'``, ``'Te(keV)'``, ...), built from every record after the first five
      and stored in reverse file order;
    * one ``SortedDict`` per name in ``_POINT_NAMES`` (``'ped'``, ``'center'``, ...),
      mapping every quantity to its scalar value from the first five records.
    """

    # Candidate fixed-width layouts of one data record, tried in order until one
    # parses cleanly.  Entry 0 is the start offset (always 0); the remaining 20
    # entries are field widths, one per quantity in _PROFILE_QUANTITIES.
    _COLUMN_WIDTHS = (
        (0, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 8, 7, 8, 8, 8, 8, 8, 7, 7, 8),
        (0, 8, 8, 8, 8, 8, 9, 8, 7, 7, 8, 7, 8, 8, 8, 8, 8, 7, 8, 7, 8),
    )

    # Names of the special points held in the first records, in file order.
    _POINT_NAMES = ('ped', 'center', 'top', 'rho_0.95', 'rho_0.90')

    # Quantity stored in each column, in file column order.
    _PROFILE_QUANTITIES = (
        'PsiN',
        'a-r(cm)',
        'p(MPa)',
        'Te(keV)',
        'n(13cm-3)',
        'pp(MPa/W)',
        'j.B/B0',
        'beta%',
        'betan',
        'nu*',
        'rho',
        'width/a',
        'q',
        's',
        'alpha',
        'al_cyl',
        'eps',
        'betap',
        'rhot_dt',
        'Ti(kev)',
    )

    def __init__(self, filename, **kw):
        OMFITascii.__init__(self, filename, **kw)
        SortedDict.__init__(self)
        # OMFIT dynamic loading: load() runs lazily (presumably on first access)
        self.dynaLoad = True

    @staticmethod
    def _parse_records(records, widths):
        """Split each fixed-width record into a list of floats using one layout.

        :param records: record strings (two file lines concatenated)
        :param widths: one entry of ``_COLUMN_WIDTHS``
        :return: list of rows, one list of floats per record
        :raises ValueError: if any field is not a valid float for this layout
        """
        edges = np.cumsum(widths).tolist()  # field i spans [edges[i], edges[i+1])
        spans = list(zip(edges[:-1], edges[1:]))
        rows = []
        for record in records:
            padded = '  ' + record  # the layouts account for two leading columns
            row = []
            for start, stop in spans:
                field = padded[start:stop]
                # '*' marks an unprintable value (presumably Fortran field overflow)
                row.append(float('nan') if '*' in field else float(field))
            rows.append(row)
        return rows

    @dynaLoad
    def load(self):
        """Parse the TOQ profiles file into this object (see class docstring).

        :raises ValueError: if no column layout parses every record, or if there are
            fewer records than special points
        """
        with open(self.filename, 'r') as f:
            raw_lines = f.readlines()

        # Skip the two header lines; drop any line with lowercase letters (labels /
        # annotations) and any blank or whitespace-only line.  Kept lines retain their
        # leading whitespace because the columns are fixed-width.
        data_lines = []
        for raw in raw_lines[2:]:
            if re.search('[a-z]', raw):
                continue
            line = raw.rstrip('\n')
            if line.strip():
                data_lines.append(line)

        # Every record spans two consecutive lines; an unpaired trailing line is ignored.
        records = [first + second for first, second in zip(data_lines[0::2], data_lines[1::2])]
        n_points = len(self._POINT_NAMES)
        if len(records) < n_points:
            raise ValueError(
                '%s: expected at least %d data records, found %d' % (self.filename, n_points, len(records))
            )

        layouts = self._COLUMN_WIDTHS
        for attempt, widths in enumerate(layouts):
            try:
                table = self._parse_records(records, widths)
                break
            except ValueError:
                if attempt == len(layouts) - 1:
                    raise  # no layout fits: surface the last parse error
        columns = np.array(table).T  # one row per quantity

        self.clear()
        for name in self._POINT_NAMES:
            self[name] = SortedDict()
        for quantity, column in zip(self._PROFILE_QUANTITIES, columns):
            # Records after the special points form the profile; stored reversed
            # (presumably so that PsiN ascends).
            self[quantity] = column[n_points:][::-1]
            for name, value in zip(self._POINT_NAMES, column):
                self[name][quantity] = value

    @dynaLoad
    def plot(self):
        """Plot every profile against ``PsiN`` on a 4x5 grid of x-shared subplots,
        overlaying the special points (one marker style per point)."""
        markers = ['o', '.', 's', 'd', 'x']
        profiles = [key for key in list(self.keys()) if key != 'PsiN' and isinstance(self[key], np.ndarray)]
        points = [key for key in list(self.keys()) if isinstance(self[key], dict)]
        first_ax = None
        for index, quantity in enumerate(profiles):
            if first_ax is None:
                first_ax = pyplot.subplot(4, 5, index + 1)
            else:
                pyplot.subplot(4, 5, index + 1, sharex=first_ax)
            (profile_line,) = pyplot.plot(self['PsiN'], self[quantity])
            color = profile_line.get_color()
            for point_index, point in enumerate(points):
                pyplot.plot(
                    self[point]['PsiN'],
                    self[point][quantity],
                    markers[point_index % len(markers)],
                    color=color,
                    label=point,
                )
            title_inside(quantity, y=0.8)
        autofmt_sharex()
        pyplot.xlim([0, 1])
class OMFITdskeqdata(OMFITdskgato):
    r"""
    OMFIT class used to interface to equilibria files generated by TOQ (dskeqdata files)

    The file is a header terminated by ``&end`` followed by a whitespace-separated
    stream of numbers.  On load, the numeric fields are divided by the factors in
    ``_SURFACE_FIELDS`` / ``_BOUNDARY_FIELDS`` (which appear to convert TOQ's CGS
    values to SI); :meth:`save` multiplies by the same factors.

    :param filename: filename passed to OMFITobject class

    :param \**kw: keyword dictionary passed to OMFITobject class
    """

    # (key, scale) of the NSURF-long arrays, in file order.
    # load() stores file_value / scale; save() writes stored_value * scale.
    _SURFACE_FIELDS = (
        ('PSI', 1e8),
        ('PRES', 10.0),
        ('PPRIME', 1e-7),
        ('FPOL', 1e6),
        ('FFPRIM', 1e4),
        ('QPSI', 1.0),
        ('GROT', 1.0),
        ('GROTP', 1.0),
    )
    # (key, scale) of the NTHT-long arrays that follow, in file order.
    _BOUNDARY_FIELDS = (
        ('RBBBS', 100.0),
        ('ZBBBS', 100.0),
        ('CD', 1.0),
    )

    def __init__(self, filename='', **kw):
        OMFITascii.__init__(self, filename, **kw)
        SortedDict.__init__(self)
        # OMFIT dynamic loading: load() runs lazily (presumably on first access)
        self.dynaLoad = True

    @dynaLoad
    def load(self):
        """Parse the dskeqdata file into this object (parser based on readeqdata.f in TOQ).

        :raises ValueError: if the ``&end`` header terminator is missing, the numeric
            section is truncated, or a required header entry is not found
        """
        with open(self.filename, 'r') as f:
            text = f.read()

        header, terminator, body = text.partition('&end')
        if not terminator:
            raise ValueError('%s: no "&end" found terminating the dskeqdata header' % self.filename)
        self['__header__'] = header + terminator

        tokens = body.split()
        cursor = 0

        def take(count):
            # Consume the next `count` tokens of the numeric stream (as strings)
            nonlocal cursor
            chunk = tokens[cursor : cursor + count]
            if len(chunk) < count:
                raise ValueError(
                    '%s: dskeqdata file is truncated (expected %d more values, found %d)'
                    % (self.filename, count, len(chunk))
                )
            cursor += count
            return chunk

        def take_array(count, scale):
            # Next `count` values as a float array divided by `scale`
            return np.array([float(tok) for tok in take(count)]) / scale

        self['NSURF'] = nsurf = int(take(1)[0])
        self['NTHT'] = ntht = int(take(1)[0])
        for key, scale in self._SURFACE_FIELDS:
            self[key] = take_array(nsurf, scale)

        # Read RBBBS, ZBBBS, CD in file order.
        boundary = {key: take_array(ntht, scale) for key, scale in self._BOUNDARY_FIELDS}
        r_bnd, z_bnd = boundary['RBBBS'], boundary['ZBBBS']
        # An open boundary with equal end heights but different end radii is treated
        # as the upper half of an up-down symmetric equilibrium and mirrored in Z.
        symmetric = bool(z_bnd[0] == z_bnd[-1] and r_bnd[0] != r_bnd[-1])
        if symmetric:
            r_bnd = np.hstack((r_bnd, r_bnd[::-1]))
            z_bnd = np.hstack((z_bnd, -z_bnd[::-1]))
        self['RBBBS'] = r_bnd
        self['ZBBBS'] = z_bnd
        self['SYMMETRIC'] = symmetric
        self['CD'] = boundary['CD']

        header_text = self['__header__']

        def header_value(pattern):
            # Float after '=' in the first match of `pattern` within the header
            matches = re.findall(pattern, header_text)
            if not matches:
                raise ValueError('%s: header entry matching %r not found' % (self.filename, pattern))
            return float(matches[0].split('=')[-1])

        # NOTE(review): the header label says gauss ('(g)'), for which gauss->tesla is
        # a factor 1e4, yet 4000.0 is used here; kept as-is pending confirmation.
        self['BCENTR'] = header_value(r'b vac\(rzero\)\(g\)\s+=\s+[\w\.\+\-]+') / 4000.0
        self['RCENTR'] = header_value(r'rzero\s+=\s+[\w\.\+\-]+') / 100.0
        self['RMAXIS'] = header_value(r'r axis \(cm\)\s+=\s+[\w\.\+\-]+') / 100.0
        self['ZMAXIS'] = 0.0  # not read from the file
        self.add_derived()

    @dynaSave
    def save(self):
        """Write this object back to ``self.filename`` in dskeqdata format.

        Values are multiplied by the same scale factors that :meth:`load` divides by
        and written four per line.  When ``SYMMETRIC`` is set only the first half of
        ``RBBBS``/``ZBBBS`` (the part read from file before mirroring) is written.
        """
        lines = [self['__header__'], '%d' % self['NSURF'], '%d' % self['NTHT']]
        symmetric = self.get('SYMMETRIC', False)
        for key, scale in self._SURFACE_FIELDS + self._BOUNDARY_FIELDS:
            values = np.asarray(self[key])
            if symmetric and key in ('RBBBS', 'ZBBBS'):
                values = values[: len(values) // 2]
            values = values.tolist()
            # NOTE: '%16.16f' is fixed-point, so magnitudes below ~1e-16 are written as 0
            for start in range(0, len(values), 4):
                lines.append(' '.join('%16.16f' % (value * scale) for value in values[start : start + 4]))
        with open(self.filename, 'w') as f:
            f.write('\n'.join(lines) + '\n')