#!/usr/bin/python
#
# Copyright (C) 2009 ,2010 Red Hat, Inc.
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import getopt
import dbus

DATADIR = '/usr/share/firewalld'
sys.path.append(DATADIR)

import firewall_client
from firewall_error import *

def usage():
    """Write the command line synopsis and all supported options to stdout.

    The text ends with one empty line.  The "--custom" keys listed here are
    the ones the option parser accepts ("protocol", not "proto"), and the
    value-less "--panic" action is listed with the other actions.
    """
    lines = (
        "Usage: %s [<options>] <mode> <action>" % sys.argv[0],
        "    -h --help",
        "  Status:",
        "    --reload",
        "    --restart",
        "    --status",
        "  Modes:",
        "    --enable",
        "    --disable",
        "    --query",
        "    --list=<action>",
        "  Actions:",
        "    --panic",
        "    --service=<service>",
        "    --port=<port>[-<port>]:<protocol>",
        "    --trusted=<interface>",
        "    --masquerade=<interface>",
        "    --forward-port=if=<interface>:port=<port>:proto=<protocol>[:toport=<destination port>][:toaddr=<destination address>]",
        "    --icmp-block=<icmp type>",
        "    --custom=table=<table>:chain=<chain>:src=<addr>:src_port=<port>:dst=<addr>:dst_port=<port>:protocol=<protocol>:iface_in=<interface>:iface_out=<interface>:physdev_in=<bridge>:physdev_out=<bridge>:target=<target>",
        "  Enable Options:",
        "    --timeout=<seconds>",
        # trailing empty entry produces the blank line that closes the help
        "",
    )
    sys.stdout.write("\n".join(lines) + "\n")

try:
    (opts, args) = \
        getopt.getopt(sys.argv[1:], "h", 
                      [ "help", "timeout=", "reload", "restart", "status", 
                        # modes (exactly one of those)
                        "enable", "disable", "query", "list=",
                        # actions (exactly one of those)
                        "panic",
                        "service=", "port=", "trusted=", "masquerade=",
                        "forward-port=", "icmp-block=", "custom="
                        ])
except Exception, msg:
    print msg
    usage()
    sys.exit(1)

# State collected from the command line options.
# timeout: 0 until --timeout supplies a value
# status:  0 unless a later call stores a non-zero result code
timeout = status = 0
# mode / action / value stay None until the matching option is seen
mode = action = value = None

def __fail(msg=None):
    if msg:
        print msg
    usage()
    sys.exit(2)

# Actions that have a "list" branch in the dispatch code of this script.
LIST_ACTIONS = ("service", "port", "trusted", "masquerade", "forward-port",
                "icmp-block", "custom")

# Walk the parsed options and fill in mode, action, value and timeout.
# Giving a second mode or a second action is a usage error.
for (opt, val) in opts:
    if opt in ("-h", "--help"):
        usage()
        sys.exit(0)

    # status commands occupy the mode slot and need no action
    elif opt in ("--reload", "--restart", "--status"):
        if mode:
            __fail()
        mode = opt[2:]

    # timeout
    elif opt == "--timeout":
        try:
            timeout = int(val)
        except ValueError:
            # report a non-numeric value the same way as an out-of-range one
            __fail("Timeout not valid")
        if timeout < 1:
            __fail("Timeout not valid")

    # mode
    elif opt in ("--enable", "--disable", "--query"):
        if mode:
            __fail()
        mode = opt[2:]
    elif opt == "--list":
        if mode or action:
            __fail()
        # reject unknown names here; otherwise the request would run
        # through the dispatch without output and exit successfully
        if val not in LIST_ACTIONS:
            __fail("Invalid list action.")
        mode = opt[2:]
        # for --list the option argument names the action to enumerate
        action = val

    # action
    elif opt in ("--panic", "--service", "--port", "--trusted",
                 "--masquerade", "--forward-port", "--icmp-block",
                 "--custom"):
        if action:
            __fail()
        action = opt[2:]

        # --panic is the only action without an argument
        if opt != "--panic":
            if value:
                __fail()
            value = val

# Cross-check the collected settings.
if not mode:
    __fail("No mode.")
if mode not in ("reload", "restart", "status"):
    if not action:
        __fail("No action.")
    if action != "panic" and mode != "list" and not value:
        __fail("No value.")
if timeout != 0:
    if mode != "enable":
        __fail("Timeout only valid in enable mode.")
    if action == "panic":
        __fail("No timeout for panic.")

try:
    fw = firewall_client.Firewall_Client()
except dbus.DBusException, e:
    if e._dbus_error_name == 'org.freedesktop.DBus.Error.ServiceUnknown':
        print "FirewallD is probably not running."
    sys.exit(1)

# Status commands need no action: run them and leave right away.
if mode == "status":
    # the exit status is whatever fw.status() returns
    status = fw.status()
    sys.exit(status)

if mode in ("reload", "restart"):
    if mode == "reload":
        succeeded = fw.reload()
    else:
        succeeded = fw.restart()
    # exit 0 when the call returned a true value, 1 otherwise
    if succeeded:
        sys.exit(0)
    sys.exit(1)

# panic
if action == "panic":
    if mode == "enable":
        if not fw.enablePanicMode():
            sys.exit(1)
    elif mode == "disable":
        if not fw.disablePanicMode():
            sys.exit(1)
    elif mode == "query":
        if not fw.queryPanicMode():
            sys.exit(1)
# service
elif action == "service":
    if mode == "list":
        l = fw.getServices()
        if len(l) > 0:
            print " ".join(l)
        sys.exit(0)

    if mode == "enable":
        fw.enableService(value, timeout)
    elif mode == "disable":
        fw.disableService(value)
    elif mode == "query":
        if not fw.queryService(value):
            sys.exit(-1)
# port
elif action == "port":
    if mode == "list":
        l = fw.getPorts()
        if len(l) > 0:
            print " ".join(["%s:%s" % port for port in l])
        sys.exit(0)

    try:
        (port, proto) = value.split(":")
    except Exception, msg:
        __fail(msg)

    if mode == "enable":
        fw.enablePort(port, proto, timeout)
    elif mode == "disable":
        fw.disablePort(port, proto)
    elif mode == "query":
        if not fw.queryPort(port, proto):
            sys.exit(-1)

# trusted
elif action == "trusted":
    if mode == "list":
        l = fw.getTrusted()
        if len(l) > 0:
            print " ".join(l)
        sys.exit(0)

    if mode == "enable":
        fw.enableTrusted(value, timeout)
    elif mode == "disable":
        fw.disableTrusted(value)
    elif mode == "query":
        if not fw.queryTrusted(value):
            sys.exit(-1)

# masquerade
elif action == "masquerade":
    if mode == "list":
        l = fw.getMasquerades()
        if len(l) > 0:
            print " ".join(l)
        sys.exit(0)

    if mode == "enable":
        fw.enableMasquerade(value, timeout)
    elif mode == "disable":
        fw.disableMasquerade(value)
    elif mode == "query":
        if not fw.queryMasquerade(value):
            sys.exit(-1)

# forward port
elif action == "forward-port":
    if mode == "list":
        l = fw.getForwardPorts()
        if len(l) > 0:
            print "\n".join(["if=%s:port=%s:proto=%s:toport=%s:toaddr=%s" % (interface, port, protocol, toport, toaddr) for (interface, port, protocol, toport, toaddr) in l])
        sys.exit(0)

    try:
        toport = ""
        toaddr = ""
        args = value.split(":")
        if len(args) < 4 or len(args) > 6:
            __fail()
        if not (args[0].startswith("if=") or args[1].startswith("port=") or \
                    args[2].startswith("proto=")):
            __fail()
        for arg in args:
            (opt,val) = arg.split("=")
            if opt == "if":
                interface = val
            elif opt == "port":
                port = val
            elif opt == "proto":
                protocol = val
            elif opt == "toport":
                toport = val
            elif opt == "toaddr":
                toaddr = val
            else:
                raise ValueError, "invalid forward port arg %s" % (arg)
    except Exception, msg:
        __fail(msg)

    if mode == "enable":
        fw.enableForwardPort(interface, port, protocol, toport, toaddr, timeout)
    elif mode == "disable":
        fw.disableForwardPort(interface, port, protocol, toport, toaddr)
    elif mode == "query":
        if not fw.queryForwardPort(interface, port, protocol, toport,
                                       toaddr):
            sys.exit(-1)

# block icmp
elif action == "icmp-block":
    if mode == "list":
        l = fw.getIcmpBlocks()
        if len(l) > 0:
            print " ".join(l)
        sys.exit(0)

    if mode == "enable":
        fw.enableIcmpBlock(value, timeout)
    elif mode == "disable":
        fw.disableIcmpBlock(value)
    elif mode == "query":
        if not fw.queryIcmpBlock(value):
            sys.exit(-1)

# custom
elif action == "custom":
    if mode == "list":
        names = [ "table", "chain", "src", "src_port", "dst", "dst_port",
                  "protocol", "iface_in", "iface_out", "physdev_in",
                  "physdev_out", "target" ]
        items = fw.getCustoms()
        for item in items:
            args = [ ]
            for i in xrange(len(item)):
                if item[i]:
                    args.append("%s=%s" % (names[i], item[i]))
            print ", ".join(args)
        sys.exit(0)

    try:
        table = "filter"
        chain = "INPUT"
        src = src_port = dst = dst_port = protocol = iface_in = iface_out = ""
        physdev_in = physdev_out = ""
        target = "ACCEPT"

        args = value.split(":")
        for arg in args:
            (opt,val) = arg.split("=")
            if opt == "table":
                table = val
            elif opt == "chain":
                chain = val
            elif opt == "src":
                src = val
            elif opt == "src_port":
                src_port = val
            elif opt == "dst":
                dst = val
            elif opt == "dst_port":
                dst_port = val
            elif opt == "protocol":
                protocol = val
            elif opt == "iface_in":
                iface_in = val
            elif opt == "iface_out":
                iface_out = val
            elif opt == "physdev_in":
                physdev_in = val
            elif opt == "physdev_out":
                physdev_out = val
            elif opt == "target":
                target = val
            else:
                raise ValueError, "invalid custom arg %s" % (arg)
    except Exception, msg:
        __fail(msg)

    if mode == "enable":
        status = fw.enableCustom(table, chain, src, src_port, dst, dst_port,
                                 protocol, iface_in, iface_out, physdev_in,
                                 physdev_out, target, timeout)
    elif mode == "disable":
        status = fw.disableCustom(table, chain, src, src_port, dst, dst_port,
                                  protocol, iface_in, iface_out, physdev_in, 
                                  physdev_out, target)
    elif mode == "query":
        if not fw.queryCustom(table, chain, src, src_port, dst, dst_port,
                              protocol, iface_in, iface_out, physdev_in,
                              physdev_out, target):
            sys.exit(-1)

if status != 0:
    if status == ALREADY_ENABLED:
        print "already enabled"
    elif status == NOT_ENABLED:
        print "not enabled"
    elif status == ENABLE_FAILED:
        print "enable failed"
    elif status == DISABLE_FAILED:
        print "disable failed"
    elif status == NO_IPV6_NAT:
        print "no ipv6 nat"
    elif status == INVALID_ACTION:
        print "invalid action"
    elif status == INVALID_SERVICE:
        print "invalid service"
    elif status == INVALID_PORT:
        print "invalid port"
    elif status == INVALID_PROTOCOL:
        print "invalid protocol"
    elif status == INVALID_INTERFACE:
        print "invalid interface"
    elif status == INVALID_ADDR:
        print "invalid address"
    elif status == INVALID_FORWARD:
        print "invalid forward"
    elif status == INVALID_ICMP_TYPE:
        print "invalid icmp type"
    elif status == INVALID_TABLE:
        print "invalid table"
    elif status == INVALID_CHAIN:
        print "invalid chain"
    elif status == INVALID_TARGET:
        print "invlaid target"
    elif status == MISSING_TABLE:
        print "missing table"
    elif status == MISSING_CHAIN:
        print "missing chain"
    elif status == MISSING_PORT:
        print "missing port"
    elif status == MISSING_PROTOCOL:
        print "missing protocol"
    elif status == MISSING_ADDR:
        print "missing address"
    else:
        print "unknown error"

sys.exit(status)
