# -*- coding: utf-8 -*-
#
##########################################################################
# pyeole.diagnose - Diagnostic tools for EOLE
# Copyright © 2022 Pôle de compétences EOLE <eole@ac-dijon.fr>
#
# License CeCILL:
#  * in french: http://www.cecill.info/licences/Licence_CeCILL_V2-fr.html
#  * in english http://www.cecill.info/licences/Licence_CeCILL_V2-en.html
##########################################################################

"""Diagnostic tools for EOLE

Test certificates
-----------------

The diagnose module allows you to test the validity of some certificates.

Examples::

    >>> from pyeole.diagnose import test_cert
    >>> test_cert('/etc/ssl/certs/eole.crt')
    {'chain': (True, 'Certificat valide'), 'expired': ('OK', 'Fin de validité dans plus de 30 jours'),
     'dns': (False, 'scribe.domscribe.ac-test.fr scribe.ac-test.fr', '', 'scribe.ac-test.fr scribe.domscribe.ac-test.fr'),
     'issuer': (True, 'CA-scribe.domscribe.ac-test.fr'), 'expiration': (True, 'lun. 15 sept. 2025 10:21:58 CEST'),
     'valid': (False, '')}
    >>> test_cert('/etc/ssl/certs/eole.crt', expected_dns=['scribe.domscribe.ac-test.fr', 'scribe.ac-test.fr'])
    {'chain': (True, 'Certificat valide'), 'expired': ('OK', 'Fin de validité dans plus de 30 jours'),
     'dns': (True, 'scribe.domscribe.ac-test.fr scribe.ac-test.fr', '', ''), 'issuer': (True, 'CA-scribe.domscribe.ac-test.fr'),
     'expiration': (True, 'lun. 15 sept. 2025 10:21:58 CEST'),
     'valid': (True, '')}
    >>> test_cert('/etc/ssl/certs/expired.crt')
    {'valid': (True, 'Certificat valide'), 'dns': (True, 'thot.ac-test.fr'),
     'expired': ('ERROR', 'Fin de validité dans moins de 15 jours')}
    >>> test_cert('/etc/ssl/certs/eole.crt', ca='/etc/ssl/certs/ca_local.crt')
    True
    >>> from pyeole.diagnose import CertValidator
    >>> cert = CertValidator('/var/lib/lxc/addc/rootfs/var/lib/samba/private/tls/cert.pem', expected_dns='addc.domscribe.ac-test.fr')
    >>> cert.test_cert(strict_dns=False)
    {'chain': (True, 'Certificat valide'), 'expired': ('OK', 'Fin de validité dans plus de 30 jours'), 'dns': (False, 'scribe.domscribe.ac-test.fr scribe.ac-test.fr', 'addc.domscribe.ac-test.fr', ''), 'issuer': (True, 'CA-scribe.domscribe.ac-test.fr'), 'expiration': (True, 'lun. 15 sept. 2025 10:21:58 CEST'), 'valid': (False, '')}

"""

import logging
import os
from pyeole.process import system_out
import subprocess as subp
import tempfile
import re
from pathlib import Path
from creole.config import CERT_DB

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Expiration thresholds, in days, used by the expiration checks below:
# a certificate ending within DELAY_ERROR_DAY days is reported as 'ERROR',
# one ending within DELAY_WARNING_DAY days as 'WARNING'.
DELAY_ERROR_DAY = 15
DELAY_WARNING_DAY = 30


# CA bundle recorded by list_certificats() for every certificate that is not
# the self-signed ('autosigné') server certificate.
DEFAULT_CA = '/etc/ssl/certs/ca-certificates.crt'


def list_certificats(server_pem, server_cert, cert_type, le_certificates):
    """
    Build the inventory of certificates to diagnose.

    The file ``CERT_DB`` is read line by line; each line is split on ``:``,
    the first field being a usage and the second a certificate path. Several
    lines may reference the same certificate: their usages are accumulated.
    Let's Encrypt certificates are then added on top of that inventory.

    Reading the database is best effort: an unreadable file yields an
    inventory holding only the Let's Encrypt entries, and lines with fewer
    than two fields are skipped.

    :param server_pem: chain file recorded for the entry equal to ``server_cert``
    :param server_cert: path of the server certificate
    :param cert_type: type of the server certificate; ``'autosigné'`` selects
                      a local CA file instead of ``DEFAULT_CA``
    :param le_certificates: iterable of Let's Encrypt certificate paths
    :return: a dictionary keyed by certificate path. Database entries hold
             the keys ``usage``, ``chain`` and ``ca``; Let's Encrypt entries
             hold the keys ``chain``, ``ca`` and ``type``
    :rtype: dict
    """
    cert_db = {}
    try:
        with open(CERT_DB, 'r') as cert_db_fh:
            cert_lines = [cert_line.strip().split(':') for cert_line in cert_db_fh]
    except (OSError, UnicodeDecodeError) as err:
        # Best effort: a missing or unreadable database is not fatal.
        logging.getLogger(__name__).debug('Unable to read %s: %s', CERT_DB, err)
        cert_lines = []

    for fields in cert_lines:
        if len(fields) < 2:
            # Malformed (e.g. blank) line: ignore it, keep the other entries.
            continue
        usage, cert = fields[0], fields[1]
        is_server_cert = cert == server_cert
        if is_server_cert and cert_type == 'autosigné':
            if cert.startswith('/etc/ipsec.d/'):
                ca = '/etc/ipsec.d/cacerts/CertifCa.pem'
            else:
                ca = '/etc/ssl/certs/ca.crt'
        else:
            ca = DEFAULT_CA
        # The first line seen for a certificate fixes its chain and CA.
        entry = cert_db.setdefault(cert, {'usage': [],
                                          'chain': server_pem if is_server_cert else None,
                                          'ca': ca
                                          })
        entry['usage'].append(usage)

    cert_db.update({le: {'chain': str(Path(le).parent / 'fullchain.pem'), 'ca': DEFAULT_CA, 'type': 'letsencrypt'} for le in le_certificates})
    return cert_db

 
def _test_validity_cert(cert, chain, ca):
    """
    Check if a certificate is valid, using ``openssl verify``.

    :param cert: The certificate to test
    :param chain: File of intermediate certificates, passed to openssl as
                  ``-untrusted``. Can be None
    :param ca: The CA to check with. Can be a file, a directory (all the
               files found under it are concatenated in a temporary bundle)
               or None
    :return: A tuple (certificate is valid, message)
    :rtype: (bool, string)
    """
    bundle = None
    try:
        # Concat all CA if ca is a directory
        if ca and os.path.isdir(ca):
            bundle = tempfile.NamedTemporaryFile()
            for root, _dirs, files in os.walk(ca):
                for name in files:
                    # os.walk() yields bare file names: they must be joined
                    # with the directory being walked.
                    path = os.path.join(root, name)
                    try:
                        with open(path, 'rb') as infile:
                            data = infile.read()
                    except OSError as err:
                        # e.g. dangling symlink: skip it, keep the others
                        logging.getLogger(__name__).warning('Skipping unreadable CA file %s: %s', path, err)
                        continue
                    bundle.write(data)
                    if data and not data.endswith(b'\n'):
                        # keep PEM blocks of consecutive files on separate lines
                        bundle.write(b'\n')
            bundle.flush()
            ca = bundle.name

        if ca and not os.path.isfile(ca):
            return False, f'CA {ca} inexistant'

        if ca and os.stat(ca).st_size == 0:
            return False, f'Fichier CA {ca} vide'

        if os.stat(cert).st_size == 0:
            return False, f'Fichier {cert} vide'

        # Verify certificate; an argument list (no shell) keeps paths with
        # spaces or shell metacharacters intact.
        cmd = ['openssl', 'verify', '-purpose', 'any']
        if ca:
            cmd += ['-CAfile', ca, '-no-CApath', '-no-CAstore']
        if chain:
            cmd += ['-untrusted', chain]  # Useful to test entire chain
        cmd.append(cert)
        try:
            proc = subp.run(cmd, stdout=subp.PIPE, stderr=subp.STDOUT)
        except OSError as err:
            # openssl could not be started at all
            return False, 'Erreur : ' + str(err)

        # Report the first "error N at ..." line, keeping only the second
        # ':'-separated field (the reason given by openssl).
        error_re = re.compile(r'^error [0-9]* at ')
        ssl_err = ''
        for line in proc.stdout.decode('utf-8', errors='replace').split('\n'):
            if error_re.match(line):
                fields = line.split(':')
                ssl_err = (fields[1] if len(fields) > 1 else line).strip()
                break

        if ssl_err:
            return False, 'Erreur : ' + ssl_err

        if proc.returncode != 0:
            # openssl failed without a verification error line (unreadable
            # certificate, unsupported option...): do not report it as valid.
            return False, f'Erreur : échec de openssl verify (code {proc.returncode})'

        return True, 'Certificat valide'
    finally:
        if bundle is not None:
            bundle.close()  # also removes the temporary bundle


def _test_expiration_cert(cert):
    """
    Check if a certificate is about to expire.

    The check relies on the exit status of ``openssl x509 -checkend``: any
    failure of the command (including an unreadable certificate) is therefore
    reported like a certificate ending within the tested delay.

    :param cert: The certificate to test
    :return: A tuple (level, message) where level is 'OK', 'WARNING' or 'ERROR'
    :rtype: (string, string)
    """
    def ends_within(days):
        """Tell whether openssl reports the end of validity within ``days`` days."""
        # Argument list (no shell): the path is passed verbatim to openssl.
        cmd = ['openssl', 'x509', '-checkend', str(days * 24 * 3600),
               '-noout', '-enddate', '-in', cert]
        try:
            return subp.run(cmd, stdout=subp.DEVNULL, stderr=subp.DEVNULL).returncode != 0
        except OSError:
            # openssl cannot be started: same outcome as a failed command
            return True

    # Test whether the certificate expires soon or not
    if not ends_within(DELAY_WARNING_DAY):
        return 'OK', f"Fin de validité dans plus de {DELAY_WARNING_DAY} jours"

    if ends_within(DELAY_ERROR_DAY):
        return 'ERROR', f"Fin de validité dans moins de {DELAY_ERROR_DAY} jours"

    return 'WARNING', f"Fin de validité dans moins de {DELAY_WARNING_DAY} jours"


def _get_dns_cert(cert):
    """
    Get DNS names of a certificate.

    The text dump of the certificate is searched for the line following
    "X509v3 Subject Alternative Name"; the "DNS:" prefixes, the commas and
    the IPv4 "IP Address:" entries are removed from it.

    :param cert: The certificate
    :return: A tuple (info found or not, DNS or empty string)
    :rtype: (bool, string)
    """
    try:
        # Argument list (no shell): the path is passed verbatim to openssl.
        proc = subp.run(['openssl', 'x509', '-in', cert, '-noout', '-text'],
                        stdout=subp.PIPE)
    except OSError:
        # openssl cannot be started: no information available
        return False, ''

    lines = iter(proc.stdout.decode('utf-8', errors='ignore').split('\n'))
    san_values = []
    for line in lines:
        if 'X509v3 Subject Alternative Name' in line:
            # The names are on the line right after the extension title
            value = next(lines, None)
            if value is not None:
                san_values.append(value)

    ip_address_re = re.compile(r'IP Address:[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+')
    cleaned = []
    for value in san_values:
        value = value.lstrip(' ').replace('DNS:', '').replace(',', '')
        cleaned.append(ip_address_re.sub('', value))

    dns = '\n'.join(cleaned).strip()
    return dns != '', dns


def _get_issuer_cert(cert):
    """
    Read the issuer of a certificate with ``openssl x509 -issuer``.

    Only the text following the last ``=`` of the command output is kept.

    :param cert: The certificate
    :return: A tuple (info found or not, issuer or empty string)
    :rtype: (bool, string)
    """
    command = ['openssl', 'x509', '-noout', '-issuer', '-in', cert]
    try:
        code, output, _errors = system_out(command)
    except subp.CalledProcessError:
        return False, ''

    if code != 0:
        return False, ''

    # Same result as taking the last item of output.split('=')
    return True, output.rpartition('=')[2].strip()


def _get_expiration_cert(cert):
    """
    Read the end date of a certificate and reformat it with ``date -d``.

    :param cert: The certificate
    :return: A tuple (info found or not, expiration date or empty string)
    :rtype: (bool, string)
    """
    command = ['openssl', 'x509', '-noout', '-enddate', '-in', cert]
    # NOTE(review): the French locale is given to the openssl call, not to
    # the `date` call whose output is returned - confirm this is intended.
    try:
        code, output, _errors = system_out(command, env={'LC_ALL': 'fr_FR.UTF-8'})
    except subp.CalledProcessError:
        return False, ''

    if code != 0:
        return False, ''

    # The openssl output is split on '='; the second field is the date
    raw_date = output.split('=')[1].strip()
    # NOTE(review): the status returned for `date` is not checked
    date_result = system_out(['date', '-d', raw_date])
    return True, date_result[1].strip()


def test_cert(cert, chain=None, ca=None):
    """
    Check if a certificate about its validity, expiration, DNS entry, etc...
    :param cert: The certificate to test
    :param chain: The validation chain. Can be None for autosigned certificates
    :param ca: The CA to check with. Can be a directory or None
    :return: A dictionary of tuple (Error, error message), one per test
    :rtype: dict
    """
    # TODO: change 'ERROR' to a common enum at diagnose level
    status = {'valid': (False, 'Non testé'), 'expired': ('ERROR', 'Non testé'),
              'dns': (False, None), 'issuer': (False, None), 'expiration': (False, None)}

    if not os.path.isfile(cert):
        msg = f'Impossible de trouver le certificat {cert}'
        status['valid'] = (False, msg)
        return status
 
    # Check validity
    status['valid'] = _test_validity_cert(cert, chain, ca)

    if not status['valid'][0]:
        return status  # Certificate invalid, don't go further

    status['expired'] = _test_expiration_cert(cert)
    status['dns'] = _get_dns_cert(cert)
    status['issuer'] = _get_issuer_cert(cert)
    status['expiration'] = _get_expiration_cert(cert)

    return status
 
 
def test_certs(certs, chain=None, ca=None):
    """
    :param certs: The certificates to test
    :param chain: The validation chain. Can be None for autosigned certificates
    :param ca: The CA to check with. Can be a directory or None
    :return: A list of tuple (cert, dictionary of tuple (Error, error message)), one per test
    :rtype: [str, dict]
    """
    res = []
    for cert in certs:
        res.append((cert, test_cert(cert, chain, ca)))
 
    return res

class CertValidator:
    """
    Run the diagnostic checks (chain, expiration, DNS names, issuer) on one
    certificate file and format the result as shell commands.
    """

    def __init__(self, cert, cert_id=None, chain=None, ca=None, expected_dns=None, usage=None):
        """
        :param cert: path of the certificate to test
        :param cert_id: label of the certificate; defaults to the base name of ``cert``
        :param chain: The validation chain. Can be None
        :param ca: The CA to check with. Can be a file, a directory or None
        :param expected_dns: DNS name (str) or list of DNS names expected in the certificate
        :param usage: list of usages of the certificate, displayed by format_diagnostic()
        """
        self.cert = cert
        self.cert_id = cert_id if cert_id else os.path.basename(self.cert)
        self.chain = chain
        self.ca = ca
        if expected_dns is None:
            expected_dns = []
        elif isinstance(expected_dns, str):
            expected_dns = [expected_dns]
        self.expected_dns = expected_dns
        # A fresh list per instance: a list default would be shared by all
        self.usage = usage if usage is not None else []

    @staticmethod
    def _shell_escape(text):
        """Escape the characters that stay special inside a double-quoted shell string."""
        return re.sub(r'([\\"$`])', r'\\\1', str(text))

    @staticmethod
    def _concat_ca_directory(directory, bundle):
        """Write the content of every file found under ``directory`` to ``bundle``."""
        for root, _dirs, files in os.walk(directory):
            for name in files:
                # os.walk() yields bare file names: they must be joined with
                # the directory being walked.
                path = os.path.join(root, name)
                try:
                    with open(path, 'rb') as infile:
                        data = infile.read()
                except OSError as err:
                    # e.g. dangling symlink: skip it, keep the others
                    logging.getLogger(__name__).warning('Skipping unreadable CA file %s: %s', path, err)
                    continue
                bundle.write(data)
                if data and not data.endswith(b'\n'):
                    # keep PEM blocks of consecutive files on separate lines
                    bundle.write(b'\n')
        bundle.flush()

    def _run_openssl_verify(self, ca):
        """Run ``openssl verify`` on the certificate and return (valid, message)."""
        # Argument list (no shell): paths are passed verbatim to openssl.
        cmd = ['openssl', 'verify', '-purpose', 'any']
        if ca:
            cmd += ['-CAfile', ca, '-no-CApath', '-no-CAstore']
        if self.chain:
            cmd += ['-untrusted', self.chain]  # Useful to test entire chain
        cmd.append(self.cert)
        try:
            proc = subp.run(cmd, stdout=subp.PIPE, stderr=subp.STDOUT)
        except OSError as err:
            # openssl could not be started at all
            return False, 'Erreur : ' + str(err)

        # Report the first "error N at ..." line, keeping only the second
        # ':'-separated field (the reason given by openssl).
        error_re = re.compile(r'^error [0-9]* at ')
        for line in proc.stdout.decode('utf-8', errors='replace').split('\n'):
            if error_re.match(line):
                fields = line.split(':')
                ssl_err = (fields[1] if len(fields) > 1 else line).strip()
                if ssl_err:
                    return False, 'Erreur : ' + ssl_err
                break

        if proc.returncode != 0:
            # openssl failed without a usable error line (unreadable
            # certificate, unsupported option...): do not report it as valid.
            return False, f'Erreur : échec de openssl verify (code {proc.returncode})'

        return True, 'Certificat valide'

    def _test_validity_cert(self):
        """
        Check if the certificate is valid, using ``openssl verify`` with
        ``self.chain`` and ``self.ca``.

        When ``self.ca`` is a directory, the files found under it are
        concatenated in a temporary bundle; ``self.ca`` itself is left
        untouched so that the check can be run again.

        :return: A tuple (certificate is valid, message)
        :rtype: (bool, string)
        """
        ca = self.ca
        bundle = None
        try:
            # Concat all CA if ca is a directory
            if ca and os.path.isdir(ca):
                bundle = tempfile.NamedTemporaryFile()
                self._concat_ca_directory(ca, bundle)
                ca = bundle.name

            if ca and not os.path.isfile(ca):
                return False, f'CA {ca} inexistant'

            if ca and os.stat(ca).st_size == 0:
                return False, f'Fichier CA {ca} vide'

            if os.stat(self.cert).st_size == 0:
                return False, f'Fichier {self.cert} vide'

            return self._run_openssl_verify(ca)
        finally:
            if bundle is not None:
                bundle.close()  # also removes the temporary bundle

    def _ends_within(self, days):
        """Tell whether openssl reports the end of validity within ``days`` days."""
        cmd = ['openssl', 'x509', '-checkend', str(days * 24 * 3600),
               '-noout', '-enddate', '-in', self.cert]
        try:
            return subp.run(cmd, stdout=subp.DEVNULL, stderr=subp.DEVNULL).returncode != 0
        except OSError:
            # openssl cannot be started: same outcome as a failed command
            return True

    def _test_expiration_cert(self):
        """
        Check if the certificate is about to expire.

        The check relies on the exit status of ``openssl x509 -checkend``:
        any failure of the command is reported like a certificate ending
        within the tested delay.

        :return: A tuple (level, message) where level is 'OK', 'WARNING' or 'ERROR'
        :rtype: (string, string)
        """
        # Test whether the certificate expires soon or not
        if not self._ends_within(DELAY_WARNING_DAY):
            return 'OK', f"Fin de validité dans plus de {DELAY_WARNING_DAY} jours"

        if self._ends_within(DELAY_ERROR_DAY):
            return 'ERROR', f"Fin de validité dans moins de {DELAY_ERROR_DAY} jours"

        return 'WARNING', f"Fin de validité dans moins de {DELAY_WARNING_DAY} jours"

    def _get_dns_cert(self, strict_dns=True):
        """
        Get the names of the certificate (commonName, then DNS entries of the
        subjectAltName extension) and compare them with ``self.expected_dns``.

        :param strict_dns: when True the names must be exactly the expected
                           ones; when False the expected names only have to
                           be present
        :return: A tuple (names found and matching, names of the certificate,
                 missing expected names, unexpected names); the last three
                 are space-separated strings
        :rtype: (bool, string, string, string)
        """
        subject_alt_name_re = re.compile(r'(DNS:(?P<subject>[a-z0-9.-]+))')
        # Named group "subject": from commonName to the first ',' or the end of the line
        subject_name_re = re.compile(r'subject=.*?commonName=(?P<subject>[^,\n]*)')
        try:
            x509_res = subp.check_output(['/usr/bin/openssl', 'x509',
                                          '-in', self.cert,
                                          '-noout',
                                          '-nameopt', 'lname',
                                          '-subject',
                                          '-ext', 'subjectAltName']).decode('utf-8', errors='ignore').strip()
        except (subp.CalledProcessError, OSError):
            # openssl failed or cannot be started: no name available
            x509_res = ''

        dns = []
        subject_match = subject_name_re.search(x509_res)
        if subject_match is not None:
            dns.append(subject_match.group('subject'))
        for alt_name in subject_alt_name_re.finditer(x509_res):
            name = alt_name.group('subject')
            if name not in dns:
                dns.append(name)

        expected = set(self.expected_dns)
        found = set(dns)
        missing = expected - found
        # Unexpected names are only an error (and only reported) in strict mode
        unwanted = found - expected if strict_dns else set()
        matching_dns = not missing and not unwanted

        # Sorted for a reproducible output
        missing_dns = ' '.join(sorted(missing))
        unwanted_dns = ' '.join(sorted(unwanted))
        return bool(dns) and matching_dns, ' '.join(dns), missing_dns, unwanted_dns

    def _get_issuer_cert(self):
        """
        Get the issuer of the certificate: the text following the last ``=``
        in the output of ``openssl x509 -issuer``.

        :return: A tuple (info found or not, issuer or empty string)
        :rtype: (bool, string)
        """
        try:
            ret, stdout, stderr = system_out(['openssl', 'x509', '-noout', '-issuer', '-in', self.cert])
        except subp.CalledProcessError:
            ret = 1

        if ret == 0:
            return True, stdout.split('=')[-1].strip()
        return False, ''

    def _get_expiration_cert(self):
        """
        Get the expiration date of the certificate, reformatted by ``date -d``.

        :return: A tuple (info found or not, expiration date or empty string)
        :rtype: (bool, string)
        """
        # NOTE(review): the French locale is given to the openssl call, not
        # to the `date` call whose output is returned - confirm this is intended.
        try:
            ret, stdout, stderr = system_out(['openssl', 'x509', '-noout', '-enddate', '-in', self.cert]
                                             , env={'LC_ALL': 'fr_FR.UTF-8'})
        except subp.CalledProcessError:
            ret = 1

        if ret == 0:
            enddate = system_out(['date', '-d', stdout.split('=')[1].strip()])[1].strip()
            return True, enddate
        return False, ''

    def test_cert(self, strict_dns=True):
        """
        Run every check on the certificate: chain, expiration, DNS names,
        issuer and expiration date.

        :param strict_dns: see _get_dns_cert(); forced to False when no DNS
                           name is expected
        :return: A dictionary of tuples, one per test, plus the overall
                 'valid' entry. Only 'valid' is present when the certificate
                 file does not exist
        :rtype: dict
        """
        # If no expected dns provided, enforce strict_dns value
        if not self.expected_dns:
            strict_dns = False

        # TODO: change 'ERROR' to a common enum at diagnose level
        status = {}

        if not os.path.isfile(self.cert):
            msg = f'Impossible de trouver le certificat {self.cert}'
            status['valid'] = (False, msg)
            return status

        # Check validity
        status['chain'] = self._test_validity_cert()

        status['expired'] = self._test_expiration_cert()
        status['dns'] = self._get_dns_cert(strict_dns)
        status['issuer'] = self._get_issuer_cert()
        status['expiration'] = self._get_expiration_cert()
        msg = ''

        # NOTE(review): status['expired'][0] is always a non-empty string
        # ('OK', 'WARNING' or 'ERROR'), so the expiration level never makes
        # the certificate invalid here - confirm this is intended.
        checked = ('chain', 'expired', 'dns', 'issuer', 'expiration')
        status['valid'] = all(status[st][0] for st in checked), msg

        return status

    def format_diagnostic(self, strict_dns=True):
        """
        Run test_cert() and build a list of shell command strings displaying
        its result.

        The commands reference names that are not defined here (EchoVert,
        EchoOrange, EchoRouge, EchoGras, len_pf, len_pf_accent): presumably
        they are provided by the shell script evaluating them.

        :param strict_dns: passed to test_cert()
        :return: a list of command to pretty print the results
        :rtype: list of str
        """
        # Interpolated values may come from the certificate itself: escape
        # them so they cannot break out of the double-quoted shell strings.
        esc = self._shell_escape
        status = self.test_cert(strict_dns)
        cmd = ['echo "Validité du certificat ' + esc(self.cert_id) + '"']

        # Usage
        if self.usage:
            usage = ', '.join(self.usage)
            cmd.append('printf ".  %${len_pf}s => %s" "Impact" "$(EchoGras "' + esc(usage) + '")"')
        if status['valid'][0]:
            cmd.append('msg=$(EchoVert "OK")')
        else:
            msg = status['valid'][1] if status['valid'][1] else 'Invalide'
            cmd.append('msg=$(EchoRouge "' + esc(msg) + '")')
        cmd.append('printf ".  %${len_pf}s => %s" "Certificat" "${msg}"')

        # Chain
        if 'chain' in status:
            if status['chain'][0]:
                cmd.append('msg=$(EchoVert "OK")')
            else:
                cmd.append('msg=$(EchoRouge "' + esc(status['chain'][1]) + '")')
            cmd.append('printf ".  %${len_pf_accent}s => %s" "Chaîne" "${msg}"')

        # Expiration
        if 'expired' in status:
            level, text = status['expired']
            # Anything but 'OK' and 'WARNING' is displayed as an error
            color = {'OK': 'EchoVert', 'WARNING': 'EchoOrange'}.get(level, 'EchoRouge')
            cmd.append('msg=$(' + color + ' "' + esc(text) + '")')
            cmd.append('printf ".  %${len_pf}s => %s" "Expiration" "${msg}"')

        # DNS
        if 'dns' in status:
            cmd.append('printf ".  %${len_pf}s => %s" "DNS reconnus" "' + esc(status['dns'][1]) + '"')
            cmd.append('echo')
            if status['dns'][2]:
                cmd.append('msg=$(EchoRouge "' + esc(status['dns'][2]) + '")')
                cmd.append('fix=$(EchoOrange "Le certificat doit être généré à nouveau pour intégrer les DNS manquants.")')
                cmd.append('printf ".  %${len_pf}s => %s" "DNS manquants" "${msg}"')
                cmd.append('printf "   %s" "${fix}"')
            if status['dns'][3]:
                cmd.append('msg=$(EchoRouge "' + esc(status['dns'][3]) + '")')
                cmd.append('fix=$(EchoOrange "Les DNS légitimes doivent ajoutés à la variable « Nom DNS alternatif du serveur ».")')
                cmd.append('printf ".  %${len_pf}s => %s" "DNS superflus" "${msg}"')
                cmd.append('printf "   %s" "${fix}"')

        # Expiration date
        if 'expiration' in status and status['expiration'][0]:
            cmd.append('printf ".  %${len_pf}s => %s" "Date de fin" "$(EchoGras "' + esc(status['expiration'][1]) + '")"')

        # Issuer
        if 'issuer' in status and status['issuer'][0]:
            cmd.append('printf ".  %${len_pf}s => %s" "CA" "$(EchoGras "' + esc(status['issuer'][1]) + '")"')

        return cmd
