git » chasquid » smarthost » tree

[smarthost] / test / util / chamuyero

#!/usr/bin/env python3
"""
chamuyero is a tool to test and validate line-oriented commands and servers.

It can launch and communicate with other processes, and follow a script of
line-oriented request-response, validating the dialog as it goes along.

This can be used to test line-oriented network protocols (such as SMTP) or
interactive command-line tools.
"""

import argparse
import os
import re
import ssl
import socket
import subprocess
import sys
import threading
import time

# Command-line flags.
# The only (positional) argument is the chamuyero script to run; argparse
# opens it for reading as utf8 text.
# Note this runs at import time: `args` is a module-level global, which the
# __main__ block at the bottom of the file hands to the interpreter.
ap = argparse.ArgumentParser()
ap.add_argument("script", type=argparse.FileType('r', encoding='utf8'))
args = ap.parse_args()

# Make sure stdout is open in utf8 mode, as we will print our input, which is
# utf8, and want it to work regardless of the environment.
# buffering=1 selects line buffering, so every echoed script line is flushed
# as soon as it is complete.
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf8', buffering=1)


class Process (object):
    """A child process we talk to through its standard streams.

    Thin wrapper around subprocess.Popen that exposes write/readline/close,
    the same operations the socket classes in this file provide, plus wait.
    """

    def __init__(self, cmd, **kwargs):
        """Launch cmd; kwargs are passed verbatim to subprocess.Popen."""
        self.cmd = subprocess.Popen(cmd, **kwargs)

    def write(self, s):
        """Send s to the process' stdin (which must have been piped)."""
        self.cmd.stdin.write(s)
        # The pipe is buffered by default, so without an explicit flush the
        # data may never reach the child, and a later readline() on its
        # output would block forever.
        self.cmd.stdin.flush()

    def readline(self):
        """Return the next line of the process' stdout (empty on EOF)."""
        return self.cmd.stdout.readline()

    def wait(self):
        """Block until the process exits, and return its exit code."""
        return self.cmd.wait()

    def close(self):
        """Ask the process to terminate (see Popen.terminate)."""
        return self.cmd.terminate()

class Sock (object):
    """A (generic) socket.

    This class implements the common code for socket support.
    Subclasses will implement the behaviour specific to different socket
    types.
    """
    def __init__(self, addr):
        """Remember addr; subclasses are expected to replace self.sock."""
        self.addr = addr
        self.sock = NotImplemented

        # Connection accepted by _accept() (listening sockets only).
        self._conn = None

        # Text-mode (utf8) file objects used to read from / write to the
        # peer. They are None until a connection is established.
        self.connr = None
        self.connw = None

        # Set once connr and connw are ready to be used.
        self.has_conn = threading.Event()

    def listen(self):
        """Bind to self.addr and accept one connection in the background."""
        self.sock.bind(self.addr)
        self.sock.listen(1)
        # Daemon thread: if nobody ever connects, a thread blocked in
        # accept() must not keep the whole program from exiting.
        threading.Thread(target=self._accept, daemon=True).start()

    def _accept(self):
        """Accept one connection and wrap it in the read/write streams."""
        conn, _ = self.sock.accept()
        # Keep a reference so close() can release the connection explicitly.
        self._conn = conn
        self.connr = conn.makefile(mode="r", encoding="utf8")
        self.connw = conn.makefile(mode="w", encoding="utf8")
        self.has_conn.set()

    def write(self, s):
        """Send s to the peer, waiting for a connection if needed."""
        self.has_conn.wait()
        self.connw.write(s)
        self.connw.flush()

    def readline(self):
        """Return the next line from the peer, waiting for a connection
        if needed."""
        self.has_conn.wait()
        return self.connr.readline()

    def close(self):
        """Close the connection (if there is one) and the socket itself."""
        # Any of these can still be None if no connection was established.
        for f in (self.connr, self.connw, self._conn):
            if f is not None:
                f.close()
        self.sock.close()

class UnixSock (Sock):
    """Stream socket on an UNIX-domain address (a filesystem path)."""

    def __init__(self, addr):
        """Create the socket; addr is the filesystem path to use."""
        Sock.__init__(self, addr)
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    def listen(self):
        """Remove any stale file at self.addr, then listen on it."""
        # Remove unconditionally instead of checking os.path.exists() first:
        # exists() is False for a dangling symlink (which would still make
        # bind() fail), and check-then-remove is racy.
        try:
            os.remove(self.addr)
        except FileNotFoundError:
            pass
        Sock.listen(self)


class TCPSock (Sock):
    """TCP socket, usable either to listen or to connect.

    addr is a "host:port" string; the port is whatever follows the last
    colon.
    """

    def __init__(self, addr):
        """Parse addr and create an (unconnected) IPv4 stream socket."""
        host, port = addr.rsplit(":", 1)
        Sock.__init__(self, (host, int(port)))
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def connect(self):
        """Connect to self.addr and set up the read/write streams."""
        # create_connection() returns a brand new socket, so close the one
        # made in __init__ (only needed for listening) instead of leaking
        # its file descriptor.
        unused = self.sock
        self.sock = socket.create_connection(self.addr)
        unused.close()

        self.connr = self.sock.makefile(mode="r", encoding="utf8")
        self.connw = self.sock.makefile(mode="w", encoding="utf8")
        self.has_conn.set()


class TLSSock (Sock):
    """TLS client socket.

    addr is a "host:port" string. The TCP connection is established and
    wrapped in TLS at construction time; connect() only sets up the
    read/write streams.
    """

    def __init__(self, addr):
        """Connect to addr over TCP and wrap the connection in TLS."""
        host, port = addr.rsplit(":", 1)
        Sock.__init__(self, (host, int(port)))

        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in
        # 3.12, so build the context explicitly.
        # NOTE(review): certificate and hostname verification stay disabled,
        # which is what ssl.wrap_socket() did by default; presumably the
        # servers under test use self-signed certificates -- confirm before
        # tightening. check_hostname must be cleared before verify_mode.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

        plain_sock = socket.create_connection(self.addr)
        self.sock = ctx.wrap_socket(plain_sock)

    def connect(self):
        """Set up the read/write streams over the TLS connection."""
        self.connr = self.sock.makefile(mode="r", encoding="utf8")
        self.connw = self.sock.makefile(mode="w", encoding="utf8")
        self.has_conn.set()


class Interpreter (object):
    """Interpreter for chamuyero scripts.

    A script is a sequence of lines of the form "<proc> <op> [params]",
    where <proc> is a user-chosen id for a process or socket, and <op> is
    one of the operations handled in run().
    Blank lines and lines starting with "#" are ignored, and a line ending
    with a backslash is continued on the following line.
    """
    def __init__(self):
        # Processes and sockets we have spawned. Indexed by the id provided
        # by the user.
        self.procs = {}

        # Line number we are processing.
        self.nline = 0

    def syntax_error(self, msg):
        """Raise a SyntaxError for msg, tagged with the current line."""
        raise SyntaxError("Error in line %d: %s" % (self.nline, msg))

    def runtime_error(self, msg):
        """Raise a RuntimeError for msg, tagged with the current line."""
        raise RuntimeError("Error in line %d: %s" % (self.nline, msg))

    def run(self, fd):
        """Main processing loop.

        Reads the script from fd (an iterable of text lines), echoes each
        command to stdout and executes it.

        Raises SyntaxError on malformed or unknown commands, and
        RuntimeError when the dialog does not go as the script expects.
        """
        cont_l = ""
        for l in fd:
            self.nline += 1

            # Remove rightmost \n, if there is one: the last line of the
            # script may not have it, and must not lose a character.
            if l.endswith("\n"):
                l = l[:-1]

            # Continuations with \.
            if cont_l:
                l = cont_l + " " + l.lstrip()
            if l.endswith("\\"):
                cont_l = l[:-1]
                continue
            else:
                cont_l = ""

            # Comments start with a "#".
            if l.strip().startswith("#") or l.strip() == "":
                continue

            print(l)

            # Everything else is of the form:
            # <proc> <op> [params]
            sp = l.split(None, 2)
            if len(sp) == 3:
                proc, op, params = sp
            elif len(sp) == 2:
                proc, op = sp
                params = ""
            else:
                # A single word: there is no operation to run.
                self.syntax_error("missing operation")

            # =   Launch a process.
            if op == "=":
                cmd = Process(params, shell=True, universal_newlines=True,
                        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT)
                self.procs[proc] = cmd

            # |=   Launch a process, do not capture stdout.
            # stdin is opened in text mode, as "->" and ".>" write strings.
            elif op == "|=":
                cmd = Process(params, shell=True, universal_newlines=True,
                        stdin=subprocess.PIPE)
                self.procs[proc] = cmd

            # unix_listen  Listen on an UNIX socket.
            elif op == "unix_listen":
                sock = UnixSock(params)
                sock.listen()
                self.procs[proc] = sock

            # tcp_listen  Listen on a TCP socket.
            elif op == "tcp_listen":
                sock = TCPSock(params)
                sock.listen()
                self.procs[proc] = sock

            # tcp_connect  Connect to a TCP address.
            elif op == "tcp_connect":
                sock = TCPSock(params)
                sock.connect()
                self.procs[proc] = sock

            # tls_connect  Connect to a TCP address, using TLS.
            elif op == "tls_connect":
                sock = TLSSock(params)
                sock.connect()
                self.procs[proc] = sock

            # ->   Send to a process stdin, with a \n at the end.
            # .>   Send to a process stdin, no \n at the end.
            elif op == "->":
                self.procs[proc].write(params + "\n")
            elif op == ".>":
                self.procs[proc].write(params)

            # <-   Read from the process, expect matching input.
            # <~   Read from the process, match input using regexp.
            # <... Read many lines until one matches.
            elif op == "<-":
                read = self.procs[proc].readline()
                if read != params + "\n":
                    self.runtime_error("data different that expected:\n"
                            + "  expected: %s\n" % repr(params)
                            + "       got: %s" % repr(read))
            elif op == "<~":
                read = self.procs[proc].readline()
                m = re.match(params, read)
                if m is None:
                    self.runtime_error("data did not match regexp:\n"
                            + "  regexp: %s\n" % repr(params)
                            + "     got: %s" % repr(read))
            elif op == "<...":
                while True:
                    read = self.procs[proc].readline()
                    if re.match(params, read):
                        break
                    # readline() returns an empty string at EOF, and would
                    # keep doing so forever: give up instead of spinning.
                    if not read:
                        self.runtime_error("EOF while waiting for regexp:\n"
                                + "  regexp: %s" % repr(params))

            # sleep  Sleep this number of seconds (process-independent).
            elif op == "sleep":
                time.sleep(float(params))

            # wait   Wait for the process to exit (with the given code).
            elif op == "wait":
                retcode = self.procs[proc].wait()
                if params and retcode != int(params):
                    self.runtime_error("return code did not match:\n"
                            + "  expected %s, got %d" % (params, retcode))

            # close  Close the process.
            elif op == "close":
                self.procs[proc].close()

            else:
                self.syntax_error("unknown syntax")


if __name__ == "__main__":
    # Feed the script given on the command line to a fresh interpreter.
    Interpreter().run(args.script)