#!/usr/bin/python
"""Ask the antitheft server which software version I should be running."""
# Copyright (C) 2007-8 One Laptop Per Child Association, Inc.
# Licensed under the terms of the GNU GPL v2 or later; see COPYING for details.
# Written by C. Scott Ananian <cscott@laptop.org>
from __future__ import with_statement
from __future__ import division
from urllib import urlencode
from urllib2 import urlopen
from binascii import hexlify
import bitfrost.util.json as json
from bitfrost.contents.utils import open_envel, UnifiedContents
from bitfrost.update import perform_update, inhibit_suspend, check_signature
from bitfrost.leases.keys import OATS_KEYS
from bitfrost.leases.core import find_lease
from bitfrost.leases.crypto import date_cmp
from random import SystemRandom
from subprocess import call, check_call
import os, os.path, sys, shutil
import hashlib
import re
import time, calendar
from datetime import datetime, timedelta


# Antitheft (OATS) server host; ANTITHEFT_SERVER_FILE, when non-empty,
# overrides the default.
DEFAULT_ANTITHEFT_SERVER = 'antitheft.laptop.org'
ANTITHEFT_SERVER_FILE = '/security/oats-server'

# Bookkeeping used to pace --auto checks: mtimes of the attempt/query stamps
# and the requested number of checks per month.
LAST_ATTEMPT_FILE = '/security/update-attempt'
LAST_QUERY_FILE = '/security/update-query'
INTERVAL_FILE = '/security/update-interval'

# Update stream: the per-machine setting wins over the build's default.
STREAM_FILE = '/security/update-stream'
DEFAULT_STREAM_FILE = '/etc/olpc-update/update-stream'

# The installed lease, replaced in place by update_lease().
LEASE_FILE = '/security/lease.sig'

def read_ofw(dev_path):
    """Return the value of an OFW device-tree node, or None if it is absent.

    The node is looked up beneath /ofw; trailing newline and NUL characters
    are stripped from the value.
    """
    node = os.path.join('/ofw', dev_path)
    if os.path.exists(node):
        with open(node) as handle:
            raw = handle.read()
        return raw.rstrip("\n\0")
    return None

def current_version():
    """Return the contents hash of the running software version, or None.

    /versions/running is read as a symlink; its basename selects the entry
    under /versions/contents/ whose UnifiedContents hash is returned.  Any
    ordinary failure (missing link, unreadable contents) yields None, but
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        shortname = os.path.basename(os.readlink('/versions/running'))
        contents = UnifiedContents('/versions/contents/%s' % shortname)
        return contents.contents_hash()
    except Exception:
        return None

def check_credential(data, credential, valid_keys, serialnum):
    """Require at least one signature in *credential* to validate *data*.

    Every 'sig' envelope in *credential* is checked, in order, against the
    JSON serialisation of *data*, even after one has matched.

    Returns True on success; raises RuntimeError if no signature matched.
    """
    # sig01: sha256 keyid data\n
    #  3 2 2    6  1  64 1 N   1
    verdicts = [check_signature(json.write(data), sig, valid_keys, serialnum)
                for sig in open_envel('sig', 1, credential)]
    # check_signature is expected never to raise; at least one must be true.
    if not any(verdicts):
        raise RuntimeError("No signatures match our keys")
    return True

def timestr_to_secs(timestr):
    """Convert a UTC stamp like '20081231T235959Z' to seconds since the epoch.

    Raises ValueError if *timestr* is not in YYYYMMDDTHHMMSSZ form.
    """
    parsed = time.strptime(timestr, '%Y%m%dT%H%M%SZ')
    # timegm() interprets the fields as UTC and only looks at the first six
    # (year..second), so the DST flag never matters.
    return calendar.timegm(parsed)

def maybe_update_lease(new_lease_data, sn, uuid, lease_expires, report=None):
    """
    Install new_lease_data in place of the current lease when appropriate.

    The candidate is passed through find_lease() first, so it is rejected
    (by whatever find_lease raises) if it fails the security checks.  It is
    installed when there is no usable existing lease (lease_expires is None),
    or when it expires later than the existing one according to date_cmp().
    Progress is sent to report(0, msg) when a reporter is given.
    """
    def say(msg):
        if report is not None:
            report(0, msg)

    say('Considering new lease from server')

    # cryptographically verify new lease (find_lease is given this
    # machine's serial number and UUID)
    minimized = find_lease(sn, uuid, new_lease_data)
    new_expiration = lease_or_sig_expiry(minimized)

    if lease_expires is None:
        say('Existing lease broken or missing, installing new one')
    else:
        say('Current lease expires %s, new lease expires %s' % (lease_expires, new_expiration))
        if date_cmp(new_expiration, lease_expires) <= 0:
            say('Staying with existing lease')
            return

    update_lease(minimized, report=report)

def update_lease(new_lease, report=None):
    """
    Atomically put a new lease in place at LEASE_FILE.

    The lease text is written to LEASE_FILE + ".new" and flushed to disk.
    Next, the current lease is copied to LEASE_FILE + ".bak" (best effort).
    Finally, the new file is renamed over LEASE_FILE, so readers see either
    the old lease or the complete new one.
    """

    if report is not None:
        report(0, 'Switching to new lease')

    tmp_path = LEASE_FILE + ".new"

    # Remove any stale temporary so open() below creates a fresh file.
    try:
        os.unlink(tmp_path)
    except OSError:
        pass

    # 'with' guarantees the descriptor is closed even if the write fails.
    with open(tmp_path, "w") as new_fd:
        new_fd.write(new_lease)
        # Make sure the data is on disk before the rename publishes it;
        # otherwise a crash could leave an empty/truncated LEASE_FILE.
        new_fd.flush()
        os.fsync(new_fd.fileno())

    try:
        shutil.copy(LEASE_FILE, LEASE_FILE + ".bak")
    except EnvironmentError:
        # don't abort lease update if backup failed
        pass

    os.rename(tmp_path, LEASE_FILE)


def check_stolen_hash(hash, nonce, uuid):
    """Return True if *hash* equals the hex SHA-256 of "<uuid>:<nonce>:STOLEN"."""
    token = uuid + ":" + nonce + ":STOLEN"
    expected = hashlib.sha256(token).hexdigest()
    return hash == expected

@inhibit_suspend
def query(urls, valid_keys, lease_expires, delay=1, stream=None, report=None):
    nonce = hexlify(os.urandom(16))
    serialnum = read_ofw('mfg-data/SN')
    uuid      = read_ofw('mfg-data/U#')
    isxo     = True
    if serialnum is None:
	serialnum = 'SHF00000000' # for testing.
        isxo = False
    if uuid is None:
        uuid = 'A000B000-C000-D000-E000-F000G000H000'
        isxo = False

    params = { 'serialnum': serialnum,
               'nonce': nonce }
    vhash = current_version()
    if vhash is not None:
        params['version'] = vhash
    try:
        params['stream'] = open(DEFAULT_STREAM_FILE).read().strip()
    except: pass # no default stream for this build.
    try:
        params['stream'] = open(STREAM_FILE).read().strip()
    except: pass # no stream set.
    if stream is not None:
        params['stream'] = stream # manually force an update stream
    # add free disk space in /versions/pristine
    try:
        st = os.statvfs('/versions/pristine/')
        params['freespace'] = str(st.f_bfree * st.f_bsize // 1024)
    except: pass # not a critical parameter; be safe.

    resp_map = None

    # query all urls
    for url in urls:
        report(0, 'Querying %s' % url)
        try:
            # XXX: should use a handler to handle hashcash stuff here.
            resp = urlopen(url, urlencode(params))
            if resp.code == 200:
                data, credential = open_envel('oatc-signed-resp',1,json.read(resp.read()))
                check_credential(data, credential, valid_keys, serialnum)
                resp_map = open_envel('oatc-resp',1,data)
                break
            else:
                report(0, 'Bad HTTP status code: %d' % resp.code)
        except:
            report(0, 'HTTP query failed')

    if resp_map == None:
        report(0, 'Could not contact any OAT server')
        return None
      
    if 'nonce' not in resp_map or nonce != resp_map['nonce']:
        raise RuntimeError('bad nonce in reply')

    if isxo and 'stolen' in resp_map \
            and check_stolen_hash(resp_map['stolen'], uuid, nonce):
        # this machine has been reported stolen - get rid of
        # any leases and poweroff
        if os.path.exists(LEASE_FILE):
            os.unlink(LEASE_FILE)
        call(['/usr/bin/halt'])
        exit()

    if 'time' in resp_map:
        report(0, 'Server time: ' + resp_map['time'])
        serverepoch = timestr_to_secs(resp_map['time'])

        now = time.time()
        # The main goal is to reset the clock in case it is really off
        # due to RTC battery or tinkering.
        # We only set the time if it's off by more than a day.
        # This prevents messing with NTP.
        if abs(now - serverepoch) > 3600 * 24:
            newtimestr = time.strftime('%Y-%m-%d %H:%M:%S +0000', time.gmtime(serverepoch))
            report(0, "Setting hwclock to " + newtimestr)
            check_call(['/usr/sbin/hwclock',
                        '--set', '--date', newtimestr])
            check_call(['/usr/sbin/hwclock',
                        '--hctosys'])

    if 'lease' in resp_map:
        try:
            maybe_update_lease(resp_map['lease'], serialnum, uuid, lease_expires, report=report)
        except Exception, e:
            # don't abort a sw update if the leasing doesn't work out
            if report is not None:
                report(0, 'Encountered %s when attempting lease update' % type(e))
            pass

    if 'update' in resp_map:
        return resp_map['update']

    if 'delegate' in resp_map:
        new_url, new_key = resp_map['delegate']
        time.sleep(delay) # limit worst case loop traffic.
        return query([ new_url ], [ new_key ], lease_expires, delay=delay*2, stream=stream, report=report)
    return None # i guess that means we shouldn't upgrade

def _clamp(n):
    """clamp number to the interval [0,1]."""
    return max(0, min(n, 1))

def randomly_do_nothing(lease_expires, report=None):
    """Randomly decide whether this run should skip the server query.

    The mtimes of LAST_ATTEMPT_FILE and LAST_QUERY_FILE are combined with the
    rate in INTERVAL_FILE (checks per 30-day month; default 30) into a query
    probability.  If the lease expiry is parseable and already past, the
    base probability is multiplied by 4.  If the last successful query is
    long overdue, the probability is raised further.

    lease_expires: a 'YYYYMMDDTHHMMSSZ' string, or None if there is no lease.
    Returns True if the caller should do nothing this time.
    """
    def filetime(f, now):
        """Return the mtime of f in fp seconds; None if missing or in the future."""
        try:
            t = os.stat(f).st_mtime # in floating-point seconds
            return t if t <= now else None # file modified in future!
        except OSError:
            return None # couldn't find file.
    now = time.time() # current time (in fp seconds)
    last_attempt = filetime(LAST_ATTEMPT_FILE, now)
    last_query = filetime(LAST_QUERY_FILE, now)
    # if we get an error, assume the last attempt was 15 minutes ago.
    if last_attempt is None: last_attempt = now - (15*60)
    # if we get an error, assume that the last query was successful.
    if last_query is None: last_query = last_attempt
    # interval is how many attempts per month should be attempted.
    try:
        with open(INTERVAL_FILE) as fh:
            interval = int(fh.read().strip())
    except (EnvironmentError, ValueError):
        interval = 30 # check once a day as a fallback.
    if report is not None:
        report(0, 'Last attempt: %d sec ago; last query: %d sec ago; interval: %d' % (now-last_attempt, now-last_query, interval))
    # normalize to attempts per second.
    interval = interval / (30*24*60*60.)
    basic_prob = _clamp(interval * (now - last_attempt))

    if lease_expires is not None:
        try:
            expiry = timestr_to_secs(lease_expires)
        except ValueError:
            # e.g. the '00000000T000000Z' placeholder: no usable date, so
            # leave the probability unadjusted.
            expiry = None
        # if lease has expired, increase probability 4x
        if expiry is not None and now >= expiry:
            basic_prob = _clamp(basic_prob * 4.)

    # if it's been too long (10x expected time) since the last successful
    # query, increase the query probability by up to 10x.
    upper_prob = _clamp((10.*interval) * (now - last_attempt))
    slider     = _clamp(     interval  * (now - last_query) / 10.)
    adj_prob = basic_prob + slider*max(0, upper_prob - basic_prob)

    # roll the dice!
    if report is not None:
        report(0, 'Chances of a query: %3.1f%%' % (adj_prob * 100))
    r = SystemRandom().random()
    return r > adj_prob   # if r is higher, do nothing.

def lease_or_sig_expiry(lease):
    """Return the expiry timestamp found in the text of a lease.

    The lease is split on single spaces. Each token matching
    YYYYMMDDTHHMMSSZ is compared against the running choice with date_cmp(),
    and the chosen token is returned. If no token looks like a timestamp,
    the placeholder '00000000T000000Z' is returned.
    """
    # Rough parsing - act01+sig01 has a different layout from act01+sig02.
    # The only consistent thing is that only the datestamp looks like a
    # datestamp.
    expiry = '00000000T000000Z'
    for token in lease.split(' '):
        # NOTE(review): '$' also matches before a trailing newline, so a
        # token such as '20081231T000000Z\n' matches and is kept *with* the
        # newline -- confirm that cannot happen for real leases.
        if re.match(r'\d{8}T\d{6}Z$', token) and date_cmp(expiry, token) > 0:
            # NOTE(review): this keeps the token when date_cmp(expiry, token)
            # is positive.  If date_cmp orders like cmp() (as its use in
            # maybe_update_lease suggests), that selects the *earliest* date
            # and can never displace the all-zero starting value -- confirm
            # the intended direction.
            expiry = token
    return expiry

def touch(f, ignore_errors=False):
    """Create f if needed and set its mtime to now.

    If ignore_errors is true, filesystem errors are suppressed.  This
    includes IOError from open(), which in Python 2 is not an OSError
    subclass and so previously escaped.
    """
    try:
        if not os.path.exists(f):
            # 'a' creates the file without truncating one that appeared
            # between the exists() check and this open().
            open(f, 'a').close()
        os.utime(f, None) # set mtime to current time
    except EnvironmentError:
        if not ignore_errors:
            raise

def get_current_lease_expiry():
    """Return the expiry timestamp of the installed lease, or None.

    Reads LEASE_FILE and passes its text to lease_or_sig_expiry().  Any
    ordinary failure (missing/unreadable file, parse error) yields None.
    maybe_update_lease() treats that as a broken or missing lease.
    """
    try:
        with open(LEASE_FILE, "r") as fd:
            current_lease = fd.read()
        return lease_or_sig_expiry(current_lease)
    except Exception:
        return None

def main():
    from optparse import OptionParser
    from bitfrost.update import VERSION_INFO
    parser = OptionParser(usage="""
 %prog [options]
 %prog --help""")
    parser.add_option('-a','--auto', action='store_true', dest='auto',
                      default=False, help="use randomness to ensure that we don't check too often.")
    parser.add_option('-s','--sleep', action='store', type='float',
                      dest='sleep', default=0, metavar='MINUTES',
                      help='sleep for a random period up to the specified limit before starting; this helps avoid synchronized queries.')
    parser.add_option('-v',action='count',dest='verbosity',default=-1,
                      help='display verbose progress information.')
    parser.add_option('-f','--force', action='store_true',dest='force',
                      default=False,help='Force update even if low priority.')
    parser.add_option('--stream', action='store',dest='stream',metavar='STREAM',
                      default=None,help='Force update to the given stream.')
    parser.add_option('--version',action='store_true',dest='version',
                      default=False,
                      help="display version and license information.")
    (options, args) = parser.parse_args()
    if options.version:
        print VERSION_INFO
    if os.getuid() != 0:
        parser.error('Must be run as root.')
    def report(lvl, msg):
        if options.verbosity >= lvl:
            print >>sys.stderr, msg

    lease_expires = get_current_lease_expiry()

    if options.sleep > 0:
        # sleep some number of minutes to prevent all queries from being
        # synchronized
        report(0, 'Sleeping for a bit.')
        time.sleep(SystemRandom().uniform(0, options.sleep*60))

    try:
        if options.auto:
            if randomly_do_nothing(lease_expires, report=report):
                report(0, 'Not time for next query yet.')
                return # don't check.
            
    finally:
        touch(LAST_ATTEMPT_FILE, ignore_errors=True)

    oats_server = DEFAULT_ANTITHEFT_SERVER
    try:
        fh = open(ANTITHEFT_SERVER_FILE)
        fc = fh.read(4096)
        fh.close()
        if fc:
            oats_server = fc.strip()
    except:
        pass

    oats_urls = [ 'http://schoolserver/antitheft/1/',
                  'http://' + oats_server + '/antitheft/1/' ]

    resp = query(oats_urls, OATS_KEYS, lease_expires, stream=options.stream, report=report )
    touch(LAST_QUERY_FILE, ignore_errors=True)
    if resp is None: return # no upgrade information in the response.

    vhash, check_frequency, priority, hints = resp
    report(0, 'Requested update %s with priority %s' % (vhash, priority))
    try: # try to record the requested check_frequency.
        open(INTERVAL_FILE, 'w').write(str(check_frequency))
    except: pass
    if vhash == current_version():
        report(0, 'Already up to date.')
        return # don't need to upgrade
    if priority == 'low':
        if options.force:
            report(0, 'Forcing low priority update.')
        else:
            report(0, 'Skipping low priority update; use --force to override.')
            return # don't need to upgrade

    # invoke updater.
    # XXX: once we've found that we want to update, we should try more often?
    report(0, 'Performing update: %s %s %s' % (vhash, priority, hints))
    perform_update(vhash, priority, hints, options.verbosity)

# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
