import hashlib
import json
import logging
import os
import socket
import sys
import select

from celery.datastructures import ExceptionInfo
from django.conf import settings as djsettings

from orchestra.utils.sys import sshrun
from orchestra.utils.python import CaptureStdout, import_class

from . import settings


logger = logging.getLogger(__name__)

# Process-wide cache of connected paramiko SSHClient objects, keyed by server
# address, so that Paramiko() can reuse a connection across calls.
paramiko_connections = {}


def Paramiko(backend, log, server, cmds, async=False):
    """
    Executes cmds to remote server using Pramaiko
    """
    import paramiko
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = script
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    channel = None
    ssh = None
    try:
        addr = server.get_address()
        # ssh connection
        ssh = paramiko_connections.get(addr)
        if not ssh:
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            key = settings.ORCHESTRATION_SSH_KEY_PATH
            try:
                ssh.connect(addr, username='root', key_filename=key)
            except socket.error as e:
                logger.error('%s timed out on %s' % (backend, addr))
                log.state = log.TIMEOUT
                log.stderr = str(e)
                log.save(update_fields=('state', 'stderr', 'updated_at'))
                return
            paramiko_connections[addr] = ssh
        transport = ssh.get_transport()
        channel = transport.open_session()
        channel.exec_command(backend.script_executable)
        channel.sendall(script)
        channel.shutdown_write()
        # Log results
        logger.debug('%s running on %s' % (backend, server))
        if async:
            second = False
            while True:
                # Non-blocking is the secret ingridient in the async sauce
                select.select([channel], [], [])
                if channel.recv_ready():
                    part = channel.recv(1024).decode('utf-8')
                    while part:
                        log.stdout += part
                        part = channel.recv(1024).decode('utf-8')
                if channel.recv_stderr_ready():
                    part = channel.recv_stderr(1024).decode('utf-8')
                    while part:
                        log.stderr += part
                        part = channel.recv_stderr(1024).decode('utf-8')
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
                if channel.exit_status_ready():
                    if second:
                        break
                    second = True
        else:
            log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
            log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
        
        log.exit_code = channel.recv_exit_status()
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except:
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))
        if channel is not None:
            channel.close()


def OpenSSH(backend, log, server, cmds, async=False):
    """
    Executes cmds to remote server using SSH with connection resuse for maximum performance
    """
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = script
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    channel = None
    ssh = None
    try:
        ssh = sshrun(server.get_address(), script, executable=backend.script_executable,
            persist=True, async=async, silent=True)
        logger.debug('%s running on %s' % (backend, server))
        if async:
            second = False
            for state in ssh:
                log.stdout += state.stdout.decode('utf8')
                log.stderr += state.stderr.decode('utf8')
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
            log.exit_code = state.exit_code
        else:
            log.stdout = ssh.stdout.decode('utf8')
            log.stderr = ssh.stderr.decode('utf8')
            log.exit_code = ssh.exit_code
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except:
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))


def SSH(*args, **kwargs):
    """
    Facade that dispatches to the configured SSH backend.

    The backend is resolved with ``import_class`` from
    ``settings.ORCHESTRATION_SSH_METHOD_BACKEND`` and is called with the
    given arguments unchanged; its return value is returned as is.
    """
    backend_path = settings.ORCHESTRATION_SSH_METHOD_BACKEND
    return import_class(backend_path)(*args, **kwargs)


def Python(backend, log, server, cmds, async=False):
    script = [ str(cmd.func.__name__) + str(cmd.args) for cmd in cmds ]
    script = json.dumps(script, indent=4).replace('"', '')
    log.state = log.STARTED
    log.script = '\n'.join((log.script, script))
    log.save(update_fields=('script', 'state', 'updated_at'))
    try:
        for cmd in cmds:
            with CaptureStdout() as stdout:
                result = cmd(server)
            for line in stdout:
                log.stdout += line + '\n'
            if async:
                log.save(update_fields=('stdout', 'updated_at'))
    except:
        log.exit_code = 1
        log.state = log.FAILURE
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
    else:
        log.exit_code = 0
        log.state = log.SUCCESS
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
    log.save()