# -*- coding: utf-8 -*-

##########################################################################
# Copyright © 2019 Pôle de compétences EOLE <eole@ac-dijon.fr>
#
# License CeCILL:
#  * in french: http://www.cecill.info/licences/Licence_CeCILL_V2-fr.html
#  * in english http://www.cecill.info/licences/Licence_CeCILL_V2-en.html
##########################################################################

import ldap
import ldap.filter
from scribe.ldapconf import AD_USER, AD_BASE, AD_PWDFILE, AD_LDAPS, \
        AD_ADDRESS, LDAP_MODE

# Distinguished name of the AD account used by connect() to bind,
# located in the "CN=Users" container directly under AD_BASE
AD_USERDN = "CN={user},CN=Users,{base}".format(user=AD_USER, base=AD_BASE)

if LDAP_MODE == 'openldap':
    def get_ad_pwd():
        """
        Read AD password from special file
        and return it

        :return: content of AD_PWDFILE with surrounding whitespace
                 (such as a trailing newline) removed
        """
        # Context manager: the descriptor is closed as soon as the password
        # has been read instead of waiting for garbage collection
        with open(AD_PWDFILE) as pwd_file:
            return pwd_file.read().strip()

    def connect():
        """
        Initialize authenticated LDAP AD connection

        Binds as AD_USERDN with the password returned by get_ad_pwd(),
        over ldaps:// when AD_LDAPS is set, plain ldap:// otherwise.
        The caller owns the returned connection and should release it
        with unbind_s() once done.

        :return: bound python-ldap connection object
        :raises: errors from reading the password file or from
                 simple_bind_s are propagated unchanged
        """
        # OPT_REFERRALS=0 is mandatory for AD
        ldap.set_option(ldap.OPT_REFERRALS, 0)
        if AD_LDAPS:
            # NOTE(review): the server certificate is not verified, so LDAPS
            # only provides encryption here, not server authentication
            ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
            scheme = 'ldaps'
        else:
            scheme = 'ldap'
        connection = ldap.initialize('{}://{}'.format(scheme, AD_ADDRESS))
        try:
            connection.simple_bind_s(AD_USERDN, get_ad_pwd())
        except Exception:
            # Do not leak the connection object when the bind fails;
            # the original error is re-raised for the caller
            connection.unbind_s()
            raise
        return connection

    def group_test(group):
        """
        Test if a group exists in LDAP AD

        :param group: cn of the group to look for; special filter characters
                      are escaped so the value is matched literally
        :return: True if at least one group entry with this cn exists
                 under AD_BASE, False otherwise
        """
        ldb_filter = '(&(objectclass=group)(cn={}))'.format(
            ldap.filter.escape_filter_chars(group))
        connexion = connect()
        try:
            results = connexion.search_s(AD_BASE, ldap.SCOPE_SUBTREE,
                                         ldb_filter, ['cn'])
        finally:
            # connect() opens a new bound connection on every call:
            # release it so repeated tests do not leak connections
            connexion.unbind_s()
        # python-ldap reports search references (referrals, not followed
        # since OPT_REFERRALS is 0) as (None, [urls]) items: only items
        # carrying a real DN are actual group entries. An empty result
        # simply means the group does not exist.
        return any(dn is not None for dn, _ in results)

    def member_test(member, group):
        """
        Test if a user is member of an LDAP AD group

        :param member: cn of the user; special filter characters are escaped
                       so the value is matched literally
        :param group: cn of the group
        :return: True if one of the memberOf values of a user entry with this
                 cn starts with "CN=<group>," (compared case-insensitively,
                 like the cn match performed by the LDAP server), False
                 otherwise, including when the user is not found
        """
        ldb_filter = '(&(objectclass=user)(cn={}))'.format(
            ldap.filter.escape_filter_chars(member))
        prefix = 'CN={},'.format(group).lower()
        connexion = connect()
        try:
            results = connexion.search_s(AD_BASE, ldap.SCOPE_SUBTREE,
                                         ldb_filter, ['memberOf'])
        finally:
            # connect() opens a new bound connection on every call:
            # release it so repeated tests do not leak connections
            connexion.unbind_s()
        for dn, attrs in results:
            # Referrals come back as (None, [urls]): skip them
            if dn is None:
                continue
            for grp in attrs.get('memberOf', []):
                # python-ldap returns bytes values on Python 3; bring the
                # value to the same text type as the prefix before comparing
                if not isinstance(grp, type(prefix)):
                    grp = grp.decode('utf-8')
                if grp.lower().startswith(prefix):
                    return True
        return False
else:
    def get_ad_pwd():
        """
        Stub used when LDAP_MODE is not 'openldap':
        there is no AD password file to read

        :return: None
        """
        return None

    def connect():
        """
        Stub used when LDAP_MODE is not 'openldap':
        no AD connection is opened

        :return: None
        """
        return None

    def group_test(group):
        """
        Stub used when LDAP_MODE is not 'openldap':
        every group is reported as existing, without any lookup

        :param group: ignored
        :return: always True
        """
        return True

    def member_test(member, group):
        """
        Stub used when LDAP_MODE is not 'openldap':
        every membership is reported as true, without any lookup

        :param member: ignored
        :param group: ignored
        :return: always True
        """
        return True
