#!python

'''
     Usage:  
        j2sqlgen --config <configfile> --source <schemafile> --table <tablename>
        j2sqlgen --config <configfile> --source <schemafile> --tables <comma_delim_tbl_list>
        j2sqlgen --config <configfile> --source <schemafile> --list_tables

'''

import os, sys
import json
from snap import common
import docopt


# Source column types that is_string_type() treats as strings; columns of these
# types are given a length when a table spec is built.
BIGQUERY_STRING_TYPES = ['STRING']

# Skeleton of the emitted DDL statement. generate_sql() fills in {schema},
# {tablename} and {table_body} (the comma-separated column definitions).
TABLE_DDL_TEMPLATE = '''
CREATE TABLE {schema}.{tablename} (
{table_body}
);
'''

# Names of settings at table scope.
# NOTE(review): neither list is consulted by the code visible in this file;
# presumably they document/validate the override sections of the config.
# Every entry needs its own trailing comma: adjacent string literals are
# silently concatenated by Python into a single (wrong) entry.
TABLE_LEVEL_SETTINGS = [
    'autocreate_pk',
    'pk_name',
    'pk_type',
    'autocreated_pk_default',
    'varchar_length',
    'column_type_map',
    'column_name_map',
]

# Names of settings at column scope (same caveat as above).
COLUMN_LEVEL_SETTINGS = [
    'varchar_length',
    'column_type_map',
    'column_name_map',
]


def is_string_type(type_name):
    '''Return True when type_name is listed in BIGQUERY_STRING_TYPES, else False.'''
    return type_name in BIGQUERY_STRING_TYPES




class DuplicateColumn(Exception):
    '''Raised by TableSpec.add_column() when the target column name is already present.'''

    def __init__(self, column_name):
        message = 'TableSpec already contains column "%s".' % column_name
        super(DuplicateColumn, self).__init__(message)


class ColumnSpec(object):
    '''A single column of a target table, renderable as one line of DDL.

    Parameters:
        column_name: name of the column as it should appear in the DDL.
        column_type: datatype name as it should appear in the DDL.
        is_nullable: when False, the rendered line carries NOT NULL.
        is_pk: marks the column as the table's primary key (recorded only;
            sql() does not render a key constraint).
        **kwargs: optional 'length' (rendered as "(length)" after the type
            when truthy) and 'default' (rendered as a DEFAULT clause when
            not None).
    '''

    def __init__(self, column_name, column_type, is_nullable=True, is_pk=False, **kwargs):
        self.name = column_name
        self.datatype = column_type
        self.nullable = is_nullable
        self.length = kwargs.get('length', '')  # can be null; usually only applies to string types
        self.is_primary_key = is_pk
        self.default_value = kwargs.get('default')

    def sql(self):
        '''Return the column definition, e.g. "name VARCHAR(64) NOT NULL DEFAULT x".

        Clauses are joined with exactly one space, so an absent NOT NULL
        clause never leaves a doubled space in front of DEFAULT.
        '''
        length_clause = '(%s)' % self.length if self.length else ''
        clauses = ['{name} {datatype}{length}'.format(name=self.name,
                                                      datatype=self.datatype,
                                                      length=length_clause)]
        if not self.nullable:
            clauses.append('NOT NULL')

        if self.default_value is not None:
            # Diagnostics go to stderr: stdout carries the generated DDL and
            # must stay valid SQL when redirected to a file.
            sys.stderr.write('generating field line for %s with default clause...\n' % self.name)
            clauses.append('DEFAULT {default}'.format(default=self.default_value))

        return ' '.join(clauses).strip()


class TableSpec(object):
    '''An ordered collection of ColumnSpec objects describing one target table.

    Parameters:
        table_name: name of the table.
        schema_name: name of the schema the table belongs to.
    '''

    def __init__(self, table_name, schema_name):
        self.name = table_name
        self.schema = schema_name
        self.columns = []

    def has_pk(self):
        '''Return True if any column is flagged as the primary key.'''
        return any(columnspec.is_primary_key for columnspec in self.columns)

    def generate_pk_column(self, table_creation_context):
        '''Build (but do not add) a non-nullable primary-key ColumnSpec.

        Name, type and default value are taken from the pk_name, pk_type and
        pk_value attributes of table_creation_context.
        '''
        return ColumnSpec(table_creation_context.pk_name,
                          table_creation_context.pk_type,
                          False,
                          True,
                          default=table_creation_context.pk_value)

    def insert_pk(self, columnspec):
        '''Place columnspec at the front of the column list.'''
        self.columns.insert(0, columnspec)

    def add_columnspec(self, columnspec):
        '''Append an already-built column object to the column list.'''
        self.columns.append(columnspec)

    def add_column(self, column_name, column_type, creation_context, **kwargs):
        '''Translate a source column through creation_context and append it.

        Parameters:
            column_name: source column name; mapped via
                creation_context.get_target_column_name().
            column_type: source column type; mapped via
                creation_context.get_target_column_type().
            creation_context: object supplying the two mapping methods.
            **kwargs: 'is_nullable' (default True) plus any keyword accepted
                by ColumnSpec, such as 'length' or 'default'.

        Raises:
            DuplicateColumn: if the mapped column name is already present.
        '''
        column_type = creation_context.get_target_column_type(self.name, column_name, column_type)
        column_name = creation_context.get_target_column_name(self.name, column_name)

        if self.get_column(column_name) is not None:
            raise DuplicateColumn(column_name)

        # Remove is_nullable from kwargs before forwarding them: it is passed
        # positionally below, and supplying it twice raises a TypeError.
        nullable = kwargs.pop('is_nullable', True)
        is_pk = False
        self.columns.append(ColumnSpec(column_name, column_type, nullable, is_pk, **kwargs))

    def get_column(self, column_name):
        '''Return the first column whose name equals column_name, or None.'''
        for candidate in self.columns:
            if candidate.name == column_name:
                return candidate
        return None

    def remove_column(self, column_name):
        '''Remove the named column; return True if it existed, False otherwise.'''
        target = self.get_column(column_name)
        if target is None:
            return False
        self.columns.remove(target)
        return True


class TableCreationContext(object):
    '''Settings that drive translation of source columns into target columns.

    Parameters:
        yaml_config: mapping with a required 'defaults' section (keys read:
            'autocreate_pk_if_missing', 'varchar_length', 'column_type_map',
            optional 'pk_value', and 'pk_type'/'pk_name' when PK autocreation
            is enabled) and an optional 'tables' section holding per-table
            overrides keyed by table name.
    '''

    def __init__(self, yaml_config):
        defaults = yaml_config['defaults']
        self.pk_type = None
        self.pk_name = None
        self.pk_value = defaults.get('pk_value')
        self.autocreate_pk = defaults['autocreate_pk_if_missing']
        # Plain truthiness, matching how callers test autocreate_pk; an
        # "== True" comparison would leave pk_name/pk_type as None for truthy
        # values that are not equal to True.
        if self.autocreate_pk:
            self.pk_type = defaults['pk_type']
            self.pk_name = defaults['pk_name']

        self.default_varchar_length = defaults['varchar_length']
        self.type_map = defaults['column_type_map']

        # The overrides section is optional; every reader below already copes
        # with it being empty.
        self.overrides = yaml_config.get('tables')

    def _table_overrides(self, table_name):
        '''Return the override mapping for table_name, or {} when there is none.'''
        if not self.overrides:
            return {}
        return self.overrides.get(table_name) or {}

    def map_column_type(self, tablename, source_typename, target_typename):
        '''Register (or replace) a source-to-target type mapping.

        The mapping is stored in the shared type map; tablename is accepted
        but not used.
        '''
        self.type_map[source_typename] = target_typename

    def get_target_column_type(self, table_name, column_name, source_column_type):
        '''Return the mapped type for source_column_type, or the type itself if unmapped.

        table_name and column_name are accepted but not used.
        '''
        return self.type_map.get(source_column_type, source_column_type)

    def get_target_column_name(self, table_name, column_name):
        '''Return the renamed column per the table's 'column_name_map', else column_name.'''
        name_map = self._table_overrides(table_name).get('column_name_map') or {}
        return name_map.get(column_name, column_name)

    def get_varchar_length(self, table_name, column_name):
        '''Return the varchar length for a column.

        Precedence, most specific first: the column's 'length' under the
        table's 'column_settings', the table's 'varchar_length', then the
        default varchar length.
        '''
        table_cfg = self._table_overrides(table_name)
        varchar_length = table_cfg.get('varchar_length', self.default_varchar_length)

        column_cfg = (table_cfg.get('column_settings') or {}).get(column_name) or {}
        return column_cfg.get('length', varchar_length)


def create_tablespec_from_json_config(tablename, json_config, default_schema, creation_context):
    '''Build a TableSpec for one table entry of the JSON schema.

    Parameters:
        tablename: name given to the resulting TableSpec.
        json_config: table entry; must contain 'columns' (each with
            'column_name' and 'column_type') and may contain 'schema_name'.
        default_schema: schema used when json_config has no 'schema_name'.
        creation_context: supplies type/name mappings, varchar lengths and
            the primary-key autocreation settings.

    Returns:
        The populated TableSpec. If no column is a primary key and
        creation_context.autocreate_pk is truthy, a generated primary-key
        column is inserted first.
    '''
    tspec = TableSpec(tablename, json_config.get('schema_name', default_schema))
    for column_config in json_config['columns']:
        column_name = column_config['column_name']
        source_type = column_config['column_type']

        settings = {}
        if is_string_type(source_type):
            settings['length'] = creation_context.get_varchar_length(tablename, column_name)

        # Hand over the *source* type: add_column() performs the type mapping
        # itself, and mapping here as well would translate twice whenever a
        # mapped target type is also a key of the type map.
        tspec.add_column(column_name,
                         source_type,
                         creation_context,
                         **settings)

    if not tspec.has_pk() and creation_context.autocreate_pk:
        primary_key_col = tspec.generate_pk_column(creation_context)
        tspec.insert_pk(primary_key_col)

    return tspec


def generate_sql(tablespec):
    '''Render tablespec as a CREATE TABLE statement using TABLE_DDL_TEMPLATE.

    Each column contributes the text of its sql() method; the column lines
    are separated by a comma and a newline.
    '''
    body = ',\n'.join([col.sql() for col in tablespec.columns])
    return TABLE_DDL_TEMPLATE.format(schema=tablespec.schema,
                                     tablename=tablespec.name,
                                     table_body=body)

    
def find_table_config(tablename, json_dbschema):
    '''Return the first entry of json_dbschema['tables'] whose 'table_name'
    equals tablename, or None when no entry matches.
    '''
    matches = (candidate for candidate in json_dbschema['tables']
               if candidate['table_name'] == tablename)
    return next(matches, None)


def main(args):
    '''Run the tool for the parsed docopt arguments.

    Reads the config file and the JSON schema file, then either lists the
    table names found in the schema (--list_tables) or prints the CREATE
    TABLE statement for the table given with --table, or for each table in
    the comma-delimited list given with --tables.

    Parameters:
        args: the dictionary returned by docopt for this module's usage text.
    '''
    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)

    schema_filename = args['<schemafile>']
    with open(schema_filename) as f:
        json_dbschema = json.load(f)

    project_schema = json_dbschema['schema_name']
    tablenames = [entry['table_name'] for entry in json_dbschema['tables']]

    if args.get('--list_tables'):
        print('\n'.join(tablenames))
        return

    table_creation_context = TableCreationContext(yaml_config)

    if args.get('--table'):
        requested_tables = [args['<tablename>']]
    elif args.get('--tables'):
        # Tolerate spaces around the commas and stray empty items.
        requested_tables = [name.strip()
                            for name in args['<comma_delim_tbl_list>'].split(',')
                            if name.strip()]
    else:
        return

    for table_name in requested_tables:
        tablecfg = find_table_config(table_name, json_dbschema)
        if not tablecfg:
            # Reported on stderr so the message can never end up inside DDL
            # that is being redirected from stdout.
            sys.stderr.write('No table "%s" defined in schema file %s.\n' % (table_name, schema_filename))
            continue

        tablespec = create_tablespec_from_json_config(table_name,
                                                      tablecfg,
                                                      project_schema,
                                                      table_creation_context)
        print(generate_sql(tablespec))

if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring,
    # then hand the resulting argument dictionary to main().
    args = docopt.docopt(__doc__)
    main(args)
