import hashlib
import json
import logging
import os
import socket
import sys
import select

import paramiko
from celery.datastructures import ExceptionInfo
from django.conf import settings as djsettings

from . import settings


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably intended as a cache of SSH transports -- confirm before relying
# on it or removing it.
transports = {}


def BashSSH(backend, log, server, cmds):
    """
    Run ``cmds`` as one bash script on ``server`` over SSH.
    
    The commands are joined into a single script (prefixed with ``set -e``
    and ``set -o pipefail``), written to a local temporary file named after
    the script's MD5 digest, uploaded through SFTP and executed remotely
    only when the remote ``md5sum`` matches that digest.
    
    ``backend`` is only used in log messages. ``server`` must provide
    ``get_address()``. ``cmds`` is a list of shell command strings; when it
    is empty only ``log.script`` is stored and nothing is executed.
    
    ``log`` is updated and saved along the way:
      * ``script`` is always stored,
      * ``state`` becomes ``TIMEOUT`` when connecting raises ``socket.error``,
      * ``state`` becomes ``SUCCESS`` or ``FAILURE`` depending on the remote
        exit code, with ``stdout``, ``stderr`` and ``exit_code`` filled in,
      * ``state`` becomes ``ERROR`` (with ``traceback``) on any other
        exception; the exception is not re-raised.
    
    Always returns ``None``.
    """
    from .models import BackendLog
    
    script = '\n'.join(['set -e', 'set -o pipefail'] + cmds + ['exit 0'])
    script = script.replace('\r', '')
    # Hash and upload the very same bytes, so the remote md5sum check compares
    # like with like even when the script contains non-ASCII text.
    payload = script if isinstance(script, bytes) else script.encode('utf-8')
    digest = hashlib.md5(payload).hexdigest()
    path = os.path.join(settings.ORCHESTRATION_TEMP_SCRIPT_PATH, digest)
    remote_path = "%s.remote" % path
    log.script = '# %s\n%s' % (remote_path, script)
    log.save(update_fields=['script'])
    if not cmds:
        return
    channel = None
    ssh = None
    # Path of the local temporary script while it still has to be removed
    local_script = None
    try:
        logger.debug('%s is going to be executed on %s', backend, server)
        # Avoid "Argument list too long" on large scripts by generating a file
        # and copying it to the remote server
        local_script = path
        with os.fdopen(os.open(path, os.O_WRONLY | os.O_CREAT, 0o600), 'wb') as handle:
            handle.write(payload)
        
        # ssh connection
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        addr = server.get_address()
        try:
            ssh.connect(addr, username='root', key_filename=settings.ORCHESTRATION_SSH_KEY_PATH, timeout=10)
        except socket.error:
            logger.error('%s timed out on %s', backend, server)
            log.state = BackendLog.TIMEOUT
            log.save(update_fields=['state'])
            return
        transport = ssh.get_transport()
        
        # Copy script to remote server; close the SFTP session even if the
        # upload fails
        sftp = paramiko.SFTPClient.from_transport(transport)
        try:
            sftp.put(path, remote_path)
            sftp.chmod(remote_path, 0o600)
        finally:
            sftp.close()
        os.remove(path)
        local_script = None
        
        # Execute it
        context = {
            'remote_path': remote_path,
            'digest': digest,
            # Interpolate the path here: the command template below is only
            # formatted once, so a nested placeholder would reach the remote
            # shell verbatim and the script would never be removed.
            'remove': '' if djsettings.DEBUG else "rm -fr %s\n" % remote_path,
        }
        cmd = (
            "[[ $(md5sum %(remote_path)s|awk {'print $1'}) == %(digest)s ]] && bash %(remote_path)s\n"
            "RETURN_CODE=$?\n"
            "%(remove)s"
            "exit $RETURN_CODE" % context
        )
        channel = transport.open_session()
        channel.exec_command(cmd)
        
        # Log results
        logger.debug('%s running on %s', backend, server)
        if True: # TODO if not async
            # Undecodable bytes are replaced instead of turning an otherwise
            # finished execution into an ERROR and losing its exit code
            log.stdout += channel.makefile('rb', -1).read().decode('utf-8', 'replace')
            log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8', 'replace')
        else:
            # Currently unreachable; kept as the sketch for the async mode
            while True:
                # Non-blocking is the secret ingredient in the async sauce
                select.select([channel], [], [])
                if channel.recv_ready():
                    log.stdout += channel.recv(1024)
                if channel.recv_stderr_ready():
                    log.stderr += channel.recv_stderr(1024)
                log.save(update_fields=['stdout', 'stderr'])
                if channel.exit_status_ready():
                    break
        log.exit_code = exit_code = channel.recv_exit_status()
        log.state = BackendLog.SUCCESS if exit_code == 0 else BackendLog.FAILURE
        logger.debug('%s execution state on %s is %s', backend, server, log.state)
        log.save()
    except Exception:
        # KeyboardInterrupt and SystemExit are deliberately left to propagate
        log.state = BackendLog.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s', backend, server)
        logger.debug(log.traceback)
        log.save()
    finally:
        if local_script is not None:
            # Best-effort cleanup when the connection or the upload failed
            try:
                os.remove(local_script)
            except OSError:
                pass
        if channel is not None:
            channel.close()
        if ssh is not None:
            ssh.close()


def Python(backend, log, server, cmds):
    """
    Run ``cmds`` in-process, calling each one with ``server``.
    
    Every item of ``cmds`` must be callable and expose ``func`` and ``args``
    (presumably ``functools.partial`` objects -- confirm against the callers).
    A ``name(args)`` listing of the commands is appended to ``log.script``
    and saved before anything is executed.
    
    The string form of each result is appended to ``log.stdout``. When a
    command raises, execution stops, ``log.exit_code`` is set to 1,
    ``log.state`` to ``FAILURE`` and ``log.traceback`` is filled in;
    otherwise the exit code is 0 and the state ``SUCCESS``. Output collected
    before a failure is kept. ``backend`` is not used. Returns ``None``.
    """
    from .models import BackendLog
    # __name__ is available on both Python 2 and 3, unlike func_name
    script = [str(cmd.func.__name__) + str(cmd.args) for cmd in cmds]
    script = json.dumps(script, indent=4).replace('"', '')
    log.script = '\n'.join([log.script, script])
    log.save(update_fields=['script'])
    outputs = []
    try:
        for cmd in cmds:
            outputs.append(str(cmd(server)))
    except Exception:
        # KeyboardInterrupt and SystemExit are deliberately left to propagate
        log.exit_code = 1
        log.state = BackendLog.FAILURE
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
    else:
        log.exit_code = 0
        log.state = BackendLog.SUCCESS
    log.stdout += ''.join(outputs)
    log.save()