#!/usr/bin/env python

# ZFS snapshot client

import socket, os, sys

ZSERVER = ('gohan10.freebsd.org', 8888)
ZFSLOCAL = '/tmp/.zserver'

def connect():
    """Open a stream connection to the snapshot service.

    The UNIX-domain socket at ZFSLOCAL is tried first, but only when that
    path exists; if it is absent or the connection fails, a TCP connection
    to ZSERVER is attempted.

    Returns a (socket, islocal) tuple, where islocal is True for the
    UNIX-domain connection and False for the TCP one.  If no connection
    could be made, returns (None, None).
    """

    candidates = []
    if os.path.exists(ZFSLOCAL):
        candidates.append((socket.AF_UNIX, ZFSLOCAL, True))
    candidates.append((socket.AF_INET, ZSERVER, False))

    for family, address, islocal in candidates:
        s = socket.socket(family, socket.SOCK_STREAM)
        try:
            s.connect(address)
        except socket.error:
            # Only connection-level failures fall through to the next
            # candidate; KeyboardInterrupt and programming errors propagate.
            s.close()
            continue
        return (s, islocal)

    return (None, None)

def send(sock, cmd):
    """Write cmd to sock, flush it, and return the next line read back.

    sock is a file-like object providing write(), flush() and readline();
    the caller in this file passes the result of socket.makefile().
    The reply line is returned exactly as readline() produced it.
    """
    sock.write(cmd)
    sock.flush()
    return sock.readline()

def error(msg):
    """Report a fatal error on stderr and exit with status 1.

    msg: message text; trailing whitespace (for example the newline left
         on a server reply line) is stripped before it is written.

    The line written is "<sys.argv[0]>: <msg>".  Does not return: it
    raises SystemExit through sys.exit(1).
    """
    # Written with sys.stderr.write() instead of the "print >>" statement
    # form; the bytes emitted are the same.
    sys.stderr.write("%s: %s\n" % (sys.argv[0], msg.rstrip()))
    sys.exit(1)

def do_list(sockfile, islocal, args):
    """Send LIST and print every line of the server's listing.

    sockfile: file-like object connected to the server.
    islocal, args: unused; present so all handlers share one signature.

    A reply line starting with "2" is treated as success, after which each
    remaining line from sockfile is printed with trailing whitespace
    removed.  Any other reply is passed to error() from character 4
    onwards (presumably skipping a three-digit status code and a space --
    confirm against the server), which exits the program.
    """
    res = send(sockfile, "LIST\n")
    if not res:
        # readline() hit EOF: report it instead of failing on res[0].
        error("connection closed by server")
    elif res.startswith("2"):
        for line in sockfile:
            sys.stdout.write(line.rstrip() + "\n")
    else:
        error(res[4:])

def do_get(sockfile, islocal, args):
    """Send GET with two arguments and copy the server's payload to stdout.

    sockfile: file-like object connected to the server.
    islocal:  unused; present so all handlers share one signature.
    args:     needs at least two items, sent verbatim as "GET <a0> <a1>"
              (presumably filesystem and snapshot name -- confirm against
              the server).  Fewer items raise IndexError to the caller.

    A reply line starting with "2" is treated as success and everything
    that follows is streamed to stdout until EOF.  Any other reply is
    passed to error() from character 4 onwards, which exits the program.
    """
    chunk_size = 32 * 1024

    res = send(sockfile, "GET %s %s\n" % (args[0], args[1]))
    if not res:
        # readline() hit EOF: report it instead of failing on res[0].
        error("connection closed by server")
    elif res.startswith("2"):
        while True:
            block = sockfile.read(chunk_size)
            if not block:
                break
            sys.stdout.write(block)
    else:
        error(res[4:])

def do_diff(sockfile, islocal, args):
    """Send DIFF with three arguments and copy the payload to stdout.

    sockfile: file-like object connected to the server.
    islocal:  unused; present so all handlers share one signature.
    args:     needs at least three items, sent verbatim as
              "DIFF <a0> <a1> <a2>" (presumably a filesystem and two
              snapshot names -- confirm against the server).  Fewer items
              raise IndexError to the caller.

    A reply line starting with "2" is treated as success and everything
    that follows is streamed to stdout until EOF.  Any other reply is
    passed to error() from character 4 onwards, which exits the program.
    """
    chunk_size = 32 * 1024

    res = send(sockfile, "DIFF %s %s %s\n" % (args[0], args[1], args[2]))
    if not res:
        # readline() hit EOF: report it instead of failing on res[0].
        error("connection closed by server")
    elif res.startswith("2"):
        while True:
            block = sockfile.read(chunk_size)
            if not block:
                break
            sys.stdout.write(block)
    else:
        error(res[4:])

def do_reg(sockfile, islocal, args):
    """Send REGISTER for args[0]; only allowed over the local connection.

    sockfile: file-like object connected to the server.
    islocal:  true when the connection is the local one; otherwise the
              request is refused via error() before anything is sent.
    args:     needs at least one item, sent verbatim as "REGISTER <a0>".
              An empty list raises IndexError to the caller.

    A reply line starting with "2" is treated as success and its text from
    character 4 onwards is printed.  Any other reply is passed to error()
    from character 4 onwards, which exits the program.
    """
    # Use the islocal argument rather than a module-level variable, so the
    # check works regardless of how the handler is invoked.
    if not islocal:
        error("must register on local machine")
    res = send(sockfile, "REGISTER %s\n" % args[0])
    if not res:
        # readline() hit EOF: report it instead of failing on res[0].
        error("connection closed by server")
    elif res.startswith("2"):
        # The reply text is written as received plus a newline, matching
        # what the print statement used to emit.
        sys.stdout.write(res[4:] + "\n")
    else:
        error(res[4:])

def do_unreg(sockfile, islocal, args):
    """Send UNREGISTER for args[0]; only allowed over the local connection.

    sockfile: file-like object connected to the server.
    islocal:  true when the connection is the local one; otherwise the
              request is refused via error() before anything is sent.
    args:     needs at least one item, sent verbatim as "UNREGISTER <a0>".
              An empty list raises IndexError to the caller.

    A reply line starting with "2" is treated as success and its text from
    character 4 onwards is printed.  Any other reply is passed to error()
    from character 4 onwards, which exits the program.
    """
    # Use the islocal argument rather than a module-level variable, so the
    # check works regardless of how the handler is invoked.
    if not islocal:
        error("must unregister on local machine")
    res = send(sockfile, "UNREGISTER %s\n" % args[0])
    if not res:
        # readline() hit EOF: report it instead of failing on res[0].
        error("connection closed by server")
    elif res.startswith("2"):
        # The reply text is written as received plus a newline, matching
        # what the print statement used to emit.
        sys.stdout.write(res[4:] + "\n")
    else:
        error(res[4:])

def do_help(sockfile, islocal, args):
    """Print every command in cmddict with its help text, sorted by name.

    All three parameters are unused; they exist so this handler has the
    same signature as the others in the command table.  Each output line
    is the command name right-aligned in 15 columns, " - ", and the
    description stored as the second element of its cmddict entry.
    """
    for name in sorted(cmddict):
        sys.stdout.write("%15s - %s\n" % (name, cmddict[name][1]))

# Command table: maps each command-line verb to (handler, help text).
# Handlers are called as handler(sockfile, islocal, args); the help text
# is what do_help prints next to the verb.
cmddict = {
    'list': (do_list, 'List available filesystem/snapshot pairs'),
    'get': (do_get, 'Get a snapshot'),
    'diff': (do_diff, 'Get the diffs between two snapshots'),
    'register': (do_reg, 'Register a new filesystem (privileged)'),
    'reg': (do_reg, 'Alias for register'),
    'unregister': (do_unreg, 'Unregister a filesystem (privileged)'),
    'unreg': (do_unreg, 'Alias for unregister'),
    'help': (do_help, 'Display this help'),
}

if __name__ == "__main__":

    args = sys.argv

    # Resolve the command before connecting, so a missing or unknown
    # command is reported as such instead of surfacing as a traceback.
    try:
        cmd = args[1]
        handler = cmddict[cmd][0]
    except (KeyError, IndexError):
        error("No such command\n")
    arg = args[2:]

    # Keep the (socket, islocal) tuple bound to the module-level name
    # "sock": other code in this module may refer to it by that name.
    sock = connect()
    if sock[0] is None:
        # connect() signals failure with (None, None) rather than raising.
        error("unable to connect to snapshot server")

    try:
        handler(sock[0].makefile(), sock[1], arg)
    except IndexError:
        # Handlers index straight into the argument list, so too few
        # arguments on the command line show up here as IndexError.
        error("missing argument(s) for '%s'" % cmd)
    
