#!python
import os
import logging
import argparse
import time

import torch

def main(args):
    """
    Extract the state dict of a checkpoint and store it as 'model.pt'

    The checkpoint at args.check is loaded onto the CPU, the first six
    characters are cut from every key of its 'state_dict' entry, and the
    re-keyed mapping is written to 'model.pt' in the current directory.
    """
    cpu = torch.device('cpu')
    checkpoint = torch.load(args.check, map_location=cpu)

    # NOTE(review): the six dropped characters are presumably a wrapper
    # prefix such as 'model.' -- confirm against the training code.
    weights = {}
    for name, tensor in checkpoint['state_dict'].items():
        weights[name[6:]] = tensor

    torch.save(weights, 'model.pt')

def file_path(path : str) -> str:
    """
    Type check of argparse: accept only paths that name an existing file

    Returns the path unchanged when os.path.isfile() accepts it.

    Raises argparse.ArgumentTypeError otherwise. argparse only converts
    ArgumentTypeError, TypeError and ValueError raised by a type callable
    into a usage error; any other exception escapes parse_args() as a raw
    traceback.
    """
    if os.path.isfile(path):
        return path
    raise argparse.ArgumentTypeError('not an existing file: {}'.format(path))

if __name__ == '__main__':
    # Command line: one mandatory checkpoint path, one optional verbosity flag
    cli = argparse.ArgumentParser(
        description='Extract the model from a given checkpoint'
    )
    cli.add_argument('--verbose', '-v', dest='verbose', action='store_true')
    cli.add_argument('check', type=file_path, help='Checkpoint path')
    args = cli.parse_args()

    # INFO messages are only enabled when -v/--verbose was given
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Time the extraction with wall-clock timestamps around main()
    start = time.time()
    logging.info('Start.')
    out = main(args)
    end = time.time()
    elapsed = end - start
    logging.info('Finished. Time: {}'.format(elapsed))
