# -*- coding:utf-8 -*-
import logging
import re
import hashlib
import time
from tempfile import mkstemp
from os import unlink, system, sep, listdir
from os.path import basename, dirname, isdir, join
from commands import getoutput, getstatusoutput

from creole.cert import ca_keyfile

from formencode import validators, foreach
from sqlalchemy.exc import IntegrityError
from arv.config import key2keyid_path, id2sql_path, bin2sql_path, ssl_dir, \
                       vpn_path
from arv.lib.logger import logger

def trace(hide_args=None, hide_kwargs=None):
    """Build a decorator that logs each call of the decorated function at DEBUG level.

    Before the wrapped function runs, a line "-> entering name([args], {kwargs})"
    is logged, but only when the module logger has DEBUG enabled. Arguments
    listed in hide_args / hide_kwargs are replaced by 'XXXXXXXX' in that log
    line only; the wrapped function always receives the real values.

    @param hide_args: List of positional argument indexes to mask if present
    @type hide_args: C{list}
    @param hide_kwargs: List of keyword argument names to mask if present
    @type hide_kwargs: C{list}
    @return: the decorator
    """
    # functools.wraps copies __module__ as well as __name__/__doc__/__dict__,
    # which the hand-written copy used to miss.
    from functools import wraps

    def tracedec(func):
        @wraps(func)
        def newFunc(*args, **kwargs):
            # Do nothing if debug is not enabled
            if logger.isEnabledFor(logging.DEBUG):
                # Work on copies so the real arguments are left untouched
                args_list = list(args)
                args_dict = kwargs.copy()
                if hide_args is not None:
                    for index in hide_args:
                        # Only mask positions that were actually passed
                        if index < len(args_list):
                            args_list[index] = 'XXXXXXXX'
                if hide_kwargs is not None:
                    for keyname in hide_kwargs:
                        if keyname in args_dict:
                            args_dict[keyname] = 'XXXXXXXX'
                logger.debug("-> entering %s(%s, %s)" % (func.__name__, str(args_list), str(args_dict)))
            return func(*args, **kwargs)

        return newFunc
    return tracedec

@trace()
def normalize_unicode(string):
    """Return *string* as a unicode object.

    Unicode input (including subclasses) is returned unchanged; byte strings
    are decoded as UTF-8.

    @raise TypeError: if *string* is neither a str nor a unicode object
    @raise UnicodeDecodeError: if a byte string is not valid UTF-8
    """
    # isinstance() rather than type() == so that str/unicode subclasses are
    # accepted instead of being rejected as "unsupported".
    if isinstance(string, unicode):
        return string
    if isinstance(string, str):
        return unicode(string, "utf-8")
    raise TypeError("unsupported encoding")

@trace()
def valid(value, typ):
    """Validate and convert *value* with the formencode validator matching *typ*.

    @param value: value to validate; it is turned into a byte string before
        validation (unicode values are UTF-8 encoded)
    @param typ: one of 'string', 'bool', 'integer', 'enum', 'ip'
    @return: the converted value; unicode for 'string' and 'ip';
        *value* itself, unvalidated, for 'enum'
    @raise ValueError: if *typ* is not one of the supported types
    Invalid input makes the formencode validator raise (presumably
    formencode.Invalid -- confirm against the formencode version in use).
    """
    if typ == 'enum':
        #FIXME : nothing is done in this case by now
        #validator = foreach.ForEach()
        return value
    validator_classes = {
        'string': validators.String,
        'bool': validators.StringBoolean,
        'integer': validators.Int,
        'ip': validators.IPAddress,
    }
    validator_class = validator_classes.get(typ)
    if validator_class is None:
        # Previously an unknown typ surfaced as UnboundLocalError on 'validator'.
        raise ValueError("unsupported type: {0}".format(typ))
    validator = validator_class()
    if isinstance(value, unicode):
        # str() on non-ASCII unicode raises UnicodeEncodeError under Python 2;
        # encode explicitly (normalize_unicode decodes it back as UTF-8).
        raw_value = value.encode('utf-8')
    else:
        raw_value = str(value)
    val = validator.to_python(raw_value)
    if typ in ('string', 'ip'):
        val = normalize_unicode(val)
    return val
# ____________________________________________________________
@trace()
def try_unique_column(func_name, function, **args):
    """Call function(**args) and return its result, normalizing failures.

    An IntegrityError whose text contains " column name is not unique " is
    turned into Exception("Name should be unique"). Every other error,
    IntegrityError or not, becomes Exception("error in <func_name>: <error>").
    """
    try:
        return function(**args)
    except IntegrityError as err:
        details = str(err)
        if " column name is not unique " not in details:
            raise Exception("error in %s: %s" % (func_name, details))
        raise Exception("Name should be unique")
    except Exception as err:
        raise Exception("error in %s: %s" % (func_name, str(err)))

# ____________________________________________________________
#
@trace()
def get_scndline(lines):
    """Extract the hex payload from the second line of a command output.

    Typical output::

      type    encoding
      2,  X'736466736466'

    The text after the first comma of the second line is stripped. It must
    look like X'...', and the characters between the quotes are returned.

    @raise TypeError: if that text is not of the form X'...'
    """
    logger.debug("scndline : " + lines)
    encoded_string = lines.split('\n')[1].split(',')[1].strip()
    well_formed = encoded_string.startswith("X'") and encoded_string.endswith("'")
    if not well_formed:
        raise TypeError("unexpected encoded string: {0}".format(encoded_string))
    return encoded_string[2:-1]

@trace()
def get_lastline(lines):
    """Extract the two key identifiers from the last two lines of a command output.

    Typical output::

    writing RSA key
    parsed 2048 bits RSA private key.
    subjectPublicKeyInfo keyid: b0:71:fb:0a:62:f7:8d:7b:9d:35:d7:c9:4d:12:f5:d5:51:e6:db:da
    subjectPublicKey keyid:     0e:5d:33:43:15:b0:f2:a0:0b:b6:6b:f3:24:44:6a:f5:08:91:da:0d

    @return: C{(subjkey, keyid)}, taken from the penultimate and the last
        line with their label removed and whitespace stripped
    """
    logger.debug("lastline : " + lines)
    rows = lines.split('\n')
    penultimate, last_line = rows[-2], rows[-1]
    # The labels are removed with replace() rather than split(':') because
    # the key ids themselves are full of colons.
    subjkey = penultimate.replace('subjectPublicKeyInfo keyid:', '').strip()
    keyid = last_line.replace('subjectPublicKey keyid:', '').strip()
    return (subjkey, keyid)

@trace()
def bin_encoding(clear):
    """Encode *clear* into the binary form stored in the strongswan database.

    The id2sql helper is run on *clear*. The hex payload is taken from the
    second line of its output (get_scndline) and returned hex-decoded to raw
    bytes, so it can be stored in a SQLAlchemy binary column.
    """
    hex_payload = get_scndline(getoutput('%s "%s"' % (id2sql_path, clear)))
    logger.debug("#-> bin_encoding result var: " + hex_payload)
    return hex_payload.decode("hex")

@trace()
def suppress_colon(keyid):
    """Convert a colon-separated hex key id (e.g. 'b0:71:fb') into raw bytes.

    The colons are removed and the remaining hex digits are hex-decoded.
    """
    hex_digits = ''.join(keyid.split(':'))
    return hex_digits.decode('hex')

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_in_certif(certif_name, passwd=None, certiftype='ca'):
    """Return the key ids of the private key belonging to a certificate.

    *certif_name* is the full path of the certificate
    (ex: /var/lib/arv/CA/certs/CaCert.pem). For certiftype 'ca' the CA key
    file (ca_keyfile) is used. Otherwise the key is expected at
    <dir>/private/priv-<name> when the certificate is <dir>/<sub>/<name>.

    @return: C{(subjkey, keyid)} as returned by get_keyid_from_keyid_in_certif
    """
    short_name = basename(certif_name)
    private_dir = join(dirname(dirname(certif_name)), 'private')
    if certiftype == 'ca':
        priv_certname = ca_keyfile
    else:
        priv_certname = join(private_dir, 'priv-' + short_name)
    return get_keyid_from_keyid_in_certif(priv_certname, passwd, mode='rsa')

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_from_keyid_in_certif(certif_name, passwd=None, mode='rsa'):
    """Compute key identifiers when none is present in the certificate.

    @param certif_name: path of the file to read
    @param passwd: private key password (only used in 'rsa' mode)
    @param mode: 'rsa' reads a password-protected private key and pipes its
        text dump through the key2keyid helper; 'x509' reads the text after
        "keyid:" in a certificate's text dump
    @return: C{(subjkey, keyid)}; in 'x509' mode both items are the same string
    @raise Exception: 'Invalid password' when *passwd* does not open the key
    @raise ValueError: when *mode* is neither 'rsa' nor 'x509'
    """
    if mode == 'rsa':
        if not password_OK(certif_name, passwd):
            raise Exception('Invalid password')
        #Extract from private key
        cmd = 'openssl {0} -in {1} -text -passin pass:"{2}" | {3}'.format(mode, certif_name, passwd, key2keyid_path)
        subjkey, keyid = get_lastline(getoutput(cmd))
        return (subjkey, keyid)
    if mode == 'x509':
        #Extract from credential
        cmd = 'openssl {0} -in {1} -text -noout| grep keyid'.format(mode, certif_name)
        output = getoutput(cmd)
        keyid = output.split("keyid:")[1]
        logger.debug('{0}'.format(keyid))
        return (keyid, keyid)
    # Any other mode used to fall through to an UnboundLocalError.
    raise ValueError('unsupported mode: {0}'.format(mode))

@trace(hide_args=[1], hide_kwargs=['passwd'])
def get_keyid_from_certifstring(certif_string, passwd=None, mode='rsa'):
    """Compute key identifiers for a key/certificate given as a string.

    *certif_string* is written to a temporary file, which is passed to
    get_keyid_from_keyid_in_certif(). The file is removed afterwards, on
    success and on failure alike.

    @return: C{(subjkey, keyid)}
    @raise Exception: 'Cannot generate keyid in certificate : <reason>' on any failure
    """
    import os
    certif_name = None
    try:
        fd, certif_name = mkstemp()
        # Write through the descriptor mkstemp() returned so that it gets
        # closed; re-opening the path by name leaked one fd per call.
        with os.fdopen(fd, 'w') as fh:
            fh.write(certif_string)
        return get_keyid_from_keyid_in_certif(certif_name, passwd, mode=mode)
    except Exception as e:
        msg = 'Cannot generate keyid in certificate : %s' % str(e)
        logger.warning(msg)
        raise Exception(msg)
    finally:
        if certif_name is not None:
            try:
                unlink(certif_name)
            except OSError:
                # Best-effort cleanup: never mask the real result/error.
                pass

@trace(hide_args=[1], hide_kwargs=['password'])
def password_OK(private_key, password):
    """Tell whether *password* opens the private key file *private_key*.

    Runs ``openssl rsa`` on the key with the password and discards all of
    its output.

    @return: True if openssl exits with status 0, False otherwise (including
        when the openssl binary cannot be executed)
    """
    import os
    import subprocess
    # An argument list (no shell) instead of os.system(): "&> /dev/null" is a
    # bashism that a POSIX /bin/sh (e.g. dash) parses as "run in background",
    # which made the exit status always 0. This also stops shell
    # metacharacters in the path or password from being interpreted.
    cmd = ['openssl', 'rsa', '-in', private_key, '-text',
           '-passin', 'pass:{0}'.format(password), '-noout']
    devnull = open(os.devnull, 'w')
    try:
        retcode = subprocess.call(cmd, stdout=devnull, stderr=devnull)
    except OSError:
        return False
    finally:
        devnull.close()
    return retcode == 0

@trace(hide_args=[1], hide_kwargs=['passwd'])
def decrypt_privkey(privkey_string, passwd):
    """Return *privkey_string* with its password protection removed.

    The key is piped through ``openssl rsa`` with stderr merged into the
    output. The first output line is dropped, presumably openssl's
    "writing RSA key" notice (verify).

    @raise Exception: if the command exits with a non-zero status
    """
    # NOTE(review): under a POSIX /bin/sh such as dash, `echo -e` prints a
    # literal "-e " -- confirm which shell runs this.
    cmd = 'echo -e "{0}" | openssl rsa -passin pass:"{1}" 2>&1'.format(privkey_string, passwd)
    errcode, output = getstatusoutput(cmd)
    if errcode:
        raise Exception('Unable to decrypt private key, check password')
    return "\n".join(output.split('\n')[1:])

def get_req_archive(name):
    """Return a tgz archive holding the credential request of *name* and its private key.

    <ssl_dir>/req/<name>.p10 and <ssl_dir>/private/priv-<name>.pem are moved
    into ssl_dir, tarred into <ssl_dir>/<name>.tgz and the loose copies are
    removed. The archive is then read, deleted, and its content returned.
    """
    reqname = name + ".p10"
    privkeyname = "priv-" + name + ".pem"
    script = """cd {0}
mv {1} {0}
mv {2} {0}
tar -czf {3}.tgz {4} {5}
rm {4} {5}
""".format(ssl_dir,
           join(ssl_dir, "req", reqname),
           join(ssl_dir, "private", privkeyname),
           name, reqname, privkeyname)
    system(script)
    tarfname = join(ssl_dir, name + ".tgz")
    with open(tarfname, 'r') as archive:
        content = archive.read()
    unlink(tarfname)
    return content

@trace()
def gen_archive_name(uai, name):
    """Build the directory and file name of a VPN configuration archive.

    @param uai: establishment identifier, converted with str()
    @param name: archive label, as unicode or a UTF-8 byte string
    @return: C{(amonpath, archivename)} where amonpath is <vpn_path>/<uai>
        and archivename is "<uai>-<name>.tar.gz" (UTF-8 encoded)
    """
    amonpath = vpn_path + sep + str(uai)
    # unicode.encode(name, ...) raised TypeError for byte strings; normalize
    # first so both str and unicode names are accepted.
    archivename = str(uai) + "-" + normalize_unicode(name).encode("utf-8") + ".tar.gz"
    return amonpath, archivename

@trace()
def split_pkcs7(pkcs7_cred):
    """Split a PKCS#7 string into its CA credential and its credential.

    The output of ``openssl pkcs7 -print_certs`` is parsed as blocks separated
    by blank lines. Each block is a subject line, an issuer line, then the
    PEM lines. A block whose subject CN equals its issuer CN goes to ca_cred,
    any other to cred.

    @return: C{(ca_cred, cred)}
    @raise Exception: 'Cannot read pkcs7: <reason>' if the output cannot be
        parsed or does not contain exactly one of each credential
    """
    # "\s*=\s*" accepts both "/CN=name" (OpenSSL < 1.1) and "CN = name"
    # (OpenSSL >= 1.1) subject/issuer formats.
    subject_re = re.compile(r'subject(.*)CN\s*=\s*(.*)')
    issuer_re = re.compile(r'issuer(.*)CN\s*=\s*(.*)')
    try:
        cmd = 'echo "{0}" | openssl pkcs7 -print_certs '.format(pkcs7_cred)
        two_creds = getoutput(cmd)
        new_cred = True
        subject = ''
        issuer = ''
        ca_cred = ''
        cred = ''
        for line in two_creds.splitlines():
            if line == '':
                # A blank line ends the current certificate block
                subject = ''
                issuer = ''
                new_cred = True
            elif new_cred:
                # First line of a block: the subject
                match = subject_re.match(line)
                if not match:
                    raise Exception('Cannot get subject from pkcs7 file')
                subject = match.group(2)
                new_cred = False
            elif issuer == '' and subject:
                # Second line of a block: the issuer
                match = issuer_re.match(line)
                if not match:
                    raise Exception('Cannot get issuer from pkcs7 file')
                issuer = match.group(2)
                if subject == issuer:
                    if ca_cred != '':
                        raise Exception('Already a ca_cred in pkcs7 file')
                else:
                    if cred != '':
                        raise Exception('Already a cred in pkcs7 file')
            elif subject and issuer:
                # PEM lines: self-signed -> CA credential
                if subject == issuer:
                    ca_cred += '{0}\n'.format(line)
                else:
                    cred += '{0}\n'.format(line)
            else:
                raise Exception('Unknown line in pkcs7 file {0}'.format(line))

        if ca_cred == '' or cred == '':
            raise Exception('Pkcs7 file not complete')
        return (ca_cred, cred)
    except Exception as e:
        # '{1}' with a single argument raised IndexError and hid the real error.
        msg = 'Cannot read pkcs7: {0}'.format(e)
        logger.error(msg)
        raise Exception(msg)

@trace()
def extract_crls_from_certifstring(certifstring):
    """Extract the CRL distribution point URIs from a PEM certificate string.

    The ``openssl x509 -text`` dump is scanned. After the
    "CRL Distribution Points:" header, every "URI:..." value is collected.
    Blank lines and "Full Name:" labels are skipped, and any other line (the
    next extension) ends the section.

    @return: list of URIs (empty if none); None if an unexpected error occurs
    """
    try:
        # The full dump is parsed: piping it through grep "URI" (as before)
        # also removed the header line this parser keys on, so no CRL was
        # ever found.
        cmd = 'echo "{0}"|openssl x509 -text -noout'.format(certifstring)
        errcode, output = getstatusoutput(cmd)
        crl = []
        crlbegin = False
        for line in output.split('\n'):
            if "CRL Distribution Points:" in line:
                crlbegin = True
                continue
            if not crlbegin:
                continue
            stripped = line.strip()
            if 'URI:' in stripped:
                crl.append(stripped.split('URI:', 1)[1])
            elif stripped == '' or stripped.startswith('Full Name:'):
                # Layout details of newer OpenSSL dumps within the section
                continue
            else:
                break
        return crl
    except Exception:
        logger.warning("No crl found in credential")
        return None

@trace()
def ipsec_running():
    """Return True when ``ipsec status`` exits with status 0, False otherwise."""
    status = getstatusoutput('ipsec status > /dev/null')[0]
    return status == 0

@trace()
def ipsec_restart():
    """Restart the ipsec service through its init script; the result is ignored."""
    restart_cmd = "/etc/init.d/ipsec restart > /dev/null"
    getstatusoutput(restart_cmd)

@trace()
def ipsec_down(connstring):
    """Bring down the IPsec connection named *connstring* with ``ipsec down``.

    The command's exit status and output are discarded.
    """
    # The connection name must be interpolated: the command used to contain
    # the literal word "connstring" instead of the argument's value.
    cmd = 'ipsec down "{0}" >/dev/null'.format(connstring)
    getstatusoutput(cmd)

@trace()
def ipsec_up(connstring):
    """Bring up the IPsec connection named *connstring* with ``ipsec up``.

    The command's exit status and output are discarded.
    """
    # The connection name must be interpolated: the command used to contain
    # the literal word "connstring" instead of the argument's value.
    cmd = 'ipsec up "{0}" >/dev/null'.format(connstring)
    getstatusoutput(cmd)

@trace()
def purge_file(filename):
    """Empty *filename* in place (create it if missing) by opening it for writing."""
    with open(filename, "w") as handle:
        handle.write('')

@trace()
def fill_file(from_filename, to_filename):
    """Copy the content of from_filename into to_filename without deleting to_filename.

    to_filename is truncated and rewritten in place rather than removed and
    recreated.
    """
    with open(from_filename, "rb") as from_fd:
        content = from_fd.read()
    # Opening with "w" truncates the existing file itself, which is all the
    # former purge_file() call did. Binary mode matches the binary read.
    with open(to_filename, "wb") as to_fd:
        to_fd.write(content)

@trace()
def md5(filename, sw_database_mode='True'):
    """Return the hex MD5 digest of a file.

    When sw_database_mode is the string 'True' (the default), the file is
    dumped with ``sqlite3 .dump`` and the digest is computed over that text.
    If sqlite3 fails, None is returned. Any other value hashes the file's
    raw bytes.
    """
    if sw_database_mode != 'True':
        logger.debug('database mode False')
        with open(filename, "rb") as handle:
            content = handle.read()
    else:
        logger.debug('database mode True')
        errcode, content = getstatusoutput('/usr/bin/sqlite3 "%s" ".dump"' % filename)
        if errcode != 0:
            logger.debug("Unable to open the file in readmode:" + filename)
            return None
    digest = hashlib.md5()
    digest.update(content)
    return digest.hexdigest()

@trace(hide_args=[2], hide_kwargs=['passwd'])
def valid_priv_and_cred(private_key, credential, passwd):
    """Tell whether a private key and a credential belong together.

    Both RSA moduli are printed with openssl and compared.

    @return: True if both commands succeed and print the same modulus line
    """
    privkey_cmd = 'echo "{0}" | openssl rsa -noout -modulus -passin pass:"{1}"'.format(private_key, passwd)
    cred_cmd = 'echo "{0}" | openssl x509 -noout -modulus'.format(credential)
    errcode, privkey_modulus = getstatusoutput(privkey_cmd)
    if errcode != 0:
        return False
    errcode, cred_modulus = getstatusoutput(cred_cmd)
    if errcode != 0:
        return False
    return privkey_modulus == cred_modulus

@trace()
def cred_end_validity_date(credential):
    """Return the end of validity date of a PEM credential.

    The "notAfter=Mar 22 09:32:39 2015 GMT" line printed by openssl is
    converted to the "2015/03/22" (YYYY/MM/DD) format.

    @return: the converted date; the raw "notAfter=..." text if it cannot be
        converted; None if the command fails (e.g. no notAfter line found)
    """
    cred_cmd = 'echo "{0}" | openssl x509 -noout -dates|grep ^"notAfter="'.format(credential)
    errcode, not_after_date = getstatusoutput(cred_cmd)
    if errcode != 0:
        return None
    try:
        exp_date = not_after_date.split("=")[1]
        # %b is locale dependent; any conversion failure falls back to the
        # raw text rather than failing the caller.
        conv = time.strptime(exp_date, "%b %d %H:%M:%S %Y GMT")
        return time.strftime("%Y/%m/%d", conv)
    except Exception:
        # Was a bare except, which also swallowed KeyboardInterrupt/SystemExit.
        return not_after_date

def escape_special_characters(text, characters='\\\'"'):
    """Prefix every occurrence of each item of *characters* with a backslash.

    The items are processed one after the other, in the given order, each
    over the result of the previous one. With the default order the
    backslash is escaped first, so the backslashes added for the quotes are
    not doubled.
    """
    escaped = text
    for special in characters:
        escaped = escaped.replace(special, '\\' + special)
    return escaped

def clean_directory(d):
    """Recursively delete every file below directory *d*, keeping the directories.

    Symbolic links are deleted themselves and never followed. A link that
    points to a directory outside *d* therefore cannot cause files outside
    the tree to be removed.
    """
    from os.path import islink
    for f in listdir(d):
        full_path = join(d, f)
        # isdir() follows symlinks; without the islink() guard a link to an
        # outside directory would be descended into and emptied.
        if isdir(full_path) and not islink(full_path):
            clean_directory(full_path)
        else:
            unlink(full_path)

def remove_special_characters(text, characters='\\\'"'):
    """Delete every occurrence of each item of *characters* from *text*.

    The items are processed one after the other, in the given order.
    """
    cleaned = text
    for unwanted in characters:
        if unwanted in cleaned:
            cleaned = cleaned.replace(unwanted, '')
    return cleaned

