#! /usr/bin/env python
#
# auth-client-config: update PAM and NSS for use with a particular auth
#    mechanism
#
# Copyright (C) 2007 Jamie Strandboge <jamie@ubuntu.com>
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published
#    by the Free Software Foundation; either version 2 of the License,
#    or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful, but
#    WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#    General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with auth-client-config; if not, write to the Free Software
#    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

from optparse import OptionParser
from ConfigParser import *
import os
import re
import shutil
from stat import *
import stat
import sys
from tempfile import mkstemp
from tempfile import mkdtemp
import datetime
import time

# Program identity.  commentStr is the marker this tool puts in front of
# configuration lines it comments out; createComment(), stripComment() and
# hasComments() all key off it.
version = '0.9'
programName = "auth-client-config"
commentStr = "# pre_" + programName + " #"

# Refuse to run on interpreters older than Python 2.3
if sys.version_info[0] < 2 or \
   (sys.version_info[0] == 2 and sys.version_info[1] < 3):
    print >> sys.stderr, programName + ": Need at least python 2.3\n"
    sys.exit(1)

# These are default settings
# MaxFileSize is the upper bound applied to every configuration file that is
# read, and to the combined size of the profile files read by getProfiles()
MaxFileSize = 10 * 1024 * 1024  # 10MB
# Maps each supported type name to the configuration file it manages.
# NOTE(review): '#CONFIG_PREFIX#' looks like a placeholder that is replaced
# at install time -- confirm against the build scripts
files = {'nss': '#CONFIG_PREFIX#/nsswitch.conf',
    'pam-auth': '#CONFIG_PREFIX#/pam.d/common-auth',
    'pam-account': '#CONFIG_PREFIX#/pam.d/common-account',
    'pam-password': '#CONFIG_PREFIX#/pam.d/common-password',
    'pam-session': '#CONFIG_PREFIX#/pam.d/common-session'}
# Directory holding the .INI style profile files read by getProfiles()
profilesDir = "#CONFIG_PREFIX#/auth-client-config/profile.d"

# Name of the profile file in profilesDir that getProfiles() reads first
default_profile = "acc-default"
# When True, inDatabase() and findActions() write diagnostics to stderr
debug = False
# NOTE(review): not referenced in the visible part of the file -- presumably
# relaxes the security checks; verify against doChecks()/main
insecure = False


def createComment(str):
    '''Return str prefixed with the uniform comment marker and one space'''
    return " ".join([commentStr, str])


def stripComment(str):
    '''Strips uniform comment from str

       Removes the comment marker (commentStr) and any whitespace following
       it when it appears at the very beginning of str.  A string that does
       not start with the marker is returned unchanged.
    '''
    # The marker is literal text, so escape it before building the pattern;
    # that keeps the match correct even if programName ever contains regex
    # metacharacters.
    pat_comment = re.compile(r"^" + re.escape(commentStr) + r"\s*")
    return pat_comment.sub('', str)


def hasComments(lines):
    '''Return how many of the given lines begin with the uniform comment'''
    marker = re.compile(r"^" + commentStr)
    return len([entry for entry in lines if marker.match(entry)])


def hasDebianSentinels(lines):
    '''Checks lines if contains the known Debian PAM sentinels. Checking for
       these in auth-client-config doesn't seem right, but apparently it needs
       to be done due to the current state of pam-auth-update...

       Returns the number of (line, sentinel) matches; a sentinel only
       matches when it appears at the very beginning of a line.
    '''
    # Raw strings: the backslashes are regex escapes for the literal
    # parentheses, not string escapes.
    sentinels = [r"# here's the fallback if no module succeeds",
        r'# end of pam-auth-update config',
        r'# here are the per-package modules \(the "Primary" block\)',
        r'# and here are more per-package modules \(the "Additional" block\)']

    # Compile once per call instead of once per line and sentinel
    patterns = [re.compile(r"^" + sentinel) for sentinel in sentinels]

    n = 0
    for line in lines:
        for pat in patterns:
            if pat.search(line):
                n += 1
    return n



def showSystem():
    '''Print current system configuration in .INI format

       Reads every file listed in the module-level 'files' table and prints
       a section header made of the current timestamp followed by one
       'key=value' line per field found (nss_* for nsswitch.conf entries,
       pam_* for PAM entries).  Multiple PAM lines of the same type are
       joined as continuation lines.

       Raises accError if a file is larger than MaxFileSize; errors from
       os.stat() and open() propagate unchanged.
    '''
    # Compile the field patterns once rather than once per line
    nss_pats = [(n, re.compile(r"^\s*" + n + ":"))
                for n in ['passwd', 'group', 'shadow', 'netgroup']]
    pam_pats = [(p, re.compile(r"^\s*" + p + r"\s"))
                for p in ['auth', 'account', 'password', 'session']]

    current = {}

    types = list(files.keys())
    types.sort()
    for t in types:
        f = files[t]
        if os.stat(f)[ST_SIZE] > MaxFileSize:
            raise accError("'" + f + "' is too big")

        orig = open(f, 'r')
        try:
            lines = readFile(orig)
        finally:
            # Close the file even when reading fails
            orig.close()

        for line in lines:
            if t == "nss":
                for n, pat in nss_pats:
                    if pat.search(line):
                        current["nss_" + n] = line + "\n"
            else:
                for p, pat in pam_pats:
                    if pat.search(line):
                        key = "pam_" + p
                        if key in current:
                            # Continuation line of a multi-line .INI value
                            current[key] += "\t" + line + "\n"
                        else:
                            current[key] = line + "\n"

    keys = list(current.keys())
    keys.sort()
    stamp = datetime.datetime.now().strftime("[%Y-%m-%d_%H:%M:%S]")
    sys.stdout.write(stamp + "\n")
    for k in keys:
        sys.stdout.write(k + "=" + current[k].strip() + "\n")


def getProfiles():
    '''Get profiles found in profiles database.  Returns dictionary with
       profile name as key and tuples for fields

       The file named by default_profile is read first and the remaining
       files in profilesDir follow in sorted order, so a later file overrides
       an earlier definition of the same profile.  Files and profiles that
       fail a check are skipped with a warning on stderr:
         - hidden files and anything that is not a regular file (silently)
         - files larger than MaxFileSize, or that would push the total
           amount read past MaxFileSize
         - files that cannot be stat'ed or parsed
         - profiles whose name or field names exceed 64 characters, or that
           have a value longer than 1024 characters

       Raises accError if profilesDir is not a directory.
    '''
    if not os.path.isdir(profilesDir):
        raise accError("Error: profiles directory does not exist")

    profiles = {}

    # Sort the list and remove the default_profile (we'll add it back in
    # later)
    names = os.listdir(profilesDir)
    try:
        names.remove(default_profile)
    except ValueError:
        sys.stderr.write("WARNING: '" + default_profile + \
                         "' not found in " + profilesDir + "\n\n")

    names.sort()

    totalSize = 0
    for f in [default_profile] + names:
        path = os.path.join(profilesDir, f)
        if not os.path.isfile(path):
            continue

        # Hidden files are skipped silently
        if f.startswith('.'):
            continue

        try:
            size = os.stat(path)[ST_SIZE]
        except OSError:
            sys.stderr.write("** WARNING: Skipping '" + f + \
                             "' (couldn't stat)\n")
            continue

        if size > MaxFileSize:
            sys.stderr.write("** WARNING: Skipping '" + f + "' (too big)\n")
            continue

        if totalSize + size > MaxFileSize:
            sys.stderr.write("** WARNING: Skipping '" + f + \
                             "' (too many files read so far)\n")
            continue
        totalSize += size

        cdict = RawConfigParser()
        try:
            cdict.read(path)
        except Exception:
            # Best effort: one malformed file must not hide the others
            sys.stderr.write("** WARNING: Skipping '" + f + \
                             "' (couldn't process)\n")
            continue

        # If multiple occurences of profile name, use the last one
        for p in cdict.sections():
            # Validate the name up front so that a profile without any
            # fields is checked as well
            if len(p) > 64:
                sys.stderr.write("WARNING: invalid profile name " + \
                                 "(too long).  Skipping\n")
                continue

            fields = cdict.items(p)
            skip = False
            for key, value in fields:
                if len(key) > 64:
                    sys.stderr.write("WARNING: invalid field for '" + p + \
                                     "' (too long).  Skipping\n")
                    skip = True
                    break
                if len(value) > 1024:
                    sys.stderr.write("WARNING: invalid value for '" + p + \
                                     ":" + key + "' (too long).  Skipping\n")
                    skip = True
                    break
            if skip:
                continue

            if p in profiles:
                sys.stderr.write("WARNING: duplicate profile '" + p + \
                                 "' found in '" + f + \
                                 "' (will use last one found)\n")

            profiles[p] = fields

    return profiles


def getProfileNames():
    '''Return the names of every profile in the profiles database.

       Any exception raised by getProfiles() is passed on to the caller.
    '''
    return getProfiles().keys()


def openFiles(f):
    '''Open f for reading together with a fresh temporary file.

       Returns a dictionary with the open original ("orig"), its name
       ("origname"), the descriptor of the temporary file ("tmp") and the
       temporary file's name ("tmpname").  Raises accError when f is larger
       than MaxFileSize; errors from os.stat(), open() and mkstemp() are
       passed on to the caller.
    '''
    if os.stat(f)[ST_SIZE] > MaxFileSize:
        raise accError("'" + f + "' is too big")

    source = open(f, 'r')

    try:
        (handle, scratch) = mkstemp()
    except:
        # Don't leave the original open when no temporary file is available
        source.close()
        raise

    return { "orig": source, "origname": f, "tmp": handle,
             "tmpname": scratch }


def closeFiles(fns, update = True):
    '''Closes the specified files (as returned by openFiles), and update
       original file with the temporary file

       fns is the dictionary returned by openFiles().  When update is true,
       the permissions and timestamps of the original are applied to the
       temporary file, whose content is then copied over the original.  The
       temporary file is always removed, including when the copy fails.

       Errors from the copy or from removing the temporary file propagate to
       the caller.
    '''
    fns['orig'].close()
    os.close(fns['tmp'])

    try:
        if update:
            shutil.copystat(fns['origname'], fns['tmpname'])
            shutil.copy(fns['tmpname'], fns['origname'])
    finally:
        # Never leave the temporary file behind, even on a failed copy
        os.unlink(fns['tmpname'])


def readFile(fn):
    '''Read in lines for fn, and return a list of lines.  fn must already
       be open

       At most MaxFileSize characters are read.  The text is split on the
       newline convention reported by fn.newlines when that is a single
       string, and on '\\n' otherwise; in the '\\n' case every carriage
       return is removed as well.  Because the text is split rather than
       read line by line, a file ending in a newline yields a final empty
       string.
    '''
    delim = '\n'

    # fn.newlines may also be a tuple when several conventions were seen;
    # a tuple cannot be used as a split delimiter, so keep '\n' and let the
    # carriage-return stripping below normalise the text.
    newlines = fn.newlines
    if newlines is not None and not isinstance(newlines, tuple):
        delim = newlines

    text = fn.read(MaxFileSize)

    # if Unix, then strip out all carriage returns
    if delim == '\n':
        text = text.replace('\r', '')

    return text.split(delim)


def recursive_rm(dirPath):
    '''recursively remove directory

       Removes every entry below dirPath and then dirPath itself.  Symbolic
       links are removed, never followed, so nothing outside of dirPath is
       touched.  Errors from os.listdir(), os.unlink() and os.rmdir()
       propagate to the caller.
    '''
    for name in os.listdir(dirPath):
        path = os.path.join(dirPath, name)
        # os.path.isdir() follows symlinks: check for a link first so that a
        # link to a directory is unlinked rather than descended into
        if os.path.islink(path) or not os.path.isdir(path):
            os.unlink(path)
        else:
            recursive_rm(path)
    os.rmdir(dirPath)

#
# Classes
#
class accError(Exception):
    '''Represents auth-client-config exceptions

       The message passed to the constructor is available as the 'value'
       attribute; str() of the exception is the repr() of that message.
    '''
    def __init__(self, value):
        # Initialise the base class too, so that 'args' carries the message
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return repr(self.value)


class acc_Type:
    '''Interface for various types

       Holds the logic shared by the concrete handlers: the profiles
       database, the state machine deciding whether existing settings are
       commented out and/or archived, and the generic reset of a
       configuration file.  Subclasses must override findProfiles(),
       updateConfig() and verifyConfig().
    '''
    def __init__(self, type, dry, db):
        '''type is stored as the handler's type name, dry enables dry-run
           mode (output goes to stdout instead of the file) and db enables
           database-only mode (see findActions).  Loads the profiles
           database; errors from getProfiles() propagate.
        '''
        self.config = ""
        self.dryrun = dry
        self.dbonly = db
        self.type = type
        self.updates = {}
        self.profiles = getProfiles()

    def inDatabase(self, field, line):
        '''Checks if line is in the profiles database for a particular field

           Returns the list of profile names whose value for field equals
           line once leading/trailing whitespace is removed and every run of
           whitespace is collapsed to a single space.
        '''
        found = []
        pat = re.compile(r"\s+")
        # Normalise the candidate once instead of once per comparison
        wanted = pat.sub(' ', line.strip())
        for m in self.profiles:
            for key, value in self.profiles[m]:
                if field == key and wanted == pat.sub(' ', value.strip()):
                    found.append(m)

        if debug and len(found) > 0:
            os.write(sys.stderr.fileno(), "inDatabase: found '" + field + \
                     "', line '" + line + "' in " + str(found) + "\n")

        return found

    def determineState(self, error, doComment, doArchive):
        '''Return the name of the finite-state-machine state(s) that the
           given action flags correspond to (used for debug output)
        '''
        if error:
            return "S_3 or S_7"
        if doComment:
            if doArchive:
                return "S_6"
            return "S_2"
        return "S_1, S_5, or S_8"

    def findActions(self, ncomments, indb, reset = False):
        '''Return actions tuple based on state machine

           ncomments is the number of lines carrying our comment marker,
           indb tells whether the current settings match a profile in the
           database and reset whether a reset was requested.  Returns
           (error, doComment, doArchive):
             error     - the requested operation cannot be performed
             doComment - current settings must be commented out
             doArchive - previously commented settings must be archived
        '''
        error = False
        doComment = False
        doArchive = False
        if ncomments < 1:            # S_0
            if reset:                # S_3
                error = True
            elif indb:               # S_1
                pass
            elif self.dbonly:        # S_3
                error = True
            else:                    # S_2
                doComment = True
        else:                        # S_4
            if indb:                 # S_5
                # Reset may need to check ncomments on its own to
                # determine S_7 vs S_8
                pass
            else:
                if reset:            # S_7
                    error = True
                elif self.dbonly:    # S_7
                    error = True
                else:                # S_6
                    doComment = True
                    doArchive = True

        if debug:
            args = " when comm=" + str(ncomments) + " indb=" + str(indb) + \
                   " reset=" + str(reset) + " dbonly=" + str(self.dbonly)
            os.write(sys.stderr.fileno(), "found state:" + \
                     self.determineState(error, doComment, doArchive) + args \
                     + "\n")

        return (error, doComment, doArchive)

    def setConfig(self, f):
        '''Set the location of the configuration file

           Raises accError if f is not a writable regular file; whatever
           verifyConfig() raises for f propagates as well.
        '''
        if not os.path.isfile(f):
            raise accError("'" + f + "' is not a file")
        if not os.access(f, os.W_OK):
            raise accError("'" + f + "' is not writable")
        if not stat.S_ISREG(os.stat(f)[stat.ST_MODE]):
            raise accError("Not a regular file")

        self.verifyConfig(f)

        self.config = f

    def resetConfig(self):
        '''Restores commented out lines in configuration file

           Lines carrying our comment marker are restored and the active
           lines that replaced them are dropped.  In dry-run mode the result
           is written to stdout and the file is left alone.

           Raises accError if the file has no previous settings, or if the
           current settings are not in the database.  The configuration file
           is only replaced when everything succeeded.
        '''
        fns = openFiles(self.config)

        # Only copy the temporary file back when it holds the complete new
        # content: never after an error, and never in dry-run mode (nothing
        # is written to it then, so copying would empty the configuration
        # file).
        commit = False
        try:
            # Write to stdout or to tmpfile
            if self.dryrun:
                fd = sys.stdout.fileno()
            else:
                fd = fns['tmp']

            lines = readFile(fns['orig'])
            ncomments = hasComments(lines)

            # This is technically covered by profileIsCurrent, but let's
            # leave here until sure we want to enforce profile with reset
            # in main script
            indb = len(self.findProfiles(lines)) > 0

            (error, _doComment, _doArchive) = \
                self.findActions(ncomments, indb, True)
            if error:
                if ncomments < 1:
                    raise accError("No previous settings found in " + \
                                   "current file")
                raise accError("Current settings not in database")

            # First time through, see what we commented out, and build a
            # pattern from the first token after our comment string
            pat = re.compile(r'^' + re.escape(commentStr) + r'\s*')
            pat_comment_sentinel = re.compile(r'^' + re.escape(commentStr) + \
                                              r'\s+#')
            tokens = []
            for line in lines:
                if pat.search(line) and not pat_comment_sentinel.search(line):
                    tmp = pat.sub('', line).split(None, 1)
                    # A marker with nothing behind it names no setting;
                    # using it would drop every line starting with
                    # whitespace or a colon
                    if tmp and tmp[0] not in tokens:
                        tokens.append(tmp[0])

            # The tokens are literal text, so escape them
            active = [re.compile(r'^' + re.escape(token) + r'[:\s]')
                      for token in tokens]

            # i is used to keep us from adding an extra newline at
            # end every time we run
            last = len(lines)
            i = 0
            for line in lines:
                i += 1

                # If the line starts with one of the tokens we found
                # above, skip the line, otherwise, print the line
                # whilst stripping out any comments
                skip = False
                for pat_begin in active:
                    if pat_begin.search(line):
                        skip = True
                        break

                if not skip and i < last:
                    os.write(fd, stripComment(line) + "\n")

            commit = not self.dryrun
        finally:
            closeFiles(fns, commit)

    def setProfile(self, m):
        '''Select profile m: copy its fields into self.updates.  Raises
           accError if m is not in the profiles database.
        '''
        if m not in self.profiles:
            raise accError("Profile '" + m + "' not found in " + profilesDir)
        for key, value in self.profiles[m]:
            self.updates[key] = value

    def profileIsCurrent(self, name):
        '''Return True if the configuration file currently matches the
           profile called name, False otherwise.  The file is not modified.
        '''
        fns = openFiles(self.config)
        try:
            lines = readFile(fns['orig'])
            isCurrent = name in self.findProfiles(lines)
        finally:
            # don't update the files when closing
            closeFiles(fns, False)

        return isCurrent

    # API overrides
    def findProfiles(self, lines):    # should be overridden
        '''Returns profiles contained in lines'''
        raise accError("acc_Type.findProfiles: Need to override findProfiles")

    def updateConfig(self):        # should be overridden
        '''Update configuration file'''
        raise accError("acc_Type.updateConfig: Need to override updateConfig")

    def verifyConfig(self, file):    # should be overridden
        '''Verify configuration'''
        raise accError("acc_Type.verifyConfig: Need to override verifyConfig")


class acc_PAM(acc_Type):
    '''Represents PAM module type

       self.type is one of the PAM management groups ('auth', 'account',
       'password' or 'session'); the matching profile field is
       'pam_' + self.type.
    '''
    def __init__(self, t, n, d):
        '''t is the PAM type, n the dry-run flag and d the database-only
           flag; see acc_Type.__init__
        '''
        acc_Type.__init__(self, t, n, d)

    def findProfiles(self, lines):
        '''Returns profiles contained in lines

           All lines of one PAM type are gathered into a single block, and
           a profile is reported when its 'pam_<type>' field matches such a
           block (see inDatabase).
        '''
        pats = [(t, re.compile(r"^\s*" + t + r"\s"))
                for t in ['auth', 'account', 'password', 'session']]

        # Gather the like lines, and put them in dictionary
        current = {}
        for line in lines:
            for t, pat in pats:
                if pat.search(line):
                    current[t] = current.get(t, "") + line + "\n"

        # Check to see if entries in dictionary match the database
        entries = {}
        for t in current:
            for x in self.inDatabase("pam_" + t, current[t]):
                entries[x] = entries.get(x, 0) + 1

        return list(entries.keys())

    def updateConfig(self):
        '''Update configuration file

           Replaces the lines of this PAM type with the value of the
           selected profile (see setProfile).  Depending on the state
           machine (findActions), the existing lines are commented out and
           previously commented lines are archived as plain comments.  In
           dry-run mode the result goes to stdout and the file is left
           alone.

           Raises accError if the selected profile has no value for this
           type, or if database-only mode forbids the update.
        '''
        field = 'pam_' + self.type
        # Fail before anything is opened or written
        if field not in self.updates:
            raise accError("'" + field + "' not found")

        fns = openFiles(self.config)

        # Only copy the temporary file back when it holds the complete new
        # content: never after an error, and never in dry-run mode (nothing
        # is written to it then).
        commit = False
        try:
            # Write to stdout or to tmpfile
            if self.dryrun:
                fd = sys.stdout.fileno()
            else:
                fd = fns['tmp']

            lines = readFile(fns['orig'])

            indb = len(self.findProfiles(lines)) > 0

            (error, doComment, doArchive) = \
                self.findActions(hasComments(lines), indb, False)
            if error:
                raise accError("Current settings not in database, but " + \
                               "database-only specified.  Skipping 'pam-" + \
                               self.type + "'")

            # NOTE(review): unlike findProfiles/verifyConfig this uses \s*,
            # so any line merely starting with the type name matches --
            # confirm whether \s was intended
            pat = re.compile(r"^\s*" + self.type + r"\s*")
            last = len(lines)
            i = 0
            for line in lines:
                i += 1

                # If doArchive and this line is a comment
                if doArchive and hasComments([line]) > 0:
                    os.write(fd, "# " + stripComment(line) + "\n")
                    continue

                if pat.search(line) or hasDebianSentinels([line]) > 0:
                    if doComment:
                        os.write(fd, createComment(line) + "\n")
                    continue

                # The last element is the empty string after the final
                # newline; skipping it avoids growing the file on each run
                if i < last:
                    os.write(fd, line + "\n")

            os.write(fd, self.updates[field] + "\n")

            commit = not self.dryrun
        finally:
            closeFiles(fns, commit)

    def verifyConfig(self, file):
        '''Verify configuration

           Raises accError unless file has at least one entry for this PAM
           type.  The file is not modified.
        '''
        fns = openFiles(file)
        try:
            lines = readFile(fns['orig'])
        finally:
            # Don't update the files when closing
            closeFiles(fns, False)

        pat = re.compile(r"^\s*" + self.type + r"\s")
        entries = len([line for line in lines if pat.search(line)])

        if entries < 1:
            raise accError("'" + os.path.basename(file) + \
                           "' doesn't have any entries for '" + self.type + \
                           "'")


class acc_NSS(acc_Type):
    '''Represents Name Service switch file

       Manages the 'passwd', 'group', 'shadow' and 'netgroup' entries; the
       matching profile fields are 'nss_<entry>'.
    '''
    # The entries of the name service switch file handled by this class
    _nss_fields = ('passwd', 'group', 'shadow', 'netgroup')

    def __init__(self, t, n, d):
        '''t is the type name, n the dry-run flag and d the database-only
           flag; see acc_Type.__init__
        '''
        acc_Type.__init__(self, t, n, d)

    def _fieldPatterns(self):
        '''Return (entry, compiled pattern) pairs matching each NSS entry'''
        return [(t, re.compile(r"^\s*" + t + ":")) for t in self._nss_fields]

    def findProfiles(self, lines):
        '''Returns profiles contained in lines

           A profile is only reported when all four of its nss_* fields
           match the corresponding lines.
        '''
        pats = self._fieldPatterns()
        entries = {}
        for line in lines:
            for t, pat in pats:
                if pat.search(line):
                    for x in self.inDatabase("nss_" + t, line):
                        entries[x] = entries.get(x, 0) + 1

        # Since an nss profile must have 4 matching fields, check our
        # dictionary for profiles with 4 matches
        return [k for k in entries if entries[k] == 4]

    def resetConfig(self):
        '''Restores commented out lines in configuration file

           Checks first that the file can be reset automatically, then lets
           acc_Type.resetConfig() do the work.  Raises accError if an entry
           was commented out more than once, or if the current entries do
           not match any profile.
        '''
        fns = openFiles(self.config)
        try:
            lines = readFile(fns['orig'])
            matching = self.findProfiles(lines)
        finally:
            # Don't update the file yet
            closeFiles(fns, False)

        pat = re.compile(r'^' + re.escape(commentStr) + r'\s*')

        error = ""
        types = []
        for line in lines:
            if pat.search(line):
                # \s+ rather than \s*: a pattern that can match the empty
                # string splits between every character on newer Pythons
                tmp = re.split(r'\s+', pat.sub('', line))

                # If found another comment for this field, error out, since
                # we can't automatically recover (n > 1 in state machine).
                if tmp[0] in types:    # S_7 in state diagram
                    error = "Too many previous configurations found.  " + \
                            "Please fix manually."
                    break

                types.append(tmp[0])

        if len(matching) < 1:
            error = "No matching profile found for existing entries.  " + \
                    "Please reset them manually."

        if error != "":
            raise accError(error)

        # If we made it here, then call our parent to actually update
        # the file (S_8 in state diagram)
        acc_Type.resetConfig(self)

    def updateConfig(self):
        '''Update configuration file

           Replaces each of the four NSS entries with the value of the
           selected profile (see setProfile).  Depending on the state
           machine (findActions), the existing entries are commented out and
           previously commented lines are archived as plain comments.  In
           dry-run mode the result goes to stdout and the file is left
           alone.

           Raises accError if the selected profile lacks one of the nss_*
           fields, or if database-only mode forbids the update.
        '''
        # Check that our config file has good types
        for t in self._nss_fields:
            if 'nss_' + t not in self.updates:
                raise accError("'nss_" + t + "' not found")

        fns = openFiles(self.config)

        # Only copy the temporary file back when it holds the complete new
        # content: never after an error, and never in dry-run mode (nothing
        # is written to it then).
        commit = False
        try:
            # Write to stdout or to tmpfile
            if self.dryrun:
                fd = sys.stdout.fileno()
            else:
                fd = fns['tmp']

            lines = readFile(fns['orig'])

            indb = len(self.findProfiles(lines)) > 0

            (error, doComment, doArchive) = \
                self.findActions(hasComments(lines), indb, False)
            if error:
                raise accError("Current settings not in database, but " + \
                               "database-only specified.  Skipping 'nss'")

            pats = self._fieldPatterns()
            last = len(lines)
            i = 0
            for line in lines:
                i += 1

                # If doArchive and this line is a comment
                if doArchive and hasComments([line]) > 0:
                    os.write(fd, "# " + stripComment(line) + "\n")
                    continue

                wrote_line = False
                for t, pat in pats:
                    if pat.search(line):
                        if doComment:
                            os.write(fd, createComment(line) + "\n")
                        os.write(fd, self.updates['nss_' + t] + "\n")
                        wrote_line = True
                        break

                # The last element is the empty string after the final
                # newline; skipping it avoids growing the file on each run
                if not wrote_line and i < last:
                    os.write(fd, line + "\n")

            commit = not self.dryrun
        finally:
            closeFiles(fns, commit)

    def verifyConfig(self, file):
        '''Verify configuration

           Raises accError unless file has exactly one entry for each of
           the four NSS entries.  The file is not modified.
        '''
        fns = openFiles(file)
        try:
            lines = readFile(fns['orig'])
        finally:
            # Don't update the files when closing
            closeFiles(fns, False)

        pats = self._fieldPatterns()
        entries = {}
        for line in lines:
            for t, pat in pats:
                if pat.search(line):
                    entries[t] = entries.get(t, 0) + 1

        for k in self._nss_fields:
            if k not in entries:
                raise accError("'" + os.path.basename(file) + \
                               "' doesn't have an entry for '" + k + "'")

            if entries[k] > 1:
                raise accError("'" + os.path.basename(file) + \
                               "' has multiple entries for '" + k + "'")


def process_args():
    '''Process command line arguments

       Parses sys.argv and returns the optparse options object.  The
       informational options (--list-profiles, --show-system, --list-types)
       print their output and exit with status 0.

       Raises accError when the combination of options is invalid:
       'database-only' together with 'reset', a missing profile, 'file' or
       'type' together with 'all-types', neither 'type' nor 'all-types', or
       an unknown type name.  Errors from getProfileNames() propagate.
    '''
    # optparse requires 'choices' to be a list or tuple
    profiles = list(getProfileNames())

    usage = "%prog -p PROFILE -a -t TYPE [-dn -f FILE]\n       " + \
            "%prog -p PROFILE -a -t TYPE -r [-n -f FILE]\n       " + \
            "%prog -p PROFILE -a -t TYPE -s [-f FILE]"
    description = "This program updates nsswitch.conf and pam " + \
                  "configuration files to aid in authentication configuration."
    parser = OptionParser(usage=usage,
                          version="%prog: " + version,
                          description=description)
    parser.add_option("-a", "--all-types",
                      action="store_true",
                      dest="applyall",
                      help="apply all types for specified profile")
    parser.add_option("-d", "--database-only",
                      action="store_true",
                      dest="dbonly",
                      help="update only if current entries are in database")
    parser.add_option("-f", "--file",
                      dest="file",
                      help="update FILE instead of default",
                      metavar="FILE")
    parser.add_option("-l", "--list-profiles",
                      action="store_true",
                      dest="listprofiles",
                      help="list available profiles")
    parser.add_option("-L", "--list-types",
                      action="store_true",
                      dest="listtypes",
                      help="list available types")
    parser.add_option("-n", "--dry-run",
                      action="store_true",
                      dest="dryrun",
                      help="don't modify anything, just show the changes")
    parser.add_option("-p", "--profile",
                      dest="profile",
                      help="set profile to PROFILE",
                      metavar="PROFILE",
                      choices=profiles)
    parser.add_option("-r", "--reset",
                      action="store_true",
                      dest="reset",
                      help="reset to previous non-" + programName + " values")
    parser.add_option("-s", "--check-system",
                      action="store_true",
                      dest="system",
                      help="determine if system files are set to PROFILE")
    parser.add_option("-S", "--show-system",
                      action="store_true",
                      dest="showsystem",
                      help="show current system settings as a profile")
    parser.add_option("-t", "--type",
                      dest="type",
                      help="modify files for TYPE",
                      metavar="TYPE")

    (options, args) = parser.parse_args()

    if options.listprofiles:
        sys.stdout.write("Available profiles are:\n")
        profiles.sort()
        for m in profiles:
            sys.stdout.write("  " + m + "\n")
        sys.exit(0)

    if options.showsystem:
        showSystem()
        sys.exit(0)

    if options.listtypes:
        sys.stdout.write("Available types are:\n")
        types = list(files.keys())
        types.sort()
        for t in types:
            sys.stdout.write("  " + t + "\n")
        sys.exit(0)

    if options.dbonly and options.reset:
        raise accError ("\nCannot specify 'database-only' when using 'reset'")

    if not options.profile:
        raise accError ("\n'profile' is required")

    if options.applyall:
        if options.file:
            raise accError ("\nCannot specify 'file' when using 'apply all'")
        if options.type:
            raise accError ("\nCannot specify 'type' when using 'apply all'")
    else:
        if not options.type:
            raise accError ("\neither '-t' or '-a' is required")

    if options.type:
        # TYPE may be a comma separated list; every element must be known
        for t in options.type.split(','):
            if t not in files:
                err = "\nInvalid type in '%s'. Valid types are:" % \
                      (options.type)
                for v in files.keys():
                    err += "\n  %s" % (v)
                raise accError (err)

    return options

def _checkOwnerAndPerms(statinfo, name, uid):
    '''Root-only ownership and write-permission checks for one path

    statinfo is the os.stat() result of the path, name is how the path
    is called in messages and uid is the real uid of the caller (only
    used in the error message).

    Raises accError if the path is not owned by root or is world
    writable.  A group writable path only gets a warning on stderr.
    '''
    if statinfo.st_uid != 0:
        raise accError("ERROR: uid is " + str(uid) + " but '" + name +
                       "' is owned by " + str(statinfo.st_uid))

    mode = statinfo[ST_MODE]
    if mode & S_IWOTH:
        raise accError("ERROR: " + name + " is world writable!")
    if mode & S_IWGRP:
        os.write(sys.stderr.fileno(), "** WARNING: " + name +
                 " is group writable **\n\n")


def doChecks():
    '''Perform some security checks

    Raises accError when a check fails.  An OSError from os.stat() or
    os.listdir() (eg a path that vanished) is not caught here.
    '''

    # Does the following checks:
    #    is setuid or setgid (for non-Linux systems)
    #    checks that if run by root, then script is owned by root
    #    checks that profilesDir is a directory
    #    checks that profilesDir and script aren't in a hidden directory
    #        somewhere
    #    checks that if run by root, then every component in absolute
    #        paths is owned by root
    #    checks that if run by root, every component of absolute paths
    #        is not a symlink
    #    checks for symlinks and perms of files in profilesDir (defer
    #        hidden checks to getProfiles())
    #    warn if script is group writable
    #    warn if profilesDir or part of script path are group writable
    #
    # Doing this at the beginning causes a race condition with later
    # operations that don't do these checks.  However, if the user running
    # this script is root, then need to be root to exploit the race
    # condition (and you are hosed anyway...)

    # Not needed on Linux, but who knows the places we will go...
    if os.getuid() != os.geteuid():
        raise accError("ERROR: this script should not be SUID")
    if os.getgid() != os.getegid():
        raise accError("ERROR: this script should not be SGID")
    uid = os.getuid()

    # The script itself (its messages differ from the generic ones)
    script = os.path.abspath(sys.argv[0])
    statinfo = os.stat(script)
    mode = statinfo[ST_MODE]

    if uid == 0:
        if statinfo.st_uid != 0:
            raise accError("ERROR: script not owned by root!")
        if mode & S_IWOTH:
            raise accError("ERROR: script is world writable!")
        if mode & S_IWGRP:
            os.write(sys.stderr.fileno(),
                     "** WARNING: script is group writable **\n\n")

    if not os.path.isdir(profilesDir):
        raise accError("ERROR: profiles directory does not exist")

    # Walk from each starting directory up to the root, checking every
    # component on the way
    for start in [os.path.dirname(script), os.path.abspath(profilesDir)]:
        path = start
        while True:
            if os.path.basename(path).startswith('.'):
                raise accError("ERROR: found hidden directory in path: " +
                               path)

            statinfo = os.stat(path)

            if uid == 0:
                if os.path.islink(path):
                    raise accError("ERROR: found symbolic link in path: " +
                                   path)
                _checkOwnerAndPerms(statinfo, path, uid)

            # Stop after processing the root.  Comparing with the parent
            # (rather than with "/" only) also terminates for a path with
            # two leading slashes, whose root is "//" and never becomes "/".
            parent = os.path.dirname(path)
            if parent == path:
                break
            if not parent:
                raise accError("ERROR: could not find parent directory of: " +
                               path)
            path = parent

    # Now check the files in profilesDir
    for f in os.listdir(profilesDir):
        fullpath = os.path.join(profilesDir, f)
        statinfo = os.stat(fullpath)

        if uid == 0:
            # Test the entry inside profilesDir, not a same-named file in
            # the current working directory
            if os.path.islink(fullpath):
                raise accError("ERROR: found symbolic link: " + f)
            _checkOwnerAndPerms(statinfo, f, uid)


#
# MAIN SCRIPT STARTS HERE
#
try:
    if not insecure:
        doChecks()
except accError, e:
    print >> sys.stderr, e.value + "\nAborting."
    sys.exit(1)

try:
    opts = process_args()
except accError, e:
    print >> sys.stderr, e.value + "\n"
    sys.exit(1)

# Build one handler object per requested type, keyed by type name.
# 'nss' is handled by acc_NSS; every other type is 'pam-<stack>' and is
# handled by acc_PAM, which is given the stack name without the 'pam-'
# prefix.
services = {}
if opts.applyall:
    requested = ['nss', 'pam-auth', 'pam-account', 'pam-password',
                 'pam-session']
else:
    requested = opts.type.split(',')

pat = re.compile(r"^pam-")
for t in requested:
    # Decide per type, so that 'nss' is also recognized when it is only
    # one entry of a comma separated list (eg '-t nss,pam-auth')
    if t == "nss":
        services[t] = acc_NSS("nss", opts.dryrun, opts.dbonly)
    else:
        services[t] = acc_PAM(pat.sub('', t), opts.dryrun, opts.dbonly)

# All edits are made on copies inside a fresh temporary directory.  A
# failure to create it is fatal and simply propagates.
tmpdir = mkdtemp()

# tmpfiles maps a service name to its working copy inside tmpdir;
# 'error' and 'current' record the outcome of the processing loop
tmpfiles = {}
error = current = False

# Sort the keys for consistency
service_keys = services.keys()
service_keys.sort()
for service in service_keys:
    if not services[service]:
        print >> sys.stderr, "Problem initializing '" + service + "'\n"
        error = True
        break

    if len(services) == 1 and opts.file:
        if not os.path.isfile(opts.file):
            print >> sys.stderr, "'" + opts.file + "' does not exist\n"
            error = True
            break
        files[service] = opts.file

    # Set the config file for service (in tmpdir)
    tmp = os.path.join(tmpdir, os.path.basename(files[service]))

    try:
        shutil.copy(files[service], tmp)
        shutil.copystat(files[service], tmp)
    except:
        print >> sys.stderr, "Error in creating temporary file"
        error = True
        break

    try:
        services[service].setConfig(tmp)
        tmpfiles[service] = tmp
    except accError, e:
        print >> sys.stderr, "Error in setting the file: " + e.value
        error = True
        break

    if opts.reset:
        try:
            if not services[service].profileIsCurrent(opts.profile):
                raise accError("'" + opts.profile + \
                               "' does not match system settings")
            services[service].resetConfig()
        except accError, e:
            print >> sys.stderr, "Error in resetting '" + service + "': " + \
                     e.value
            error = True
            break
        except:
            if debug:
                recursive_rm(tmpdir)
                raise
            print >> sys.stderr, "Error in resetting '" + service + "'"
            error = True
            break
    elif opts.system:
        try:
            current = services[service].profileIsCurrent(opts.profile)
            if not current:
                break
        except accError, e:
            print >> sys.stderr, "Error in testing '" + service + "': " + \
                     e.value
            error = True
            break
        except:
            if debug:
                recursive_rm(tmpdir)
                raise
            print >> sys.stderr, "Error in testing '" + service + "'"
            error = True
            break
    else:
        try:
            services[service].setProfile(opts.profile)
            services[service].updateConfig()
        except accError, e:
            print >> sys.stderr, "Error in updating the file: " + e.value
            error = True
            break
        except:
            if debug:
                recursive_rm(tmpdir)
                raise
            print >> sys.stderr, "Error in updating the file"
            error = True
            break

# When only testing the system settings nothing gets installed: remove
# the working copies and report through the exit status
#   0: profile matches, 1: profile does not match, 2: errors
if opts.system:
    recursive_rm(tmpdir)
    if error:
        print >> sys.stderr, "--\nErrors found.  Aborting"
        sys.exit(2)
    if current:
        sys.exit(0)
    sys.exit(1)

# For every other action an error means the real files stay untouched
if error:
    print >> sys.stderr, "--\nErrors found.  Aborting (no changes made)"
    recursive_rm(tmpdir)
    sys.exit(1)

# First verify that the tmpfiles are valid
error = False
for service in services:
    if services[service].dryrun:
        continue

    try:
        services[service].verifyConfig(tmpfiles[service])
    except accError, e:
        print >> sys.stderr, "ERROR: " + e.value
        print >> sys.stderr, "--\nErrors found.  Aborting (no changes made)"
        recursive_rm(tmpdir)
        sys.exit(1)
    except:
        raise

# Now copy the files over.  Services in dry run mode are skipped.  A
# failed copy is reported and remembered in the exit status, but the
# remaining files are still attempted.
exit_code = 0
for service, handler in services.items():
    if handler.dryrun:
        continue

    orig = files[service]
    try:
        # Give the working copy the stat info of the real file before
        # copying it into place
        shutil.copystat(orig, tmpfiles[service])
        shutil.copy(tmpfiles[service], orig)
    except:
        print >> sys.stderr, "Error: '" + orig + "' not updated"
        exit_code = 1

# Clean up
recursive_rm(tmpdir)

sys.exit(exit_code)

