#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2012 Red Hat, Inc.
#
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

from __future__ import print_function
import sys
import os, os.path
from copy import copy

from firewall.config import *
from firewall.core.io.firewalld_conf import firewalld_conf
from firewall.core.io.zone import Zone, zone_reader, zone_writer
from optparse import Option, OptionError, OptionParser, Values, \
    SUPPRESS_HELP, BadOptionError, OptionGroup
from firewall.functions import getPortID, getPortRange, getServiceName, \
    checkIP, checkInterface

# This tool has to be started by root (uid 0); any other user gets a
# message and the program ends with exit value -1.
if os.getuid():
    print(_("You need to be root to run %s.") % sys.argv[0])
    raise SystemExit(-1)

def usage():
    """Print a one-line usage hint that points to the help option."""
    program = sys.argv[0]
    print("Usage: %s -h | --help" % program)

def error(text):
    """Print text on stdout, prefixed with the translated "Error:" label."""
    label = _("Error:")
    print("%s %s" % (label, text))

def warning(text):
    """Print text on stdout, prefixed with the translated "Warning:" label."""
    label = _("Warning:")
    print("%s %s" % (label, text))

def __fail(msg=None):
    """Report msg through error() if one is given, then exit with status 2."""
    if msg:
        error(msg)
    # same effect as sys.exit(2)
    raise SystemExit(2)

# system-config-firewall: fw_parser

def _check_port(option, opt, value):
    """optparse type checker for options of type "port".

    The value has to look like "<port>[-<port>]:<protocol>" where the
    protocol is "tcp" or "udp".

    @param option the Option instance being processed (unused)
    @param opt the option string as seen on the command line
    @param value the raw option argument
    @return tuple (ports, protocol), both stripped of surrounding whitespace
    @raise OptionError if the value is malformed, the port range is
           ambiguous or reversed, or the protocol is unknown
    """
    try:
        (ports, protocol) = value.split(":")
    except (ValueError, AttributeError):
        # not exactly one ":" separator (or not a string at all)
        raise OptionError(_("invalid port definition %s.") % value, opt)

    ports = ports.strip()
    port_range = getPortRange(ports)
    # The None check has to come first: comparing None with an int is True
    # on Python 2 (hiding this branch) and a TypeError on Python 3.
    if port_range is None:
        raise OptionError(_("port range %s is not unique.") % value, opt)
    if isinstance(port_range, int):
        # a negative number is getPortRange's failure indicator
        if port_range < 0:
            raise OptionError(_("invalid port definition %s.") % value, opt)
    elif len(port_range) == 2 and port_range[0] >= port_range[1]:
        raise OptionError(_("%s is not a valid range (start port >= end "
                            "port).") % value, opt)

    protocol = protocol.strip()
    if protocol not in ("tcp", "udp"):
        raise OptionError(_("%s is not a valid protocol.") % protocol, opt)
    return (ports, protocol)

def _check_forward_port(option, opt, value):
    """optparse type checker for options of type "forward_port".

    The value is a ":" separated list of key=value pairs:
    if=<interface>:port=<port>:proto=<protocol>[:toport=<port>][:toaddr=<addr>]

    "if", "port" and "proto" are mandatory and at least one of "toport" and
    "toaddr" has to be given.

    @param option the Option instance being processed (unused)
    @param opt the option string as seen on the command line
    @param value the raw option argument
    @return dict with the accepted key/value pairs (values are strings)
    @raise OptionError if a pair is malformed or invalid, or if a mandatory
           key is missing
    """
    result = {}
    problem = None
    splits = value.split(":", 1)
    while len(splits) > 0:
        key_val = splits[0].split("=")
        if len(key_val) != 2:
            problem = _("Invalid argument %s") % splits[0]
            break
        (key, val) = key_val

        if key == "if":
            valid = checkInterface(val)
        elif key == "proto":
            valid = val in ("tcp", "udp")
        elif key == "toaddr":
            valid = checkIP(val)
        elif key in ("port", "toport"):
            # Do not compare the result with 0 directly: it may be None or
            # a sequence, which only "worked" through Python 2's arbitrary
            # cross-type ordering and is a TypeError on Python 3.
            port_range = getPortRange(val)
            if port_range is None:
                valid = False
            elif isinstance(port_range, int):
                valid = port_range > 0
            else:
                valid = True
        else:
            valid = False

        if not valid:
            problem = _("Invalid argument %s") % splits[0]
            break
        result[key] = val

        if len(splits) > 1:
            if splits[1].count("=") == 1:
                # last element, keep it in one piece even if it has a ":"
                splits = [ splits[1] ]
            else:
                splits = splits[1].split(":", 1)
        else:
            # finish
            splits.pop()

    if problem:
        params = { "option": opt, "value": value, "error": problem }
        raise OptionError(_("option %(option)s: invalid forward_port "
                                 "'%(value)s': %(error)s.") % params, opt)

    missing = [ key for key in ("if", "port", "proto") if key not in result ]
    if missing or ("toport" not in result and "toaddr" not in result):
        params = { "option": opt, "value": value }
        raise OptionError(_("option %(option)s: invalid forward_port "
                                 "'%(value)s'.") % params, opt)

    return result

def _check_interface(option, opt, value):
    """optparse type checker for options of type "interface".

    @return value unchanged if checkInterface() accepts it
    @raise OptionError otherwise
    """
    if checkInterface(value):
        return value
    raise OptionError(_("invalid interface '%s'.") % value, opt)

def _append_unique(option, opt, value, parser, *args, **kwargs):
    """optparse callback: append value to the list stored at option.dest.

    A value that is already part of the list is not added a second time.
    The list is created on first use.
    """
    current = getattr(parser.values, option.dest)
    if not current or value not in current:
        parser.values.ensure_value(option.dest, []).append(value)

class _Option(Option):
    """optparse Option that knows the additional firewall value types."""
    # Additional names accepted for the "type" attribute of an option.
    TYPES = Option.TYPES + ("port", "rulesfile", "service", "forward_port",
                            "icmp_type", "interface")
    # Private copy of the stock checker table, extended by the checkers
    # defined in this file (port, forward_port and interface only).
    TYPE_CHECKER = dict(Option.TYPE_CHECKER,
                        port=_check_port,
                        forward_port=_check_forward_port,
                        interface=_check_interface)

def _addStandardOptions(parser):
    """Register the system-config-firewall/lokkit options on parser.

    --enabled and --disabled switch the boolean "enabled" value (default
    True).  All other options collect their arguments in a list through
    the _append_unique callback.
    """
    def add_list_option(*flags, **attrs):
        # common wiring of every list collecting option
        parser.add_option(*flags, action="callback",
                          callback=_append_unique, **attrs)

    parser.add_option("--enabled", dest="enabled",
                      action="store_true", default=True,
                      help=_("Enable firewall (default)"))
    parser.add_option("--disabled", dest="enabled",
                      action="store_false",
                      help=_("Disable firewall"))
#    parser.add_option("--update",
#                      action="store_false", dest="update",
#                      help=_("Ignored option, was used to update the firewall"))
    add_list_option("--addmodule",
                    dest="add_module", type="string",
                    metavar=_("<module>"),
                    help=_("Ignored option, was used to enable an iptables module"))
    add_list_option("--removemodule",
                    dest="remove_module", type="string",
                    metavar=_("<module>"),
                    help=_("Ignored option, was used to disable an iptables module"))
    add_list_option("-s", "--service",
                    dest="services", type="service", default=[ ],
                    metavar=_("<service>"),
                    help=_("Enable a service in the default zone (example: ssh)"))
    add_list_option("--remove-service",
                    dest="remove_services", type="service", default=[ ],
                    metavar=_("<service>"),
                    help=_("Disable a service in the default zone (example: ssh)"))
    add_list_option("-p", "--port",
                    dest="ports", type="port",
                    metavar=_("<port>[-<port>]:<protocol>"),
                    help=_("Enable a port in the default zone "
                           "(example: ssh:tcp)"))
    add_list_option("-t", "--trust",
                    dest="trust", type="interface",
                    metavar=_("<interface>"),
                    help=_("Bind an interface to the trusted zone"))
    add_list_option("-m", "--masq",
                    dest="masq", type="interface",
                    metavar=_("<interface>"),
                    help=_("Enables masquerading in the default zone, interface argument is ignored. This is IPv4 only."))
    add_list_option("--custom-rules",
                    dest="custom_rules", type="rulesfile",
                    metavar=_("[<type>:][<table>:]<filename>"),
                    help=_("Ignored option. Was used to add custom rules to the firewall (Example: ipv4:filter:/etc/sysconfig/ipv4_filter_addon)"))
    add_list_option("--forward-port",
                    dest="forward_port", type="forward_port",
                    metavar=_("if=<interface>:port=<port>:proto=<protocol>"
                              "[:toport=<destination port>]"
                              "[:toaddr=<destination address>]"),
                    help=_("Forward the port with protocol for the interface to either another local destination port (no destination address given) or to an other destination address with an optional destination port. This will be added to the default zone. This is IPv4 only."))
    add_list_option("--block-icmp",
                    dest="block_icmp", type="icmp_type", default=[ ],
                    metavar=_("<icmp type>"),
                    help=_("Block this ICMP type in the default zone. The default is to accept all ICMP types."))

def _parse_args(parser, args, options=None):
    """Run parser on args and return the resulting option values.

    Any exception raised while parsing and every leftover positional
    argument is handed to parser.error().  If parser._fw_exit is set the
    program exits with status 2.  The attributes "filename" (None) and
    "converted" (False) are added to the result if they are missing.

    @return the option values, or None if parsing raised and parser.error()
            returned
    """
    try:
        (parsed, leftover) = parser.parse_args(args, options)
    except Exception as exc:
        parser.error(exc)
        return None

    for stray in leftover:
        parser.error(_("no such option: %s") % stray)
    if parser._fw_exit:
        sys.exit(2)
    for (attr, fallback) in (("filename", None), ("converted", False)):
        if not hasattr(parsed, attr):
            setattr(parsed, attr, fallback)
    return parsed

class _OptionParser(OptionParser):
    """OptionParser used for the lokkit style argument parsing.

    Differences to the stock parser that are visible in this class:
     - the help output is framed by an introduction and a closing hint
     - no usage line is printed
     - errors are written to stderr, prefixed with _fw_source if that is
       set, and end the program with exit status 2
     - long options have to be spelled out completely

    The attribute _fw_source is read by error() and has to be set by the
    code creating the parser.
    """

    # overload print_help: rhpl._ returns UTF-8
    def print_help(self, file=None):
        """Write introduction, option help and closing hint to file, then
        exit with status 0.  file defaults to sys.stdout."""
        if file is None:
            file = sys.stdout

        file.write(_("This tool tries to convert system-config-firewall/lokkit options as much as possible to firewalld, but there are limitations for example with custom rules, modules and masquerading.") + "\n\n")
        help_text = self.format_help()
        # Only a Python 2 unicode object is "not a str" here and needs to be
        # encoded.  Testing against str avoids the name "unicode", which
        # does not exist on Python 3.
        if not isinstance(help_text, str):
            encoding = self._get_encoding(file)
            help_text = help_text.encode(encoding, "replace")
        file.write(help_text)
        file.write("\n" + _("If no options are given, the configuration from '%s' be migrated.") % (CONFIG) + "\n")
        self.exit()

    def print_usage(self, file=None):
        """Suppress the usage line."""
        pass

    def exit(self, status=0, msg=None):
        """Print msg (if any) to stderr and exit with status."""
        if msg:
            print(msg, file=sys.stderr)
        sys.exit(status)

    def error(self, msg):
        """Exit with status 2 after printing msg, prefixed with the source
        name if _fw_source is set."""
        if self._fw_source:
            text = "%s: %s" % (self._fw_source, msg)
        else:
            text = str(msg)
        self.exit(2, msg=text)

    def _match_long_opt(self, opt):
        """Return opt if it is a known long option.

        @raise BadOptionError for anything else, abbreviations included
        """
        if opt in self._long_opt:
            return opt
        raise BadOptionError(opt)

    def _process_long_opt(self, rargs, values):
        # allow to ignore errors in the ui
        try:
            self.__process_long_opt(rargs, values)
        except Exception as msg:
            self.error(msg)

    def _process_short_opts(self, rargs, values):
        # allow to ignore errors in the ui
        try:
            OptionParser._process_short_opts(self, rargs, values)
        except Exception as msg:
            self.error(msg)

    def __process_long_opt(self, rargs, values):
        """Consume one long option and its argument(s) from rargs and
        process it.  Problems are reported through error()."""
        arg = rargs.pop(0)

        # Value explicitly attached to arg?  Pretend it's the next
        # argument.
        if "=" in arg:
            (opt, next_arg) = arg.split("=", 1)
            had_explicit_value = True
        else:
            opt = arg
            had_explicit_value = False

        opt = self._match_long_opt(opt)
        option = self._long_opt[opt]
        if option.takes_value():
            nargs = option.nargs
            if len(rargs)+int(had_explicit_value) < nargs:
                if nargs == 1:
                    self.error(_("%s option requires an argument") % opt)
                else:
                    params = { "option": opt, "count": nargs }
                    self.error(_("%(option)s option requires %(count)s "
                                 "arguments") % params)
            elif nargs == 1 and had_explicit_value:
                value = next_arg
            elif nargs == 1:
                value = rargs.pop(0)
            elif had_explicit_value:
                value = tuple([ next_arg ] + rargs[0:nargs-1])
                del rargs[0:nargs-1]
            else:
                value = tuple(rargs[0:nargs])
                del rargs[0:nargs]

        elif had_explicit_value:
            self.error(_("%s option does not take a value") % opt)

        else:
            value = None

        option.process(opt, value, values, self)

def _gen_parser(source=None):
    """Create an _OptionParser using _Option as option class.

    source is stored as parser._fw_source, parser._fw_exit starts as False.
    """
    parser = _OptionParser(option_class=_Option)
    for (attr, initial) in (("_fw_source", source), ("_fw_exit", False)):
        setattr(parser, attr, initial)
    return parser

def parseSysconfigArgs(args, options=None, source=None):
    """Parse args with a freshly created parser that has the standard
    options registered.

    @param args list of arguments, handed on to the parser unchanged
    @param options optional values object to parse into
    @param source name stored on the parser as origin of the arguments
    @return the result of _parse_args()
    """
    option_parser = _gen_parser(source)
    _addStandardOptions(option_parser)
    parsed = _parse_args(option_parser, args, options)
    return parsed

# system-config-firewall: fw_sysconfig
# Path of the system-config-firewall configuration file; it is read when the
# tool is started without command line arguments.
CONFIG = '/etc/sysconfig/system-config-firewall'

def read_sysconfig_args(filename=None):
    """Read the option lines of a system-config-firewall configuration file.

    Every line is stripped of surrounding whitespace; empty lines and lines
    starting with '#' are skipped.

    @param filename file to read, defaults to CONFIG
    @return tuple (argv, filename) with argv being the list of remaining
            lines, or None if the file does not exist, is not a regular
            file or can not be opened
    """
    if filename is None:
        filename = CONFIG
    if not os.path.isfile(filename):
        return None
    try:
        fd = open(filename, 'r')
    except (IOError, OSError):
        return None
    argv = [ ]
    # the with block closes the file even if reading fails
    with fd:
        for line in fd:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            argv.append(line)
    return (argv, filename)

def parse_sysconfig_args(args, merge_config=None, filename=None):
    """Parse args and record filename on the resulting configuration.

    @param args list of arguments, handed on to parseSysconfigArgs()
    @param merge_config optional values object to parse into
    @param filename origin of the arguments, also used as parser source
    @return the configuration with its filename attribute set, or None if
            parsing did not yield a configuration
    """
    config = parseSysconfigArgs(args, options=merge_config, source=filename)
    if config:
        config.filename = filename
        return config
    return None

def read_sysconfig_config(merge_config=None):
    """Read and parse the system-config-firewall configuration file.

    @param merge_config optional values object to parse into
    @return the parsed configuration, or merge_config unchanged if the
            configuration file could not be read
    """
    found = read_sysconfig_args() # (argv, filename) or None
    if not found:
        return merge_config
    (argv, filename) = found
    return parse_sysconfig_args(argv, merge_config, filename)

# Build the configuration either from the command line or, if no arguments
# were given, from the system-config-firewall configuration file.
if len(sys.argv) > 1:
    # Parse the cmdline args and setup the initial firewall state
    conf = parse_sysconfig_args(None)
    if not conf:
        error(_("Problem parsing arguments."))
        sys.exit(1)
else:
    # open system-config-firewall config
    conf = read_sysconfig_config()
    if not conf:
        # Translate first, format afterwards: formatting inside of _()
        # would look up the already formatted string, which is not a
        # message id.
        error(_("Opening of '%s' failed, exiting.") % CONFIG)
        sys.exit(1)


# Bring the firewalld service in line with the requested state: disable and
# mask it for --disabled, unmask and enable it otherwise.  The result of the
# systemctl calls is not checked.
if conf.enabled == False:
    _service_commands = ("systemctl disable firewalld.service",
                         "systemctl mask firewalld.service")
else:
    _service_commands = ("systemctl unmask firewalld.service",
                         "systemctl enable firewalld.service")
for _command in _service_commands:
    os.system(_command)


# open firewalld config file to get default zone

default_zone = "public" # default zone in case of missing config file
trusted_zone = "trusted"

_firewalld_conf = firewalld_conf(FIREWALLD_CONF)
try:
    _firewalld_conf.read()
except Exception:
    # ignore read error, use default zone
    pass
else:
    # keep the built-in default if the config has no usable DefaultZone
    # NOTE(review): assumes get() returns a false value for a missing key
    _configured_zone = _firewalld_conf.get("DefaultZone")
    if _configured_zone:
        default_zone = _configured_zone

def _find_zone_dir(zone_name):
    """Return the first directory out of ETC_FIREWALLD_ZONES and
    FIREWALLD_ZONES that contains <zone_name>.xml, or None."""
    for path in [ ETC_FIREWALLD_ZONES, FIREWALLD_ZONES ]:
        if os.path.exists("%s/%s.xml" % (path, zone_name)):
            return path
    return None

# load the default zone
obj = None
_zone_dir = _find_zone_dir(default_zone)
if _zone_dir is not None:
    # translate first, format afterwards, otherwise the lookup never matches
    print(_("Opening default zone '%s'") % default_zone)
    obj = zone_reader("%s.xml" % default_zone, _zone_dir)

if not obj:
    error(_("Unable to open default zone '%s', exiting.") % default_zone)
    # create new zone?
    sys.exit(1)

# load the trusted zone, it is only required if interfaces get trusted
trusted_obj = None
if default_zone != trusted_zone:
    _zone_dir = _find_zone_dir(trusted_zone)
    if _zone_dir is not None:
        trusted_obj = zone_reader("%s.xml" % trusted_zone, _zone_dir)
    if conf.trust and not trusted_obj:
        error(_("Unable to open zone '%s', exiting.") % trusted_zone)
        sys.exit(1)
else:
    # the default zone is the trusted zone, work on the same object
    trusted_obj = obj

# Apply the parsed options to the zone objects.  "changed" and
# "changed_trusted" record whether the default or the trusted zone were
# modified.
changed = False
changed_trusted = False

# fields that can not get converted into a zone, need NM work

if conf.trust:
    if trusted_obj:
        for dev in conf.trust:
            # like for all other settings: do not add an entry twice
            if dev in trusted_obj.interfaces:
                continue
            warning(_("The device '%s' will be bound to the %s zone.") % \
                          (dev, trusted_zone))
            trusted_obj.interfaces.append(dev)
            changed_trusted = True

# no custom rules
if conf.custom_rules:
    for custom in conf.custom_rules:
        # The value is usually the plain option string, because no checker
        # is registered for the "rulesfile" type.  Joining a string would
        # put a ":" between every single character, so only join sequences.
        if isinstance(custom, (list, tuple)):
            custom = ":".join(custom)
        warning(_("Ignoring custom-rule file '%s'") % custom)

# no modules
if conf.add_module:
    for module in conf.add_module:
        warning(_("Ignoring addmodule '%s'") % module)
if conf.remove_module:
    for module in conf.remove_module:
        warning(_("Ignoring removemodule '%s'") % module)

# masquerading is a zone wide setting, the first device enables it
if conf.masq:
    for dev in conf.masq:
        if obj.masquerade != True:
            warning(_("Device '%s' was masqueraded, enabling masquerade for the default zone.") % dev)
            obj.masquerade = True
            changed = True

if conf.ports:
    for item in conf.ports:
        if item not in obj.ports:
            print(_("Adding port '%s/%s' to default zone.") % \
                      (item[0], item[1]))
            obj.ports.append(item)
            changed = True

# removals are done before additions
if conf.remove_services:
    for service in conf.remove_services:
        if service in obj.services:
            print(_("Removing service '%s' from default zone.") % service)
            obj.services.remove(service)
            changed = True

if conf.services:
    for service in conf.services:
        if service not in obj.services:
            print(_("Adding service '%s' to default zone.") % service)
            obj.services.append(service)
            changed = True

if conf.block_icmp:
    for icmp in conf.block_icmp:
        if icmp not in obj.icmp_blocks:
            print(_("Adding icmpblock '%s' to default zone.") % icmp)
            obj.icmp_blocks.append(icmp)
            changed = True

if conf.forward_port:
    for fwd in conf.forward_port:
        # ignore interface, should belong to default zone
        entry = (fwd.get("port", ""), fwd.get("proto", ""),
                 fwd.get("toport", ""), fwd.get("toaddr", ""))
        if entry not in obj.forward_ports:
            print(_("Adding forward port %s:%s:%s:%s to default zone.") % \
                      (entry[0], entry[1], entry[2], entry[3]))
            obj.forward_ports.append(entry)
            changed = True

# Write the default zone only if it was modified.
if not changed:
    print(_("No changes to default zone needed."))
else:
    zone_writer(obj, ETC_FIREWALLD_ZONES)

# Write the trusted zone if interfaces were added and explain how this
# interacts with NetworkManager and ifcfg files.
if changed_trusted:
    zone_writer(trusted_obj, ETC_FIREWALLD_ZONES)
    print(_("Changed trusted zone configuration."))
    print("\n")
    warning(_("If one of the trusted interfaces is used for a connection with NetworkManager or if there is an ifcfg file for this interface, the zone will be changed to the zone defined in the configuration as soon as it gets activated. To change the zone of a connection use <command>nm-connection-editor</command> and set the zone to trusted, for an ifcfg file, use an editor and add \"ZONE=trusted\". If the zone is not defined in the ifcfg file, the firewalld default zone will be used."))

raise SystemExit(0)
