#!/usr/bin/env python
# -*- coding: utf-8 -*-
# vim: ai ts=4 sts=4 et sw=4
# clamavmirror: ClamAV Signature Mirroring Tool
# Copyright (C) 2015 Andrew Colin Kissa <andrew@topdog.za.net>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""ClamAV Signature Mirroring Tool

Why
---

The existing clamdownloader.pl script does not have any error
correction; it simply bails out if a downloaded file is not
valid and is unable to retry different mirrors if one fails.

This script will retry if a download fails with an HTTP code
that is not 404. It will connect to another mirror if retries
fail, if the file is not found, or if the downloaded file is invalid.

It has options to set the locations for the working and
mirror directory as well as user/group ownership for the
downloaded files. It uses locking to prevent multiple
instances from running at the same time.

Requirements
------------

UrlGrabber module - http://urlgrabber.baseurl.org/
DNS-Python module - http://www.dnspython.org/

Usage
-----

$ ./clamavmirror.py -h
Usage: clamavmirror.py [options]

Options:
  -h, --help            show this help message and exit
  -a HOSTNAME, --hostname=HOSTNAME
                        ClamAV source server hostname
  -r TXTRECORD, --text-record=TXTRECORD
                        ClamAV Updates TXT record
  -w WORKDIR, --work-directory=WORKDIR
                        Working directory
  -d MIRRORDIR, --mirror-directory=MIRRORDIR
                        The mirror directory
  -u USER, --user=USER  Change file owner to this user
  -g GROUP, --group=GROUP
                        Change file group to this group
  -l LOCKDIR, --locks-directory=LOCKDIR
                        Lock files directory
  -v, --verbose         Display verbose output

Example Usage
-------------

./extras/scripts/clamavmirror.py -w ~/tmp/clamavtmp/ \
-d ~/tmp/clamavmirror/ -u andrew -g staff -a db.za.clamav.net \
-l ~/Downloads/
"""
import os
import pwd
import grp
import sys
import time
import fcntl
import hashlib
import logging

import urlgrabber.grabber

from shutil import move
from random import shuffle
from optparse import OptionParser
from subprocess import PIPE, Popen

from urlgrabber.mirror import MirrorGroup
from urlgrabber.progress import text_progress_meter
from urlgrabber.grabber import URLGrabber, URLGrabError
from dns.resolver import query, Resolver, NXDOMAIN, NoAnswer


def setup_logging(loglevel):
    """Send urlgrabber's log output to stdout at the requested level.

    ``loglevel`` may be a level name known to the ``logging`` module
    (for example ``'INFO'``) or a value convertible to a positive
    integer. When it cannot be resolved to a level of at least 1 the
    urlgrabber logger is disabled by passing ``None`` to
    ``urlgrabber.grabber.set_logger``.

    The stdout handler is attached at most once, so calling this
    function repeatedly does not multiply every log line.
    """
    try:
        # getLevelName maps a known level name to its number; anything
        # else comes back as a string and is tried as a plain integer.
        level = logging.getLevelName(loglevel)
        if not isinstance(level, int):
            level = int(loglevel)
        if level < 1:
            raise ValueError()

        dbobj = logging.getLogger('urlgrabber')
        already_attached = any(
            isinstance(hdlr, logging.StreamHandler) and
            getattr(hdlr, 'stream', None) is sys.stdout
            for hdlr in dbobj.handlers)
        if not already_attached:
            formatter = logging.Formatter('%(asctime)s %(message)s')
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(formatter)
            dbobj.addHandler(handler)
        dbobj.setLevel(level)
    except (KeyError, ImportError, ValueError):
        dbobj = None
    urlgrabber.grabber.set_logger(dbobj)

def get_file_md5(filename):
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is read in binary mode in 64 KiB blocks. An empty string
    is returned when ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        return ''
    digest = hashlib.md5()
    with open(filename, 'rb') as source:
        # iter() with a sentinel stops at the first empty read (EOF).
        for block in iter(lambda: source.read(65536), b''):
            digest.update(block)
    return digest.hexdigest()


def get_md5(string):
    """Return the hexadecimal MD5 digest of a string.

    Byte strings are hashed as they are. Text (unicode) strings are
    encoded as UTF-8 first, because the hasher only accepts bytes and
    an implicit ASCII conversion fails on non-ASCII characters.
    """
    if not isinstance(string, bytes):
        string = string.encode('utf-8')
    hasher = hashlib.md5()
    hasher.update(string)
    return hasher.hexdigest()


def chunk_report(bytes_so_far, total_size):
    """Write a one-line download progress report to stdout.

    The line ends with a carriage return so that successive reports
    overwrite each other; a newline is added once ``bytes_so_far``
    reaches ``total_size``. A ``total_size`` of zero is reported as
    100% instead of raising ZeroDivisionError.
    """
    if total_size:
        percent = round(float(bytes_so_far) / total_size * 100, 2)
    else:
        # Nothing to download: treat the transfer as complete.
        percent = 100.0
    sys.stdout.write("[x] Downloaded %d of %d bytes (%0.2f%%)\r" %
                     (bytes_so_far, total_size, percent))
    if bytes_so_far >= total_size:
        sys.stdout.write('\n')


def chunk_read(response, handle, chunk_size=8192, report_hook=None):
    """Copy a response body to ``handle`` in chunks.

    The expected size is taken from the response's Content-Length
    header. After every non-empty chunk ``report_hook`` (when given) is
    called with the number of bytes received so far and the expected
    size. ``handle`` is closed once the response is exhausted.

    Returns the total number of bytes read.
    """
    header = response.info().getheader('Content-Length')
    expected = int(header.strip())
    received = 0
    while True:
        data = response.read(chunk_size)
        handle.write(data)
        received += len(data)
        if data:
            if report_hook:
                report_hook(received, expected)
            continue
        # An empty read marks the end of the response.
        handle.close()
        return received


def error(msg):
    """Write ``msg`` followed by a newline to stderr.

    The stream is written to directly because the ``print >>``
    statement form only exists on Python 2.
    """
    sys.stderr.write('%s\n' % (msg,))


def info(msg):
    """Write ``msg`` followed by a newline to stdout.

    The stream is written to directly because the ``print >>``
    statement form only exists on Python 2.
    """
    sys.stdout.write('%s\n' % (msg,))


def deploy_signature(source, dest, user=None, group=None):
    """Deploy a signature file.

    Moves ``source`` to ``dest`` and sets the mode of ``dest`` to 0644.
    When both ``user`` and ``group`` are given, ownership of ``dest`` is
    changed to them as well.
    """
    move(source, dest)
    # 0o644 is the octal spelling accepted by Python 2.6+ and Python 3.
    os.chmod(dest, 0o644)
    if user and group:
        try:
            uid = pwd.getpwnam(user).pw_uid
            gid = grp.getgrnam(group).gr_gid
            os.chown(dest, uid, gid)
        except (KeyError, OSError):
            # Best effort: an unknown user/group (KeyError) or a failed
            # chown (OSError) leaves the deployed file in place.
            pass


def create_file(name, content):
    """Write ``content`` to the file ``name``, replacing what is there."""
    target = open(name, 'w')
    try:
        target.write(content)
    finally:
        target.close()


def get_name_servers():
    """Get the name servers for clamav.net.

    Looks up the NS records of ``clamav.net`` and resolves each name
    server to its A records. Returns a shuffled list of strings of the
    form ``'nameserver <ip>'``.

    Name servers that do not exist (NXDOMAIN) or that have no A record
    (NoAnswer) are skipped so one bad entry does not abort the lookup.
    """
    # NOTE(review): callers in this file pass the returned list as
    # Resolver(filename=...), presumably relying on dnspython reading it
    # line by line like a resolv.conf file -- confirm.
    answers = query('clamav.net', 'NS')
    ips = []
    for rdata in answers:
        hostname = rdata.to_text()
        try:
            ipanswers = query(hostname, 'A')
        except (NXDOMAIN, NoAnswer):
            continue
        ips.extend(['nameserver %s' % answer.address
                    for answer in ipanswers])
    shuffle(ips)
    return ips


def get_ip_addresses(hostname):
    """Resolve ``hostname`` to a list of its A record addresses.

    The query goes through a Resolver built from the output of
    get_name_servers(). An empty list is returned when the name does
    not exist (NXDOMAIN).
    """
    try:
        nameservers = get_name_servers()
        lookup = Resolver(filename=nameservers)
        return [record.address for record in lookup.query(hostname, 'A')]
    except NXDOMAIN:
        return []


def get_txt_record(hostname):
    """Get the text record.

    Queries the TXT record of ``hostname`` through a Resolver built from
    get_name_servers() and returns the first string of the first answer.

    Returns an empty string when the name does not exist (NXDOMAIN),
    when the server has no TXT data for it (NoAnswer) or when the
    answer is empty, so the caller can treat all three as "try again".
    """
    try:
        resolver = Resolver(filename=get_name_servers())
        answers = resolver.query(hostname, 'TXT')
        return answers[0].strings[0]
    except (IndexError, NXDOMAIN, NoAnswer):
        return ''


def get_local_version(sigdir, sig):
    """Get the local version of a signature.

    Runs ``sigtool -i`` on ``<sigdir>/<sig>.cvd`` and returns the second
    whitespace-separated field of the first output line that starts
    with ``Version:``. Returns ``None`` when the file does not exist or
    no such line is found.
    """
    filename = os.path.join(sigdir, '%s.cvd' % sig)
    if not os.path.exists(filename):
        return None
    sigtool = Popen(['sigtool', '-i', filename], stdout=PIPE, stderr=PIPE)
    # communicate() drains both pipes; reading only stdout and then
    # calling wait() can block if the child fills the other pipe.
    output, _ = sigtool.communicate()
    for line in output.splitlines():
        if not line.startswith(b'Version:'):
            continue
        fields = line.split()
        if len(fields) < 2:
            return None
        version = fields[1]
        if not isinstance(version, str):
            # Python 3 yields bytes from the pipe; callers expect text.
            version = version.decode('ascii', 'replace')
        return version
    return None


def verify_sigfile(sigdir, sig):
    """Verify a signature file.

    Runs ``sigtool -i`` on ``<sigdir>/<sig>.cvd`` and returns ``True``
    when the command exits with status 0, ``False`` otherwise.
    """
    filename = os.path.join(sigdir, '%s.cvd' % sig)
    sigtool = Popen(['sigtool', '-i', filename], stdout=PIPE, stderr=PIPE)
    # The output is not needed, but the pipes must be drained:
    # wait() alone can deadlock once the child fills a pipe buffer.
    sigtool.communicate()
    return sigtool.returncode == 0


def failure_cb(obj):
    """Report a failed request on stderr.

    Prints the url, tries and retry attributes of ``obj`` in red.
    """
    template = ("[-] \033[91mRequest(s) for: %s failed, num of"
                " tries: %s, retries: %s\033[0m")
    details = (obj.url, obj.tries, obj.retry)
    error(template % details)


# pylint: disable-msg=W0613
def check_download(obj, *args, **kwargs):
    """Check a downloaded signature against the expected version.

    The positional arguments are, in order, the expected version, the
    working directory and the signature name. Nothing is checked when
    the expected version is empty.

    Raises URLGrabError(-1) when sigtool rejects the file or when the
    version it reports differs from the expected one.
    """
    version, workdir, signame = args[0], args[1], args[2]
    if not version:
        return
    local_version = get_local_version(workdir, signame)
    verified = verify_sigfile(workdir, signame)
    if verified and version == local_version:
        return
    error("[-] \033[91mFailed to verify signature: %s from: %s\033[0m"
          % (signame, obj.url))
    raise URLGrabError(-1)


def download_sig(opts, ips, sig, version=None):
    """Download signature for IP list.

    Builds a urlgrabber MirrorGroup from ``ips`` (in random order) and
    fetches ``<sig>.cvd`` into ``opts.workdir`` when ``version`` is
    given, or ``<sig>.cdiff`` otherwise. The download is checked with
    check_download and failures are reported through failure_cb.

    Returns a ``(downloaded, code)`` tuple. ``downloaded`` tells whether
    the target file exists after a successful grab; ``code`` is 200 on
    success and 404 when urlgrabber raised URLGrabError.
    """
    code = None
    downloaded = False
    try:
        # Shuffle a copy so the caller's address list is left untouched.
        mirrors = list(ips)
        shuffle(mirrors)
        if opts.verbose:
            setup_logging('INFO')
        sendargs = (version, opts.workdir, sig)
        useagent = 'ClamAV/0.98.7 (OS: linux-gnu, ARCH: x86_64, CPU: x86_64)'
        grabber = URLGrabber(progress_obj=text_progress_meter(), timeout=2,
                             reget=None, user_agent=useagent,
                             checkfunc=(check_download, sendargs, {}),
                             failure_callback=failure_cb)
        mirrorgrp = MirrorGroup(grabber,
                                ['http://%s' % mirror for mirror in mirrors])
        if version:
            path = '/%s.cvd' % sig
            filename = os.path.join(opts.workdir, '%s.cvd' % sig)
        else:
            path = '/%s.cdiff' % sig
            filename = os.path.join(opts.workdir, '%s.cdiff' % sig)

        mirrorgrp.urlgrab(path, filename=filename)
        code = 200
        downloaded = os.path.exists(filename)
    except URLGrabError as err:
        # NOTE: every URLGrabError is reported as 404, whatever its cause.
        code = 404
        error('[Errno %i]: %s' % (err.errno, err.strerror))
    return downloaded, code


def get_addrs(hostname):
    """Resolve ``hostname`` to its addresses, trying up to five times.

    A pass that yields no addresses is followed by a five second sleep;
    a pass that ends in NoAnswer moves straight on to the next one.
    Exits the program with status 2 when no pass yields an address.
    """
    attempts = 1
    resolved = None
    for attempt in range(1, 6):
        attempts = attempt
        info("[+] \033[92mResolving hostname:\033[0m %s pass: %d" %
             (hostname, attempt))
        try:
            resolved = get_ip_addresses(hostname)
        except NoAnswer:
            continue
        if resolved:
            info("=> Resolved to: %s" % ','.join(resolved))
            break
        info("=> Resolution failed, sleeping 5 secs")
        time.sleep(5)
    if resolved:
        return resolved
    error("=> Resolving hostname: %s failed after %d tries" %
          (hostname, attempts))
    sys.exit(2)
    return resolved


def get_record(opts):
    """Fetch the TXT record named by ``opts.txtrecord``.

    Makes up to four queries, sleeping five seconds after each empty
    result. Exits the program with status 3 when every query comes back
    empty; otherwise returns the record.
    """
    attempts = 1
    for attempt in range(1, 5):
        attempts = attempt
        info("[+] \033[92mQuerying TXT record:\033[0m %s pass: %s" %
             (opts.txtrecord, attempt))
        record = get_txt_record(opts.txtrecord)
        if record:
            info("=> Query returned: %s" % record)
            return record
        info("=> Txt record query failed, sleeping 5 secs")
        time.sleep(5)
    error("=> Txt record query failed after %d tries" % attempts)
    sys.exit(3)
    return record


def copy_sig(sig, opts, isdiff):
    """Move a downloaded signature from the work to the mirror directory.

    The file is ``<sig>.cdiff`` when ``isdiff`` is true and ``<sig>.cvd``
    otherwise. Ownership is handed to ``opts.user`` / ``opts.group`` by
    deploy_signature.
    """
    info("[+] \033[92mDeploying signature:\033[0m %s" % sig)
    basename = ('%s.cdiff' if isdiff else '%s.cvd') % sig
    deploy_signature(os.path.join(opts.workdir, basename),
                     os.path.join(opts.mirrordir, basename),
                     opts.user, opts.group)
    info("=> Deployed signature: %s" % sig)


def update_sig(options, addrs, sign, vers):
    """Bring the mirrored ``<sign>.cvd`` up to the remote version.

    The local version is read from ``options.mirrordir`` and the remote
    one from ``vers[sign]``. When there is no local file, or the local
    version is lower, the signature is downloaded (up to five passes,
    five seconds apart) and deployed. A 404 result stops the retries.
    """
    info("[+] \033[92mChecking signature version:\033[0m %s" % sign)
    localver = get_local_version(options.mirrordir, sign)
    remotever = vers[sign]
    outdated = localver is None or (
        localver and int(localver) < int(remotever))
    if not outdated:
        info("=> No update required L: %s => R: %s" % (localver, remotever))
        return
    info("=> Update required local: %s => remote: %s" %
         (localver, remotever))
    for attempt in range(1, 6):
        info("=> Downloading signature: %s pass: %d" % (sign, attempt))
        fetched, code = download_sig(options, addrs, sign, remotever)
        if fetched:
            info("=> Downloaded signature: %s" % sign)
            copy_sig(sign, options, 0)
            return
        if code == 404:
            error("=> \033[91mSignature:\033[0m %s not found,"
                  " will not retry" % sign)
            return
        error("=> \033[91mDownload failed:\033[0m %s "
              "pass: %d, sleeping 5sec" % (sign, attempt))
        time.sleep(5)


def update_diff(opts, addrs, sig):
    """Download and deploy the cdiff file for ``sig``.

    Makes up to five passes, five seconds apart. The first successful
    download is deployed; a 404 result stops the retries.
    """
    for attempt in range(1, 6):
        info("[+] \033[92mDownloading cdiff:\033[0m %s pass: %d" %
             (sig, attempt))
        fetched, code = download_sig(opts, addrs, sig)
        if fetched:
            info("=> Downloaded cdiff: %s" % sig)
            copy_sig(sig, opts, 1)
            return
        if code == 404:
            error("=> \033[91mSignature:\033[0m %s not found,"
                  " will not retry" % sig)
            return
        error("=> \033[91mDownload failed:\033[0m %s pass: %d,"
              " sleeping 5sec" % (sig, attempt))
        time.sleep(5)


def create_dns_file(opts, record):
    """Store ``record`` as dns.txt in the mirror directory.

    The file is rewritten only when the MD5 of its current contents
    differs from the MD5 of ``record``.
    """
    info("[+] \033[92mUpdating dns.txt file\033[0m")
    target = os.path.join(opts.mirrordir, 'dns.txt')
    current = get_file_md5(target)
    wanted = get_md5(record)
    if current == wanted:
        info("=> No update required L: %s => R: %s" % (current, wanted))
        return
    create_file(target, record)
    info("=> dns.txt file updated")


def main(options):
    """The main function.

    Resolves the mirror addresses (``options.hostname`` plus
    ``db.us.big.clamav.net.``), fetches the TXT record and takes the
    main, daily, safebrowsing and bytecode versions from its 2nd, 3rd,
    7th and 8th colon-separated fields. For daily, bytecode and
    safebrowsing the cdiff files between the local and the remote
    version that are missing from the mirror directory are downloaded
    first; then every signature is updated and dns.txt is refreshed.

    Always ends by exiting the program with status 0.
    """
    addrs = get_addrs(options.hostname)
    addrs.extend(get_addrs('db.us.big.clamav.net.'))
    addrs = list(set(addrs))
    record = get_record(options)
    # pylint: disable-msg=W0621
    _, mainv, dailyv, _, _, _, safebrowsingv, bytecodev = record.split(':')
    versions = {'main': mainv, 'daily': dailyv,
                'safebrowsing': safebrowsingv,
                'bytecode': bytecodev}
    for signature_type in ['main', 'daily', 'bytecode', 'safebrowsing']:
        if signature_type in ['daily', 'bytecode', 'safebrowsing']:
            # download diffs
            localver = get_local_version(options.mirrordir, signature_type)
            # Look the version up in the dict instead of via locals().
            remotever = versions[signature_type]
            if localver is not None:
                for num in range(int(localver), int(remotever) + 1):
                    sig_diff = '%s-%d' % (signature_type, num)
                    filename = os.path.join(options.mirrordir,
                                            '%s.cdiff' % sig_diff)
                    if not os.path.exists(filename):
                        update_diff(options, addrs, sig_diff)
        update_sig(options, addrs, signature_type, versions)
    create_dns_file(options, record)
    sys.exit(0)


# Script entry point: parse the command line, take an exclusive
# non-blocking lock so only one instance runs, then start the mirror run.
if __name__ == '__main__':
    PARSER = OptionParser()
    PARSER.add_option('-a', '--hostname',
                      help='ClamAV source server hostname',
                      dest='hostname',
                      type='str',
                      default='db.de.clamav.net')
    PARSER.add_option('-r', '--text-record',
                      help='ClamAV Updates TXT record',
                      dest='txtrecord',
                      type='str',
                      default='current.cvd.clamav.net')
    PARSER.add_option('-w', '--work-directory',
                      help='Working directory',
                      dest='workdir',
                      type='str',
                      default='/var/spool/clamav-mirror')
    PARSER.add_option('-d', '--mirror-directory',
                      help='The mirror directory',
                      dest='mirrordir',
                      type='str',
                      default='/srv/www/clamav')
    PARSER.add_option('-u', '--user',
                      help='Change file owner to this user',
                      dest='user',
                      type='str',
                      default='nginx')
    PARSER.add_option('-g', '--group',
                      help='Change file group to this group',
                      dest='group',
                      type='str',
                      default='nginx')
    PARSER.add_option('-l', '--locks-directory',
                      help='Lock files directory',
                      dest='lockdir',
                      type='str',
                      default='/var/lock/subsys')
    PARSER.add_option('-v', '--verbose',
                      help='Display verbose output',
                      dest='verbose',
                      action='store_true',
                      default=False)
    OPTIONS, _ = PARSER.parse_args()
    LOCKFILE = os.path.join(OPTIONS.lockdir, 'clamavmirror')
    try:
        LOCK = open(LOCKFILE, 'w+')
    except IOError as err:
        # A lock file that cannot be opened is not a running instance;
        # say so, but keep the exit status used for lock problems.
        error("=> Unable to open lock file %s: %s" % (LOCKFILE, err))
        sys.exit(254)
    with LOCK:
        # Only the lock attempt is guarded, so an IOError raised during
        # the mirror run is no longer misreported as a second instance.
        try:
            fcntl.lockf(LOCK, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError:
            info("=> Another instance is already running")
            sys.exit(254)
        main(OPTIONS)
