#!/usr/bin/python3
#                                                         -*- coding: utf-8 -*-
#
# A simple utility to transfer volume_key escrow packets to a server
#
# Copyright (C) 2010 Marko Myllynen <myllynen@redhat.com>
# Copyright (C) 2018 Jiří Kučera <jkucera@redhat.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
#
"""A simple utility to transfer volume_key escrow packets to a server"""

import subprocess
import argparse
import smtplib
import hashlib
import string
import random
import base64
import sys
import os
import io

from email.encoders import encode_base64
from email.mime.base import MIMEBase
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart

# Script version
VERSION = "0.1"

# Configuration: built-in defaults, each overridable on the command line
# (-r for MAIL_RECIP, -H for SMTP_SERVER, -d for ESCROW_DIR)
MAIL_RECIP = "escrow@localhost.localdomain"
SMTP_SERVER = "localhost.com"
ESCROW_DIR = "/root"

# Rudimentary intranet connectivity check against a host, None to disable
# (overridable on the command line with -I)
TEST_HOST_CONN = "intra.example.com"
#TEST_HOST_CONN = None

# Base name this script was invoked as; used in version and usage messages
me = os.path.basename(sys.argv[0])
# Directory containing this script file, with symlinks resolved
here = os.path.dirname(os.path.realpath(__file__))
# Shorthand for terminating the process with an exit status
bye = sys.exit

def tostr(s):
    """Return *s* as a text string.

    ``str`` values are returned unchanged, ``bytes`` values are decoded as
    Latin-1, and any other object is converted with ``str()``.
    """
    if isinstance(s, bytes):
        return str(s, 'latin')
    return s if isinstance(s, str) else str(s)

def fwrite(f, nl, fmt, *args):
    """Write ``fmt.format(*args)`` to the file object *f*.

    When *nl* is true a newline is written after the text; otherwise the
    stream is flushed so that a partial line becomes visible immediately.
    """
    text = fmt.format(*args)
    f.write(text)
    if not nl:
        f.flush()
        return
    f.write('\n')

def wout(fmt, *args):
    """Write a formatted message to stdout without a trailing newline.

    The stream is flushed, so this suits progress text that is completed
    by a later message on the same line.
    """
    fwrite(sys.stdout, False, fmt, *args)

def pout(fmt, *args):
    """Write a formatted message followed by a newline to stdout."""
    fwrite(sys.stdout, True, fmt, *args)

def perr(fmt, *args):
    """Write a formatted message followed by a newline to stderr.

    ``sys.stderr`` is looked up at call time, so a reassignment of
    ``sys.stderr`` made earlier in the run is honored.
    """
    fwrite(sys.stderr, True, fmt, *args)

def error(ec, fmt, *args):
    """Print a formatted message to stderr and exit with status *ec*."""
    perr(fmt, *args)
    bye(ec)

# Command line parser
def parse_cmd_line():
    """Build the option parser and parse the process command line.

    Returns the ``(namespace, extras)`` pair produced by
    ``ArgumentParser.parse_known_args()``.  Switches are stored as 0/1
    integers; options taking a value default to the module-level
    configuration constants or to the empty string.
    """
    cli = argparse.ArgumentParser(description=__doc__, allow_abbrev=False)

    def switch(*flags, dest, text):
        # On/off option: 0 when absent, 1 when present
        cli.add_argument(
            *flags, dest=dest, action="store_const", const=1, default=0,
            help=text
        )

    def valued(*flags, metavar, dest, default, text):
        # Option that takes one value
        cli.add_argument(
            *flags, metavar=metavar, dest=dest, default=default, help=text
        )

    # The registration order below is the order shown by --help
    switch(
        "-V", "--version", dest="version",
        text="show program's version number and exit"
    )
    switch("-v", "--verbose", dest="verbose", text="verbose execution")
    switch(
        "-b", "--batch", dest="batch",
        text="run as batch (including underlying programs), invert -y"
    )
    valued(
        "-c", "--public-cert", metavar="FILE", dest="cert", default="",
        text="public certificate file to create a new packet"
    )
    valued(
        "-f", "--use-file", metavar="FILE", dest="file", default="",
        text="use an existing escrow packet file"
    )
    valued(
        "-r", "--mail-recipient", metavar="ADDRESS", dest="recip",
        default=MAIL_RECIP,
        text="mail recipient, default: {}".format(MAIL_RECIP)
    )
    valued(
        "-s", "--mail-sender", metavar="ADDRESS", dest="sender", default="",
        text="e-mail sender"
    )
    switch(
        "-x", "--remove-files", dest="remove",
        text="remove all used files after successful execution"
    )
    switch(
        "-y", "--assume-yes", dest="all_yes",
        text="non-interactive mode, no confirmations asked"
    )
    valued(
        "-p", "--fs-prefix", metavar="PATH_PREFIX", dest="fs_prefix",
        default="",
        text="prefix of /proc and /dev (allow testing)"
    )
    valued(
        "-H", "--smtp-host", metavar="HOST[:PORT]", dest="smtp_host",
        default=SMTP_SERVER,
        text="SMTP server address, default: {}".format(SMTP_SERVER)
    )
    valued(
        "-I", "--intranet-host", metavar="HOST[:PORT]", dest="intranet_host",
        default=TEST_HOST_CONN,
        text="host to which the connectivity is checked against, "
             "default: {!r}".format(TEST_HOST_CONN)
    )
    valued(
        "-d", "--escrow-dir", metavar="PATH", dest="escrow_dir",
        default=ESCROW_DIR,
        text="directory the escrow packets are written to, default: {}"
             .format(ESCROW_DIR)
    )
    switch(
        "--only-stdout", dest="only_stdout",
        text="all messages goes to stdout"
    )
    # Hidden catch-all positional; it is not listed in the help output
    cli.add_argument(
        'args', nargs=argparse.REMAINDER, help=argparse.SUPPRESS
    )
    return cli.parse_known_args()

def get_host_and_port(host, port):
    """Split a ``HOST[:PORT]`` string into a ``(host, port)`` tuple.

    *port* is the default returned when *host* carries no port of its own.
    Accepted forms:

    * ``name`` or ``name:PORT`` (host name or IPv4 address),
    * ``[v6addr]`` or ``[v6addr]:PORT`` (brackets are stripped),
    * a bare IPv6 literal such as ``::1``; it contains several colons and
      is therefore returned whole with the default port.

    Raises ValueError when the port part is not a valid integer.
    """
    # Bracketed literal: the port, if any, follows the closing bracket
    if host.startswith('['):
        end = host.find(']')
        if end != -1:
            rest = host[end + 1:]
            if not rest:
                return host[1:end], port
            if rest.startswith(':'):
                return host[1:end], int(rest[1:])
    # More than one colon without brackets cannot be name:port; treat the
    # whole string as an address instead of cutting it at the last colon
    if host.count(':') > 1:
        return host, port
    name, sep, port_text = host.rpartition(':')
    if not sep:
        return host, port
    return name, int(port_text)

# Check connectivity to a host
def host_conn_check(host):
    """Return True if *host* (``HOST[:PORT]``, default port 80) resolves.

    Only a name/address lookup via ``socket.getaddrinfo`` is performed; no
    connection is opened.  Any ordinary failure (unresolvable name,
    malformed port, ...) yields False.
    """
    import socket
    try:
        host, port = get_host_and_port(host, 80)
        socket.getaddrinfo(host, port, 0, 0, socket.SOL_TCP)
    except Exception:
        # Deliberately broad: this is a best-effort probe.  Unlike a bare
        # ``except:`` it lets KeyboardInterrupt and SystemExit propagate, so
        # Ctrl-C during a slow lookup is not misreported as "no connection".
        return False
    return True

def yes_or_no_or_cancel(prompt):
    """Ask *prompt* repeatedly until the user answers yes, no or cancel.

    Only the first non-blank character of the reply is significant and it
    is matched case-insensitively, so ``y``, ``Yes`` and `` YES`` all count
    as yes.  An empty reply silently re-asks; any other reply re-asks with
    an "Incorrect answer" notice in front of the prompt.

    Returns one of the lowercase letters ``'y'``, ``'n'`` or ``'c'``.
    """
    info = ""
    while True:
        answer = input("{}{} ([y]es/[n]o/[c]ancel) ".format(info, prompt))
        info = ""
        if not answer:
            continue
        # Normalize so that capitalized or space-prefixed replies are
        # accepted; a whitespace-only reply yields '' and is rejected below
        choice = answer.strip()[:1].lower()
        if choice in ('y', 'n', 'c'):
            return choice
        info = "Incorrect answer {}. ".format(repr(answer))

def get_and_verify_email_address(
    default, host, suggestion=None, side=None, all_yes=False, verbose=False
):
    """Obtain an e-mail address and check it with the SMTP server *host*.

    The first candidate is *suggestion* if set, otherwise *default*; when
    neither is set the user is prompted for an address.  Unless *all_yes*
    is true the user confirms each candidate (yes / no / cancel).  A
    confirmed candidate is passed to ``smtplib.SMTP.verify`` and accepted
    when the reply code is 250, 251 or 252; otherwise the user is asked
    for another address.

    *side* (for example "sender") is only used to word the prompts and
    messages.  Returns the accepted address, or None if the user cancels.
    """
    label = " {} ".format(side) if side else " "
    address = suggestion or default or None

    skip_confirm = all_yes
    while True:
        while not address:
            address = input("Please enter{}e-mail address: ".format(label))

        if skip_confirm:
            answer = 'y'
        else:
            answer = yes_or_no_or_cancel(
                "Using{}address '{}' - is this ok?".format(label, address)
            )
        # A confirmation forced by an earlier failure applies only once
        skip_confirm = all_yes

        if answer == "c":
            return None
        if answer == "n":
            address = None
            continue

        # Answer is yes: let the server judge the address.
        with smtplib.SMTP(host) as server:
            reply = server.verify(address)
        if reply[0] in (250, 251, 252):
            if verbose:
                pout("Tentatively verified{}address {}.", label, address)
            return address

        perr("{} reply: {} {}", host, reply[0], tostr(reply[1]))
        perr("{}: verification failed - address unknown.", address)
        # The address was rejected, so another one is requested.  The next
        # candidate must be confirmed interactively even in all_yes mode,
        # which gives the user a way to cancel if the server rejects
        # every address.
        skip_confirm = False
        address = None

# Get the LUKS partition
def get_luks_partition(p, verbose=False):
    """Return the device path of the only LUKS partition found.

    Device names are taken from the last column of ``<p>/proc/partitions``
    and each ``<p>/dev/<name>`` is probed with ``cryptsetup isLuks``.  *p*
    is a path prefix prepended to both locations (empty for the real
    filesystem).

    The process exits with status 2 when a listed device is not readable
    and writable, when no LUKS partition is found, or when more than one
    is found.
    """
    with open('{}/proc/partitions'.format(p), 'r') as table:
        rows = list(table)

    found = []
    for row in rows:
        columns = row.split()
        # Skip blank lines and the column-header line
        if not columns or "#blocks" in row:
            continue
        device = "{}/dev/{}".format(p, columns[-1])
        if not os.access(device, os.R_OK | os.W_OK):
            error(2, "Only root can add new backup passphrases.")
        status = subprocess.call(
            ['cryptsetup', 'isLuks', device],
            stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
        )
        if status == 0:
            found.append(device)

    if not found:
        error(2, "No LUKS partitions detected.")
    if len(found) > 1:
        perr("Multiple LUKS partitions detected.")
        if verbose:
            pout(
                "Create the escrow packet(s) manually " \
                "with volume_key(1) and use the -f option."
            )
        bye(2)
    return found[0]

# Read the batch-mode passphrase
def _read_until_nul(stream):
    """Read text from *stream* up to and including the first NUL character.

    Characters are read one at a time so that nothing past the NUL is
    consumed.  Reading also stops at end of input, in which case whatever
    was read so far is returned without a terminating NUL.
    """
    chars = []
    while True:
        ch = stream.read(1)
        if not ch:
            # End of input: stop instead of polling the stream forever
            break
        chars.append(ch)
        if ch == '\0':
            break
    return "".join(chars)

# Main
def main():
    """A simple utility to transfer volume_key escrow packets to a server

    Parses the command line, optionally checks intranet connectivity,
    determines and verifies the sender and recipient addresses, then either
    uses an existing escrow packet (-f) or creates one with volume_key for
    the detected LUKS partition (-c), and mails the packet as an attachment.

    Always terminates through ``sys.exit`` with one of these statuses:
    0 - success (or -V); 1 - usage error, no intranet connectivity, or
    address entry canceled; 2 - LUKS partition detection failed (exit made
    inside get_luks_partition); 3 - user declined adding the passphrase;
    4 - volume_key failed; 5 - the SMTP server refused the recipient.
    """
    opts, args = parse_cmd_line()

    if opts.only_stdout:
        sys.stderr = sys.stdout
    if opts.batch:
        # In batch mode the meaning of -y is inverted
        opts.all_yes = 1 - opts.all_yes

    if opts.version:
        pout("{} {}", me, VERSION)
        bye(0)

    # Exactly one of -f and -c must be given, and nothing unrecognized
    if len(args) > 0 or ((not opts.file) is (not opts.cert)):
        error(1, "{}: use -h for help", me)

    # Test intranet connectivity against the defined host
    if opts.intranet_host:
        if opts.verbose:
            wout("Testing intranet connectivity against {}... ",
                opts.intranet_host
            )
        if not host_conn_check(opts.intranet_host):
            if opts.verbose:
                pout("failed.")
            pout(
                "You must be connected to intranet in order to use this tool."
            )
            bye(1)
        if opts.verbose:
            pout("ok.")

    # Set sender email address
    sender = get_and_verify_email_address(
        None, opts.smtp_host, opts.sender, "sender",
        opts.all_yes, opts.verbose
    )
    if not sender:
        error(1, "{}: missing sender e-mail address (canceled by user)", me)

    # Set recipient address
    recipient = get_and_verify_email_address(
        MAIL_RECIP, opts.smtp_host, opts.recip, "recipient",
        opts.all_yes, opts.verbose
    )
    if not recipient:
        error(1, "{}: missing recipient e-mail address (canceled by user)", me)

    escrow = None
    if opts.file:
        # Use the specified escrow file; opening it up front makes an
        # unreadable file fail before anything else is done
        with open(opts.file, 'rb'):
            pass
        escrow = opts.file
    else:
        # Create a new backup passphrase escrow file
        part = get_luks_partition(opts.fs_prefix, opts.verbose)
        escrow = "{}/escrow{}-backup-passphrase-{}".format(
            opts.escrow_dir, part.replace("/", "-"),
            "".join(random.sample(string.ascii_letters + string.digits, 6))
        )
        pout(
            "About to add a random backup passphrase to LUKS partition {}.",
            part
        )
        # Fail early if the certificate cannot be read
        with open(opts.cert, 'rb'):
            pass
        accept = "y"
        if not opts.all_yes:
            accept = input("Is this ok? (y/n) ")
        if accept != "y":
            if opts.verbose:
                pout("Operation aborted.")
            bye(3)

        # Add a random backup passphrase to the LUKS device
        pout("Current passphrase for {} needed to unlock the device.", part)
        command = ['volume_key']
        kwargs = {}
        if opts.batch:
            command.append('--batch')
            # jkucera: extract passphrase from stdin manually; constructions
            #          like `subprocess.call(command)` or
            #          `subprocess.call(command, stdin=...)` makes volume_key
            #          complaining about missing passphrase
            idata = _read_until_nul(sys.stdin)
            kwargs = dict(input=bytes(idata, encoding=sys.stdin.encoding))
        command.extend([
            '--save', part,
            '-c', opts.cert,
            '--create-random-passphrase', escrow,
            '--output-format', 'asymmetric'
        ])
        if subprocess.run(command, **kwargs).returncode != 0:
            error(4, "Operation failed.")
        if opts.verbose:
            pout("Random passphrase has been successfully added to {}.", part)
        pout("Encrypted packet stored at {}.", escrow)

    # Read in the escrow packet as a base64 encoded binary attachment
    if opts.verbose:
        pout("Using escrow packet {}.", escrow)
    # Read the file once so the checksum describes exactly the bytes that
    # are attached
    with open(escrow, 'rb') as f:
        payload = f.read()
    packet = MIMEBase('application', 'octet-stream')
    packet.set_payload(payload)
    encode_base64(packet)
    packet.add_header(
        'Content-Disposition', 'attachment; filename="{}"'.format(escrow)
    )
    md5_sum = hashlib.md5(payload).hexdigest()

    # Prepare the transfer
    msg = MIMEMultipart()
    msg['Subject'] = "Backup Passphrase Escrow Packet from {}".format(sender)
    msg['From'] = sender
    msg['To'] = recipient
    msg.attach(MIMEText(
        "Backup passphrase escrow packet from: {}\n" \
        "\n" \
        "Packet MD5 checksum: {}\n".format(sender, md5_sum),
        'plain'
    ))
    msg.attach(packet)

    # Transfer the packet; the with-block closes the connection even when
    # sendmail raises something other than SMTPRecipientsRefused
    rc = 0
    with smtplib.SMTP(opts.smtp_host) as s:
        # NOTE(jkucera): Always raise smtplib.SMTPRecipientsRefused for just
        #                one recipient which is refused.
        try:
            s.sendmail(sender, recipient, msg.as_string())
        except smtplib.SMTPRecipientsRefused:
            rc = 5
        if rc == 0:
            pout("Escrow packet successfully sent to {}.", recipient)

    # Remove the packet if requested
    removed = False
    if opts.remove and rc == 0:
        accept = "y"
        if not opts.all_yes:
            accept = input(
                "Removing escrow packet {} - is this ok? (y/n) ".format(escrow)
            )
            if accept != "y":
                if opts.verbose:
                    pout("Escrow packet not removed.")
        if accept == "y":
            os.unlink(escrow)
            removed = True
            pout("Escrow packet {} removed.", escrow)
    if rc == 0 and not removed and opts.verbose:
        pout("Remove the escrow packet after receiving confirmation e-mail.")

    # All done.
    sys.exit(rc)

# Run the tool only when executed as a script, not when imported
if __name__ == "__main__":
    main()
