#!/usr/bin/env python
#    Download raw and/or visibility data from a TART telescope.
#    The telescope should be in RAW mode if you want it to capture raw data every minute.
#
#    Tim Molteno 2017-2019 - tim@elec.ac.nz

import argparse
import json
import logging
import os
import hashlib
import shutil
import urllib.request, urllib.error, urllib.parse
import time

from tart_tools.api_handler import AuthorizedAPIhandler

# Root logger (no name given); its level and console handler are configured
# in the __main__ block when the file is run as a script.
logger = logging.getLogger()

def sha256_checksum(filename, block_size=65536):
    """Return the hex-encoded SHA-256 digest of the file at `filename`.

    The file is opened in binary mode and consumed in pieces of at most
    `block_size` bytes, so the whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(block_size)
            if chunk == b'':
                # An empty read marks end-of-file.
                break
            digest.update(chunk)
    return digest.hexdigest()

def download_file(url, checksum=0, file_path=None):
    """Download `url` to `file_path`, optionally verifying its SHA-256 checksum.

    The data is first streamed into a temporary ``<file_path>.part`` file and
    only moved to `file_path` once the transfer has completed and (if a
    checksum was supplied) the digest matches. This guarantees that a file
    found at `file_path` is never a truncated or corrupt download.

    Parameters:
        url: Address of the file to fetch (anything `urllib.request.urlopen`
            accepts).
        checksum: Expected SHA-256 hex digest. A falsy value (the default)
            disables verification.
        file_path: Local destination. When None, the last path component of
            `url` is used, relative to the current working directory.

    Raises:
        ValueError: if `file_path` is None and no file name can be derived
            from `url`.
        urllib.error.URLError / OSError: if the transfer or a file operation
            fails. The partial file is removed before the error propagates.

    On a checksum mismatch nothing is raised: the downloaded data is
    discarded, a warning is logged, and any file already at `file_path`
    is left untouched.
    """
    if file_path is None:
        # Previously this fell through to open(None) and failed with a
        # TypeError after the request had already been made.
        file_path = os.path.basename(urllib.parse.urlparse(url).path)
        if not file_path:
            raise ValueError("Cannot derive a file name from URL: {}".format(url))

    logger.info("Download_file({}, {}) -> {}".format(url, checksum, file_path))

    partial_path = "{}.part".format(file_path)
    try:
        # Download the file from `url` into the temporary location.
        with urllib.request.urlopen(url) as response, open(partial_path, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)

        if checksum:
            downloaded_checksum = sha256_checksum(partial_path)
            if downloaded_checksum != checksum:
                logger.warning("Removing file: Checksum failed\n{}\n{}".format(checksum, downloaded_checksum))
                return  # the finally clause discards the bad download

        # Atomic on the same filesystem: readers see the old file or the new one.
        os.replace(partial_path, file_path)
    finally:
        # Still present only if the transfer failed or the checksum mismatched.
        if os.path.exists(partial_path):
            os.remove(partial_path)

# Script entry point: poll the telescope API forever and mirror every
# advertised raw and/or visibility data file into a local directory.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Download data from the telescope', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--api', required=False, default='https://tart.elec.ac.nz/signal', help="Telescope API server URL.")
    # --pw is mandatory, so a default value can never be used; suppress it so
    # the generated help text does not advertise a bogus default password.
    parser.add_argument('--pw', default=argparse.SUPPRESS, required=True, type=str, help='API password')
    parser.add_argument('--dir', type=str, default=".", help='local directory to download')
    parser.add_argument('--raw', action='store_true', help='Download Raw Data in HDF format')
    parser.add_argument('--vis', action='store_true', help='Download Visibility Data in HDF format')

    ARGS = parser.parse_args()

    # Validate before touching the filesystem or the network; parser.error
    # prints the usage text and exits with status 2 instead of a traceback.
    if not (ARGS.raw or ARGS.vis):
        parser.error("Either --raw or --vis must be specified")

    logger.setLevel(logging.DEBUG)
    # create console handler; only INFO and above reach the console
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # add formatter to ch
    ch.setFormatter(formatter)
    # add ch to logger
    logger.addHandler(ch)

    api = AuthorizedAPIhandler(ARGS.api, ARGS.pw)

    logger.info("Downloading Raw Data from {}".format(ARGS.api))
    os.makedirs(ARGS.dir, exist_ok=True)

    # Trailing slash so urljoin() resolves relative file names below the API root.
    tart_endpoint = ARGS.api + "/"

    while True:
        try:
            # The listing requests live inside the try so that a transient
            # API failure is logged and retried rather than ending the loop.
            resp_vis = []
            resp_raw = []
            if ARGS.vis:
                resp_vis = api.get('vis/data')
            if ARGS.raw:
                resp_raw = api.get('raw/data')

            for entry in resp_raw + resp_vis:
                if 'filename' not in entry:
                    continue

                # Handle each entry on its own so that one bad file does not
                # prevent the rest of the batch from being downloaded.
                try:
                    data_url = urllib.parse.urljoin(tart_endpoint, entry['filename'])
                    checksum = entry['checksum']

                    file_name = data_url.split('/')[-1]
                    file_path = os.path.join(ARGS.dir, file_name)

                    if os.path.isfile(file_path):
                        if sha256_checksum(file_path) == checksum:
                            logger.info('Skipping {}'.format(file_path))
                            continue
                        # Local copy does not match the server: fetch it again.
                        logger.info('Corrupted File: {}'.format(file_path))
                        os.remove(file_path)

                    download_file(data_url, checksum, file_path)
                except Exception:
                    logger.exception("Failed to process entry {}".format(entry))

        except Exception as e:
            logger.exception(e)

        # Pause between polls. Deliberately not in a `finally` clause, so
        # Ctrl-C (KeyboardInterrupt) terminates immediately.
        logger.info("Pausing.")
        time.sleep(2)

