#!python3.6
import argparse
import pandas as pd
import scanpy as sc
import os
def _min_probability(value):
    """Convert a '--Threshold_MinProbability' command-line value to a float.

    The help text advertises "OTSU" as a legal value, so it is accepted here
    (in any letter case) and mapped to -1.0, the same sentinel the option uses
    as its default. Any other value must be parseable as a float.

    Raises:
        argparse.ArgumentTypeError: if the value is neither "OTSU" nor a number.
    """
    text = str(value).strip()
    if text.upper() == "OTSU":
        return -1.0
    try:
        return float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            'invalid value {!r}: expected a number or "OTSU"'.format(value)
        ) from None


def CStreet_parser():
    """Build the command-line argument parser of CStreet.

    Returns:
        argparse.ArgumentParser: a parser whose namespace attributes are named
        after the long options (e.g. ``Input_ExpMatrix``, ``Output_Dir``),
        except ``-T/--Input_CellinCol`` which is stored as ``row_cell_flag``.
        ``Threshold_MinProbability`` is -1 when OTSU estimation is requested
        (the default, or an explicit "OTSU").
    """
    cstreet_parser = argparse.ArgumentParser(prog='CStreet', usage='%(prog)s [-h] <-i ExpMatrix1 ExpMatrix2 ExpMatrix3 ...> [-s CellStates1 CellStates2 CellStates3 ...] [-n ProjectName] [-o OutputDir] [options]',
                                            formatter_class=argparse.RawDescriptionHelpFormatter,
                                            description='CStreet is a cell states trajectory inference method for time-series single-cell RNA-seq data.')
    cstreet_parser.add_argument('-v',
                                '--version',
                                action='version',
                                version='CStreet 1.0.2',
                                help='''show version number of CStreet and exit''')

    # ---- Input / output ----
    cstreet_parser.add_argument('-i',
                                '--Input_ExpMatrix',
                                nargs='+',
                                action='store',
                                type=str,
                                help='''This indicates expression matrixes, which contain the time-series expression level as read counts or normalized values in tab-delimited format.  
                                 For example: '-i ExpressionMatrix_t1.txt ExpressionMatrix_t2.txt ExpressionMatrix_t3.txt' indicates the input of 3 timepoint expression matrixes.'''
                                )

    cstreet_parser.add_argument('-T',
                                '--Input_CellinCol',
                                action='store_true',
                                dest='row_cell_flag',
                                help='''This determines whether the expression level of one cell is displayed as a column in the expression matrixes. 
                                For example, '-T' indicates that the gene expression levels of one cell are displayed in a column and the expression levels of one gene across all cells are displayed in a row. '''
                                )

    cstreet_parser.add_argument('-o',
                                '--Output_Dir',
                                action='store',
                                type=str,
                                default="./",
                                help='''The project directory, which is used to save all output files. 
                                DEFAULT: "./".'''
                                )

    cstreet_parser.add_argument('-n',
                                '--Output_Name',
                                action='store',
                                type=str,
                                default="CStreet",
                                help='''The project name, which is used to generate output file names as a prefix.
                                DEFAULT: "CStreet"'''
                                )

    cstreet_parser.add_argument('-s',
                                '--Input_CellStates',
                                nargs='*',
                                action='store',
                                type=str,
                                help='''An optional parameter that uses CStreet's built-in dimensionality reduction and clustering methods to perform clustering without knowing the cell states (DEFAULT) or accepts the user’s input. 
                                The input files should contain the cell state information sharing the same cell ID in the expression matrixes in tab-delimited format.'''
                                )

    # ---- Built-in clustering (used only when no cell states are supplied) ----
    cstreet_parser.add_argument('--CellClusterParam_PCAn',
                                action='store',
                                type=int,
                                default=10,
                                help='''The number of principal components to use, which is enabled only if the cell state information is not provided. 
                                It can be set from 1 to the minimum dimension size of the expression matrixes. 
                                DEFAULT: 10.'''
                                )

    cstreet_parser.add_argument('--CellClusterParam_k',
                                action='store',
                                type=int,
                                default=15,
                                help='''The number of nearest neighbors to be searched, which is enabled only if the cell state information is not provided. 
                                It should be in the range of 2 to 100 in general. 
                                DEFAULT: 15. '''
                                )

    cstreet_parser.add_argument('--CellClusterParam_Resolution',
                                action='store',
                                type=float,
                                default=0.1,
                                help='''The resolution of the Louvain algorithm, which is enabled only if the cell state information is not provided. 
                                A higher resolution means that more and smaller clusters are found. 
                                DEFAULT: 0.1. '''
                                )

    # ---- Gene and cell filters ----
    cstreet_parser.add_argument('--Switch_DeadCellFilter',
                                action='store',
                                type=str,
                                choices=["ON","OFF"],
                                default="ON",
                                help='''The switch for the dead cell filter, which filters cell outliers based on the count percent of mitochondrial genes. 
                                DEFAULT: "ON".'''
                                )

    cstreet_parser.add_argument('--Threshold_MitoPercent',
                                action='store',
                                type=float,
                                default=0.2,
                                help='''The maximum count percent of mitochondrial genes needed for a cell to pass filtering, which is enabled only if '--Switch_DeadCellFilter' is "ON". 
                                DEFAULT: 0.2.'''
                                )

    cstreet_parser.add_argument('--Switch_LowCellNumGeneFilter',
                                action='store',
                                type=str,
                                choices=["ON","OFF"],
                                default="ON",
                                help='''The switch for the low cell number gene filter, which retains genes that are expressed in at least a certain number of cells. 
                                DEFAULT: "ON".'''
                                )

    cstreet_parser.add_argument('--Threshold_LowCellNum',
                                action='store',
                                type=int,
                                default=3,
                                help='''The minimum number of cells expressed that is required for a gene to pass filtering, which is enabled only if '--Switch_LowCellNumGeneFilter' is "ON". 
                                DEFAULT: 3.'''
                                )

    cstreet_parser.add_argument('--Switch_LowGeneCellsFilter',
                                action='store',
                                type=str,
                                choices=["ON","OFF"],
                                default="ON",
                                help='''The switch for the low gene number cell filter, which retains cells with at least a certain number of genes expressed. 
                                DEFAULT: "ON".'''
                                )

    cstreet_parser.add_argument('--Threshold_LowGeneNum',
                                action='store',
                                type=int,
                                default=200,
                                help='''The minimum number of genes expressed that is required for a cell to pass filtering, which is enabled only if '--Switch_LowGeneCellsFilter' is "ON". 
                                DEFAULT: 200.'''
                                )

    # ---- Normalization ----
    cstreet_parser.add_argument('--Switch_Normalize',
                                action='store',
                                type=str,
                                choices=["ON","OFF"],
                                default="ON",
                                help='''The switch to enable total read count normalization for each cell. 
                                DEFAULT: "ON".'''
                                )

    cstreet_parser.add_argument('--Threshold_NormalizeBase',
                                action='store',
                                type=float,
                                default=1000000,
                                help='''The base of normalization, which is enabled only if '--Switch_Normalize' is "ON". If the DEFAULT is chosen, it is CPM normalization. 
                                DEFAULT: 1e6.'''
                                )

    cstreet_parser.add_argument('--Switch_LogTransform',
                                action='store',
                                type=str,
                                choices=["ON","OFF"],
                                default="ON",
                                help='''The switch to logarithmize the expression matrix. 
                                DEFAULT: "ON".'''
                                )

    # ---- kNN graph construction ----
    cstreet_parser.add_argument('--KNNParam_metric',
                                action='store',
                                type=str,
                                default="euclidean",
                                help='''The distance metric to use for kNN. It can be set to "euclidean" or "correlation".
                                DEFAULT: "euclidean".'''
                                )

    cstreet_parser.add_argument('--WithinTimePointParam_PCAn',
                                action='store',
                                type=int,
                                default=10,
                                help='''The number of principal components to use within a timepoint. It can be set from 1 to the minimum dimension size of the expression matrixes. 
                                DEFAULT: 10.'''
                                )

    cstreet_parser.add_argument('--WithinTimePointParam_k',
                                action='store',
                                type=int,
                                default=15,
                                help='''The number of nearest neighbors to be searched within a timepoint. It should be in the range of 2 to 100 in general. 
                                DEFAULT: 15.'''
                                )

    cstreet_parser.add_argument('--BetweenTimePointParam_PCAn',
                                action='store',
                                type=int,
                                default=10,
                                help='''The number of principal components to use between timepoints. It can be set from 1 to the minimum dimension size of the expression matrixes. 
                                DEFAULT: 10. '''
                                )

    cstreet_parser.add_argument('--BetweenTimePointParam_k',
                                action='store',
                                type=int,
                                default=15,
                                help='''The number of nearest neighbors to be searched between timepoints. It should be in the range of 2 to 100 in general. 
                                DEFAULT: 15.'''
                                )

    # ---- Connection probability estimation ----
    cstreet_parser.add_argument('--ProbParam_SamplingSize',
                                action='store',
                                type=int,
                                default=5,
                                help='''The number of repeated sampling trials used to estimate the connection probability. 
                                DEFAULT: 5.
                                '''
                                )

    cstreet_parser.add_argument('--ProbParam_RandomSeed',
                                action='store',
                                type=int,
                                default=0,
                                help='''The random seed of the repeated sampling trials used to estimate the connection probability. 
                                DEFAULT: 0.'''
                                )

    # ---- Figure / visualization ----
    cstreet_parser.add_argument('--FigureParam_FigureSize',
                                action='store',
                                nargs=2,
                                type=float,
                                default=(6,7),
                                help="Figure size of the result figure. For example: '--FigureParam_FigureSize 6 7' indicates width is 6 and height is 7. DEFAULT: 6 7 "
                                )

    cstreet_parser.add_argument('--FigureParam_LabelBoxWidth',
                                action='store',
                                type=int,
                                default=10,
                                help='''The width of the label box in the result figure. For example, '--FigureParam_LabelBoxWidth 10' means that 10 characters will be shown in the label box of the resulting figure. 
                                DEFAULT: 10.'''
                                )

    # The default -1 is a sentinel meaning "OTSU"; _min_probability maps an
    # explicit "OTSU" on the command line to the same sentinel.
    cstreet_parser.add_argument('--Threshold_MinProbability',
                                action='store',
                                type=_min_probability,
                                default=-1,
                                help='''The minimum probability of edge for each cell state that is displayed, which will only be used for visualization. It can be a number between 0 and 1 or "OTSU". When OTSU is selected, it will be automatically estimated using OTSU's Method (OTSU, 1979).
                                DEFAULT: "OTSU".'''
                                )

    cstreet_parser.add_argument('--Threshold_MaxOutDegree',
                                action='store',
                                type=int,
                                default=10,
                                help='''The maximum number of outdegrees for each cell state that is displayed, which will only be used for visualization. 
                                DEFAULT: 10.'''
                                )

    cstreet_parser.add_argument('--Threshold_MinCellNumofStates',
                                action='store',
                                type=int,
                                default=0,
                                help='''The minimum cell number of each cell state that is displayed, which will only be used for visualization. 
                                DEFAULT: 0.'''
                                )

    return cstreet_parser
    

if __name__ == '__main__':
    # Command-line entry point: parse the options, load one expression matrix
    # (and optionally one cell-state table) per timepoint, copy every option
    # onto a CStreetData object and run it.

    cstreet_parser = CStreet_parser()
    args = cstreet_parser.parse_args()

    # Without any expression matrix there is nothing to do: show the help.
    if args.Input_ExpMatrix is None:
        cstreet_parser.print_help()
        # ArgumentParser.exit() exits via sys.exit(), so the help text that was
        # just written to (possibly buffered) stdout is flushed; os._exit()
        # terminates without flushing stdio buffers.
        cstreet_parser.exit(0)

    if (args.Input_CellStates is not None) and len(args.Input_CellStates) != len(args.Input_ExpMatrix):
        raise ValueError("Length of Input_CellStates does not match length of Input_ExpMatrix. ")

    if len(args.Input_ExpMatrix) < 2:
        raise ValueError("2 timepoint data at least. ")

    # Load every timepoint as a DataFrame with cells in rows, remembering the
    # cell IDs so that user-supplied cell states can be aligned to them.
    scdata_list = []
    cell_name_list = []
    for path in args.Input_ExpMatrix:
        if str(path).endswith(".txt"):
            data = pd.read_csv(path, header=0, sep="\t", index_col=0)
            if args.row_cell_flag:
                # '-T': cells are stored in columns, so transpose.
                cell_name_list.append(data.columns)
                scdata_list.append(data.T)
            else:
                cell_name_list.append(data.index)
                scdata_list.append(data)
        elif str(path).endswith(".h5ad"):
            # NOTE(review): '-T' is not applied to .h5ad input — presumably
            # because AnnData already stores cells in rows; confirm.
            data = sc.read_h5ad(path).to_df()
            cell_name_list.append(data.index)
            scdata_list.append(data)
        else:
            raise ValueError("Just accept '.txt' or '.h5ad' format.")

    # Optional cell states: one two-column, header-less, tab-delimited file
    # (cell ID, state) per timepoint, reordered to match the matrix's cells.
    cluster_list = []
    if args.Input_CellStates is not None:
        for state_path, cell_names in zip(args.Input_CellStates, cell_name_list):
            data = pd.read_csv(state_path, header=None, sep="\t", index_col=0)
            data.columns = ["state"]
            try:
                states = data.loc[cell_names, "state"].to_list()
            except KeyError as e:
                # Chain the KeyError so the offending cell IDs stay visible.
                raise ValueError("The cell state information must share the same cell ID in the expression matrixes!") from e

            cluster_list.append(states)

    # Imported here rather than at the top of the file, so it only happens
    # after argument handling and input loading succeeded.
    from cstreet import *

    cdata = CStreetData()
    for i, scdata in enumerate(scdata_list):
        if i < len(cluster_list):
            cdata.add_new_timepoint_scdata(scdata, cluster_list[i])
        else:
            cdata.add_new_timepoint_scdata(scdata)

    #Step0:basic params#
    cdata.params.Output_Dir = args.Output_Dir
    cdata.params.Output_Name = args.Output_Name

    #Step1:cell_cluster#
    cdata.params.CellClusterParam_PCAn = args.CellClusterParam_PCAn
    cdata.params.CellClusterParam_k = args.CellClusterParam_k
    cdata.params.CellClusterParam_Resolution = args.CellClusterParam_Resolution

    #Step2:gene and cell filter#
    # The "ON"/"OFF" switches are handed over as booleans.
    cdata.params.Switch_DeadCellFilter = args.Switch_DeadCellFilter == "ON"
    cdata.params.Threshold_MitoPercent = args.Threshold_MitoPercent
    cdata.params.Switch_LowCellNumGeneFilter = args.Switch_LowCellNumGeneFilter == "ON"
    cdata.params.Threshold_LowCellNum = args.Threshold_LowCellNum
    cdata.params.Switch_LowGeneCellsFilter = args.Switch_LowGeneCellsFilter == "ON"
    cdata.params.Threshold_LowGeneNum = args.Threshold_LowGeneNum

    #Step3:normalize#
    cdata.params.Switch_Normalize = args.Switch_Normalize == "ON"
    cdata.params.Threshold_NormalizeBase = args.Threshold_NormalizeBase
    cdata.params.Switch_LogTransform = args.Switch_LogTransform == "ON"

    #Step4:get_graph#
    cdata.params.KNNParam_metric = args.KNNParam_metric
    cdata.params.WithinTimePointParam_PCAn = args.WithinTimePointParam_PCAn
    cdata.params.WithinTimePointParam_k = args.WithinTimePointParam_k
    cdata.params.BetweenTimePointParam_PCAn = args.BetweenTimePointParam_PCAn
    cdata.params.BetweenTimePointParam_k = args.BetweenTimePointParam_k

    #Step5: calculate probability
    cdata.params.ProbParam_SamplingSize = args.ProbParam_SamplingSize
    cdata.params.ProbParam_RandomSeed = args.ProbParam_RandomSeed

    #Step6:plot graph#
    # argparse yields a 2-item list (or the default tuple); pass a tuple.
    cdata.params.FigureParam_FigureSize = tuple(args.FigureParam_FigureSize)
    cdata.params.FigureParam_LabelBoxWidth = args.FigureParam_LabelBoxWidth
    cdata.params.Threshold_MaxOutDegree = args.Threshold_MaxOutDegree
    cdata.params.Threshold_MinCellNumofStates = args.Threshold_MinCellNumofStates
    # -1 is the parser's sentinel for "estimate the threshold with OTSU".
    cdata.params.Threshold_MinProbability = "OTSU" if args.Threshold_MinProbability == -1 else args.Threshold_MinProbability

    cdata.run_cstreet()




