django-orchestra/orchestra/contrib/orchestration/methods.py
2023-07-09 07:51:51 +00:00

187 lines
6.8 KiB
Python

import inspect
import logging
import socket
import sys
import select
import textwrap
from celery.datastructures import ExceptionInfo
from orchestra.settings import ORCHESTRA_SSH_DEFAULT_USER
from orchestra.utils.sys import sshrun
from orchestra.utils.python import CaptureStdout, import_class
from . import settings
logger = logging.getLogger(__name__)
def _iter_decoded(recv, decoder, size=1024):
    """
    Yields the decoded text of every chunk returned by recv(size) until recv
    returns an empty result

    decoder is an incremental decoder: a multi-byte character whose bytes are
    split across two reads is completed by the following chunk instead of
    raising UnicodeDecodeError.
    """
    data = recv(size)
    while data:
        yield decoder.decode(data)
        data = recv(size)


def Paramiko(backend, log, server, cmds, run_async=False, paramiko_connections={}):
    """
    Executes cmds to remote server using Paramiko

    The commands are joined into a single script that is sent to the standard
    input of backend.script_executable on the remote side. Progress and results
    are recorded on log:

    * log.script and log.state (STARTED) are saved before anything is executed
    * log.state becomes TIMEOUT when connecting raises socket.error
    * log.state becomes SUCCESS or FAILURE depending on the remote exit code
    * log.state becomes ERROR, with log.traceback filled, on any other exception
    * log.state becomes ABORTED when the run is interrupted while still STARTED

    With run_async the output is appended to log.stdout/log.stderr and saved
    while it is being read, otherwise both streams are read once at the end.

    paramiko_connections maps addresses to connected SSHClient objects. The
    mutable default is intentional: it is shared between calls and works as a
    connection cache.

    Nothing is returned; when cmds is empty only the initial log state is saved.
    """
    import codecs
    import paramiko
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = script
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    channel = None
    try:
        addr = server.get_address()
        # ssh connection, reused when one is already cached for this address
        ssh = paramiko_connections.get(addr)
        if not ssh:
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            key = settings.ORCHESTRATION_SSH_KEY_PATH
            try:
                ssh.connect(addr, username=ORCHESTRA_SSH_DEFAULT_USER, key_filename=key)
            except socket.error as e:
                logger.error('%s timed out on %s' % (backend, addr))
                log.state = log.TIMEOUT
                log.stderr = str(e)
                log.save(update_fields=('state', 'stderr', 'updated_at'))
                return
            # Only successfully connected clients are cached
            paramiko_connections[addr] = ssh
        transport = ssh.get_transport()
        channel = transport.open_session()
        channel.exec_command(backend.script_executable)
        channel.sendall(script)
        # Closing the write side tells the remote executable the script is complete
        channel.shutdown_write()
        # Log results
        logger.debug('%s running on %s' % (backend, server))
        if run_async:
            # One decoder per stream, kept across reads, so characters split
            # between two 1024-byte chunks are decoded correctly
            new_decoder = codecs.getincrementaldecoder('utf-8')
            stdout_decoder = new_decoder()
            stderr_decoder = new_decoder()
            second = False
            while True:
                # Wait until the channel has something to read
                select.select([channel], [], [])
                if channel.recv_ready():
                    for part in _iter_decoded(channel.recv, stdout_decoder):
                        log.stdout += part
                if channel.recv_stderr_ready():
                    for part in _iter_decoded(channel.recv_stderr, stderr_decoder):
                        log.stderr += part
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
                if channel.exit_status_ready():
                    # Do one more pass after the exit status shows up to pick up
                    # the output that is still pending
                    if second:
                        break
                    second = True
            # Raises UnicodeDecodeError if a stream ended in the middle of a character
            log.stdout += stdout_decoder.decode(b'', final=True)
            log.stderr += stderr_decoder.decode(b'', final=True)
        else:
            log.stdout += channel.makefile('rb', -1).read().decode('utf-8')
            log.stderr += channel.makefile_stderr('rb', -1).read().decode('utf-8')
        log.exit_code = channel.recv_exit_status()
        log.state = log.SUCCESS if log.exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except Exception:
        # KeyboardInterrupt and SystemExit are deliberately not caught here:
        # they propagate and the run is marked as ABORTED in the finally clause
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))
        if channel is not None:
            channel.close()
def OpenSSH(backend, log, server, cmds, run_async=False):
    """
    Executes cmds to remote server using SSH with connection reuse for maximum performance

    The commands are joined into a single script, which is appended to
    log.script and executed through sshrun() with backend.script_executable.
    Progress and results are recorded on log:

    * log.state (STARTED) and log.script are saved before anything is executed
    * with run_async, stdout/stderr are appended and saved for every state
      yielded by sshrun(); otherwise they are appended once at the end
    * the exit code and the resulting state (TIMEOUT, SUCCESS or FAILURE) are
      only recorded when log does not already carry a non-zero exit code
    * log.state becomes ERROR, with log.traceback filled, on any exception
    * log.state becomes ABORTED when the run ends while still STARTED

    Nothing is returned; when cmds is empty only the initial log state is saved.
    """
    script = '\n'.join(cmds)
    script = script.replace('\r', '')
    log.state = log.STARTED
    log.script = '\n'.join((log.script, script))
    log.save(update_fields=('script', 'state', 'updated_at'))
    if not cmds:
        return
    try:
        ssh = sshrun(server.get_address(), script, executable=backend.script_executable,
                     persist=True, run_async=run_async, silent=True)
        logger.debug('%s running on %s' % (backend, server))
        if run_async:
            for state in ssh:
                log.stdout += state.stdout.decode('utf8')
                log.stderr += state.stderr.decode('utf8')
                log.save(update_fields=('stdout', 'stderr', 'updated_at'))
            # The last yielded state carries the exit code of the run
            exit_code = state.exit_code
        else:
            log.stdout += ssh.stdout.decode('utf8')
            log.stderr += ssh.stderr.decode('utf8')
            exit_code = ssh.exit_code
        if not log.exit_code:
            log.exit_code = exit_code
            # Exit code 255 together with this stderr prefix is treated as a
            # connection failure rather than a failure of the script
            if exit_code == 255 and log.stderr.startswith('ssh: connect to host'):
                log.state = log.TIMEOUT
            else:
                log.state = log.SUCCESS if exit_code == 0 else log.FAILURE
        logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
        log.save()
    except Exception:
        # KeyboardInterrupt and SystemExit are deliberately not caught here:
        # they propagate and the run is marked as ABORTED in the finally clause
        log.state = log.ERROR
        log.traceback = ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
        logger.debug(log.traceback)
        log.save()
    finally:
        if log.state == log.STARTED:
            log.state = log.ABORTED
            log.save(update_fields=('state', 'updated_at'))
def SSH(*args, **kwargs):
    """
    Facade that delegates to the configured SSH method

    The callable named by settings.ORCHESTRATION_SSH_METHOD_BACKEND is resolved
    with import_class() on every call and invoked with the given arguments; its
    return value is passed through unchanged.
    """
    backend_path = settings.ORCHESTRATION_SSH_METHOD_BACKEND
    return import_class(backend_path)(*args, **kwargs)
def Python(backend, log, server, cmds, run_async=False):
    """
    Executes cmds as Python callables in the current process

    Every cmd is called with server as its only argument. The text stored in
    log.script is the source of each distinct cmd.func followed by one comment
    line per cmd with the function name and its arguments.

    Whatever the commands print is appended to log.stdout, together with a
    '# Result:' line for truthy return values; with run_async the log is saved
    after every command. Any exception marks the log as FAILURE with exit code 1
    and appends the traceback; otherwise the log is marked as SUCCESS with exit
    code 0, unless it already carries a non-zero exit code. The log is saved
    at the end in both cases and nothing is returned.
    """
    seen = set()
    pieces = []
    for cmd in cmds:
        if cmd.func in seen:
            continue
        seen.add(cmd.func)
        source_lines = inspect.getsourcelines(cmd.func)[0]
        pieces.append(textwrap.dedent(''.join(source_lines)))
        pieces.append('\n')
    pieces.extend('# %s %s\n' % (cmd.func.__name__, cmd.args) for cmd in cmds)
    script = ''.join(pieces)
    log.state = log.STARTED
    log.script = '\n'.join((log.script, script))
    log.save(update_fields=('script', 'state', 'updated_at'))
    # Stays bound to the output captured for the command being executed, so
    # the exception handler can still record it
    stdout = ''
    try:
        for cmd in cmds:
            with CaptureStdout() as stdout:
                result = cmd(server)
            log.stdout += ''.join(line + '\n' for line in stdout)
            if result:
                log.stdout += '# Result: %s\n' % result
            if run_async:
                log.save(update_fields=('stdout', 'updated_at'))
    except:
        log.exit_code = 1
        log.state = log.FAILURE
        log.stdout += '\n'.join(stdout)
        log.traceback += ExceptionInfo(sys.exc_info()).traceback
        logger.error('Exception while executing %s on %s' % (backend, server))
    else:
        if not log.exit_code:
            log.exit_code = 0
            log.state = log.SUCCESS
    logger.debug('%s execution state on %s is %s' % (backend, server, log.state))
    log.save()