backup: Implement the restore method for all backends

Also implements the `stream_changes` function in the FileBackend which is used
by `restore` to get all the changes.
This commit is contained in:
Christian Decker
2020-04-04 20:11:29 +02:00
parent caa5c0a16b
commit a68758b1d0
2 changed files with 45 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
from collections import namedtuple
from pyln.client import Plugin
from tqdm import tqdm
from typing import Mapping, Type, Iterator
from urllib.parse import urlparse
import json
@@ -8,6 +9,7 @@ import logging
import os
import struct
import sys
import sqlite3
plugin = Plugin()
@@ -78,6 +80,27 @@ class Backend(object):
"""
raise NotImplementedError
def restore(self, dest: str, remove_existing: bool = False) -> bool:
"""Restore the backup in this backend to its former glory.
"""
if os.path.exists(dest):
os.unlink(dest)
db = sqlite3.connect(dest)
for c in tqdm(self.stream_changes(), total=self.version):
if c.snapshot is not None:
if os.path.exists(dest):
os.unlink(dest)
with open(dest, 'wb') as f:
f.write(c.snapshot)
db = sqlite3.connect(dest)
if c.transaction is not None:
cur = db.cursor()
for q in c.transaction:
cur.execute(q.decode('UTF-8'))
db.commit()
return True
class FileBackend(Backend):
def __init__(self, destination: str):
@@ -160,6 +183,27 @@ class FileBackend(Backend):
self.prev_version, self.offsets[1] = 0, 0
return True
def stream_changes(self) -> Iterator[Change]:
self.read_metadata()
version = -1
with open(self.url.path, 'rb') as f:
# Skip the header
f.seek(512)
while version < self.version:
length, version, typ = struct.unpack("!IIb", f.read(9))
payload = f.read(length)
if typ == 1:
yield Change(version=version, snapshot=None, transaction=payload.split(b'\x00'))
elif typ == 2:
yield Change(version=version, snapshot=payload, transaction=None)
else:
raise ValueError("Unknown FileBackend entry type {}".format(typ))
if version != self.version:
raise ValueError("Versions do not match up: restored version {}, backend version {}".format(version, self.version))
assert(version == self.version)
def resolve_backend_class(backend_url):
backend_map: Mapping[str, Type[Backend]] = {