#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import os
import sys
import glob
import argparse
import logging
import time
from multiprocessing import cpu_count
from threading import Thread
from subprocess import call, Popen, PIPE

try:
    # python 2
    from Queue import Queue, Empty
except ImportError:
    # python 3
    from queue import Queue, Empty

# Seconds a blocking queue read waits before giving up
TIMEOUT = 0.2

# Make the tests look for taskd in $PATH instead of task/src/
os.environ["TASKD_USE_PATH"] = "1"


def run_test(testqueue, outqueue, threadname):
    """Worker loop: run queued test executables and collect their output.

    Takes test paths from ``testqueue`` until no item arrives within
    ``TIMEOUT`` seconds.  Each test is run as a child process and a
    ``(header, stdout, stderr)`` tuple is put on ``outqueue``, where
    ``header`` is ``# <basename of the test>`` followed by a newline.

    Args:
        testqueue: Queue of paths to test executables.
        outqueue: Queue that receives one output tuple per executed test.
        threadname: Label used only in the final log message.

    If a test cannot be launched the exception is logged and the worker
    stops without taking further tests.
    """
    start = time.time()
    while True:
        try:
            test = testqueue.get(block=True, timeout=TIMEOUT)
        except Empty:
            break

        log.info("Running test %s", test)

        try:
            p = Popen(os.path.abspath(test), stdout=PIPE, stderr=PIPE,
                      env=os.environ)
            out, err = p.communicate()
        except Exception as e:
            log.exception(e)
            # Keep get()/task_done() balanced on the failure path as well
            testqueue.task_done()
            # NOTE(review): nothing is put on outqueue for this test, so it
            # never shows up in the TAP output.
            # Premature end
            break

        if sys.version_info[0] >= 3:
            # Test output is arbitrary bytes.  A strict decode would raise
            # UnicodeDecodeError on invalid UTF-8 and kill this worker, so
            # undecodable bytes are replaced with U+FFFD instead.
            out = out.decode('utf-8', 'replace')
            err = err.decode('utf-8', 'replace')

        output = ("# {0}\n".format(os.path.basename(test)), out, err)
        log.debug("Collected output %s", output)
        outqueue.put(output)

        testqueue.task_done()

    log.warning("Finished %s thread after %s seconds",
                threadname, round(time.time() - start, 3))


class TestRunner(object):
    """Find the test executables in the current directory and run them.

    Executable ``*.t`` / ``*.t.exe`` files are split between a parallel
    queue, served by one worker thread per CPU, and a serial queue, served
    by a single worker thread.  Everything the tests print is collected on
    an output queue and written to the TAP file named by
    ``cmd_args.tapfile``.

    Relies on the module-level ``cmd_args``, ``log``, ``run_test`` and
    ``TIMEOUT`` names.
    """

    def __init__(self):
        """Open the TAP file for writing and create the empty work queues."""
        self.threads = []
        self.tap = open(cmd_args.tapfile, 'w')

        self._parallelq = Queue()
        self._serialq = Queue()
        self._outputq = Queue()

    def close(self):
        """Close the TAP file.  Calling it more than once is harmless."""
        self.tap.close()

    def _find_tests(self):
        """Queue every executable test found in the current directory."""
        for test in glob.glob("*.t") + glob.glob("*.t.exe"):
            if os.access(test, os.X_OK):
                # Executables only
                if self._is_parallelizable(test):
                    log.debug("Treating as parallel: %s", test)
                    self._parallelq.put(test)
                else:
                    log.debug("Treating as serial: %s", test)
                    self._serialq.put(test)
            else:
                log.debug("Ignored test %s as it is not executable", test)

        log.info("Parallel tests: %s", self._parallelq.qsize())
        log.info("Serial tests: %s", self._serialq.qsize())

    def _prepare_threads(self):
        """Create (but do not start) the serial and parallel workers."""
        # Serial thread
        self.threads.append(
            Thread(target=run_test,
                   args=(self._serialq, self._outputq, "Serial"))
        )
        # Parallel threads
        self.threads.extend([
            Thread(target=run_test,
                   args=(self._parallelq, self._outputq, "Parallel"))
            for _ in range(cpu_count())
        ])
        log.info("Spawned %s threads to run tests", len(self.threads))

    def _start_threads(self):
        """Start every prepared worker as a daemon thread."""
        for thread in self.threads:
            # Threads die when main thread dies
            log.debug("Starting thread %s", thread)
            thread.daemon = True
            thread.start()

    def _print_timestamp_to_tap(self):
        """Write a ``# <epoch> ==> <local time>`` comment to the TAP file."""
        now = time.time()
        timestamp = "# {0} ==> {1}\n".format(
            now,
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(now)),
        )

        log.debug("Adding timestamp %s to TAP file", timestamp)
        self.tap.write(timestamp)

    def _is_parallelizable(self, test):
        """Return True if ``test`` should go on the parallel queue.

        Always False when ``--serial`` was given.  Otherwise the first 100
        bytes of the file are inspected: the test is parallelizable when
        its first line contains ``/usr/bin/env python`` or its second line
        ends with ``bash_tap_tw.sh``.
        """
        if cmd_args.serial:
            return False

        # This is a pretty weird way to do it, and not reliable.
        # We are dealing with some binary tests though.
        with open(test, 'rb') as fh:
            header = fh.read(100).split(b"\n")

        # split() always yields at least one element, so the shebang check
        # does not depend on a second line being present.
        if b"/usr/bin/env python" in header[0]:
            return True

        # rstrip() tolerates CRLF line endings and trailing blanks.
        return (len(header) >= 2 and
                header[1].rstrip().endswith(b"bash_tap_tw.sh"))

    def _get_remaining_tests(self):
        """Return how many tests are still waiting in both queues."""
        return self._parallelq.qsize() + self._serialq.qsize()

    def is_running(self):
        """Return True while at least one worker thread is alive."""
        for thread in self.threads:
            if thread.is_alive():
                return True

        return False

    def start(self):
        """Run all discovered tests and write their output to the TAP file.

        Blocks until every worker has stopped and the output queue is
        drained.  With ``--verbose`` the output is echoed to stdout too.
        The TAP file is closed when this method returns or raises, so a
        runner can only be started once.

        Raises:
            RuntimeError: if tests are left in a queue after the workers
                have stopped.
        """
        try:
            self._find_tests()
            self._prepare_threads()

            self._print_timestamp_to_tap()

            finished = 0
            total = self._get_remaining_tests()

            self._start_threads()

            while self.is_running() or not self._outputq.empty():
                try:
                    outputs = self._outputq.get(block=True, timeout=TIMEOUT)
                except Empty:
                    continue

                log.debug("Outputting to TAP: %s", outputs)

                for output in outputs:
                    self.tap.write(output)

                    if cmd_args.verbose:
                        sys.stdout.write(output)

                self._outputq.task_done()
                finished += 1

                log.warning("Finished %s out of %s tests", finished, total)

            self._print_timestamp_to_tap()

            if not self._parallelq.empty() or not self._serialq.empty():
                raise RuntimeError(
                    "Something went wrong, not all tests were ran. {0} "
                    "remaining.".format(self._get_remaining_tests()))
        finally:
            # Nothing writes to the TAP file after this point; closing it
            # here releases the handle and flushes it to disk on every
            # exit path.
            self.close()

    def show_report(self):
        """Run ``problems --summary`` on the TAP file.

        Returns:
            The exit status of the ``problems`` command.
        """
        # start() closes the TAP file; flush only a handle that is still
        # open (flushing a closed file raises ValueError).
        if not self.tap.closed:
            self.tap.flush()

        log.debug("Calling 'problems --summary' for report")
        return call([os.path.abspath("problems"), "--summary",
                     cmd_args.tapfile])


def parse_args():
    """Parse the command line (``sys.argv``) and return the options.

    The returned namespace has ``verbose`` and ``serial`` (booleans),
    ``logging`` (number of -l flags given, 0 by default) and ``tapfile``
    (path, "all.log" by default).
    """
    # (flags, keyword settings) in the order they appear in --help
    option_table = (
        (('--verbose', '-v'),
         dict(action="store_true",
              help="Also send TAP output to stdout")),
        (('--logging', '-l'),
         dict(action="count", default=0,
              help="Logging level. -lll is the highest level")),
        (('--serial',),
         dict(action="store_true",
              help="Do not run tests in parallel")),
        (('--tapfile',),
         dict(default="all.log",
              help="File to use for TAP output")),
    )

    cli = argparse.ArgumentParser(description="Run Taskwarrior tests")
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()


def main():
    """Run every discovered test and return the process exit status.

    With --verbose no summary report is produced and the status is 0;
    otherwise the status of the summary report is returned.
    """
    runner = TestRunner()
    runner.start()

    if cmd_args.verbose:
        return 0

    # If we're producing summary report, propagate the return code
    return runner.show_report()


if __name__ == "__main__":
    # cmd_args and log are read as module-level globals by the code above.
    cmd_args = parse_args()

    # No -l logs errors only; each extra -l is one step more verbose and
    # anything beyond -lll stays at DEBUG.
    verbosity_levels = (
        logging.ERROR,
        logging.WARN,
        logging.INFO,
        logging.DEBUG,
    )
    level = verbosity_levels[max(0, min(cmd_args.logging, 3))]

    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    log = logging.getLogger(__name__)

    log.debug("Parsed commandline arguments: %s", cmd_args)

    # Any unexpected error is logged with its traceback and becomes exit
    # status 1.
    try:
        exit_status = main()
    except Exception as e:
        log.exception(e)
        exit_status = 1

    sys.exit(exit_status)

# vim: ai sts=4 et sw=4
