#!/usr/bin/python
#
# Module that removes channels from an installed satellite
#
#
# Copyright (c) 2008--2010 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
# 
# Red Hat trademarks are not licensed under GPLv2. No permission is
# granted to use or replicate Red Hat trademarks that are incorporated
# in this software or its documentation. 
#

import sys
import string
import os
import shutil
from optparse import Option, OptionParser

try:
    from rhn import rhnLockfile    #new place for rhnLockfile
except:
    from spacewalk.common import rhnLockfile #old place for rhnLockfile

from spacewalk.satellite_tools.progress_bar import ProgressBar
from spacewalk.common import CFG, initCFG, initLOG, log_debug, log_error
from spacewalk.server import rhnSQL

# Command-line switches understood by this tool; main() hands this list to
# OptionParser.  Options without a short form are only reachable through
# their long name.
options_table = [
    Option(
        "-v", "--verbose",
        action="count",
        help="Increase verbosity"
    ),
    Option(
        "-l", "--list",
        action="store_true",
        help="List defined channels and exit"
    ),
    Option(
        "-c", "--channel",
        action="append",
        help="Delete this channel (can be present multiple times)"
    ),
    Option(
        "-u", "--unsubscribe",
        action="store_true",
        help="Unsubscribe systems registered to the specified channels"
    ),
    Option(
        "--justdb",
        action="store_true",
        help="Delete only from the database, do not remove files from disk"
    ),
    Option(
        "--force",
        action="store_true",
        help="Remove the channel packages from any other channels too (Not Recommended)"
    ),
    Option(
        "-p", "--skip-packages",
        action="store_true",
        help="Do not remove package metadata or packages from the filesystem (Not Recommended)"
    ),
    Option(
        "--skip-channels",
        action="store_true",
        help="Remove only packages from channel not the channel itself."
    ),
]

# Lock object acquired by main(); stays None until the lock has been taken.
LOCK = None

def main():
    """Command-line entry point: list the known channels or delete some.

    The process must run as root.  The satellite-sync lock file is taken
    for the whole run; the command line is parsed, the requested channel
    labels are validated, and delete_channels() is then called, followed
    by a commit (or a rollback if it raises).

    Returns:
        0 on success, -1 when validation fails, None after --list.

    The process exits directly through sys.exit() when not run as root
    (status 8), when the lock is already held (-1) and when neither
    --channel nor --list was given (0).

    The lock is released on every way out of this function, including the
    early returns and sys.exit() calls made after it was acquired.
    """
    global LOCK

    if os.getuid() != 0:
        sys.stderr.write('ERROR: must be root to execute\n')
        sys.exit(8)

    try:
        LOCK = rhnLockfile.Lockfile('/var/run/satellite-sync.pid')
    except rhnLockfile.LockfileLockedException:
        print("ERROR: Either another instance of spacewalk-remove-channel is running, or a satellite-sync is currently taking place.")
        sys.exit(-1)

    try:
        parser = OptionParser(option_list=options_table)
        (options, args) = parser.parse_args()

        # Positional arguments are not supported at all.
        if args:
            for arg in args:
                sys.stderr.write("Not a valid option ('%s'), try --help\n" % arg)
            return -1

        if not (options.channel or options.list):
            sys.stderr.write("Nothing to do\n")
            sys.exit(0)

        initCFG('server')
        initLOG("stdout", options.verbose or 0)

        rhnSQL.initDB(CFG.DEFAULT_DB)

        dict_label, dict_parents = __listChannels()
        if options.list:
            # Base channels in alphabetical order, children indented below.
            for c in sorted(dict_parents.keys()):
                print(c)
                for sub in dict_parents[c]:
                    print("\t" + sub)
            return

        # Collapse repeated --channel arguments into a set of labels.
        channels = {}
        for channel in options.channel:
            channels[channel] = None
        labels = list(channels.keys())

        child_test_fail = False
        for channel in labels:
            if channel not in dict_label:
                print("Unknown channel %s" % channel)
                return -1
            # Sanity check: when a base channel is going to be removed, all
            # of its subchannels must have been selected as well.  Nothing
            # to check when only packages are removed or for child channels.
            if options.skip_channels or channel not in dict_parents:
                continue
            children = [subch for subch in dict_parents[channel]
                        if subch not in channels]
            if children:
                child_test_fail = True
                print("Error: cannot remove channel %s: subchannel(s) exist: " % (
                    channel))
                for child in children:
                    print("\t\t\t" + child)

        if child_test_fail:
            return -1

        if not options.skip_channels:
            # Both checks return a true value when something still blocks
            # the removal (subscribed systems, associated kickstarts).
            if __serverCheck(labels, options.unsubscribe):
                return -1

            if __kickstartCheck(labels):
                return -1

        try:
            delete_channels(labels, force=options.force,
                            justdb=options.justdb,
                            skip_packages=options.skip_packages,
                            skip_channels=options.skip_channels)
        except:
            # Deliberately bare: roll back on any failure, including an
            # interrupt, then let the error propagate to the caller.
            rhnSQL.rollback()
            raise
        rhnSQL.commit()
        return 0
    finally:
        # Release on every exit path; clearing LOCK keeps a later
        # releaseLOCK() call from releasing the same lock twice.
        releaseLOCK()
        LOCK = None

def __serverCheck(labels, unsubscribe):
    """Check whether any system is still subscribed to the given channels.

    labels      -- channel labels to look up
    unsubscribe -- when true, subscribed systems are unsubscribed through
                   __unsubscribeServers() instead of being reported

    Returns 0 when no system is subscribed.  With unsubscribe set, returns
    whatever __unsubscribeServers() returns.  Otherwise prints a table of
    the subscribed systems and returns how many rows were found.
    """
    query = """
        select distinct S.org_id, S.id, S.name
        from rhnChannel c inner join 
             rhnServerChannel sc on c.id = sc.channel_id inner join
             rhnServer s on s.id = sc.server_id
        where c.label in (%s)
    """
    binds, placeholders = _bind_many(labels)
    cursor = rhnSQL.prepare(query % ', '.join(placeholders))
    cursor.execute(**binds)
    servers = cursor.fetchall_dict()
    if not servers:
        return 0

    if unsubscribe:
        return  __unsubscribeServers(labels)

    print("\nCurrently there are systems subscribed to one or more of the specified channels.")
    print("If you would like to automatically unsubscribe these systems, simply use the --unsubscribe flag.\n")
    print("The following systems were found to be subscribed:")

    # Header and rows share the same layout: two padded columns and the
    # name, separated by single spaces.
    print("%s %s %s" % ('org_id'.ljust(8), 'id'.ljust(14), 'name'))
    print("-" * 32)
    for server in servers:
        print("%s %s %s" % (str(server['org_id']).ljust(8),
                            str(server['id']).ljust(14),
                            server['name']))

    return len(servers)


def __unsubscribeServers(labels):
    """Unsubscribe every system from the channels named in labels.

    Prints, per channel label, how many subscriptions will be removed,
    then removes each (server, channel) pair while driving a progress bar.
    Nothing is returned explicitly.
    """
    query = """
        select distinct sc.server_id as server_id, C.id as channel_id, c.parent_channel, c.label
        from rhnChannel c inner join 
             rhnServerChannel sc on c.id = sc.channel_id
        where c.label in (%s) order by C.parent_channel
    """
    binds, placeholders = _bind_many(labels)
    cursor = rhnSQL.prepare(query % ', '.join(placeholders))
    cursor.execute(**binds)
    subscriptions = cursor.fetchall_dict()

    # Count the subscriptions per channel label for the summary table.
    per_channel = {}
    for entry in subscriptions:
        per_channel[entry['label']] = per_channel.get(entry['label'], 0) + 1

    print("\nThe following channels will have their systems unsubscribed:")
    for label in sorted(per_channel.keys()):
        print("%s %s" % (label.ljust(40), str(per_channel[label]).ljust(8)))

    pb = ProgressBar(prompt='Unsubscribing:    ', endTag=' - complete',
                     finalSize=len(subscriptions), finalBarLength=40,
                     stream=sys.stdout)
    pb.printAll(1)

    for entry in subscriptions:
        __unbsubscribeServer(entry['server_id'], entry['channel_id'])
        pb.addTo(1)
        pb.printIncrement()
    pb.printComplete()


def __unbsubscribeServer(server_id, channel_id):
    """Remove one system's subscription to one channel.

    Calls the rhn_channel.unsubscribe_server stored procedure with the
    given server and channel ids.
    """
    statement = rhnSQL.prepare("""
         call rhn_channel.unsubscribe_server(:server_id, :channel_id)
    """)
    statement.execute(server_id=server_id, channel_id=channel_id)


def __kickstartCheck(labels):
    """Check whether any kickstart still points at the given channels.

    A kickstart counts when its kickstartable tree belongs to one of the
    channels named in labels.  Returns 0 when there is none; otherwise
    prints the offending kickstarts and returns how many rows were found.
    """
    query = """
        select K.org_id, K.label
        from rhnKSData K inner join 
             rhnKickstartDefaults KD on KD.kickstart_id = K.id inner join
             rhnKickstartableTree KT on KT.id = KD.kstree_id inner join
             rhnChannel c on c.id = KT.channel_id
        where c.label in (%s)
    """
    binds, placeholders = _bind_many(labels)
    cursor = rhnSQL.prepare(query % ', '.join(placeholders))
    cursor.execute(**binds)
    kickstarts = cursor.fetchall_dict()

    if not kickstarts:
        return 0

    print("The following kickstarts are associated with one of the specified channels. " +
          "Please remove these or change their associated base channel.\n")
    # Header and rows: padded org_id column, one space, then the label.
    print("%s %s" % ('org_id'.ljust(8), 'label'))
    print("-" * 20)
    for kickstart in kickstarts:
        print("%s %s" % (str(kickstart['org_id']).ljust(8),
                         kickstart['label']))

    return len(kickstarts)


def __listChannels():
    """Load the channel hierarchy from the database.

    Selects the channels for which the channel itself, or its parent row
    in the outer join, has a NULL org_id.

    Returns a (labels, parents) tuple:
        labels  -- maps every channel label to its parent's label, or to
                   None for a base channel
        parents -- maps every base channel label to the list of its child
                   channel labels
    """
    sql = """
        select c1.label, c2.label parent_channel 
        from rhnChannel c1 left outer join rhnChannel c2 on c1.parent_channel = c2.id
        where C1.org_id is null or C2.org_id is null
        order by c2.label desc, c1.label asc
    """
    h = rhnSQL.prepare(sql)
    h.execute()
    labels = {}
    parents = {}
    while 1:
        row = h.fetchone_dict()
        if not row:
            break
        label = row['label']
        parent_channel = row['parent_channel']
        labels[label] = parent_channel
        # setdefault() on both branches keeps this correct whichever way
        # the database sorts the NULL parent rows: a child seen before its
        # base channel creates the entry, and the base channel row then
        # leaves the already collected children in place.
        if parent_channel:
            parents.setdefault(parent_channel, []).append(label)
        else:
            parents.setdefault(label, [])

    return labels, parents


def delete_channels(channelLabels, force=0, justdb=0, skip_packages=0, skip_channels=0):
    """Delete the packages and the database rows of the given channels.

    channelLabels -- labels of the channels to process; nothing happens
                     when it is empty
    force         -- passed to list_packages(): when true, packages are
                     selected even if a channel outside channelLabels
                     also contains them
    justdb        -- when true, package files and the cached repository
                     metadata on disk are left alone
    skip_packages -- when true, neither package rows nor package files
                     are removed
    skip_channels -- when true, the rhnChannel row and the family /
                     dist / release mappings are kept

    The caller is expected to commit or roll back afterwards.
    NOTE(review): _delete_rpm_group() commits by itself after each batch,
    so a later rollback does not restore package rows already removed.
    """
    if not channelLabels:
        return

    if not skip_packages:
        rpms_ids = list_packages(channelLabels, force=force, sources=0)
        srpms_ids = list_packages(channelLabels, force=force, sources=1)

        # The paths live in the rows that are about to be deleted, so they
        # have to be read first.  They are only needed when files are
        # going to be removed from disk.
        paths = []
        if not justdb:
            paths = (list(_get_package_paths(rpms_ids, sources=0)) +
                     list(_get_package_paths(srpms_ids, sources=1)))

        _delete_srpms(srpms_ids)
        _delete_rpms(rpms_ids)

        if not justdb:
            _delete_files(paths)

    # Get the channel ids
    h = rhnSQL.prepare("""
        select id, parent_channel 
        from rhnChannel 
        where label = :label 
        order by parent_channel""")
    channel_ids = []
    for label in channelLabels:
        h.execute(label=label)
        row = h.fetchone_dict()
        if not row:
            # Unknown label: skip it, but keep processing the remaining
            # labels instead of silently dropping all of them.
            continue
        channel_id = row['id']
        if row['parent_channel']:
            # Subchannel, we have to remove it first
            channel_ids.insert(0, channel_id)
        else:
            channel_ids.append(channel_id)

    if not channel_ids:
        return

    # Tables that reference a channel through an intermediate table:
    # [intermediate table, its channel column, dependent table, link column]
    indirect_tables = [
        ['rhnKickstartableTree', 'channel_id', 'rhnKSTreeFile', 'kstree_id'],
    ]
    query = """
        delete from %(table_2)s where %(link_field)s in (
            select id 
              from %(table_1)s
             where %(channel_field)s = :channel_id
        )
    """
    for e in indirect_tables:
        args = {
            'table_1'       : e[0],
            'channel_field' : e[1],
            'table_2'       : e[2],
            'link_field'    : e[3],
        }
        h = rhnSQL.prepare(query % args)
        h.executemany(channel_id=channel_ids)

    # Tables that reference the channel directly: [table, channel column].
    tables = [
        ['rhnErrataFileChannel', 'channel_id'],
        ['rhnErrataNotificationQueue', 'channel_id'],
        ['rhnChannelErrata', 'channel_id'],
        ['rhnChannelPackage', 'channel_id'],
        ['rhnRegTokenChannels', 'channel_id'],
        ['rhnServerProfile', 'base_channel'],
        ['rhnKickstartableTree', 'channel_id'],
    ]

    if not skip_channels:
        # The channel row itself goes last, after everything pointing at it.
        tables.extend([
            ['rhnChannelFamilyMembers', 'channel_id'],
            ['rhnDistChannelMap', 'channel_id'],
            ['rhnReleaseChannelMap', 'channel_id'],
            ['rhnChannel', 'id'],
        ])

    query = "delete from %s where %s = :channel_id"
    for table, field in tables:
        log_debug(3, "Processing table %s" % table)
        h = rhnSQL.prepare(query % (table, field))
        h.executemany(channel_id=channel_ids)
    if not justdb:
        __deleteRepoData(channelLabels)


def __deleteRepoData(labels):
    """Remove the per-channel directories below the repomd cache.

    For every label, the directory named after it under
    '/var/cache/' + CFG.repomd_path_prefix is deleted recursively if it
    exists.
    """
    cache_root = '/var/cache/' + CFG.repomd_path_prefix
    for label in labels:
        channel_dir = cache_root + '/' + label
        if os.path.isdir(channel_dir):
            shutil.rmtree(channel_dir)


def list_packages(channelLabels, sources=0, force=0):
    """Return the ids of the packages found in the given channels.

    channelLabels -- channel labels to inspect; an empty value yields []
    sources       -- when true, source package ids are listed instead of
                     binary package ids
    force         -- when true, every package of the channels is listed;
                     otherwise the packages that also belong to a channel
                     outside channelLabels are subtracted
    """
    kind = "rpms"
    if sources:
        kind = "srpms"
    log_debug(3, "Listing %s" % kind)
    if not channelLabels:
        return []

    binds, placeholders = _bind_many(channelLabels)
    in_list = ', '.join(placeholders)

    if sources:
        templ = _templ_srpms()
    else:
        templ = _templ_rpms()

    # Packages that are in the selected channels.
    selected = templ % ("", in_list)
    if force:
        query = selected
    else:
        if CFG.DB_BACKEND == 'oracle':
            minus_op = 'MINUS' # ORA syntax
        else:
            minus_op = 'EXCEPT' # ANSI syntax
        # Same bind names on both sides: selected channels minus the rest.
        others = templ % ("not", in_list)
        query = """
            %s
            %s
            %s
        """ % (selected, minus_op, others)
    cursor = rhnSQL.prepare(query)
    cursor.execute(**binds)
    return [entry['id'] for entry in cursor.fetchall_dict() or []]

def _templ_rpms():
    """Return the query template that selects binary package ids.

    The template has two %s slots: the first takes "" or "not" in front of
    the IN operator, the second takes the list of bind placeholders.
    """
    log_debug(4, "Generating template for querying rpms")
    template = """\
        select cp.package_id id
        from rhnChannel c, rhnChannelPackage cp
        where c.label %s in (%s)
        and cp.channel_id = c.id"""
    return template

def _templ_srpms():
    """Return the query template that selects source package ids.

    The template has two %s slots: the first takes "" or "not" in front of
    the IN operator, the second takes the list of bind placeholders.
    Source packages are matched to channel packages by source_rpm_id and
    by org_id (both NULL, or equal).
    """
    log_debug(4, "Generating template for querying srpms")
    template = """\
        select  ps.id id
        from    rhnPackage p,
                rhnPackageSource ps,
                rhnChannelPackage cp,
                rhnChannel c
        where   c.label %s in (%s)
            and c.id = cp.channel_id
            and cp.package_id = p.id
            and p.source_rpm_id = ps.source_rpm_id
            and ((p.org_id is null and ps.org_id is null) or
                p.org_id = ps.org_id)"""
    return template

def _delete_srpms(srcPackageIds):
    """Delete the rhnPackageSource rows with the given ids.

    Does nothing when srcPackageIds is empty.  The number of deleted rows
    is written to the debug log.
    """
    if not srcPackageIds:
        return
    statement = rhnSQL.prepare("""
        delete
        from rhnPackageSource
        where id = :id
    """)
    # A false result from executemany() is reported as 0.
    deleted = statement.executemany(id=srcPackageIds) or 0
    log_debug(2, "Successfully deleted %s/%s source package ids" % (
        deleted, len(srcPackageIds)))

def _delete_rpms(packageIds):
    """Delete the metadata of the given packages, in batches.

    packageIds -- list of package ids; nothing happens when it is empty

    The ids are handed to _delete_rpm_group() in slices of at most 300
    while a progress bar is printed on stdout.  The bar is advanced by the
    real size of each slice, so a short final batch does not push the
    count past the total.
    """
    if not packageIds:
        return
    group = 300
    total = len(packageIds)
    print("Deleting package metadata (" + str(total) + "):")
    pb = ProgressBar(prompt='Removing:         ', endTag=' - complete',
                     finalSize=total, finalBarLength=40, stream=sys.stdout)
    pb.printAll(1)

    for start in range(0, total, group):
        batch = packageIds[start:start + group]
        _delete_rpm_group(batch)
        pb.addTo(len(batch))
        pb.printIncrement()
    pb.printComplete()

def _delete_rpm_group(packageIds):
    """Delete one batch of packages and every row referring to them.

    packageIds -- ids of the packages to delete

    Rows are first removed from each table that refers to a package
    through its package_id column, then from rhnPackage itself.  The
    transaction is committed before returning.
    """
    references = [
        'rhnChannelPackage', 
        'rhnErrataPackage', 
        'rhnErrataPackageTMP', 
        'rhnPackageChangelogRec', 
        'rhnPackageConflicts', 
        'rhnPackageFile', 
        'rhnPackageObsoletes', 
        'rhnPackageProvides', 
        'rhnPackageRequires', 
        'rhnServerNeededCache',
    ]
    deleteStatement = "delete from %s where package_id = :package_id"
    for table in references:
        h = rhnSQL.prepare(deleteStatement % table)
        # executemany() may hand back a false value instead of a number
        # (_delete_srpms guards against the same thing); "%d" would then
        # raise TypeError, so report such a result as 0.
        count = h.executemany(package_id=packageIds) or 0
        log_debug(3, "Deleted from %s: %d rows" % (table, count))
    deleteStatement = "delete from rhnPackage where id = :package_id"
    h = rhnSQL.prepare(deleteStatement)
    count = h.executemany(package_id=packageIds)
    if count:
        log_debug(2, "DELETED package id %s" % str(packageIds))
    else:
        log_error("No such package id %s" % str(packageIds))
    rhnSQL.commit()

def _delete_files(relpaths):
    """Remove package files from below CFG.MOUNT_POINT.

    relpaths -- paths relative to CFG.MOUNT_POINT

    Best effort: a missing file or a failing unlink is only written to the
    debug log and the remaining paths are still processed.
    """
    for relpath in relpaths:
        target = os.path.join(CFG.MOUNT_POINT, relpath)
        if os.path.exists(target):
            try:
                os.unlink(target)
            except OSError:
                log_debug(1,  "Error unlinking %s;" % target)
        else:
            log_debug(1, "Not removing %s: no such file" % target)

def _bind_many(l):
    """Build bind variables for a list of values.

    Returns a (binds, placeholders) tuple: binds maps the generated names
    'p_0', 'p_1', ... to the values of l, and placeholders holds the
    matching ':p_0', ':p_1', ... strings in the same order.
    """
    binds = {}
    placeholders = []
    for position, value in enumerate(l):
        name = 'p_%s' % position
        binds[name] = value
        placeholders.append(':' + name)
    return binds, placeholders

def _get_package_paths(package_ids, sources=0):
    """Look up the stored file paths of the given packages.

    package_ids -- ids to look up, one query per id
    sources     -- when true, rhnPackageSource is queried instead of
                   rhnPackage

    Returns the distinct non-empty paths; ids without a row or without a
    path are skipped.
    """
    table = "rhnPackage"
    if sources:
        table = "rhnPackageSource"
    cursor = rhnSQL.prepare("select path from %s where id = :package_id" % table)
    # Dictionary keys are used to drop duplicate paths.
    unique_paths = {}
    for package_id in package_ids:
        cursor.execute(package_id=package_id)
        row = cursor.fetchone_dict()
        if row and row['path']:
            unique_paths[row['path']] = None

    return list(unique_paths.keys())



def releaseLOCK():
    """Release the module-level lock, if one is currently held.

    Safe to call more than once: LOCK is reset to None after a successful
    release, so later calls do nothing.
    """
    global LOCK
    if LOCK:
        LOCK.release()
        LOCK = None



# Script entry point: run main() and turn its result into the exit status.
if __name__ == '__main__':
    try:
        sys.exit(main() or 0)
    except SystemExit:
        # Normal exit
        raise
    except KeyboardInterrupt:
        sys.stderr.write("\nUser interrupted process.\n")
        releaseLOCK()
        sys.exit(0)
    except Exception:
        # Fetch the active exception without version-specific syntax.
        error = sys.exc_info()[1]
        releaseLOCK()
        sys.stderr.write("\nERROR: unhandled exception occurred: (%s).\n" % error)
        import traceback
        traceback.print_exc()
        sys.exit(-1)

