#!/usr/bin/python3

import getopt, glob, os, os.path, re, shutil, socket, sys
import subprocess, random, time, tempfile, traceback, shlex

# (identity file path or None, key text) pair used for SSH access in the
# remote test; main() sets it from -i/--identity or guess_ssh_id_file_keys().
# The key text is what takeOver() appends to the remote authorized_keys.
ssh_id_file_keys = None
# When true, remoteTest() adds a <snapshot> element to the client's backup config.
perform_snapshots = True
# Extra command-line arguments inserted into the safekeep invocations (--args).
safekeep_args = ''
# Number of backup/verify iterations each test performs (--reps).
test_reps=2

class TestFailure(Exception):
    """Raised when a step of the SafeKeep test fails."""

def mkssh(cmd='ssh'):
    """
    Builds the command prefix used to reach remote hosts.
    Returns `cmd` on its own, or `cmd -i IDFILE` (IDFILE shell-quoted)
    when the first element of ssh_id_file_keys names an identity file.
    """
    identity = ssh_id_file_keys[0]
    extra = ['-i', shlex.quote(identity)] if identity else []
    return ' '.join([cmd] + extra)

def rpipe(cmd, user, host, desc, cmd2, user2, host2):
    """
    Runs `cmd` and feeds its standard output to the standard input of `cmd2`.
    Each command is run through ssh as user@host when its host is given,
    otherwise it is run locally. If `desc` is set, a non-zero status from
    either side raises TestFailure; otherwise the first non-zero status
    (or the status of the second command) is returned.
    """
    remote_fmt = '%s %s@%s %s'
    if host:
        cmd = remote_fmt % (mkssh(), user, host, shlex.quote(cmd))
    if host2:
        cmd2 = remote_fmt % (mkssh(), user2, host2, shlex.quote(cmd2))
    print(cmd, '|', cmd2)

    producer = os.popen(cmd, 'r')
    consumer = os.popen(cmd2, 'w')
    consumer.writelines(producer)
    producer_status = producer.close()
    consumer_status = consumer.close()

    if desc:
        # report the producer first, then the consumer
        for status, where in ((producer_status, host), (consumer_status, host2)):
            if status:
                raise TestFailure('Failed to %s on host %s' % (desc, where))
    return producer_status or consumer_status

def rcmd(cmd, user, host, desc=None):
    """
    Runs `cmd` through os.system, via ssh as user@host when `host` is given.
    Returns the status from os.system; if `desc` is set and the status is
    non-zero, raises TestFailure instead.
    """
    command = cmd
    if host:
        command = '{} {}@{} {}'.format(mkssh(), user, host, shlex.quote(cmd))
    print(command)
    status = os.system(command)
    if status and desc:
        raise TestFailure('Failed to %s on host %s' % (desc, host))
    return status

def rcmdin(cmd, content, user, host, desc=None):
    """
    Runs `cmd` (via ssh as user@host when `host` is given) and writes
    `content` to its standard input. Returns the status reported when the
    pipe is closed; if `desc` is set and that status is non-zero, raises
    TestFailure instead.
    """
    command = cmd
    if host:
        command = '{} {}@{} {}'.format(mkssh(), user, host, shlex.quote(cmd))
    print(command)
    sink = os.popen(command, 'w')
    sink.write(content)
    status = sink.close()
    if desc and status:
        raise TestFailure('Failed to %s for %s@%s' % (desc, user, host))
    return status

def rcmdout(cmd, user, host, desc=None):
    """
    Runs `cmd` (via ssh as user@host when `host` is given) and returns
    everything it wrote to standard output.
    If the command exits with a non-zero status, raises TestFailure when
    `desc` is set, and returns None otherwise.
    """
    if host:
        cmd = '%s %s@%s %s' % (mkssh(), user, host, shlex.quote(cmd))
    print(cmd)
    out = os.popen(cmd, 'r')
    try:
        text = out.read()
    finally:
        # Only the close stays in the finally clause: a return or raise here
        # would silently replace an exception coming out of read().
        st = out.close()
    if st:
        if desc:
            raise TestFailure('Failed to %s on host %s' % (desc, host))
        return None
    return text

def readfile(file, user=None, host=None):
    """
    Returns the content of the named file.
    When `host` is given the file is read with `cat` over ssh as user@host
    (raising TestFailure if that fails), otherwise it is read locally.
    """
    if host:
        # quote the path for the remote shell, like writefile() does
        cmd = 'cat %s' % shlex.quote(file)
        return rcmdout(cmd, user, host, 'read file %s' % file)
    with open(file) as fin:
        return fin.read()

def writefile(file, content, perm=None, mode='w', user=None, host=None):
    """
    Creates a file with the given name, content, and mode.
    When `host` is given, the content is piped to `cat` over ssh as
    user@host (appending when `mode` starts with 'a') and `perm`, if set,
    is applied with chmod on the remote side.
    Otherwise the file is written locally; `file` may then be a name or an
    open file descriptor, which is closed once the content is written.
    `perm` is applied with os.chmod only when `file` is a name.
    """
    if host:
        target = shlex.quote(file)
        redirect = '>>' if mode.startswith('a') else '>'
        cmd = 'cat %s %s' % (redirect, target)
        rcmdin(cmd, content, user, host, 'write file %s' % file)
        if perm is not None:
            # quote the path here too, so both remote commands agree
            cmd = 'chmod %s %s' % (perm, target)
            rcmd(cmd, user, host, 'set file permissions for ' + file)
        return

    with open(file, mode) as fout:
        fout.write(content)

    if perm is not None and isinstance(file, str):
        os.chmod(file, perm)

def augmentfile(file, user=None, host=None):
    """
    Appends between 5 and 20 lines of random numbers to the given file,
    locally or on user@host when `host` is given.
    """
    line_count = random.randint(5, 20)
    lines = ['Some random nr: %s\n' % random.random() for _ in range(line_count)]
    writefile(file, ''.join(lines), mode='a', user=user, host=host)

def localTest(tmproot):
    """
    Runs the local backup/restore test under the directory `tmproot`.
    Sets up a client tree, a server repository and a SafeKeep configuration,
    then for test_reps iterations: appends random data to the client files,
    runs safekeep, restores the backup with rdiff-backup and compares the
    MD5 sums of the source files against those of the restored files.
    Raises TestFailure when any step fails.
    """
    params = {'tmproot': tmproot, 'args': safekeep_args}
    os.mkdir(os.path.join(tmproot, 'backup.d'))
    os.mkdir(os.path.join(tmproot, 'client'))
    os.mkdir(os.path.join(tmproot, 'client', 'data'))
    os.mkdir(os.path.join(tmproot, 'client', 'home'))
    os.mkdir(os.path.join(tmproot, 'client', 'misc'))
    os.mkdir(os.path.join(tmproot, 'server'))
    CONFIG = """
        # Default values should do for now
    """
    writefile(os.path.join(tmproot, 'safekeep.conf'), CONFIG % params)
    BACKUP = """
       <backup id="%(tmproot)s">
          <repo path="%(tmproot)s/server" />
          <data>
             <exclude glob="**/*.nobackup" />
             <include path="%(tmproot)s/client" />
          </data>
        </backup>
    """
    writefile(os.path.join(tmproot, 'backup.d', 'test.backup'), BACKUP % params)
    # The *.nobackup files are excluded by the backup configuration above
    # and from the source checksums below.
    FILES = (
        'data/fileA.out',
        'data/fileB.nobackup',
        'data/fileC.out',
        'home/unit.out',
        'home/office.out',
        'home/street.out',
        'misc/file.nobackup',
    )
    restore_dir = os.path.join(tmproot, 'restore')
    for _ in range(test_reps):
        for name in FILES:
            augmentfile(os.path.join(tmproot, 'client', name))
        cmd = r"cd %(tmproot)s/client; find -type f -a \! -name '*.nobackup' | sort | xargs md5sum > %(tmproot)s/md5sums.client"
        if os.system(cmd % params):
            raise TestFailure("Can't compute the source MD5 sums")
        cmd = "safekeep --server %(args)s --conf '%(tmproot)s/safekeep.conf'"
        if os.system(cmd % params):
            raise TestFailure("Can't backup files")
        os.mkdir(restore_dir)
        cmd = "rdiff-backup -r now %(tmproot)s/server %(tmproot)s/restore"
        if os.system(cmd % params):
            raise TestFailure("Can't restore files")
        # the client tree is restored under its full original path
        cmd = "cd %(tmproot)s/restore%(tmproot)s/client; find -type f | sort | xargs md5sum > %(tmproot)s/md5sums.restore"
        if os.system(cmd % params):
            raise TestFailure("Can't compute the restore MD5 sums")
        cmd = "diff -u %(tmproot)s/md5sums.client %(tmproot)s/md5sums.restore"
        if os.system(cmd % params):
            raise TestFailure("The MD5 sums differ")
        # make the restore directory writable before removing it
        os.system("chmod +w " + restore_dir)
        shutil.rmtree(restore_dir)
        time.sleep(1)

def takeOver(host, role, ver):
    """
    Prepares the box `host` to act as `role` ('client' or 'server') in the
    remote test.
    Makes sure root can log in without a password (appending the key text
    from ssh_id_file_keys to root's authorized_keys if needed), checks the
    distribution from /etc/redhat-release, installs the Lattica development
    yum repository and installs or updates safekeep-<role> at version `ver`.
    Raises TestFailure if the release is not recognized or a remote step fails.
    """
    # Probe with the same ssh command line (including any -i identity file)
    # that all the following remote commands use.
    cmd = '%s -o PasswordAuthentication=no root@%s true' % (mkssh(), host)
    if os.system(cmd):
        print('The box %s does not appear to have been initialized' % (host))
        print('for functioning as %s in a SafeKeep remote test.' % (role))
        print('If you want to use it in such a role, please enter the root')
        print('password for the box so that the test can take over it.')
        print('WARNING: the test is potentially dangereous, all data on')
        print('         the box may be lost! Please make SURE this is')
        print('         the box you REALLY intend to use for the test.')
        print('To cancel the test, type Ctrl-C')
        # add the key blindly to make sure we have it, and we are asked the passwd only once
        cmd = 'umask 077; test -d .ssh || mkdir .ssh; cat >> .ssh/authorized_keys'
        rcmdin(cmd, ssh_id_file_keys[1], 'root', host)

    cmd = 'cat /etc/redhat-release'
    release = rcmdout(cmd, 'root', host, 'fetch the release')
    res = re.match(r'\s*(.*) release (.*) \(.*\)\s*', release)
    if not res:
        raise TestFailure('Unknown release format for: %s' % release)
    distro = res.group(1)
    if distro in ('CentOS', ):
        distro = 'centos'
    elif distro in ('Fedora', 'Fedora Core'):
        distro = 'fedora'
    else:
        raise TestFailure('Unknown distro %s in release %s' % (distro, release))

    lattica_repo = """
        [safekeep-test-lattica-devel]
        name=Lattica - Development Tree for Fedora Core $releasever $basearch
        baseurl=http://lattica.com/repos/lattica/devel/%s/$releasever
        gpgkey=http://lattica.com/keys/RPM-GPG-KEY-lattica-devel
        enabled=1
        gpgcheck=0
    """ % (distro)
    # drop the source indentation from every line of the repo definition
    lattica_repo = '\n'.join([line.strip() for line in lattica_repo.splitlines()])
    cmd = 'echo %s > /etc/yum.repos.d/safekeep-test-lattica-devel.repo' % (shlex.quote(lattica_repo).strip())
    rcmd(cmd, 'root', host, 'install Lattica Repo')
    # The final rpm/grep pipeline makes the command fail unless the
    # requested version-release ends up installed.
    cmd = """
        rm -rf /var/cache/yum/safekeep-test-lattica-devel/; 
        if rpm -q safekeep-%(role)s-%(verrel)s; then 
            yum update -y safekeep-%(role)s-%(verrel)s; 
        else 
            yum install -y safekeep-%(role)s-%(verrel)s;
        fi;
        rpm -q safekeep-%(role)s | grep %(verrel)s
    """  % { 'role': role, 'verrel': ver + '-1' }
    rcmd(cmd, 'root', host, 'install safekeep')

def createKey(user, host, keyname, comment):
    """
    Makes sure the SSH key .ssh/<keyname> exists for user@host, generating
    it with ssh-keygen (empty passphrase, the given comment) when missing.
    Returns the key path relative to the user's home directory.
    Raises TestFailure if the remote command fails.
    """
    key = '.ssh/%s' % keyname
    # shell-quote the path and comment so quotes or spaces cannot break the command
    quoted_key = shlex.quote(key)
    cmd = "test -f %s || ssh-keygen -q -b 1024 -t dsa -N '' -C %s -f %s" % (
        quoted_key, shlex.quote(comment), quoted_key)
    rcmd(cmd, user, host, 'create %s on server' % keyname)
    return key

#TODO: this is just old code, we need to change it
#      so that it builds the RPM on the host it will run on
def packageAndDeployRPM(tar, user, host):
    """
    Copies the tar snapshot `tar` to user@host with scp, builds RPMs from
    it locally with rpmbuild under a private build root, and checks that
    binary RPMs for the tar's version were produced.
    Raises TestFailure if the copy or build fails or no binary RPM is found.
    """
    pkgroot='/tmp/safekeep-test-XXX'   #TODO: replace XXX with curr timestamp

    rpmmacros = """
%%packager          safekeep Tester Package Builder
%%_topdir           %s
%%_debug_package    %%{nil}

%%_signature         gpg
""" % (pkgroot)

    # Create the directories needed by rpmbuild first: this also creates
    # pkgroot itself, which must exist before .rpmmacros can be written.
    # exist_ok lets a re-run reuse the fixed build root.
    for subdir in ('BUILD', 'RPMS/noarch', 'SOURCES', 'SPECS', 'SRPMS'):
        os.makedirs(os.path.join(pkgroot, subdir), exist_ok=True)

    writefile(os.path.join(pkgroot, '.rpmmacros'), rpmmacros)

    cmd = '%s %s %s@%s:%s' % \
          (mkssh('scp'), tar, user, host, pkgroot)
    print(cmd)
    if os.system(cmd):
        raise TestFailure('Failed to copy %s to the RPM host %s' % (tar, host))

    # HOME points at pkgroot so that rpmbuild reads the .rpmmacros written above
    cmd = 'HOME=%s rpmbuild -ta %s' % (pkgroot, tar)
    print(cmd)
    if os.system(cmd):
        raise TestFailure('Failed to build the .rpm')

    ver = tar[len('safekeep-'):-len('.tar.gz')]
    binrpm_list = glob.glob(os.path.join(pkgroot, 'RPMS/noarch', 'safekeep-*-' + ver + '-*.noarch.rpm'))
    if not binrpm_list:
        raise TestFailure('Failed to find binary rpm for version: %s' % ver)

    for binrpm in binrpm_list:
        if not os.path.isfile(binrpm):
            raise TestFailure('Failed to find binary rpm: %s' % binrpm)

def packageAndDeployDEB(tar, user, host):
    """
    Placeholder for building and deploying a .deb package.
    TODO: we need to implement this; for now it does nothing.
    """
    return None

def packageAndUpload(user, rpmhost, debhost):
    """
    Builds the tar snapshot with `make tar`, packages it for the given RPM
    and/or DEB hosts (each may be None to skip it), removes the tar and
    returns the version taken from the tar's file name.
    Raises TestFailure when the tar can't be built, found or removed.
    """
    cmd = 'make tar'
    if os.system(cmd):
        raise TestFailure('Failed to build the tar snapshot')

    candidates = [name for name in os.listdir('.')
                  if name.startswith('safekeep-') and name.endswith('.tar.gz')]
    if not candidates:
        raise TestFailure('Failed to determine the tar name')
    # With several snapshots around, the lexicographically greatest name wins.
    # (Comparing names against an initial None raises TypeError on Python 3.)
    mytar = max(candidates)

    if rpmhost:
        packageAndDeployRPM(mytar, user, rpmhost)

    if debhost:
        packageAndDeployDEB(mytar, user, debhost)

    cmd = 'rm %s' % (shlex.quote(mytar))
    print(cmd)
    if os.system(cmd):
        raise TestFailure('Failed to nuke the tar')

    ver = mytar[len('safekeep-'):-len('.tar.gz')]
    return ver

def remoteTest(tmproot, client, server):
    """
    Runs the client/server test between the boxes `client` and `server`.
    Packages the current tree, takes over both boxes, sets up the safekeep
    account, its keys and the backup configuration on the server, deploys
    the keys to the client, then for test_reps iterations: changes files on
    the client, runs the backup on the server, compares MD5 sums of the
    client files against the server repository and checks that no snapshots
    or --rbind mounts are left behind on the client.
    `tmproot` is currently not used by this test.
    Raises TestFailure when any step fails.
    """
    # build and upload 
    ver = packageAndUpload('root', 'ulysses', None)

    takeOver(client, 'client', ver)
    takeOver(server, 'server', ver)

    # setup the server
    # NOTE(review): the trailing '--server' on this chsh command looks unintended -- confirm
    cmd = 'chsh -s /bin/bash safekeep --server'
    rcmd(cmd, 'root', server, 'Allow access to the safekeep account')
    client_addr = socket.gethostbyname(client)
    if client_addr != client:
        # the client was given by name: make sure the server can resolve it
        cmd = '(grep -iv %s /etc/hosts; echo "%s   %s") > /etc/hosts.new; mv -f /etc/hosts.new /etc/hosts' % (client, client_addr, client)
        rcmd(cmd, 'root', server, 'install the client name in /etc/hosts')

    cmd = 'cd ~safekeep; ' + \
          'cp /root/.ssh/authorized_keys .ssh/authorized_keys; ' + \
          'chown -R safekeep.safekeep .ssh'
    rcmd(cmd, 'root', server, 'install key for the safekeep user')
    key_id   = createKey('safekeep', server, 'id_dsa', 'SafeKeep server ID')
    key_ctrl = createKey('safekeep', server, 'safekeep-server-ctrl-key', 'SafeKeep server control key')
    key_data = createKey('safekeep', server, 'id-safekeep-data-key', 'SafeKeep server data key')

    cmd = 'if ! rpm -q postgresql-server; then p=postgresql-server; fi; ' + \
          'if ! rpm -q mysql-server; then p="$p mysql-server"; fi; ' + \
          'if [ "x$p" != "x" ]; then yum install -y $p; fi; ' + \
          'if service postgresql status | grep -v running; then service postgresql start; fi; ' + \
          'if service mysqld status | grep -v running; then service mysqld start; fi; '
    rcmd(cmd, 'root', server, 'ensure RDBMSes are installed and running')
   
    # copy the client's host key entries from the local known_hosts to the server
    known_hosts = os.path.join(os.environ['HOME'], '.ssh', 'known_hosts')
    cmd = 'if test -f %s; then grep -i %s %s; fi' % (known_hosts, client, known_hosts)
    fingerprint = rcmdout(cmd, None, None, 'read the client fingerprint')
    if fingerprint:
        cmd = '(if test -f .ssh/known_hosts; then grep -iv %s .ssh/known_hosts; fi; cat) > .ssh/known_hosts.new; mv -f .ssh/known_hosts.new .ssh/known_hosts' % (client, )
        rcmdin(cmd, fingerprint, 'safekeep', server, 'deploy client fingerprint on server')

    snap_conf = ''
    if perform_snapshots:
        snap_conf = '<snapshot device="/dev/mapper/VolGroup00-LogVol00" size="50M"/>'
    conf = """
        <backup id="test-client">
          <host name="%s" key-data="%s" />
          <repo path="/var/lib/safekeep/client/data" retention="5h"/>
          <setup>
              %s
          </setup>
          <data>
             <exclude glob="**/*.nobackup" />
             <exclude path="/etc/cups/certs" />
             <exclude path="/etc/lvm" />
             <exclude path="/etc/mtab" />
             <include path="/etc" />
             <include path="/srv" />
          </data>
        </backup>
        """ % (client, key_data, snap_conf)

    writefile('/etc/safekeep/backup.d/test-client.backup', conf, '664', 'w', 'root', server)
    cmd = 'rm -rf client; mkdir -p client/data'
    rcmd(cmd, 'safekeep', server, 'create data repo') 

    # setup the client
    cmd = 'umask 077; test -d .ssh || mkdir .ssh; if test -f .ssh/authorized_keys; then cat .ssh/authorized_keys; fi'
    clkeys_txt = rcmdout(cmd, 'root', client, 'fetch client keys')
    key_id_enc = readfile('.ssh/id_dsa.pub', 'safekeep', server)
    # The key text keeps its trailing newline while splitlines() drops it,
    # so compare stripped values; otherwise the key is appended on every run.
    known_keys = [line.strip() for line in clkeys_txt.splitlines()]
    if key_id_enc.strip() not in known_keys:
        writefile('.ssh/authorized_keys', key_id_enc, None, 'a', 'root', client)

    cmd = 'safekeep %s --keys --deploy' % safekeep_args
    rcmd(cmd, 'safekeep', server, 'deploy keys to client')

    # run the test
    FILES = (
        '/etc/a-simple-file', 
        '/etc/another-file.nobackup', 
        '/srv/safekeep-test1.log', 
        '/srv/safekeep-test2.nobackup',
    )
    for _ in range(test_reps):
        cmd = 'rm -f md5sum.client md5sum.server'
        rcmd(cmd, 'safekeep', server, 'cleanup MD5 sums')
        for name in FILES:
            augmentfile(name, 'root', client)
        # The greps mirror the <exclude> entries of the backup configuration;
        # sed drops the leading '/' so the paths match the repository layout.
        cmd1 = "cd /; (find /etc -type f; find /srv -type f) | " + \
               r"grep -v '\.nobackup$' | grep -v '^/etc/cups/certs/' | " + \
               "grep -v '^/etc/mtab' | grep -v '^/etc/lvm/' | " + \
               "sort | sed s/.// | xargs md5sum"
        cmd2 = "cat > /var/lib/safekeep/md5sum.client"
        rpipe(cmd1, 'root', client, 'do MD5 sums on client', cmd2, 'safekeep', server)
        cmd = 'safekeep --server %s' % safekeep_args
        rcmd(cmd, 'root', server, 'backup data')
        cmd = 'cd /var/lib/safekeep/client/data; find -type f | grep -v rdiff-backup-data/ | sort | sed s/..// | xargs md5sum > /var/lib/safekeep/md5sum.server'
        rcmd(cmd, 'safekeep', server, 'do MD5 sums on server')
        cmd = 'diff -q /var/lib/safekeep/md5sum.client /var/lib/safekeep/md5sum.server'
        rcmd(cmd, 'safekeep', server, 'compare MD5 sums')
        
        # TODO: list available backups
        # TODO: compare against retention policy, fail if different
        cmd = 'if lvscan | grep _snap_; then false; fi'
        rcmd(cmd, 'root', client, 'test that no snapshots are left behind')
        cmd = 'if grep "safekeep-.*-rbind" /proc/mounts; then false; fi'
        rcmd(cmd, 'root', client, 'test that no --rbind mounts are left behind')
        # TODO: check iff appropriate DB dumps deleted
        time.sleep(1) 

def guess_ssh_id_file_keys():
    """
    Works out which SSH identity to use when none was given explicitly.
    Returns (None, keys) with the keys listed by a running ssh agent, or
    (identity file, key text) for the first of ~/.ssh/identity, id_dsa,
    id_rsa that exists, or None when nothing is found.
    The key text comes from the matching .pub file when there is one,
    because it ends up in a remote authorized_keys file; without a .pub
    file the identity file itself is read, as before.
    """
    if os.environ.get('SSH_AUTH_SOCK'):
        out = os.popen('ssh-add -L', 'r')
        ids = out.read()
        if out.close() or not ids:
            print('Failed to get keys from ssh agent', file=sys.stderr)
        else:
            return (None, ids)
    ssh_dir = os.path.join(os.path.expanduser('~'), '.ssh')
    for filename in ('identity', 'id_dsa', 'id_rsa'):
        id_file = os.path.join(ssh_dir, filename)
        if not os.path.isfile(id_file):
            continue
        pub_file = id_file + '.pub'
        key_file = pub_file if os.path.isfile(pub_file) else id_file
        with open(key_file) as fin:
            return (id_file, fin.read())
    return None

def usage():
    """Prints the command-line help for the script to standard output."""
    help_lines = (
        'usage: %s [options]' % (sys.argv[0]),
        '',
        'options:',
        '-l,--local            Run the local version of the test (default)',
        '-r,--remote           Run the client/server version of the test',
        '-n,--nocleanup        Do not erase the temporary root for the test',
        '-i,--identity IDFILE  SSH identity file to use (remote test only)',
        '-c,--client=ADDR      The client address (remote test only)',
        '-s,--server=ADDR      The server address (remote test only)',
        '--reps=N              The number of iterations for the test (default 2)',
        '--args=ARGS           Extra arguments to pass to SafeKeep (default to none)',
        '-h,--help             Print this help message and exit',
    )
    for text in help_lines:
        print(text)

def main():
    """
    Parses the command line, runs the local or the remote test inside a
    fresh temporary directory and exits with status 0 on success, 1 when
    the test fails, or 2 on a command-line error.
    The client and server addresses default to the SAFEKEEP_TEST_CLIENT
    and SAFEKEEP_TEST_SERVER environment variables, then to fixed names.
    """
    global ssh_id_file_keys, test_reps, safekeep_args
    try:
        # 'h' / 'help' must be registered here for -h/--help to reach the
        # loop below, and 'identity=' so that --identity takes its argument.
        opts, args = getopt.getopt(sys.argv[1:], 
                                   'lrni:c:s:h', 
                                   ['local', 'remote', 'nocleanup', 'identity=', 'reps=', 'args=', 'client=', 'server=', 'help'])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    client = os.environ.get('SAFEKEEP_TEST_CLIENT')
    server = os.environ.get('SAFEKEEP_TEST_SERVER')
    nocleanup = False
    mode = 'local'
    for o, a in opts:
        if o in ('-l', '--local'):
            mode = 'local'
        elif o in ('-r', '--remote'):
            mode = 'remote'
        elif o in ('-n', '--nocleanup'):
            nocleanup = True
        elif o in ('-i', '--identity'):
            # The key text is appended to a remote authorized_keys file, so
            # take it from the matching .pub file when there is one; fall
            # back to the given file otherwise, as before.
            pub_file = a + '.pub'
            ssh_id_file_keys = (a, readfile(pub_file if os.path.isfile(pub_file) else a))
        elif o in ('-c', '--client'):
            client = a
        elif o in ('-s', '--server'):
            server = a
        elif o in ('--reps', ):
            try:
                test_reps = int(a)
            except ValueError:
                usage()
                sys.exit(2)
        elif o in ('--args', ):
            safekeep_args = a
        elif o in ('-h', '--help'):
            usage()
            sys.exit()

    if not client: client = 'safekeep-test-client'
    if not server: server = 'safekeep-test-server'

    exitcode = 0
    tmproot = tempfile.mkdtemp()
    try:
        try:
            if mode == 'remote':
                if not ssh_id_file_keys:
                    ssh_id_file_keys = guess_ssh_id_file_keys()
                # an explicit check: an assert would be stripped under -O
                if not ssh_id_file_keys:
                    raise TestFailure("Can not determine the SSH keys")
                remoteTest(tmproot, client, server)
            else:
                localTest(tmproot)
        except TestFailure as tf:
            print(tf)
            exitcode = 1
    finally:
        if nocleanup:
            # leave the temporary root in place and tell the user where it is
            print(tmproot)
        else:
            os.system("chmod -R +w " +tmproot)
            shutil.rmtree(tmproot)

    sys.exit(exitcode)

# Run the test driver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

# vim: et ts=8 sw=4 sts=4
