#!/usr/bin/env python
from __future__ import print_function
import sys
import os
import math
import traceback
from time import time

from neo.tests import DB_PREFIX
from neo.tests.benchmark import BenchmarkRunner
from ZODB.FileStorage import FileStorage

class MatrixImportBenchmark(BenchmarkRunner):

    error_log = ''
    _size = None

    def add_options(self, parser):
        _ = parser.add_argument
        _('-d', '--datafs')
        _('-z', '--zeo', action="store_true")
        _('--min-storages', type=int, default=1)
        _('--max-storages', type=int, default=2)
        _('--min-replicas', type=int, default=0)
        _('--max-replicas', type=int, default=1)
        _('--repeat', type=int, default=1)
        _('--threaded', action="store_true")

    def load_options(self, args):
        if args.datafs and not os.path.exists(args.datafs):
            sys.exit('Missing or wrong data.fs argument')
        return dict(
            datafs = args.datafs,
            min_s = args.min_storages,
            max_s = args.max_storages,
            min_r = args.min_replicas,
            max_r = args.max_replicas,
            repeat = args.repeat,
            threaded = args.threaded,
            zeo = args.zeo,
        )

    def start(self):
        # build storage (logarithm) & replicas (linear) lists
        min_s, max_s = self._config.min_s, self._config.max_s
        min_r, max_r = self._config.min_r, self._config.max_r
        min_s2 = int(math.log(min_s, 2))
        max_s2 = int(math.log(max_s, 2))
        storages = [2 ** x for x in range(min_s2, max_s2 + 1)]
        if storages[0] < min_s:
            storages[0] = min_s
        if storages[-1] < max_s:
            storages.append(max_s)
        replicas = range(min_r, max_r + 1)
        results = {}
        def merge_min(a, b):
            for k, vb in b.iteritems():
                try:
                    va = a[k]
                except KeyError:
                    pass
                else:
                    if type(va) is dict:
                        merge_min(va, vb)
                        continue
                    if vb is None or None is not va <= vb:
                        continue
                a[k] = vb
        for x in xrange(self._config.repeat):
            merge_min(results, self.runMatrix(storages, replicas))
        return self.buildReport(storages, replicas, results)

    def runMatrix(self, storages, replicas):
        stats = {}
        if self._config.zeo:
            stats['zeo'] = self.runImport()
        for s in storages:
            stats[s] = z = {}
            for r in replicas:
                if r < s:
                    z[r] = self.runImport(1, s, r, 12*s//(1+r))
        return stats

    def runImport(self, *neo_args):
        """Import the source database into a fresh cluster and time it.

        neo_args is either empty, in which case a ZEO cluster is used, or
        the 4 values ``(masters, storages, replicas, partitions)`` that
        describe the NEO cluster to create.

        The source is the FileStorage given by the ``datafs`` option or,
        without it, a storage generated by ``neo.tests.stat_zodb.PROD1``
        with a fixed random seed.

        Returns the duration in seconds of ``copyTransactionsFrom``, or
        None if the import failed; in that case the traceback is printed
        and appended to ``self.error_log``.
        """
        datafs = self._config.datafs
        if datafs:
            dfs_storage = FileStorage(file_name=datafs)
        else:
            datafs = 'PROD1'
            import random, neo.tests.stat_zodb
            dfs_storage = getattr(neo.tests.stat_zodb, datafs)(
                random.Random(0)).as_storage(5000)
        try:
            info = "Import of " + datafs
            if neo_args:
                masters, storages, replicas, partitions = neo_args
                info += " with m=%s, s=%s, r=%s, p=%s" % (
                    masters, storages, replicas, partitions)
                if self._config.threaded:
                    from neo.tests.threaded import NEOCluster
                else:
                    from neo.tests.functional import NEOCluster
                zodb = NEOCluster(
                    db_list=['%s_matrix_%u' % (DB_PREFIX, i)
                             for i in range(storages)],
                    clear_databases=True,
                    master_count=masters,
                    partitions=partitions,
                    replicas=replicas,
                )
            else:
                from neo.tests.zeo_cluster import ZEOCluster
                info += " with ZEO"
                zodb = ZEOCluster()
            print(info)
            try:
                zodb.start()
                try:
                    storage = zodb.getZODBStorage()
                    if neo_args and not self._config.threaded:
                        assert len(zodb.getStorageList()) == storages
                        zodb.expectOudatedCells(number=0)
                    start = time()
                    storage.copyTransactionsFrom(dfs_storage)
                    end = time()
                    # Every run must import the same amount of data,
                    # otherwise the timings would not be comparable.
                    size = dfs_storage.getSize()
                    if self._size is None:
                        self._size = size
                    else:
                        assert self._size == size
                finally:
                    zodb.stop()
                # Clear DB if no error happened.
                zodb.setupDB()
                return end - start
            except Exception:
                # Not a bare except: KeyboardInterrupt and SystemExit must
                # still be able to abort the whole benchmark.
                traceback.print_exc()
                self.error_log += "%s:\n%s\n" % (
                    info, traceback.format_exc())
                return None
        finally:
            # Release the source storage (file handle and lock for a
            # FileStorage).  The generated storage may not provide close().
            close = getattr(dfs_storage, 'close', None)
            if close is not None:
                close()

    def buildReport(self, storages, replicas, results):
        # draw an array with results
        dfs_size = self._size
        self.add_status('Input size',
            dfs_size and '%-.1f MB' % (dfs_size / 1e6) or 'N/A')
        fmt = '|' + '|'.join(['  %8s  '] * (len(replicas) + 1)) + '|\n'
        sep = '+' + '+'.join(['-' * 12] * (len(replicas) + 1)) + '+\n'
        report = sep
        report += fmt % tuple(['S\R'] + replicas)
        report += sep
        failures = 0
        speedlist = []
        if self._config.zeo:
            result = results['zeo']
            if result is None:
                result = 'FAIL'
                failures += 1
            else:
                result = '%.1f kB/s' % (dfs_size / (result * 1e3))
            self.add_status('ZEO', result)
        for s in storages:
            values = []
            assert s in results
            for r in replicas:
                if r in results[s]:
                    result = results[s][r]
                    if result is None:
                        values.append('FAIL')
                        failures += 1
                    else:
                        result = dfs_size / (result * 1e3)
                        values.append('%8.1f' % result)
                        speedlist.append(result)
                else:
                    values.append('N/A')
            report += fmt % (tuple([s] + values))
            report += sep
        report += self.error_log
        if failures:
            info = '%d failures' % (failures, )
        else:
            info = '%.1f kB/s' % (sum(speedlist) / len(speedlist))
        return info, report

def main(args=None):
    """Instantiate the matrix import benchmark and run it.

    args is accepted but not used by this function.
    """
    benchmark = MatrixImportBenchmark()
    benchmark.run()

# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

