#!/usr/bin/python3 -sP
# AWS EC2 HibInit Agent. This agent does several things upon startup:
# 1. It checks for sufficient swap space to allow hibernation and fails if not enough space
# 2. If there's no swap file or the existing swap file isn't of a sufficient size,
#      it creates it and forks a background process that touches all of its blocks
#      to make sure that the EBS volumes are pre-warmed.
# 3. It updates the offset of the swap file in the kernel using SNAPSHOT_SET_SWAP_AREA ioctl.
# 4. It updates the resume offset and resume device in grub file.
#
# This file is compatible both with Python 2 and Python 3
import argparse
import array
import ctypes as ctypes
import fcntl
import mmap
import os
import struct
import sys
import syslog
import math
import requests
import signal
from subprocess import check_call, check_output

try:
    from ConfigParser import ConfigParser, NoSectionError, NoOptionError
except:
    from configparser import ConfigParser, NoSectionError, NoOptionError

# Space reserved for swap headers (bytes). main() subtracts it from the target
# size when deciding whether an existing swap file is large enough.
SWAP_RESERVED_SIZE = 16384
# When True, log() forwards messages to syslog; main() overwrites this with
# the value from the effective configuration.
log_to_syslog = True
# Path of the swap file this agent creates and manages.
SWAP_FILE = '/swap'

# State directory used when the config file has no core/state-dir option.
DEFAULT_STATE_DIR = '/var/lib/hibinit-agent'
# Name of the marker file created inside the state directory once IMDS has
# reported hibernation as configured.
HIB_ENABLED_FILE = "hibernation-enabled"
# Instance metadata service (IMDS) endpoint and the paths queried on it.
IMDS_BASEURL = 'http://169.254.169.254'
IMDS_API_TOKEN_PATH = 'latest/api/token'
IMDS_SPOT_ACTION_PATH = 'latest/meta-data/hibernation/configured'
# Slack (bytes) tolerated above the required size; main() deletes and
# re-creates an existing swap file that is larger than that.
MAX_SWAP_SIZE_OFFSET_ALLOWED = 100 * 1024

# Sigterm_handler Behaviour Modifiers
# True between begin_critical_process() and end_critical_process(); while it
# is set, SIGTERM is only recorded instead of acted upon.
CRITICAL_PROCESS_IN_PROGRESS = False
# Set by sigterm_handler() when SIGTERM arrives during a critical section;
# end_critical_process() then exits the process.
SHUTDOWN_REQUESTED = False


def log(message):
    """Send *message* to syslog unless syslog logging has been switched off."""
    if not log_to_syslog:
        return
    syslog.syslog(message)


def _swap_file_is_active(path):
    """Return True when *path* is listed as an active swap area in /proc/swaps."""
    try:
        with open('/proc/swaps', 'r') as swaps:
            # The first line is the column header; every following line
            # starts with the path of an active swap area.
            entries = swaps.read().splitlines()[1:]
    except (IOError, OSError):
        # No readable /proc/swaps: treat it as "nothing is swapped on".
        return False
    for entry in entries:
        columns = entry.split()
        if columns and columns[0] == path:
            return True
    return False


def sigterm_handler(signum, frame):
    """Handle SIGTERM: defer it during critical work, otherwise clean up and exit.

    While CRITICAL_PROCESS_IN_PROGRESS is set the request is only recorded in
    SHUTDOWN_REQUESTED. Otherwise the swap file is swapped off (if it is
    currently active), removed, and the process exits with status 0.

    Args:
        signum: Signal number delivered by the interpreter (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        subprocess.CalledProcessError: If ``swapoff`` fails.
        SystemExit: Always, on the non-deferred path.
    """
    global SHUTDOWN_REQUESTED
    if CRITICAL_PROCESS_IN_PROGRESS:
        SHUTDOWN_REQUESTED = True
        return

    # Reading /proc/swaps avoids comparing the bytes returned by check_output
    # with a str, and avoids grep's non-zero exit status (which raised
    # CalledProcessError) whenever the swap file is not active.
    if _swap_file_is_active(SWAP_FILE):
        check_call(['swapoff', SWAP_FILE])
    if os.path.isfile(SWAP_FILE) and os.access(SWAP_FILE, os.R_OK):
        os.remove(SWAP_FILE)
    sys.exit(0)


def begin_critical_process():
    """Mark the start of a section that SIGTERM must not interrupt.

    While the flag is set, sigterm_handler() only records the shutdown
    request in SHUTDOWN_REQUESTED instead of cleaning up and exiting.
    """
    global CRITICAL_PROCESS_IN_PROGRESS
    CRITICAL_PROCESS_IN_PROGRESS = True


def end_critical_process():
    """Leave the critical section and honour a SIGTERM that arrived meanwhile.

    Clears CRITICAL_PROCESS_IN_PROGRESS; if a shutdown was requested while
    the flag was set, the process exits with status 0.
    """
    global CRITICAL_PROCESS_IN_PROGRESS
    CRITICAL_PROCESS_IN_PROGRESS = False
    if not SHUTDOWN_REQUESTED:
        return
    exit(0)


def fallocate(fl, size):
    """Reserve *size* bytes for the open file *fl*.

    Tries libc's fallocate() first. If libc cannot be loaded or the call
    fails, the failure (including errno) is logged and the file is instead
    extended to *size* bytes by writing one zero byte at offset ``size - 1``.

    Args:
        fl: An open, writable file object (text or binary mode).
        size: Requested file size in bytes; must be positive.

    Raises:
        ValueError: If *size* is not positive.
    """
    if size <= 0:
        # Previously this only surfaced later as a "negative seek" error
        # from the fallback path.
        raise ValueError("fallocate() needs a positive size, got %d" % size)

    try:
        # use_errno makes the real failure reason available through
        # ctypes.get_errno(); the bare return value is always -1 on error.
        _libc = ctypes.CDLL('libc.so.6', use_errno=True)
        _fallocate = _libc.fallocate
        _fallocate.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_ulong]
        _fallocate.restype = ctypes.c_int

        # (FD, mode, offset, len)
        res = _fallocate(fl.fileno(), 0, 0, size)
        if res != 0:
            err = ctypes.get_errno()
            raise Exception("Failed to perform fallocate(). Result: %d, errno: %d (%s)"
                            % (res, err, os.strerror(err)))
    except Exception as e:
        log("Failed to call fallocate(), will use resize. Err: %s" % str(e))
        fl.seek(size - 1)
        # Write the filler byte in the type the file object expects so the
        # fallback also works for files opened in binary mode.
        binary_mode = 'b' in str(getattr(fl, 'mode', ''))
        fl.write(b'\0' if binary_mode else chr(0))


def get_file_block_number(filename):
    """Return the block number FIBMAP reports for logical block 0 of *filename*.

    Args:
        filename: Path of the file to query.

    Returns:
        The block number written back by the FIBMAP ioctl.

    Raises:
        OSError: If the file cannot be opened or the ioctl fails
            (fcntl.ioctl raises instead of returning a negative value).
    """
    # from linux/fs.h
    FIBMAP = 0x01
    # FIBMAP reads the logical block index from, and writes the result to, a
    # C int. An 'L' array is 8 bytes on LP64 and only yields the right value
    # on little-endian machines, so use an int-sized buffer.
    buf = array.array('i', [0])
    with open(filename, 'r') as handle:
        fcntl.ioctl(handle.fileno(), FIBMAP, buf)
    return buf[0]


def get_rootfs_size():
    """Return the total size of the root filesystem in GiB, rounded up."""
    root = os.statvfs('/')
    # f_blocks is counted in units of f_frsize (the fragment size), not
    # f_bsize (the preferred I/O block size).
    total_bytes = root.f_frsize * root.f_blocks
    return math.ceil(float(total_bytes) / (1024 * 1024 * 1024))


# This is only for grub2
def find_grub_mount():
    """Return the mount point holding the first readable GRUB2 config, or None.

    The candidate config paths are probed in order; for the first one that is
    a readable regular file, ``stat -L -c %m`` is used to look up the mount
    point of the filesystem it lives on.
    """
    candidates = ('/etc/grub2-efi.cfg',
                  '/boot/grub2/grub.cfg',
                  '/boot/grub2-efi/grub.cfg',
                  '/etc/grub2.cfg')
    readable = (path for path in candidates
                if os.path.isfile(path) and os.access(path, os.R_OK))
    grub_cfg = next(readable, None)
    if grub_cfg is None:
        return None
    stat_cmd = 'stat -L -c %m ' + grub_cfg
    stat_out = check_output(stat_cmd, shell=True, universal_newlines=True)
    return stat_out.strip()


def patch_grub_config(swap_device, offset):
    """Point the boot configuration at the swap file used for resume.

    Regenerates /boot/grub2/grub.cfg, adds ``no_console_suspend``,
    ``resume_offset`` and ``resume`` to every kernel entry through grubby,
    then syncs and freezes/unfreezes the filesystem that holds the GRUB
    config. When the sysfs resume files exist they are written as well so
    the settings apply without a reboot. Nothing is changed (only a message
    is logged) when no GRUB config file can be found.

    Args:
        swap_device: Device that holds the swap file; used as ``resume=``.
        offset: Offset of the swap file; used as ``resume_offset=``.

    Raises:
        subprocess.CalledProcessError: If one of the helper commands fails.
    """
    log("Updating GRUB to use the device %s with offset %d for resume" % (swap_device, offset))
    cmd = "grubby --update-kernel=ALL --args='no_console_suspend={console} resume_offset={offset} resume={swap_device}'"
    cmd = cmd.format(console='1', offset=offset, swap_device=swap_device)
    mount_point = find_grub_mount()
    if mount_point is None:
        log("Grub configuration is not updated. Grub cannot found under /boot or /etc. Please run manual grub update with resume parameters")
        return
    # The trap keeps the shell from being interrupted between freeze and unfreeze.
    fsfreeze = "sync && mountpoint -q {mount} && trap '' HUP INT QUIT TERM && fsfreeze -f {mount} && fsfreeze -u {mount}"
    fsfreeze = fsfreeze.format(mount=mount_point)
    mkconfig = "grub2-mkconfig -o {path}"
    mkconfig = mkconfig.format(path="/boot/grub2/grub.cfg")
    check_output(mkconfig, shell=True)
    check_call(cmd, shell=True)
    check_call(fsfreeze, shell=True)

    # Some of grubby versions need a restart after changing in kernel parameters
    # echo offset and swap device helps the customer to use the agent immediately
    # after rpm installation.
    # The path must be absolute: the relative "sys/power/resume" only resolved
    # correctly when the current working directory happened to be "/".
    if os.path.exists("/sys/power/resume"):
        echo_resume_device_cmd = "echo {swap_device} > /sys/power/resume"
        echo_resume_device_cmd = echo_resume_device_cmd.format(swap_device=swap_device)
        log("sys/power/resume exist and would be updated")
        check_output(echo_resume_device_cmd, shell=True)

    if os.path.exists("/sys/power/resume_offset"):
        echo_resume_offset_cmd = "echo {offset} > /sys/power/resume_offset"
        echo_resume_offset_cmd = echo_resume_offset_cmd.format(offset=offset)
        log("sys/power/resume_offset exist and would be updated")
        check_output(echo_resume_offset_cmd, shell=True)
    log("GRUB configuration is updated")


def update_kernel_swap_offset(swapon, swapoff, filename, grub_update, btrfs_enabled):
    """Register the swap file's location with the kernel and boot loader.

    Turns the swap file on, determines its offset (via btrfs tooling or the
    FIBMAP-based helper), optionally patches GRUB, tells the kernel about the
    swap area through /dev/snapshot (non-btrfs only), turns the swap file off
    again and finally rebuilds the initramfs with the resume module.

    Args:
        swapon: Command template containing ``{swapfile}`` that enables swap.
        swapoff: Command template containing ``{swapfile}`` that disables swap.
        filename: Path of the swap file.
        grub_update: Whether the GRUB configuration should be patched.
        btrfs_enabled: Whether the swap file lives on btrfs.
    """
    enable_cmd = swapon.format(swapfile=filename)
    log("Running: %s" % enable_cmd)
    check_call(enable_cmd, shell=True)
    log("Updating the kernel offset for the swapfile: %s" % filename)

    dev = os.stat(filename).st_dev
    if not btrfs_enabled:
        offset = get_file_block_number(filename)
    else:
        map_cmd = "btrfs inspect-internal map-swapfile -r {swapfile}".format(swapfile=filename)
        map_out = check_output(map_cmd, shell=True)
        offset = int(map_out.decode(sys.getfilesystemencoding()))

    if not grub_update:
        log("Skipping GRUB configuration update")
    else:
        dev_str = find_device_for_file(filename)
        patch_grub_config(dev_str, offset)

    log("Setting swap device to %d with offset %d" % (dev, offset))

    if not btrfs_enabled:
        # Set the kernel swap offset, see https://www.kernel.org/doc/Documentation/power/userland-swsusp.txt
        # From linux/suspend_ioctls.h
        SNAPSHOT_SET_SWAP_AREA = 0x400C330D
        swap_area = struct.pack('LI', offset, dev)
        with open('/dev/snapshot', 'r') as snap:
            fcntl.ioctl(snap, SNAPSHOT_SET_SWAP_AREA, swap_area)
        log("Done updating the swap offset. Turning swapoff")

    disable_cmd = swapoff.format(swapfile=filename)
    log("Running: %s" % disable_cmd)
    check_call(disable_cmd, shell=True)
    # Rebuild the initramfs with the resume module; the trap keeps the shell
    # from being interrupted by HUP/INT/QUIT/TERM while dracut runs.
    check_call("trap '' HUP INT QUIT TERM && dracut -a resume -f", shell=True)


def find_device_for_file(filename):
    """Return the device column that ``df -P`` reports for *filename*."""
    # Find the mount point for the swap file ('df -P /swap')
    raw_report = check_output(['df', '-P', filename])
    report = raw_report.decode(sys.getfilesystemencoding())
    # Line 0 is the header; the first column of line 1 is the device.
    data_row = report.split("\n")[1]
    return data_row.split()[0]


class SwapInitializer(object):
    """Creates the swap file, optionally touches every block, and formats it."""

    def __init__(self, filename, swap_size, touch_swap, btrfs_enabled, mkswap, swapoff, swapon):
        """Store the swap file parameters.

        Args:
            filename: Path of the swap file.
            swap_size: Required size of the swap file in bytes.
            touch_swap: Whether every block should be written after allocation.
            btrfs_enabled: Whether the file needs the btrfs No-COW preparation.
            mkswap: Command template containing ``{swapfile}`` used to format.
            swapoff: Command template for disabling swap (stored only).
            swapon: Command template for enabling swap (stored only).
        """
        self.filename = filename
        self.swap_size = swap_size
        self.mkswap = mkswap
        self.swapoff = swapoff
        self.swapon = swapon
        self.touch_swap = touch_swap
        self.btrfs_enabled = btrfs_enabled
        self.block_size = 1024

    def do_allocate(self):
        """Create (or truncate) the swap file, reserve its space, and chmod it 0600."""
        log("Allocating %d bytes in %s" % (self.swap_size, self.filename))
        with open(self.filename, 'w+') as fl:
            fallocate(fl, self.swap_size)
        os.chmod(self.filename, 0o600)

    def setup_btrfs(self):
        """Create an empty swap file with the No-Copy-On-Write attribute set."""
        no_cow = "truncate -s 0 {swapfile} && chattr +C {swapfile}".format(swapfile=self.filename)
        log("Setting /swap to No Copy On Write")
        check_call(no_cow, shell=True)

    def init_swap(self):
        """
            Initialize the swap using direct IO to avoid polluting the page cache

            An existing file that is already large enough is only re-formatted.
            Otherwise the file is (re)allocated, optionally filled block by
            block, and then formatted with mkswap.

            Raises:
                OSError: If the file cannot be opened or written with direct IO.
                Exception: If a write reports that no bytes were written.
        """
        try:
            cur_swap_size = os.stat(self.filename).st_size
            if cur_swap_size >= self.swap_size:
                log("Swap file size (%d bytes) is already large enough" % cur_swap_size)
                if self.init_mkswap():
                    return
        except OSError:
            # The file is missing or unusable; removal is best effort.
            try:
                os.unlink(self.filename)
            except OSError:
                pass

        if self.btrfs_enabled:
            self.setup_btrfs()

        self.do_allocate()
        if not self.touch_swap:
            log("Swap pre-heating is skipped, the swap blocks won't be touched during "
                "to ensure they are ready")
            self.init_mkswap()
            return

        written = 0
        log("Opening %s for direct IO" % self.filename)
        # os.open() raises OSError on failure, so no return-code check is
        # needed (the old check also used os.errno, which no longer exists).
        fd = os.open(self.filename, os.O_RDWR | os.O_DIRECT | os.O_SYNC | os.O_DSYNC)

        filler_block = None
        try:
            # Create a filler block that is correctly aligned for direct IO
            filler_block = mmap.mmap(-1, 1024 * 1024)
            # We're using 'b' to avoid optimizations that might happen for zero-filled pages
            filler_block.write(b'b' * 1024 * 1024)

            log("Touching all blocks in %s" % self.filename)

            while written < self.swap_size:
                # Real I/O errors surface as OSError from os.write() itself.
                res = os.write(fd, filler_block)
                if res <= 0:
                    raise Exception("Failed to touch a block. Err: os.write() wrote %d bytes" % res)
                written += res
        finally:
            os.close(fd)
            if filler_block is not None:
                filler_block.close()
        log("Swap file %s is ready" % self.filename)
        self.init_mkswap()

    def init_mkswap(self):
        """Format the swap file; return True on success, False (after logging) on failure."""
        # Do mkswap
        try:
            mkswap = self.mkswap.format(swapfile=self.filename)
            log("Running: %s" % mkswap)
            check_call(mkswap, shell=True)
            return True
        except Exception as e:
            log("Failed to initialize swap, reason: %s" % str(e))
        return False


class BackgroundInitializerRunner(object):
    """Detaches the agent into a child process and runs the hibernation setup there."""

    def __init__(self, swapper, swapon, swapoff, filename, update_grub, btrfs_enabled):
        """Store the setup parameters.

        Args:
            swapper: Object whose ``init_swap()`` creates the swap file, or
                None when the existing swap file is reused.
            swapon: Command template that enables swap.
            swapoff: Command template that disables swap.
            filename: Path of the swap file.
            update_grub: Whether the GRUB configuration should be patched.
            btrfs_enabled: Whether the swap file lives on btrfs.
        """
        self.swapper = swapper
        self.thread = None
        # Last exception raised by init_hibernate_setup(), if any.
        self.error = None
        self.swapon = swapon
        self.swapoff = swapoff
        self.filename = filename
        self.update_grub = update_grub
        self.btrfs_enabled = btrfs_enabled

    def init_background_process(self):
        """Fork; the parent exits and the child becomes a session leader.

        The child changes directory to "/", starts a new session, sets the
        umask to 022 and installs the SIGTERM handler. If fork() fails the
        error is written to stderr and the process exits with status 1.
        """
        try:
            pid = os.fork()
        except OSError as err:
            # The previous Python 2 style "print >> sys.stderr" raised a
            # TypeError on Python 3 and formatted the OSError class itself.
            sys.stderr.write("fork failed: %d (%s)\n" % (err.errno, err.strerror))
            sys.exit(1)
        if pid > 0:
            # Exit parent process
            sys.exit(0)

        # Configure the child processes environment
        os.chdir("/")
        os.setsid()
        os.umask(0o22)

        # Configures to handle kills gracefully
        signal.signal(signal.SIGTERM, sigterm_handler)

    def init_hibernate_setup(self):
        """Create the swap file (if a swapper is set) and register it with the kernel.

        The kernel/boot-loader update runs as a critical section so that a
        SIGTERM is deferred until it has finished. Any exception is logged
        and stored in ``self.error`` instead of being raised.
        """
        try:
            if self.swapper is not None:
                self.swapper.init_swap()
            begin_critical_process()
            update_kernel_swap_offset(self.swapon, self.swapoff, self.filename, self.update_grub, self.btrfs_enabled)
            end_critical_process()
        except Exception as ex:
            log("Failed to initialize swap when using async, reason: %s" % str(ex))
            self.error = ex


def identify_file_system(swapfile, file_system):
    """Tell whether *swapfile* lives on a filesystem of type *file_system*.

    Args:
        swapfile: Path of the (possibly not yet existing) swap file.
        file_system: Filesystem type to look for, e.g. "xfs" or "btrfs".

    Returns:
        True when /proc/mounts lists the device holding the swap file with
        the given filesystem type, False otherwise.

    Raises:
        Exception: If no device can be determined for any parent directory.
    """
    # Walk the parent directories of the swapfile to find on which
    # filesystem it's mounted
    swap_place = swapfile
    while True:
        parent = os.path.dirname(swap_place)
        try:
            dev = find_device_for_file(parent)
        except Exception:
            dev = None
        if dev:
            break
        # Stop once there is nothing left to climb; without this check a
        # relative path (or an empty df result) made the loop spin forever.
        if parent in ('', '/') or parent == swap_place:
            raise Exception("Failed to find the filesystem type of /")
        swap_place = parent

    with open("/proc/mounts") as fl:
        for ln in fl:
            # /proc/mounts columns: device, mount point, filesystem type, ...
            # Compare whole fields so "/dev/sda1" does not match "/dev/sda10"
            # and "xfs" does not match a mount path or option.
            fields = ln.split()
            if len(fields) >= 3 and fields[0] == dev and fields[2] == file_system:
                return True
    return False


class Config(object):
    """Effective agent settings.

    Each setting is resolved as: command-line argument, then config-file
    value, then built-in default.
    """

    def __init__(self, config, args):
        """Resolve all settings.

        Args:
            config: ConfigParser-like object exposing ``get(section, name)``.
            args: Parsed command-line arguments (argparse namespace).
        """
        def get(section, name):
            # A missing section or option simply means "not configured".
            try:
                return config.get(section, name)
            except (NoSectionError, NoOptionError):
                return None

        def get_int(section, name):
            raw = get(section, name)
            return None if raw is None else int(raw)

        cf_syslog = self.to_bool(get('core', 'log-to-syslog'))
        arg_syslog = self.to_bool(args.log_to_syslog)
        self.log_to_syslog = self.merge(cf_syslog, arg_syslog, True)

        self.mkswap = self.merge(get('swap', 'mkswap'), args.mkswap, 'mkswap {swapfile}')
        self.swapon = self.merge(get('swap', 'swapon'), args.swapon, 'swapon {swapfile}')
        self.swapoff = self.merge(get('swap', 'swapoff'), args.swapoff, 'swapoff {swapfile}')

        # The defaults for the next two settings are probed from the
        # filesystem that will hold the swap file.
        cf_touch = self.to_bool(get('core', 'touch-swap'))
        arg_touch = self.to_bool(args.touch_swap)
        self.touch_swap = self.merge(cf_touch, arg_touch, identify_file_system(SWAP_FILE, "xfs"))

        cf_btrfs = self.to_bool(get('core', 'btrfs-enabled'))
        arg_btrfs = self.to_bool(args.btrfs_enabled)
        self.btrfs_enabled = self.merge(cf_btrfs, arg_btrfs, identify_file_system(SWAP_FILE, "btrfs"))

        cf_grub = self.to_bool(get('core', 'grub-update'))
        arg_grub = self.to_bool(args.grub_update)
        self.grub_update = self.merge(cf_grub, arg_grub, True)

        cf_percentage = get_int('swap', 'percentage-of-ram')
        self.swap_percentage = self.merge(cf_percentage, args.swap_ram_percentage, 100)
        cf_target_mb = get_int('swap', 'target-size-mb')
        self.swap_mb = self.merge(cf_target_mb, args.swap_target_size_mb, 4000)

        configured_dir = get('core', 'state-dir')
        self.state_dir = DEFAULT_STATE_DIR if configured_dir is None else configured_dir

    def merge(self, cf_value, arg_value, def_val):
        """Return the first value that is not None: argument, config file, default."""
        for candidate in (arg_value, cf_value):
            if candidate is not None:
                return candidate
        return def_val

    def to_bool(self, bool_str):
        """Parse the string and return the boolean value encoded or raise an exception"""
        if bool_str is None:
            return None
        lowered = bool_str.lower()
        if lowered in ('true', 't', '1'):
            return True
        if lowered in ('false', 'f', '0'):
            return False
        # if here we couldn't parse it
        raise ValueError("%s is not recognized as a boolean value" % bool_str)

    def __str__(self):
        return str(vars(self))


def get_imds_token(seconds=21600, timeout=10):
    """Get a token to access instance metadata.

    Args:
        seconds: Requested token lifetime, sent as the TTL header.
        timeout: Seconds to wait for IMDS before giving up. Without a
            timeout an unreachable endpoint would block the agent forever.

    Returns:
        The token text, or None when IMDS answers with a status other than 200.

    Raises:
        requests.exceptions.RequestException: If the request fails or times out.
    """
    log("Requesting new IMDSv2 token.")
    request_header = {'X-aws-ec2-metadata-token-ttl-seconds': '{}'.format(seconds)}
    token_url = '{}/{}'.format(IMDS_BASEURL, IMDS_API_TOKEN_PATH)
    response = requests.put(token_url, headers=request_header, timeout=timeout)
    response.close()
    if response.status_code != 200:
        return None

    return response.text


def create_state_dir(state_dir):
    """Make sure the agent's state directory exists, creating parents as needed."""
    if os.path.isdir(state_dir):
        return
    os.makedirs(state_dir)


def hibernation_enabled(state_dir, timeout=10):
    """Returns a boolean indicating whether hibernation is enabled or not.

    Hibernation can't be enabled/disabled after the instance launch. If we find
    hibernation to be enabled, we create a semaphore file so that we don't
    have to probe IMDS again. That is useful when an instance is rebooted
    after/if the IMDS http endpoint has been disabled.

    Args:
        state_dir: Directory in which the semaphore file is kept.
        timeout: Seconds to wait for the IMDS query before giving up, so an
            unreachable endpoint cannot block the agent forever.

    Raises:
        requests.exceptions.RequestException: If the IMDS query fails or times out.
    """
    hib_sem_file = os.path.join(state_dir, HIB_ENABLED_FILE)
    if os.path.isfile(hib_sem_file):
        log("Found {!r}, configuring hibernation".format(hib_sem_file))
        return True

    imds_token = get_imds_token()
    if imds_token is None:
        # IMDS http endpoint is disabled
        return False

    request_header = {'X-aws-ec2-metadata-token': imds_token}
    response = requests.get("{}/{}".format(IMDS_BASEURL, IMDS_SPOT_ACTION_PATH),
                            headers=request_header, timeout=timeout)
    response.close()
    # Strip surrounding whitespace so a trailing newline in the body cannot
    # turn "false" into an apparent "enabled".
    if response.status_code != 200 or response.text.strip().lower() == "false":
        return False

    log("Hibernation Configured Flag found")
    os.mknod(hib_sem_file)
    return True


def main():
    """Entry point: decide whether hibernation setup is needed and run it.

    Parses the command line and optional config file, checks that hibernation
    is enabled for the instance and that the root filesystem is larger than
    RAM, then either reuses an adequately sized swap file or (re)creates it.
    The actual setup runs after init_background_process(), whose parent
    process exits, so it continues in the forked child.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(description="An EC2 background process that creates a setup for instance hibernation "
                                                 "at instance launch and also registers ACPI sleep event/actions")
    parser.add_argument('-c', '--config', help='Configuration file to use', type=str)
    parser.add_argument("-syslog", "--log-to-syslog", help='Log to syslog', type=str)
    parser.add_argument("-touch", "--touch-swap", help='Do swap initialization', type=str)
    parser.add_argument("-btrfs", "--btrfs_enabled", help='Sets no copy on write on swap file', dest="btrfs_enabled", type=str)
    parser.add_argument("-grub", "--grub-update", help='Update GRUB config with resume offset', type=str)
    parser.add_argument("-p", "--swap-ram-percentage", help='The target swap size as a percentage of RAM', type=int)
    parser.add_argument("-s", "--swap-target-size-mb", help='The target swap size in megabytes', type=int)
    parser.add_argument('--mkswap', help='The command line utility to set up swap', type=str)
    parser.add_argument('--swapon', help='The command line utility to turn on swap', type=str)
    parser.add_argument('--swapoff', help='The command line utility to turn off swap', type=str)

    args = parser.parse_args()

    # The config file is optional; without it only arguments and defaults apply
    config_file = ConfigParser()
    if args.config:
        config_file.read(args.config)

    config = Config(config_file, args)
    # Apply the configured logging choice before the first log() call
    global log_to_syslog
    log_to_syslog = config.log_to_syslog

    log("Effective config: %s" % config)
    create_state_dir(config.state_dir)

    # Let's first check if we even need to run
    if not hibernation_enabled(config.state_dir):
        log("Instance Launch has not enabled Hibernation Configured Flag. hibinit-agent exiting!!")
        exit(0)
    # Validate if disk space>total RAM
    ram_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
    # Both sides are compared in whole GiB, rounded up
    if get_rootfs_size()<=(math.ceil(float(ram_bytes)/(1024*1024*1024))):
        log("Insufficient disk space. Cannot create setup for hibernation. Please allocate a larger root device")
        exit(1)

    # The target is the larger of the configured size in MB and the
    # configured percentage of RAM
    target_swap_size = config.swap_mb * 1024 * 1024
    swap_percentage_size = ram_bytes * config.swap_percentage // 100
    if swap_percentage_size > target_swap_size:
        target_swap_size = int(swap_percentage_size)
    log("Will check if swap is at least: %d megabytes" % (target_swap_size // (1024*1024)))

    # Validate if swap file exists
    cur_swap = 0
    if os.path.isfile(SWAP_FILE) and os.access(SWAP_FILE, os.R_OK):
        cur_swap = os.path.getsize(SWAP_FILE)

    # Created without a swapper; one is attached below only when the swap
    # file has to be (re)created
    bi = BackgroundInitializerRunner(None, config.swapon, config.swapoff, SWAP_FILE, config.grub_update, config.btrfs_enabled)
    # Existing swap file is larger than needed plus the allowed slack:
    # remove it and fall through to re-creation
    if cur_swap > target_swap_size - SWAP_RESERVED_SIZE + MAX_SWAP_SIZE_OFFSET_ALLOWED:
        log("Swap already exists! (have %d, need %d), deleting existing swap file %s since current swap is "
              "sufficiently large and wasting disk space" %
            (cur_swap, target_swap_size, SWAP_FILE))
        os.remove(SWAP_FILE)
    # Existing swap file is within the accepted range: reuse it as is
    elif cur_swap >= target_swap_size - SWAP_RESERVED_SIZE:
        log("There's sufficient swap available (have %d, need %d)" %
            (cur_swap, target_swap_size))
        bi.init_background_process()
        bi.init_hibernate_setup()
        exit()
    # Validate if instance was launched from pre-created image and swap size>=total RAM, if not re-create the swap
    elif cur_swap > 0 and (cur_swap < target_swap_size - SWAP_RESERVED_SIZE):
        log("Swap already exists! (have %d, need %d), deleting existing swap file %s" %
            (cur_swap, target_swap_size, SWAP_FILE))
        os.remove(SWAP_FILE)

    log("Create swap and initialize it")
    # We need to create swap, but first validate that we have enough free space
    swap_dev = os.path.dirname(SWAP_FILE)
    st = os.statvfs(swap_dev)
    free_bytes = st.f_bavail * st.f_frsize
    # Rounding off to swap_size+10mb for swap headers
    free_space_needed = target_swap_size + 10 * 1024 * 1024
    if free_space_needed >= free_bytes:
        log("There's not enough space (%d present, %d needed) on the device: %s" % (
                free_bytes, free_space_needed, swap_dev))
        exit(1)
    log("There's enough space (%d present, %d needed) on the device: %s" % (
        free_bytes, free_space_needed, swap_dev))

    sw = SwapInitializer(SWAP_FILE, target_swap_size, config.touch_swap, config.btrfs_enabled,
                             config.mkswap, config.swapoff, config.swapon)
    bi.swapper = sw
    # The parent exits inside this call; everything below runs in the child
    bi.init_background_process()

    log("kicking child process to initiate the setup")
    bi.init_hibernate_setup()

# Run the agent only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
