Fix Uploaders to work with imports/exports (#749)

* Refactor Uploaders to work better with imports/exports
* Get S3 uploader working properly with imports/exports
* cache pip in travis
Kevin Chung, 2018-11-23 06:10:33 -05:00 (committed by GitHub)
parent 310475d739 · commit 49ed27cfd6
5 changed files with 52 additions and 32 deletions

.travis.yml

@@ -1,4 +1,5 @@
 language: python
+cache: pip
 services:
 - mysql
 - postgresql
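
A note on the CI change: cache: pip tells Travis CI to persist pip's download and wheel cache between builds, so unchanged dependencies are restored from cache rather than re-downloaded on every run.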

CTFd/__init__.py

@@ -6,8 +6,6 @@ from flask import Flask
 from werkzeug.contrib.fixers import ProxyFix
 from jinja2 import FileSystemLoader
 from jinja2.sandbox import SandboxedEnvironment
-from sqlalchemy.engine.url import make_url
-from sqlalchemy_utils import database_exists, create_database
 from six.moves import input
 from CTFd import utils

CTFd/config.py

@@ -158,9 +158,8 @@ class Config(object):
     '''
     UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem'
-    if UPLOAD_PROVIDER == 'filesystem':
-        UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
-    elif UPLOAD_PROVIDER == 's3':
+    UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
+    if UPLOAD_PROVIDER == 's3':
         AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
         AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
         AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET')
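
The behavioral change here: UPLOAD_FOLDER used to be defined only for the filesystem provider; it now always gets a default of <CTFd package dir>/uploads, so an S3-backed instance also has a local staging directory for S3Uploader.sync() (below) to download into before an export.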

CTFd/utils/exports/__init__.py

@@ -1,6 +1,7 @@
-from CTFd.utils import get_app_config, get_config, set_config
+from CTFd.utils import get_app_config
 from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp
-from CTFd.models import db, get_class_by_tablename
+from CTFd.utils.uploads import get_uploader
+from CTFd.models import db
 from CTFd.cache import cache
 from datafreeze.format import SERIALIZERS
 from flask import current_app as app
@@ -12,7 +13,6 @@ import json
 import os
 import re
 import six
-import shutil
 import zipfile
@@ -85,6 +85,9 @@ def export_ctf():
     backup_zip.writestr('db/alembic_version.json', result_file.read())
 
     # Backup uploads
+    uploader = get_uploader()
+    uploader.sync()
+
     upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER'))
     for root, dirs, files in os.walk(upload_folder):
         for file in files:
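
With sync() in the uploader interface, the export path now works for either provider: FilesystemUploader.sync() is a no-op, while S3Uploader.sync() mirrors the bucket into UPLOAD_FOLDER, after which the existing folder walk sees a complete local copy. A minimal sketch of that walk as a standalone helper (hypothetical name backup_uploads; in the commit the loop lives inline in export_ctf):

import os


def backup_uploads(backup_zip, upload_folder):
    # backup_zip is an open zipfile.ZipFile. Walk the freshly synced upload
    # folder and add every file to the archive under the 'uploads/' prefix,
    # preserving subdirectories.
    for root, _dirs, files in os.walk(upload_folder):
        for name in files:
            full_path = os.path.join(root, name)
            arcname = os.path.join('uploads', os.path.relpath(full_path, upload_folder))
            backup_zip.write(full_path, arcname=arcname)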
@@ -199,7 +202,7 @@ def import_ctf(backup, erase=True):
     # Extracting files
     files = [f for f in backup.namelist() if f.startswith('uploads/')]
-    upload_folder = app.config.get('UPLOAD_FOLDER')
+    uploader = get_uploader()
 
     for f in files:
         filename = f.split(os.sep, 1)
@@ -207,16 +210,7 @@ def import_ctf(backup, erase=True):
             continue
         filename = filename[1]  # Get the second entry in the list (the actual filename)
-        full_path = os.path.join(upload_folder, filename)
-
-        dirname = os.path.dirname(full_path)
-
-        # Create any parent directories for the file
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-
         source = backup.open(f)
-        target = open(full_path, "wb")
-        with source, target:
-            shutil.copyfileobj(source, target)
+        uploader.store(fileobj=source, filename=filename)
 
     cache.clear()
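
The import side is now symmetric: each archive member is streamed straight into whichever uploader is configured, instead of being copied under UPLOAD_FOLDER by hand, so S3-backed instances restore directly into the bucket. A condensed sketch of the flow (hypothetical helper name restore_uploads; note that zip member names always use forward slashes, so the os.sep split in the real code only strips the uploads/ prefix correctly on POSIX):

from CTFd.utils.uploads import get_uploader


def restore_uploads(backup):
    # backup is an open zipfile.ZipFile of a CTFd export.
    uploader = get_uploader()
    for member in backup.namelist():
        if not member.startswith('uploads/'):
            continue
        # Zip entries use '/' separators on every platform.
        filename = member.split('/', 1)[1]
        if not filename:  # skip the bare 'uploads/' directory entry
            continue
        with backup.open(member) as source:
            uploader.store(fileobj=source, filename=filename)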

CTFd/utils/uploads/__init__.py

@@ -13,6 +13,9 @@ class BaseUploader(object):
     def __init__(self):
         raise NotImplementedError
 
+    def store(self, fileobj, filename):
+        raise NotImplementedError
+
     def upload(self, file_obj, filename):
         raise NotImplementedError
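
The split in the interface is deliberate: store() writes a file object to an exact key, while upload() stays the user-facing entry point that sanitizes the name and prefixes a random directory. Imports need the former, because backed-up files must land at the exact paths already recorded in the database.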
@@ -22,32 +25,36 @@ class BaseUploader(object):
     def delete(self, filename):
         raise NotImplementedError
 
+    def sync(self):
+        raise NotImplementedError
+
 
 class FilesystemUploader(BaseUploader):
     def __init__(self, base_path=None):
         super(BaseUploader, self).__init__()
         self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER')
 
+    def store(self, fileobj, filename):
+        location = os.path.join(self.base_path, filename)
+        directory = os.path.dirname(location)
+
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        with open(location, 'wb') as dst:
+            copyfileobj(fileobj, dst, 16384)
+
+        return filename
+
     def upload(self, file_obj, filename):
         if len(filename) == 0:
             raise Exception('Empty filenames cannot be used')
 
         filename = secure_filename(filename)
         md5hash = hashlib.md5(os.urandom(64)).hexdigest()
+        file_path = os.path.join(md5hash, filename)
 
-        if not os.path.exists(os.path.join(self.base_path, md5hash)):
-            os.makedirs(os.path.join(self.base_path, md5hash))
-
-        location = os.path.join(self.base_path, md5hash, filename)
-        dst_file = open(location, 'wb')
-        try:
-            copyfileobj(file_obj, dst_file, 16384)
-        finally:
-            dst_file.close()
-        key = os.path.join(md5hash, filename)
-        return key
+        return self.store(file_obj, file_path)
 
     def download(self, filename):
         return send_file(safe_join(self.base_path, filename))
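
A short usage sketch of the refactored filesystem uploader (run outside an app context, hence the explicit base_path; names and contents are illustrative):

import io

from CTFd.utils.uploads import FilesystemUploader

uploader = FilesystemUploader(base_path='/tmp/ctfd-uploads')

# upload() sanitizes the name and prefixes a random md5 directory,
# returning a key such as 'd41d8cd9.../notes.txt'.
key = uploader.upload(io.BytesIO(b'example contents'), 'notes.txt')

# store() writes to exactly the key it is given, which is what
# import_ctf() relies on to recreate files at their original paths.
uploader.store(io.BytesIO(b'example contents'), key)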
@@ -58,6 +65,9 @@ class FilesystemUploader(BaseUploader):
             return True
         return False
 
+    def sync(self):
+        pass
+
 
 class S3Uploader(BaseUploader):
     def __init__(self):
@@ -81,6 +91,10 @@ class S3Uploader(BaseUploader):
         if c in string.ascii_letters + string.digits + '-' + '_' + '.':
             return True
 
+    def store(self, fileobj, filename):
+        self.s3.upload_fileobj(fileobj, self.bucket, filename)
+        return filename
+
     def upload(self, file_obj, filename):
         filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_'))
         if len(filename) <= 0:
@@ -105,3 +119,17 @@ class S3Uploader(BaseUploader):
     def delete(self, filename):
         self.s3.delete_object(Bucket=self.bucket, Key=filename)
         return True
+
+    def sync(self):
+        local_folder = current_app.config.get('UPLOAD_FOLDER')
+
+        bucket_list = self.s3.list_objects(Bucket=self.bucket)['Contents']
+        for s3_key in bucket_list:
+            s3_object = s3_key['Key']
+            local_path = os.path.join(local_folder, s3_object)
+
+            directory = os.path.dirname(local_path)
+            if not os.path.exists(directory):
+                os.makedirs(directory)
+
+            self.s3.download_file(self.bucket, s3_object, local_path)
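
Two caveats on S3Uploader.sync() as written: list_objects returns at most 1000 keys per call, and its response has no 'Contents' key when the bucket is empty, so very large or empty buckets would break it. A defensive variant using a boto3 paginator (a sketch, not what this commit ships; sync_bucket is a hypothetical standalone name):

import os


def sync_bucket(s3, bucket, local_folder):
    # Paginate past the 1000-key limit of a single list call;
    # page.get('Contents', []) also tolerates an empty bucket.
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):  # skip zero-byte "directory" placeholder keys
                continue
            local_path = os.path.join(local_folder, key)
            directory = os.path.dirname(local_path)
            if not os.path.exists(directory):
                os.makedirs(directory)
            s3.download_file(bucket, key, local_path)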