#!/usr/bin/python3

import os
import numpy as np
import nibabel as nib
from argparse import ArgumentParser
import vnmrjpy as vj

"""
Niview : image viewer for VnmrJ data
====================================
"""

class Niview():
    """Lightweight viewer for Vnmrj data files in various formats

    Can read and show:
        Nifti1 : .nii / .nii.gz
        Vnmrj image file :  .fdf / .img
        Vnmrj raw data file:  .fid

    Intended to work from the command line.

    Exactly one input is used, checked in the order nifti, img, fid.

    Args:
        nifti: path of a .nii / .nii.gz file, loaded with nibabel.
        img: path handed to vj.core.read_fdf.
        to_anatomical: fid input only; when true, varr.to_anatomical() is
            called before the data is displayed.
        fid: path handed to vj.core.read_fid.
        showkspace: fid input only; show k-space instead of image space.
        ch: fid input only; receiver channel index (last axis). None means
            all channels. Selecting a single channel is only implemented
            for k-space display.
        addch: image-space magnitude only; when true, channels are combined
            with vj.core.recon.ssos, otherwise they are concatenated along
            axis 3.
        absolute: show magnitude. Takes precedence over phase.
        phase: show phase (used only when absolute is false).
        magn, imag, real: accepted for interface compatibility, currently
            unused.

    Attributes:
        data: array handed to the viewer, or None when nothing was loaded
            or no display mode (absolute / phase) was selected.
        affine: affine handed to the viewer, or None when unavailable.
    """
    def __init__(self, nifti=None, img=None,to_anatomical=True,\
                fid=None, showkspace=False, ch=None, addch=True,\
                absolute=True, magn=False, phase=False,\
                imag=False, real=False):  # fid options

        # Always define both attributes so later code can test for None
        # instead of tripping over a missing attribute.
        self.data = None
        self.affine = None

        if nifti is not None:
            self.img = nib.load(nifti)
            self.data = self.img.get_fdata()
            self.affine = self.img.affine

        elif img is not None:

            varr = vj.core.read_fdf(img)
            #if to_anatomical:
            #    varr.to_anatomical()

            self.data = varr.data
            self.affine = varr.nifti_header.get_qform()

        elif fid is not None:
            # reading fid gets more options:
            varr = vj.core.read_fid(fid)
            varr.to_kspace()
            if to_anatomical:
                varr.to_anatomical()
            self.affine = varr.nifti_header.get_qform()

            if showkspace:
                self.data = self._kspace_view(varr.data, ch, absolute, phase)
            else:
                data5d = varr.to_imagespace().data
                self.data = self._image_view(data5d, ch, addch,
                                             absolute, phase)

        else:
            print('Possibly wrong input. Quitting ...')
            return

    @staticmethod
    def _concat_channels(data5d, transform=None):
        """Concatenate the channels (axis 4) of data5d along axis 3.

        transform, when given, is applied to each channel before joining.
        """
        channels = [data5d[..., i] for i in range(data5d.shape[4])]
        if transform is not None:
            channels = [transform(chan) for chan in channels]
        return np.concatenate(channels, axis=3)

    @staticmethod
    def _kspace_view(kspace, ch, absolute, phase):
        """Return the k-space array to display, or None if no mode is set."""
        if ch is None:
            # concatenate channels along time dimension
            data = Niview._concat_channels(kspace)
        else:
            data = kspace[..., ch]

        if absolute:
            return np.absolute(data)
        if phase:
            # Previously this combination left the viewer without data.
            return np.angle(data)
        return None

    @staticmethod
    def _image_view(data5d, ch, addch, absolute, phase):
        """Return the image-space array to display, or None if no mode is set.

        Raises:
            Exception: when a single channel is requested (not implemented).
        """
        if not (absolute or phase):
            return None
        if ch is not None:
            raise Exception('not implemented')

        if absolute:
            if addch:
                # combine channels on the image-space array fetched by
                # the caller
                return vj.core.recon.ssos(data5d)
            # concatenate channel magnitudes along time dimension
            return Niview._concat_channels(data5d, np.absolute)

        # phase of each channel, concatenated along time dimension;
        # np.angle(z) is arctan2(z.imag, z.real)
        return Niview._concat_channels(data5d, np.angle)

    def interactiveplot(self):
        """Show the loaded data using the nibabel built-in viewer.

        When no affine is available, one is built with
        vj.util.make_cubic_affine.

        NOTE(review): an earlier docstring said dim 1 is flipped because the
        natural y axis is up->down; no flip is performed in this method.

        Raises:
            AttributeError: when no data was loaded. The type is kept as
                AttributeError because that is what a missing `data`
                attribute raised before.
        """
        data = getattr(self, 'data', None)
        if data is None:
            raise AttributeError(
                'Niview has no data to plot: no input was loaded or no '
                'display mode (absolute / phase) was selected')

        # default to cubic affine, this should be close
        affine = getattr(self, 'affine', None)
        if affine is None:
            print('Backing to default cubic affine')
            affine = vj.util.make_cubic_affine(data)

        v = nib.viewers.OrthoSlicer3D(data, affine=affine)
        print('Plotting...')
        v.show()

if __name__ == '__main__':
    # Command-line entry point: pick the reader from the path suffix, build
    # a Niview and open the interactive viewer.

    description="""Command-line script to process and plot .fid, .fdf, .nii
                for quick checks."""

    parser = ArgumentParser(description=description)
    # positional path; despite the name it may also be a .img or .fid dir
    parser.add_argument('nifti')
    parser.add_argument('--phase',action='store_true')
    # --magn is accepted but has no effect: magnitude is shown whenever
    # --phase is not given
    parser.add_argument('--magn',action='store_true')
    parser.add_argument('--kspace',action='store_true')
    args = parser.parse_args()

    # Drop trailing path separators so that 'scan.fid/' (typical result of
    # shell completion on a directory) is still recognised by its suffix.
    raw_path = str(args.nifti)
    path = raw_path.rstrip(os.sep + (os.altsep or '')) or raw_path

    if path.endswith(('.nii', '.nii.gz')):
        print('Reading .nii file ...')

        # Niview takes no 'to_scanner' keyword; passing it raised TypeError
        nv = Niview(nifti=path)
        nv.interactiveplot()

    elif path.endswith('.img') and os.path.isdir(path):

        print('Reading .fdf images in .img direcory ...')
        nv = Niview(img=path)
        nv.interactiveplot()

    elif path.endswith('.fid') and os.path.isdir(path):
        print('Reading .fid file in .fid direcory ...')
        # phase display replaces the default magnitude display;
        # --kspace switches from image space to k-space
        nv = Niview(fid=path, phase=args.phase, absolute=not args.phase,
                    showkspace=args.kspace)
        nv.interactiveplot()
    else:
        print('Wrong input: not a .nii/.fid/.img file or dir')
