# Source code for twindb_backup.ssh.client

"""
Module that implements SSH client.
"""
import socket
import time
from contextlib import contextmanager

from paramiko import AuthenticationException, AutoAddPolicy, SSHClient, SSHException

from twindb_backup import LOG
from twindb_backup.ssh.exceptions import SshClientException


class SshClient(object):
    """
    SSH client class. Allows to connect to a remote SSH server and
    execute commands on it.

    :param host: Destination host to connect to. Defaults to '127.0.0.1'.
    :type host: str
    :param port: Destination port to connect to. Default is 22.
    :type port: int
    :param key: SSH client key for password-less authentication.
        Default is '/root/.id_rsa'.
    :type key: str
    :param user: SSH client username. Default is 'root'.
    :type user: str
    """

    def __init__(self, host="127.0.0.1", port=22, key="/root/.id_rsa", user="root"):
        self._host = host
        self._port = port
        self._key = key
        self._user = user

    @contextmanager
    def session(self):
        """
        Open a session channel on a fresh SSH connection.

        The channel is closed when the context exits, right before the
        underlying SSH connection is closed.

        :return: SSH session channel.
        :rtype: generator
        :raise SshClientException: if the SSH client fails to connect.
        """
        with self._shell() as client:
            transport = client.get_transport()
            session = transport.open_session()
            try:
                yield session
            finally:
                session.close()

    @contextmanager
    def _shell(self):
        """
        Create an SSHClient instance and connect to the destination host.

        Unknown host keys are accepted automatically (``AutoAddPolicy``).
        The connection is always closed when the context exits.

        :return: Client connected to the remote destination host.
        :rtype: generator(SSHClient)
        :raise FileNotFoundError: re-raised unchanged (presumably a missing
            key file).
        :raise SshClientException: if connecting fails, or if the code run
            inside the context raises AuthenticationException, SSHException
            or socket.error.
        """
        shell = SSHClient()
        shell.set_missing_host_key_policy(AutoAddPolicy())
        try:
            LOG.debug(
                "Connecting to %s:%d as %s with key %s",
                self._host,
                self._port,
                self._user,
                self._key,
            )
            shell.connect(
                hostname=self._host,
                key_filename=self._key,
                port=self._port,
                username=self._user,
            )
            yield shell
        except FileNotFoundError:
            # FileNotFoundError is an OSError (== socket.error); keep it
            # unwrapped instead of converting it below.
            raise
        except (AuthenticationException, SSHException, socket.error) as err:
            raise SshClientException(err) from err
        finally:
            shell.close()

    @property
    def host(self):
        """Remote SSH host"""
        return self._host

    @property
    def user(self):
        """User for SSH connection"""
        return self._user

    @property
    def port(self):
        """TCP port for SSH connection"""
        return self._port

    def execute(self, cmd, quiet=False, background=False):
        """
        Execute a command on a remote SSH server.

        :param cmd: Command for execution.
        :type cmd: str
        :param quiet: if quiet is True don't print error messages
        :type quiet: bool
        :param background: Don't wait until the command exits.
        :type background: bool
        :return: Strings with stdout and stderr. If the command is executed
            in background the method returns None.
        :rtype: tuple
        :raise SshClientException: if any error or non-zero exit code
        """
        max_chunk_size = 1024 * 1024
        # Pause between polls when no output is buffered, so waiting for a
        # long-running command does not spin a CPU core at 100%.
        poll_interval = 0.01
        try:
            with self._shell() as shell:
                if not background:
                    LOG.debug("%s: %s", self.host, cmd)
                    stdin_, stdout_, _ = shell.exec_command(cmd)
                    channel = stdout_.channel
                    stdin_.close()
                    channel.shutdown_write()

                    # Keep raw bytes and decode once at the end: a multi-byte
                    # UTF-8 character can be split across two recv() chunks.
                    stdout_chunks = []
                    stderr_chunks = []
                    # Poll both streams until the channel is closed and no
                    # buffered data remains on either of them.
                    while (
                        not channel.closed
                        or channel.recv_ready()
                        or channel.recv_stderr_ready()
                    ):
                        received = False
                        if channel.recv_ready():
                            stdout_chunks.append(channel.recv(max_chunk_size))
                            received = True
                        if channel.recv_stderr_ready():
                            stderr_chunks.append(channel.recv_stderr(max_chunk_size))
                            received = True
                        if not received:
                            time.sleep(poll_interval)

                    exit_code = channel.recv_exit_status()
                    stdout = b"".join(stdout_chunks).decode("utf-8")
                    stderr = b"".join(stderr_chunks).decode("utf-8")
                    if exit_code != 0:
                        if not quiet:
                            LOG.error("Failed to execute command %s", cmd)
                            LOG.error("stderr:")
                            LOG.error(stderr)
                            LOG.error("eof stderr.")
                        raise SshClientException(f"{cmd} exited with code {exit_code}")
                    return stdout, stderr
                else:
                    LOG.debug("Executing in background (%s): %s", self.host, cmd)
                    transport = shell.get_transport()
                    channel = transport.open_session()
                    channel.exec_command(cmd)
                    LOG.debug("Ran %s in background", cmd)
        except (SSHException, IOError) as err:
            if not quiet:
                LOG.error("Failed to execute %s: %s", cmd, err)
            raise SshClientException("Failed to execute %s: %s" % (cmd, err)) from err

    @contextmanager
    def get_remote_handlers(self, cmd):
        """
        Get remote stdin, stdout and stderr handlers for a command.

        The handlers are only usable inside the context; the SSH connection
        is closed when the context exits.

        :param cmd: Command for execution
        :type cmd: str
        :return: Remote stdin, stdout and stderr handlers
        :rtype: tuple(generator, generator, generator)
        :raise SshClientException: if any error
        """
        try:
            with self._shell() as shell:
                LOG.debug("Try to get remote handlers: %s", cmd)
                stdin_, stdout_, stderr_ = shell.exec_command(cmd)
                yield stdin_, stdout_, stderr_
        except SSHException as err:
            LOG.error("Failed to execute %s", cmd)
            raise SshClientException(err) from err

    def list_files(self, path, recursive=False, files_only=False):
        """
        List entries under a remote directory using ``find``.

        If ``path`` is not a directory on the remote host, nothing is
        printed and an empty list is returned.

        :param path: Remote directory to list.
        :type path: str
        :param recursive: Descend into subdirectories if True; otherwise
            list only the immediate children of ``path``. Default is False.
        :type recursive: bool
        :param files_only: Don't list directories if True. Default is False.
        :type files_only: bool
        :return: Paths as printed by ``find``; ``path`` itself is excluded.
        :rtype: list
        :raise SshClientException: if the remote command fails.
        """
        find_cmd = ["find", f'"{path}"']
        if not recursive:
            find_cmd.append("-maxdepth 1")
        if files_only:
            find_cmd.append("-type f")
        # NOTE(review): path is only double-quoted, so a path containing
        # quote characters or "$" is not passed to the remote shell literally.
        cmd = f'bash -c \'if test -d "{path}" ; ' f"then {' '.join(find_cmd)}; fi'"
        cout, cerr = self.execute(cmd)
        LOG.debug("stdout:\n%s\neof stdout.", cout)
        LOG.debug("stderr:\n%s\neof stderr.", cerr)

        # find prints one path per line. Split on newlines rather than on
        # any whitespace so that names containing spaces stay intact.
        entries = [line for line in cout.splitlines() if line]
        if files_only:
            return entries
        # Without "-type f", find prints the starting directory first; skip it.
        return entries[1:]

    def get_text_content(self, path):
        """
        Get the text content of a remote file.

        :param path: File path
        :type path: str
        :return: File content decoded as UTF-8
        :rtype: str
        """
        LOG.debug("Reading remote file %s", path)
        with self._shell() as ssh_client:
            # Close the SFTP session explicitly instead of leaving it open.
            with ssh_client.open_sftp() as sftp_client:
                with sftp_client.open(path) as remote_file:
                    return remote_file.read().decode("utf-8")

    def write_content(self, path, content):
        """
        Write content to a remote file over SFTP, truncating it first.

        :param path: Path to file
        :param content: Content
        """
        with self._shell() as ssh_client:
            # Close the SFTP session explicitly instead of leaving it open.
            with ssh_client.open_sftp() as sftp_client:
                with sftp_client.open(path, "w") as remote_file:
                    remote_file.write(content)

    def write_config(self, path, cfg):
        """
        Write a config to a remote file by piping it into ``cat - > path``.

        :param path: Path to file
        :param cfg: Instance of ConfigParser
        :raise SshClientException: if the remote command exits with a
            non-zero code.
        """
        # NOTE(review): path is interpolated unquoted into a shell command.
        with self.get_remote_handlers("cat - > %s" % path) as (cin, _, cerr):
            cfg.write(cin)
            # Send EOF so cat can finish, then wait for it to exit. Without
            # this the connection may be closed before the write completes,
            # and a failure (e.g. permission denied) goes unnoticed.
            channel = cin.channel
            cin.close()
            channel.shutdown_write()
            exit_code = channel.recv_exit_status()
            if exit_code != 0:
                error = cerr.read().decode("utf-8", errors="replace")
                LOG.error("Failed to write config to %s: %s", path, error)
                raise SshClientException(
                    f"Failed to write config to {path}: exited with code {exit_code}"
                )