From 96a22b400339d01541f45b9e162748438d4f353d Mon Sep 17 00:00:00 2001 From: Christian Decker Date: Fri, 30 Aug 2019 18:59:53 +0200 Subject: [PATCH] pytest: Add db_provider and db instances for configurable backends We will soon have a postgres backend as well, so we need a way to control the postgres process and to provision DBs to the nodes. The two interfaces are the dsn that we pass to the node, and the python query interface needed to query from tests. Signed-off-by: Christian Decker --- tests/db.py | 169 +++++++++++++++++++++++++++++++++++++++++ tests/fixtures.py | 19 ++++- tests/requirements.txt | 1 + tests/test_misc.py | 5 +- tests/utils.py | 40 ++++------ 5 files changed, 205 insertions(+), 29 deletions(-) create mode 100644 tests/db.py diff --git a/tests/db.py b/tests/db.py new file mode 100644 index 000000000..0d12309e2 --- /dev/null +++ b/tests/db.py @@ -0,0 +1,169 @@ +from ephemeral_port_reserve import reserve + +import logging +import os +import psycopg2 +import random +import shutil +import signal +import sqlite3 +import string +import subprocess +import time + + +class Sqlite3Db(object): + def __init__(self, path): + self.path = path + + def get_dsn(self): + """SQLite3 doesn't provide a DSN, resulting in no CLI-option. + """ + return None + + def query(self, query): + orig = os.path.join(self.path) + copy = self.path + ".copy" + shutil.copyfile(orig, copy) + db = sqlite3.connect(copy) + + db.row_factory = sqlite3.Row + c = db.cursor() + c.execute(query) + rows = c.fetchall() + + result = [] + for row in rows: + result.append(dict(zip(row.keys(), row))) + + db.commit() + c.close() + db.close() + return result + + def execute(self, query): + db = sqlite3.connect(self.path) + c = db.cursor() + c.execute(query) + db.commit() + c.close() + db.close() + + +class PostgresDb(object): + def __init__(self, dbname, port): + self.dbname = dbname + self.port = port + + self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format( + dbname=dbname, port=port + )) + cur = self.conn.cursor() + cur.execute('SELECT 1') + cur.close() + + def get_dsn(self): + return "postgres://postgres:password@localhost:{port}/{dbname}".format( + port=self.port, dbname=self.dbname + ) + + def query(self, query): + cur = self.conn.cursor() + cur.execute(query) + + # Collect the results into a list of dicts. + res = [] + for r in cur: + t = {} + # Zip the column definition with the value to get its name. + for c, v in zip(cur.description, r): + t[c.name] = v + res.append(t) + cur.close() + return res + + def execute(self, query): + with self.conn, self.conn.cursor() as cur: + cur.execute(query) + + +class SqliteDbProvider(object): + def __init__(self, directory): + self.directory = directory + + def start(self): + pass + + def get_db(self, node_directory, testname, node_id): + path = os.path.join( + node_directory, + 'lightningd.sqlite3' + ) + return Sqlite3Db(path) + + def stop(self): + pass + + +class PostgresDbProvider(object): + def __init__(self, directory): + self.directory = directory + self.port = None + self.proc = None + print("Starting PostgresDbProvider") + + def start(self): + passfile = os.path.join(self.directory, "pgpass.txt") + self.pgdir = os.path.join(self.directory, 'pgsql') + # Need to write a tiny file containing the password so `initdb` can pick it up + with open(passfile, 'w') as f: + f.write('cltest\n') + subprocess.check_call([ + '/usr/lib/postgresql/10/bin/initdb', + '--pwfile={}'.format(passfile), + '--pgdata={}'.format(self.pgdir), + '--auth=trust', + '--username=postgres', + ]) + self.port = reserve() + self.proc = subprocess.Popen([ + '/usr/lib/postgresql/10/bin/postgres', + '-k', '/tmp/', # So we don't use /var/lib/... + '-D', self.pgdir, + '-p', str(self.port), + '-F', + '-i', + ]) + # Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky). + time.sleep(1) + self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port)) + + # Required for CREATE DATABASE to work + self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + + def get_db(self, node_directory, testname, node_id): + # Random suffix to avoid collisions on repeated tests + nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) + dbname = "{}_{}_{}".format(testname, node_id, nonce) + + cur = self.conn.cursor() + cur.execute("CREATE DATABASE {};".format(dbname)) + cur.close() + db = PostgresDb(dbname, self.port) + return db + + def stop(self): + # Send fast shutdown signal see [1] for details: + # + # SIGINT + # + # This is the Fast Shutdown mode. The server disallows new connections + # and sends all existing server processes SIGTERM, which will cause + # them to abort their current transactions and exit promptly. It then + # waits for all server processes to exit and finally shuts down. If + # the server is in online backup mode, backup mode will be terminated, + # rendering the backup useless. + # + # [1] https://www.postgresql.org/docs/9.1/server-shutdown.html + self.proc.send_signal(signal.SIGINT) + self.proc.wait() diff --git a/tests/fixtures.py b/tests/fixtures.py index f6c4456c6..0b38b2de2 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,4 +1,5 @@ from concurrent import futures +from db import SqliteDbProvider, PostgresDbProvider from utils import NodeFactory, BitcoinD import logging @@ -149,12 +150,13 @@ def teardown_checks(request): @pytest.fixture -def node_factory(request, directory, test_name, bitcoind, executor, teardown_checks): +def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks): nf = NodeFactory( test_name, bitcoind, executor, directory=directory, + db_provider=db_provider, ) yield nf @@ -275,6 +277,21 @@ def checkMemleak(node): return 0 +# Mapping from TEST_DB_PROVIDER env variable to class to be used +providers = { + 'sqlite3': SqliteDbProvider, + 'postgres': PostgresDbProvider, +} + + +@pytest.fixture(scope="session") +def db_provider(test_base_dir): + provider = providers[os.getenv('TEST_DB_PROVIDER', 'sqlite3')](test_base_dir) + provider.start() + yield provider + provider.stop() + + @pytest.fixture def executor(teardown_checks): ex = futures.ThreadPoolExecutor(max_workers=20) diff --git a/tests/requirements.txt b/tests/requirements.txt index 4208596b7..bfa85f963 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,3 +10,4 @@ pytest-xdist==1.29.0 python-bitcoinlib==0.10.1 tqdm==4.32.2 pytest-timeout==1.3.3 +psycopg2==2.8.3 diff --git a/tests/test_misc.py b/tests/test_misc.py index 4248b2cc8..bae6fa1b2 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -969,12 +969,11 @@ def test_reserve_enforcement(node_factory, executor): l2.stop() # They should both aim for 1%. - reserves = l2.db_query('SELECT channel_reserve_satoshis FROM channel_configs') + reserves = l2.db.query('SELECT channel_reserve_satoshis FROM channel_configs') assert reserves == [{'channel_reserve_satoshis': 10**6 // 100}] * 2 # Edit db to reduce reserve to 0 so it will try to violate it. - l2.db_query('UPDATE channel_configs SET channel_reserve_satoshis=0', - use_copy=False) + l2.db.execute('UPDATE channel_configs SET channel_reserve_satoshis=0') l2.start() wait_for(lambda: only_one(l2.rpc.listpeers(l1.info['id'])['peers'])['connected']) diff --git a/tests/utils.py b/tests/utils.py index 92be43f5e..b9f946add 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -462,7 +462,9 @@ class LightningD(TailableProc): class LightningNode(object): - def __init__(self, daemon, rpc, btc, executor, may_fail=False, may_reconnect=False, allow_broken_log=False, allow_bad_gossip=False): + def __init__(self, daemon, rpc, btc, executor, may_fail=False, + may_reconnect=False, allow_broken_log=False, + allow_bad_gossip=False, db=None): self.rpc = rpc self.daemon = daemon self.bitcoin = btc @@ -471,6 +473,7 @@ class LightningNode(object): self.may_reconnect = may_reconnect self.allow_broken_log = allow_broken_log self.allow_bad_gossip = allow_bad_gossip + self.db = db def connect(self, remote_node): self.rpc.connect(remote_node.info['id'], '127.0.0.1', remote_node.daemon.port) @@ -510,28 +513,8 @@ class LightningNode(object): def getactivechannels(self): return [c for c in self.rpc.listchannels()['channels'] if c['active']] - def db_query(self, query, use_copy=True): - orig = os.path.join(self.daemon.lightning_dir, "lightningd.sqlite3") - if use_copy: - copy = os.path.join(self.daemon.lightning_dir, "lightningd-copy.sqlite3") - shutil.copyfile(orig, copy) - db = sqlite3.connect(copy) - else: - db = sqlite3.connect(orig) - - db.row_factory = sqlite3.Row - c = db.cursor() - c.execute(query) - rows = c.fetchall() - - result = [] - for row in rows: - result.append(dict(zip(row.keys(), row))) - - db.commit() - c.close() - db.close() - return result + def db_query(self, query): + return self.db.query(query) # Assumes node is stopped! def db_manip(self, query): @@ -771,7 +754,7 @@ class LightningNode(object): class NodeFactory(object): """A factory to setup and start `lightningd` daemons. """ - def __init__(self, testname, bitcoind, executor, directory): + def __init__(self, testname, bitcoind, executor, directory, db_provider): self.testname = testname self.next_id = 1 self.nodes = [] @@ -779,6 +762,7 @@ class NodeFactory(object): self.bitcoind = bitcoind self.directory = directory self.lock = threading.Lock() + self.db_provider = db_provider def split_options(self, opts): """Split node options from cli options @@ -880,11 +864,17 @@ class NodeFactory(object): if options is not None: daemon.opts.update(options) + # Get the DB backend DSN we should be using for this test and this node. + db = self.db_provider.get_db(lightning_dir, self.testname, node_id) + dsn = db.get_dsn() + if dsn is not None: + daemon.opts['wallet'] = dsn + rpc = LightningRpc(socket_path, self.executor) node = LightningNode(daemon, rpc, self.bitcoind, self.executor, may_fail=may_fail, may_reconnect=may_reconnect, allow_broken_log=allow_broken_log, - allow_bad_gossip=allow_bad_gossip) + allow_bad_gossip=allow_bad_gossip, db=db) # Regtest estimatefee are unusable, so override. node.set_feerates(feerates, False)