"""Execution methods that run backend commands on a server and record the
outcome (script, stdout, stderr, exit code, state, traceback) on a log object.
"""
import codecs
import hashlib
import json
import logging
import os
import select
import socket
import sys

import paramiko
from celery.datastructures import ExceptionInfo
from django.conf import settings as djsettings

from orchestra.utils.sys import sshrun
from orchestra.utils.python import CaptureStdout, import_class

from . import settings


logger = logging.getLogger(__name__)

# Cache of connected paramiko clients, keyed by server address.
paramiko_connections = {}


def _pop_async_flag(func_name, run_async, legacy_kwargs):
    """Return the effective async flag, honouring the legacy ``async`` keyword.

    ``async`` became a reserved word in Python 3.7, so it can no longer be a
    parameter name. Callers that still pass ``async=...`` (directly on older
    interpreters, or through ``**{'async': ...}``) land in ``legacy_kwargs``.
    Any other unexpected keyword raises ``TypeError``, as a plain signature
    would.
    """
    if 'async' in legacy_kwargs:
        run_async = legacy_kwargs.pop('async')
    if legacy_kwargs:
        unexpected = sorted(legacy_kwargs)[0]
        raise TypeError(
            "%s() got an unexpected keyword argument %r" % (func_name, unexpected))
    return run_async


def _start_script_log(log, cmds):
    """Join cmds into one script, store it on log and mark the log STARTED."""
    script = '\n'.join(cmds).replace('\r', '')
    log.state = log.STARTED
    log.script = script
    log.save(update_fields=('script', 'state'))
    return script


def _record_exception(backend, log, server):
    """Mark log as ERROR and store the traceback of the active exception."""
    log.state = log.ERROR
    log.traceback = ExceptionInfo(sys.exc_info()).traceback
    logger.error('Exception while executing %s on %s', backend, server)
    logger.debug(log.traceback)
    log.save()


def _abort_if_unfinished(log):
    """Mark log as ABORTED when execution ended without reaching a final state."""
    if log.state == log.STARTED:
        log.state = log.ABORTED
        log.save(update_fields=['state'])


def _utf8_decoder():
    """Return an incremental UTF-8 decoder for output received in chunks.

    An incremental decoder keeps the trailing bytes of a multi-byte character
    that was split across two reads. Undecodable bytes are replaced rather
    than raising, so odd command output cannot turn a run into an ERROR.
    """
    return codecs.getincrementaldecoder('utf-8')(errors='replace')


def _get_paramiko_client(addr):
    """Return a connected SSHClient for addr, reusing the cached one if alive.

    A cached client whose transport is gone or inactive is closed and replaced.
    A client that fails to connect is closed before the error propagates, and
    is never cached.
    """
    client = paramiko_connections.get(addr)
    if client is not None:
        transport = client.get_transport()
        if transport is not None and transport.is_active():
            return client
        # Stale connection: drop it so this and later runs reconnect.
        paramiko_connections.pop(addr, None)
        client.close()
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification; kept as in the original, confirm this is intended.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    key = settings.ORCHESTRATION_SSH_KEY_PATH
    try:
        client.connect(addr, username='root', key_filename=key)
    except Exception:
        client.close()
        raise
    paramiko_connections[addr] = client
    return client


def Paramiko(backend, log, server, cmds, run_async=False, **legacy_kwargs):
    """
    Executes cmds to remote server using Paramiko

    The joined script is sent to ``backend.script_executable`` over a session
    opened on a cached connection to ``server.get_address()``.

    With ``run_async`` (or the legacy ``async`` keyword) output is appended to
    ``log.stdout``/``log.stderr`` and saved while the command runs; otherwise
    it is read once the streams are exhausted.

    Final ``log.state`` is SUCCESS or FAILURE depending on the exit code,
    TIMEOUT when connecting raises ``socket.error``, ERROR on any other
    exception, and ABORTED when interrupted before reaching a final state.
    Returns None; nothing is executed when cmds is empty.
    """
    run_async = _pop_async_flag('Paramiko', run_async, legacy_kwargs)
    script = _start_script_log(log, cmds)
    if not cmds:
        return
    channel = None
    try:
        addr = server.get_address()
        try:
            ssh = _get_paramiko_client(addr)
        except socket.error as e:
            logger.error('%s timed out on %s', backend, addr)
            log.state = log.TIMEOUT
            log.stderr = str(e)
            log.save(update_fields=['state', 'stderr'])
            return
        transport = ssh.get_transport()
        channel = transport.open_session()
        channel.exec_command(backend.script_executable)
        channel.sendall(script)
        channel.shutdown_write()
        # Log results
        logger.debug('%s running on %s', backend, server)
        if run_async:
            stdout_decoder = _utf8_decoder()
            stderr_decoder = _utf8_decoder()
            # One extra pass is made after the exit status is first reported,
            # presumably to pick up output that arrived in the meantime.
            exit_seen = False
            while True:
                # Wait until the channel has something to read
                select.select([channel], [], [])
                if channel.recv_ready():
                    chunk = channel.recv(1024)
                    while chunk:
                        log.stdout += stdout_decoder.decode(chunk)
                        chunk = channel.recv(1024)
                if channel.recv_stderr_ready():
                    chunk = channel.recv_stderr(1024)
                    while chunk:
                        log.stderr += stderr_decoder.decode(chunk)
                        chunk = channel.recv_stderr(1024)
                log.save(update_fields=['stdout', 'stderr'])
                if channel.exit_status_ready():
                    if exit_seen:
                        break
                    exit_seen = True
            # Flush bytes of a character left incomplete at end of stream.
            log.stdout += stdout_decoder.decode(b'', final=True)
            log.stderr += stderr_decoder.decode(b'', final=True)
        else:
            log.stdout += channel.makefile('rb', -1).read().decode('utf-8', 'replace')
            log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8', 'replace')
        log.exit_code = channel.recv_exit_status()
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s', backend, server, log.state)
        log.save()
    except Exception:
        # KeyboardInterrupt/SystemExit are not caught here: they propagate and
        # the finally clause records the run as ABORTED.
        _record_exception(backend, log, server)
    finally:
        _abort_if_unfinished(log)
        if channel is not None:
            channel.close()


def OpenSSH(backend, log, server, cmds, run_async=False, **legacy_kwargs):
    """
    Executes cmds to remote server using SSH with connection reuse for maximum performance

    The joined script is handed to ``sshrun`` together with
    ``backend.script_executable``. With ``run_async`` (or the legacy ``async``
    keyword) the states yielded by ``sshrun`` are appended to the log and saved
    one by one; otherwise stdout, stderr and exit code are copied from its
    result.

    Final ``log.state`` is SUCCESS or FAILURE depending on the exit code, ERROR
    on any exception, and ABORTED when interrupted before reaching a final
    state. Returns None; nothing is executed when cmds is empty.
    """
    run_async = _pop_async_flag('OpenSSH', run_async, legacy_kwargs)
    script = _start_script_log(log, cmds)
    if not cmds:
        return
    try:
        # `async` is a reserved word since Python 3.7; unpacking a dict passes
        # sshrun the same keyword name the original call used.
        # NOTE(review): confirm sshrun still names this parameter `async`.
        ssh = sshrun(server.get_address(), script, executable=backend.script_executable,
                     persist=True, **{'async': run_async})
        logger.debug('%s running on %s', backend, server)
        if run_async:
            state = None
            for state in ssh:
                log.stdout += state.stdout.decode('utf8')
                log.stderr += state.stderr.decode('utf8')
                log.save()
            if state is None:
                # Without any state there is no exit code to report.
                raise RuntimeError('sshrun did not yield any state for %s' % server)
            log.exit_code = state.exit_code
        else:
            # NOTE(review): stored as returned by sshrun, without decoding,
            # unlike the async branch; confirm the type sshrun returns here.
            log.stdout = ssh.stdout
            log.stderr = ssh.stderr
            log.exit_code = ssh.exit_code
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s', backend, server, log.state)
        log.save()
    except Exception:
        # KeyboardInterrupt/SystemExit are not caught here: they propagate and
        # the finally clause records the run as ABORTED.
        _record_exception(backend, log, server)
    finally:
        _abort_if_unfinished(log)


def SSH(*args, **kwargs):
    """
    Facade function that dispatches to the SSH backend configured in
    ``settings.ORCHESTRATION_SSH_METHOD_BACKEND``, forwarding all arguments
    and returning whatever that backend returns.
    """
    method = import_class(settings.ORCHESTRATION_SSH_METHOD_BACKEND)
    return method(*args, **kwargs)


def Python(backend, log, server, cmds, run_async=False, **legacy_kwargs):
    """
    Executes cmds locally, calling each one as ``cmd(server)``

    A listing of the calls (function name plus arguments) is appended to
    ``log.script``, and whatever each command prints is appended to
    ``log.stdout``. With ``run_async`` (or the legacy ``async`` keyword) stdout
    is saved after every command.

    Final ``log.state`` is SUCCESS with exit code 0, FAILURE with exit code 1
    and a traceback when a command raises, and ABORTED when interrupted before
    reaching a final state. Returns None.
    """
    run_async = _pop_async_flag('Python', run_async, legacy_kwargs)
    # TODO collect stdout?
    script = [str(cmd.func.__name__) + str(cmd.args) for cmd in cmds]
    script = json.dumps(script, indent=4).replace('"', '')
    log.state = log.STARTED
    log.script = '\n'.join([log.script, script])
    log.save(update_fields=('script', 'state'))
    try:
        for cmd in cmds:
            with CaptureStdout() as stdout:
                cmd(server)
            for line in stdout:
                log.stdout += line + '\n'
            if run_async:
                log.save(update_fields=['stdout'])
    except Exception:
        log.exit_code = 1
        log.state = log.FAILURE
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s', backend, server)
    else:
        log.exit_code = 0
        log.state = log.SUCCESS
    finally:
        # Only reached with STARTED when something other than an Exception
        # (e.g. KeyboardInterrupt) interrupted the commands.
        _abort_if_unfinished(log)
    logger.debug('%s execution state on %s is %s', backend, server, log.state)
    log.save()