diff --git a/backup/backup.py b/backup/backup.py index 4ef0c88..1e44849 100755 --- a/backup/backup.py +++ b/backup/backup.py @@ -80,26 +80,44 @@ class Backend(object): """ raise NotImplementedError - def restore(self, dest: str, remove_existing: bool = False) -> bool: - """Restore the backup in this backend to its former glory. - """ + def _db_open(self, dest: str) -> sqlite3.Connection: + db = sqlite3.connect(dest) + db.execute("PRAGMA foreign_keys = 1") + return db + + def _restore_snapshot(self, snapshot: bytes, dest: str): if os.path.exists(dest): os.unlink(dest) - db = sqlite3.connect(dest) - for c in tqdm(self.stream_changes(), total=self.version): - if c.snapshot is not None: - if os.path.exists(dest): - os.unlink(dest) - with open(dest, 'wb') as f: - f.write(c.snapshot) - db = sqlite3.connect(dest) - if c.transaction is not None: - cur = db.cursor() - for q in c.transaction: - cur.execute(q.decode('UTF-8')) - db.commit() + with open(dest, 'wb') as f: + f.write(snapshot) + self.db = self._db_open(dest) - return True + def _restore_transaction(self, tx: Iterator[bytes]): + assert(self.db) + cur = self.db.cursor() + for q in tx: + cur.execute(q.decode('UTF-8')) + self.db.commit() + + def restore(self, dest: str, remove_existing: bool = False): + """Restore the backup in this backend to its former glory. + """ + + if os.path.exists(dest): + if not remove_existing: + raise ValueError( + "Destination for backup restore exists: {dest}".format( + dest=dest + ) + ) + os.unlink(dest) + + self.db = self._db_open(dest) + for c in tqdm(self.stream_changes()): + if c.snapshot is not None: + self._restore_snapshot(c.snapshot, dest) + if c.transaction is not None: + self._restore_transaction(c.transaction) class FileBackend(Backend):