#!python

'''
     Usage:  
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --table <tablename>
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --tables <comma_delim_tbl_list>
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --all-tables
        j2sqlgen --config <configfile> --source <schemafile> --list_tables

'''

'''
+mdoc+

j2sqlgen (JSON to SQL generator) is designed to generate equivalent tables across different database types.

The (slightly esoteric) use case is: you have a table in (say) MySQL and you want to create the equivalent table
in (say) PostgreSQL. Because the supported column types differ slightly, simply applying the DDL from the 
source DB to the target DB is usually not an option. 

j2sqlgen is driven by a YAML configfile

+mdoc+
'''

import os, sys
import json
from snap import common
import docopt


# Source column types that is_string_type() reports as string types.
BIGQUERY_STRING_TYPES = ['STRING']

# Skeleton of the generated CREATE TABLE statement; the placeholders are
# filled in with str.format().
TABLE_DDL_TEMPLATE = '''
CREATE TABLE {schema}.{tablename} (
{table_body}
) {table_suffix}
'''

# Setting names for table-level configuration.
# NOTE(review): neither list below is referenced by the code visible in this
# file -- confirm they are still needed.
# Every entry needs its own trailing comma: adjacent string literals are
# silently concatenated by Python ('a' 'b' == 'ab').
TABLE_LEVEL_SETTINGS = [
    'autocreate_pk',
    'pk_name',
    'pk_type',
    'autocreated_pk_default',
    'varchar_length',
    'column_type_map',
    'column_name_map',
]

# Setting names for column-level configuration.
COLUMN_LEVEL_SETTINGS = [
    'varchar_length',
    'column_type_map',
    'column_name_map',
]


def is_string_type(type_name):
    '''Return True when `type_name` is listed in BIGQUERY_STRING_TYPES, False otherwise.'''
    return type_name in BIGQUERY_STRING_TYPES


class DuplicateColumn(Exception):
    '''Error signalling that a column name is already present in a TableSpec.

    The exception message embeds the offending `column_name`.
    '''

    def __init__(self, column_name):
        message = 'TableSpec already contains column "%s".' % column_name
        super(DuplicateColumn, self).__init__(message)


class ColumnSpec(object):
    '''One column of a table, able to render its own line of a CREATE TABLE body.

    Parameters:
        table_name:  name of the owning table; passed to the creation context
                     for every per-table lookup.
        column_name: name of the column.
        column_type: type name of the column.
        is_nullable: when False, a NOT NULL clause is emitted.
        is_pk:       when True, PRIMARY KEY is appended to the column line.
        **kwargs:    optional `length` (rendered as "(length)" after the type
                     when truthy) and `default` (rendered as a DEFAULT clause
                     when not None).
    '''

    def __init__(self, table_name, column_name, column_type, is_nullable=True, is_pk=False, **kwargs):
        self.tablename = table_name
        self.name = column_name
        self.datatype = column_type
        self.nullable = is_nullable
        # May be empty/None; usually only applies to string types.
        self.length = kwargs.get('length', '')
        self.is_primary_key = is_pk
        # None means "emit no DEFAULT clause".
        self.default_value = kwargs.get('default')

    def sql(self, table_creation_context):
        '''Return this column's definition line, stripped of surrounding whitespace.

        The column name, type and suffix are obtained from
        `table_creation_context` (get_target_column_name,
        get_target_column_type and get_column_suffix).

        NOTE(review): TableSpec.add_column already stores the mapped name and
        type, so both are passed through the mapping a second time here --
        confirm that this is intended.
        '''
        null_clause = '' if self.nullable else 'NOT NULL'
        length_clause = '(%s)' % self.length if self.length else ''

        # get_target_column_type takes (table, column, type). The DEFAULT
        # branch used to omit the column name and therefore raised TypeError
        # for every column that carried a default value.
        fields = {
            'name': table_creation_context.get_target_column_name(self.tablename, self.name),
            'datatype': table_creation_context.get_target_column_type(self.tablename,
                                                                      self.name,
                                                                      self.datatype),
            'length': length_clause,
            'null': null_clause,
            'column_suffix': table_creation_context.get_column_suffix(self.tablename, self.name),
        }

        if self.default_value is None:
            template = '{name} {datatype}{length} {null} {column_suffix}'
        else:
            template = '{name} {datatype}{length} {null} DEFAULT {default} {column_suffix}'
            fields['default'] = self.default_value

        field_line = template.format(**fields)
        if self.is_primary_key:
            field_line = field_line + ' PRIMARY KEY'
        return field_line.strip()


class TableSpec(object):
    '''An ordered collection of ColumnSpec objects describing one table.

    Parameters:
        table_name:  name stored as `self.name`; it is passed to the creation
                     context for per-table lookups.
        schema_name: stored as `self.schema`.
        **kwargs:    accepted but not used.
    '''

    def __init__(self, table_name, schema_name, **kwargs):
        self.name = table_name
        self.schema = schema_name
        self.columns = []

    def has_pk(self):
        '''Return True if any column is flagged as a primary key.'''
        return any(column.is_primary_key for column in self.columns)

    def generate_pk_column(self, table_creation_context):
        '''Build (without adding it) a non-nullable primary-key ColumnSpec.

        Name, type and default come from the context's `pk_name`, `pk_type`
        and `pk_value` attributes.
        '''
        return ColumnSpec(self.name,
                          table_creation_context.pk_name,
                          table_creation_context.pk_type,
                          False,
                          True,
                          default=table_creation_context.pk_value)

    def insert_pk(self, columnspec):
        '''Place `columnspec` at the front of the column list.'''
        self.columns.insert(0, columnspec)

    def add_columnspec(self, columnspec):
        '''Append an already-built ColumnSpec; no duplicate check is made.'''
        self.columns.append(columnspec)

    def add_column(self, source_column_name, source_column_type, creation_context, **kwargs):
        '''Map a source column through `creation_context` and append it.

        Keyword arguments `is_nullable` (default True) and `is_pk` (default
        False) set the corresponding ColumnSpec flags; every other keyword
        (e.g. `length`, `default`) is forwarded to ColumnSpec unchanged.

        Raises:
            DuplicateColumn: if a column with the mapped name already exists.
        '''
        target_column_type = creation_context.get_target_column_type(self.name, source_column_name, source_column_type)
        target_column_name = creation_context.get_target_column_name(self.name, source_column_name)

        if self.get_column(target_column_name) is not None:
            raise DuplicateColumn(target_column_name)

        # Pop the flags that are also passed positionally; forwarding them a
        # second time through **kwargs would raise
        # "TypeError: got multiple values for argument ...".
        column_settings = dict(kwargs)
        nullable = column_settings.pop('is_nullable', True)
        is_pk = column_settings.pop('is_pk', False)
        self.columns.append(ColumnSpec(self.name,
                                       target_column_name,
                                       target_column_type,
                                       nullable,
                                       is_pk,
                                       **column_settings))

    def get_column(self, column_name):
        '''Return the first column whose name equals `column_name`, or None.'''
        return next((column for column in self.columns if column.name == column_name), None)

    def remove_column(self, column_name):
        '''Remove the named column; return True if it existed, False otherwise.'''
        column = self.get_column(column_name)
        if column:
            self.columns.remove(column)
            return True
        return False


class TableCreationContext(object):
    '''Resolves names, types, suffixes and lengths from the loaded config.

    `yaml_config` is read as follows:

      defaults:                     (required)
        varchar_length              (required)
        column_type_map             (required key; may be empty/null)
        pk_value, pk_name, pk_type, autocreate_pk_if_missing,
        table_suffix, column_suffix (all optional; pk_name and pk_type are
                                     required when autocreate_pk_if_missing
                                     is true)
      tables:                       (optional) per-table overrides keyed by
                                    table name, each of which may contain
        rename_to, table_suffix, column_suffix, varchar_length,
        autocreate_pk_if_missing, column_name_map, and
        column_settings: {<column>: {length, column_suffix}}
    '''

    def __init__(self, yaml_config):
        self.defaults = yaml_config['defaults']
        # Always pick up the configured PK name/type so that a table-level
        # `autocreate_pk_if_missing: true` override also has something to
        # build the key column from (previously these stayed None unless the
        # *default* enabled auto-creation).
        self.pk_type = self.defaults.get('pk_type')
        self.pk_name = self.defaults.get('pk_name')
        self.pk_value = self.defaults.get('pk_value')
        self.default_varchar_length = self.defaults['varchar_length']
        self.type_map = {}
        self.type_map.update(self.defaults['column_type_map'] or {})

        self.overrides = yaml_config.get('tables', {})
        autocreate_pk = self.defaults.get('autocreate_pk_if_missing', False)
        if autocreate_pk == True:  # noqa: E712 -- keeps the original comparison semantics
            # Fail fast (KeyError) when auto-creation is on by default but
            # the key column is not fully described.
            self.pk_type = self.defaults['pk_type']
            self.pk_name = self.defaults['pk_name']

    def _table_settings(self, table_name):
        '''Return the override mapping for `table_name`, or None if absent/empty.'''
        if not self.overrides:
            return None
        return self.overrides.get(table_name) or None

    def get_table_suffix(self, table_name):
        '''Return the text placed after the closing parenthesis of the table.

        A table with overrides uses its own `table_suffix` ('' when it has
        none); a table without overrides uses the default `table_suffix`.
        '''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return self.defaults.get('table_suffix', '')
        return table_settings.get('table_suffix', '')

    def get_column_suffix(self, table_name, column_name):
        '''Return the text appended to a column definition; never None.

        A table without overrides uses the default `column_suffix`. Otherwise
        the table-level `column_suffix` is used, falling back to the one in
        `column_settings[column_name]` when the table-level value is empty.
        '''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return self.defaults.get('column_suffix', '')

        suffix = table_settings.get('column_suffix', '')

        column_settings = (table_settings.get('column_settings') or {}).get(column_name)
        if column_settings:
            suffix = suffix or column_settings.get('column_suffix')

        # A missing or null suffix must come back as '' -- returning None
        # would be rendered as the literal text "None" inside the DDL.
        return suffix or ''

    def get_mapped_table_name(self, table_name):
        '''Return the table's `rename_to` override, or `table_name` itself.'''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return table_name
        return table_settings.get('rename_to', table_name)

    def should_autocreate_pk(self, table_name):
        '''Return the effective `autocreate_pk_if_missing` setting for a table.

        The table-level value wins when it is present and not None; otherwise
        the default applies (False when the default is not configured).
        '''
        # .get() mirrors __init__, which also treats the key as optional.
        setting = self.defaults.get('autocreate_pk_if_missing', False)

        table_settings = self._table_settings(table_name)
        if table_settings and table_settings.get('autocreate_pk_if_missing') is not None:
            setting = table_settings['autocreate_pk_if_missing']

        return setting

    def map_column_type(self, tablename, source_typename, target_typename):
        '''Register a source-to-target type mapping (`tablename` is not used).'''
        self.type_map[source_typename] = target_typename

    def get_target_column_type(self, table_name, column_name, source_column_type):
        '''Return the mapped type, or `source_column_type` when no mapping exists.

        `table_name` and `column_name` are accepted but not used.
        '''
        return self.type_map.get(source_column_type, source_column_type)

    def get_target_column_name(self, table_name, column_name):
        '''Return the column's entry in the table's `column_name_map`, or
        `column_name` unchanged when there is none.'''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return column_name

        name_map = table_settings.get('column_name_map')
        if name_map:
            return name_map.get(column_name, column_name)
        return column_name

    def get_varchar_length(self, table_name, column_name):
        '''Return the length for a column, most specific setting first:
        `column_settings[column]['length']`, then the table's
        `varchar_length`, then the default `varchar_length`.'''
        varchar_length = self.default_varchar_length

        table_settings = self._table_settings(table_name)
        if not table_settings:
            return varchar_length

        varchar_length = table_settings.get('varchar_length', varchar_length)

        column_settings = (table_settings.get('column_settings') or {}).get(column_name)
        if column_settings:
            varchar_length = column_settings.get('length', varchar_length)

        return varchar_length


def create_tablespec_from_json_config(tablename, json_config, dbschema, creation_context, **kwargs):
    '''Build a TableSpec for `tablename` from its entry in the JSON schema file.

    Parameters:
        tablename:        table name; used for the TableSpec and for every
                          per-table lookup on `creation_context`.
        json_config:      mapping with a 'columns' list whose items provide
                          'column_name' and 'column_type'.
        dbschema:         schema name stored on the TableSpec.
        creation_context: TableCreationContext supplying mappings and settings.
        **kwargs:         accepted but not used.

    Returns:
        The populated TableSpec. If none of its columns is a primary key and
        the context says one should be auto-created, a generated key column
        is inserted as the first column.

    Raises:
        DuplicateColumn: if two source columns map to the same target name.
    '''
    tspec = TableSpec(tablename, dbschema)

    for column_config in json_config['columns']:
        column_name = column_config['column_name']
        column_type = column_config['column_type']

        # Only string-typed source columns get a length setting.
        settings = {}
        if is_string_type(column_type):
            settings['length'] = creation_context.get_varchar_length(tablename, column_name)

        tspec.add_column(column_name, column_type, creation_context, **settings)

    if not tspec.has_pk() and creation_context.should_autocreate_pk(tablename):
        tspec.insert_pk(tspec.generate_pk_column(creation_context))

    return tspec


def generate_sql(tablespec, table_creation_context):
    '''Render `tablespec` as a CREATE TABLE statement terminated by ';'.

    Each column contributes one line (ColumnSpec.sql); the lines are joined
    with ",\\n" and substituted into TABLE_DDL_TEMPLATE together with the
    schema, the (possibly renamed) table name and the table suffix.
    '''
    field_lines = [column.sql(table_creation_context) for column in tablespec.columns]

    ddl_stmt = TABLE_DDL_TEMPLATE.format(
        schema=tablespec.schema,
        tablename=table_creation_context.get_mapped_table_name(tablespec.name),
        table_body=',\n'.join(field_lines),
        # Per-table overrides are keyed by the table's own (source) name, the
        # same key get_mapped_table_name() is called with above. Looking the
        # suffix up under the renamed name would miss the table's overrides.
        table_suffix=table_creation_context.get_table_suffix(tablespec.name))

    return ddl_stmt.strip() + ';'

    
def find_table_config(tablename, json_dbschema):
    '''Look up a table entry in the parsed schema file.

    Scans json_dbschema['tables'] in order and returns the first entry whose
    'table_name' equals `tablename`; returns None when nothing matches.
    '''
    matching_entries = (
        candidate
        for candidate in json_dbschema['tables']
        if candidate['table_name'] == tablename
    )
    return next(matching_entries, None)


def _print_table_ddl(table_name, json_dbschema, schema_filename, project_schema, table_creation_context):
    '''Print the CREATE TABLE statement for one table to stdout.

    If the schema file has no entry for `table_name`, a message is written to
    stderr (so it cannot end up inside redirected SQL output) and the process
    exits with status 1.
    '''
    tablecfg = find_table_config(table_name, json_dbschema)
    if not tablecfg:
        print('No table "%s" defined in schema file %s.' % (table_name, schema_filename), file=sys.stderr)
        sys.exit(1)

    tablespec = create_tablespec_from_json_config(table_name,
                                                  tablecfg,
                                                  project_schema,
                                                  table_creation_context)
    print(generate_sql(tablespec, table_creation_context))


def main(args):
    '''Command-line entry point.

    `args` is the dictionary produced by docopt from the module usage text.
    Depending on the selected mode this lists the table names found in the
    schema file, or prints DDL for one table (--table), for a comma-separated
    list of tables (--tables) or for every table (--all-tables).

    An unknown table name is reported on stderr and ends the process with
    exit status 1 in all three generation modes.
    '''
    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)

    schema_filename = args['<schemafile>']
    with open(schema_filename) as f:
        json_dbschema = json.load(f)

    tablenames = [entry['table_name'] for entry in json_dbschema['tables']]

    if args.get('--list_tables'):
        print('\n'.join(tablenames))
        return

    # The schema file's own schema name is only needed when --schema does not
    # override it, so it is not read (and not required) otherwise.
    if args.get('--schema'):
        project_schema = args['<db_schema>']
    else:
        project_schema = json_dbschema['schema_name']
    table_creation_context = TableCreationContext(yaml_config)

    if args.get('--table'):
        _print_table_ddl(args['<tablename>'],
                         json_dbschema,
                         schema_filename,
                         project_schema,
                         table_creation_context)
        return

    if args.get('--tables'):
        # Ignore empty entries such as the one produced by a trailing comma.
        selected_table_names = [name.strip()
                                for name in args['<comma_delim_tbl_list>'].split(',')
                                if name.strip()]
    elif args.get('--all-tables'):
        # generate all the tables specified in the metadata file
        selected_table_names = tablenames
    else:
        return

    for table_name in selected_table_names:
        _print_table_ddl(table_name,
                         json_dbschema,
                         schema_filename,
                         project_schema,
                         table_creation_context)
        # Separate consecutive statements (prints two newline characters).
        print('\n')


if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring
    # and hand the resulting option dictionary to main().
    main(docopt.docopt(__doc__))
