#!/usr/bin/env -S python3 -u
# -*- coding: utf-8 -*-

# Copyright © 2016, cyberang3l
# Copyright © 2023, IOhannes m zmölnig, IEM

# This file is part of virtshaus
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with virtshaus. If not, see <http://www.gnu.org/licenses/>.
#
# cyberang3l's original version of this script had no license attached
# to it. The original version of the script can be found at
# https://gist.github.com/cyberang3l/f4c8b1ab6fc48374fbae9553d89e5eed

import time
import os

import logging

# Give the root logger a default handler, and log everything from this script
# through a logger named after the script's file name.
logging.basicConfig()
log = logging.getLogger(os.path.basename(__file__))


# ----------------------------------------------------------------------
def start_vms(libvirt_conn, vms):
    """Start the given libvirt domains one after another.

    libvirt_conn: an open libvirt connection.
    vms: either a mapping of VM name -> seconds to wait after that VM has come
        up before starting the next one (as built by readConfigs), or a plain
        iterable of VM names (no waiting in that case).

    Unknown domains and domains that fail to start are logged and skipped;
    they do not abort starting the remaining VMs.

    Returns the list of requested VM names that are running afterwards
    (including those that were already running).
    """
    # materialise the names, so a one-shot iterator can still be reported on
    names = list(vms)
    running_vms = []
    for vm_name in names:
        try:
            wait_time = vms[vm_name]
        except (LookupError, TypeError):
            # `vms` is a plain list/iterable of names: no per-VM delay
            wait_time = 0

        try:
            vm = libvirt_conn.lookupByName(vm_name)
        except libvirt.libvirtError as e:
            if e.get_error_code() == libvirt.VIR_ERR_NO_DOMAIN:
                log.error("VM '{}' is not a valid domain".format(vm_name))
            else:
                log.exception("Unknown error")
            continue

        if vm.isActive():
            log.info("VM '{}' already started".format(vm_name))
        while not vm.isActive():
            log.info("Starting VM '{}'".format(vm_name))
            try:
                vm.create()
            except libvirt.libvirtError as e:
                # "operation invalid" means the domain is already running
                # (e.g. it came up in the meantime); anything else is a real
                # failure. Either way, give up on this VM and continue.
                if e.get_error_code() != libvirt.VIR_ERR_OPERATION_INVALID:
                    log.exception("Unable to start VM '{}'".format(vm_name))
                break
            time.sleep(1)
            if vm.isActive():
                log.info(
                    "Waiting for {} seconds before trying to start the next VM".format(
                        wait_time,
                    )
                )
                time.sleep(wait_time)
        if vm.isActive():
            running_vms.append(vm_name)

    if running_vms:
        log.info("The following VMs have been started: %s", ", ".join(running_vms))
    else:
        log.warning("No VMs have been started of '%s'", ", ".join(names))
    return running_vms


# ----------------------------------------------------------------------
def stop_vms(libvirt_conn, vms, timeout):
    """Shut down the given libvirt domains one after another.

    libvirt_conn: an open libvirt connection.
    vms: iterable of VM names, processed in the given order. It may be a
        one-shot iterator such as reversed(dict).
    timeout: seconds after which a domain that is still active is destroyed.

    While a domain is active, a shutdown is requested once per second. When
    `timeout` has elapsed, the domain is destroyed. Unknown domains are logged
    and skipped.

    Returns the list of requested VM names that are inactive afterwards.
    """
    # materialise: the names are needed again for the final report, and
    # main() passes a one-shot reversed() iterator
    names = list(vms)
    stopped_vms = []
    for vm_name in names:
        try:
            vm = libvirt_conn.lookupByName(vm_name)
        except libvirt.libvirtError as e:
            if e.get_error_code() == libvirt.VIR_ERR_NO_DOMAIN:
                log.error("VM '{}' is not a valid domain".format(vm_name))
            else:
                log.exception("Unknown error")
            continue

        # log (rather than print), like start_vms, so -q/--logfile apply
        if vm.isActive():
            log.info("Stopping VM '{}'".format(vm_name))
        else:
            log.info("VM '{}' is already stopped".format(vm_name))

        # monotonic clock: the timeout must not be affected by wall-clock jumps
        started = time.monotonic()
        while vm.isActive():
            try:
                vm.shutdown()
                time.sleep(1)
                if time.monotonic() - started >= timeout:
                    log.error(
                        "Timeout was reached and VM '{}' hasn't stopped yet. Destroying...".format(
                            vm_name
                        )
                    )
                    vm.destroy()
            except libvirt.libvirtError as e:
                # "operation invalid": the domain stopped in the meantime
                if e.get_error_code() != libvirt.VIR_ERR_OPERATION_INVALID:
                    log.exception("Unable to stop VM '{}'".format(vm_name))
                break

        if not vm.isActive():
            stopped_vms.append(vm_name)

    if stopped_vms:
        log.info("The following VMs have been stopped: %s", ", ".join(stopped_vms))
    else:
        log.warning("No VMs have been stopped of '%s'", ", ".join(names))
    return stopped_vms


# ----------------------------------------------------------------------
def status_vms(libvirt_conn, vms=None):
    """Print the activity status of libvirt domains to stdout.

    libvirt_conn: an open libvirt connection.
    vms: names of the domains to report. If empty or None, all domains
        defined on the connection (active and inactive) are reported, sorted
        by name.

    Each domain is printed on its own line, prefixed with a blank if it is
    active and with '+' if it is inactive. Names that cannot be looked up are
    skipped. If nothing could be reported, "None" is printed.
    """
    if not vms:
        # listAllDomains() returns active and inactive domains in one call,
        # avoiding the race of separate ID/name listings + lookupByID()
        vms = sorted({dom.name() for dom in libvirt_conn.listAllDomains()})

    print("Status for all VMs (active and +inactive domain names):")
    print("-----------------------------")
    found = 0
    for domName in vms:
        try:
            vm = libvirt_conn.lookupByName(domName)
            line = "\t{}{}".format(" " if vm.isActive() else "+", vm.name())
        except Exception:
            # deliberate best effort: names that are not (or no longer)
            # defined are simply left out of the listing
            continue
        print(line)
        found += 1
    if not found:
        print("\tNone")
    print("-----------------------------")


def parseCmdlineArgs():
    """Parse the command line and apply the global settings derived from it.

    Side effects:
    - sets the root logger's level from -v/-q;
    - if --logfile is given and can be opened, attaches a file handler to the
      root logger.

    Returns an argparse.Namespace with these attributes:
    - ``connection``: the libvirt URI;
    - ``command``: the chosen sub-command;
    - ``timeout``: seconds to wait for a VM to shut down;
    - ``config``: list of existing config files. Directories are expanded to
      the ``*.conf`` files they contain.
    """
    import argparse

    defaults = {
        "connection": "qemu:///system",
        "timeout": 240,
    }
    parser = argparse.ArgumentParser()
    parser.set_defaults(**defaults)

    def add_stopargs(parser):
        parser.add_argument("--timeout", type=float, help="timeout for stopping VMs")

    parser.add_argument(
        "--connection",
        type=str,
        help="libvirt connection (DEFAULT: {connection})".format(**defaults),
    )

    parser.add_argument(
        "--config",
        action="append",
        help="configfiles or directories with configfiles (ending in .conf)",
    )

    group = parser.add_argument_group("printout")
    group.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="raise verbosity (can be given multiple times)",
    )
    group.add_argument(
        "-q",
        "--quiet",
        action="count",
        default=0,
        help="lower verbosity (can be given multiple times)",
    )
    group.add_argument("--logfile", type=str, help="Logfile to write to", default=None)

    subparsers = parser.add_subparsers(
        title="commands",
        dest="command",
        required=True,
    )
    subparsers.add_parser("start", help="start VMs")
    add_stopargs(subparsers.add_parser("stop", help="stop VMs"))
    add_stopargs(subparsers.add_parser("restart", help="restart VMs"))
    subparsers.add_parser("status", help="print status of VMs")

    args = parser.parse_args()

    # handle some args centrally
    # --timeout on the stop/restart sub-parsers defaults to None. Fall back to
    # the global default only if no value was given: an explicit "--timeout 0"
    # is a valid request (which `args.timeout or ...` used to discard).
    if getattr(args, "timeout", None) is None:
        args.timeout = defaults["timeout"]

    # logging: each -q/-v moves the level one step (10) up/down from INFO,
    # clamped to [NOTSET, FATAL + 1]
    loglevel = max(
        0,
        min(logging.FATAL + 1, logging.INFO + (args.quiet - args.verbose) * 10),
    )
    logging.getLogger().setLevel(loglevel)

    del args.quiet
    del args.verbose

    if args.logfile:
        try:
            fh = logging.FileHandler(args.logfile, "w", encoding="utf-8")
        except OSError as e:
            # keep running without the logfile, but do not hide the problem
            log.error("Unable to open logfile '%s': %s", args.logfile, e)
        else:
            logging.getLogger().addHandler(fh)
    del args.logfile

    # config-files
    configfiles = []
    for filename in args.config or ["/etc/virtshaus/vmboot.conf"]:
        if os.path.isdir(filename):
            # a directory stands for all '*.conf' files in it, sorted by name
            # (opening the directory itself would fail in readConfigs)
            for entry in sorted(os.listdir(filename)):
                path = os.path.join(filename, entry)
                if entry.endswith(".conf") and os.path.isfile(path):
                    configfiles.append(path)
        elif os.path.exists(filename):
            configfiles.append(filename)
        else:
            log.fatal("config-file '%s' does not exist", filename)
    args.config = configfiles

    return args


def readConfigs(configfiles):
    """Read the list of VMs to manage from the given config files.

    Config file format (UTF-8):
    - blank lines and lines starting with '#' are ignored;
    - every other line holds a VM name, optionally followed by the number of
      seconds to wait after starting that VM (default 0);
    - further fields on a line are ignored.

    Returns a dict mapping VM name -> wait time (int), in order of first
    appearance. A VM listed more than once keeps its first position but takes
    the last wait time.

    Raises ValueError, naming file and line, if a wait time is not an integer.
    """
    vm_start_list = {}
    for conf in configfiles:
        with open(conf, encoding="utf-8") as config:
            for lineno, line in enumerate(config, start=1):
                fields = line.split()
                if not fields or fields[0].startswith("#"):
                    continue
                name = fields[0]
                try:
                    wait = int(fields[1]) if len(fields) > 1 else 0
                except ValueError as err:
                    raise ValueError(
                        "{}:{}: invalid wait time {!r} for VM '{}'".format(
                            conf, lineno, fields[1], name
                        )
                    ) from err
                vm_start_list[name] = wait
    return vm_start_list


def main(args):
    """Execute the sub-command selected on the command line.

    args: the namespace returned by parseCmdlineArgs().

    VMs are started in config-file order and stopped in reverse order. The
    libvirt connection is closed in any case. This includes Ctrl-C, whose
    KeyboardInterrupt propagates to the caller.
    """
    start_list = readConfigs(args.config)
    # a real list instead of the one-shot iterator reversed() returns, so
    # stop_vms() can iterate it more than once
    stop_list = list(reversed(start_list))

    conn = libvirt.open(args.connection)

    try:
        if args.command == "start":
            start_vms(conn, start_list)
        elif args.command == "stop":
            stop_vms(conn, stop_list, args.timeout)
        elif args.command == "restart":
            stop_vms(conn, stop_list, args.timeout)
            start_vms(conn, start_list)
        elif args.command == "status":
            status_vms(conn, start_list)
    finally:
        conn.close()


if __name__ == "__main__":
    # The command line is parsed before the libvirt bindings are imported
    # (presumably so that --help and usage errors work even where the
    # bindings are missing -- TODO confirm).
    args = parseCmdlineArgs()
    # Deliberately a module-level import: the functions above refer to
    # `libvirt` as a global name.
    import libvirt

    main(args)
