mirror of
https://github.com/aljazceru/CTFd.git
synced 2025-12-17 05:54:19 +01:00
Fix Uploaders to work with imports/exports (#749)
* Refactor Uploaders to work better with imports/exports * Get S3 uploader working properly with imports/exports * cache pip in travis
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
language: python
|
language: python
|
||||||
|
cache: pip
|
||||||
services:
|
services:
|
||||||
- mysql
|
- mysql
|
||||||
- postgresql
|
- postgresql
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ from flask import Flask
|
|||||||
from werkzeug.contrib.fixers import ProxyFix
|
from werkzeug.contrib.fixers import ProxyFix
|
||||||
from jinja2 import FileSystemLoader
|
from jinja2 import FileSystemLoader
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
from sqlalchemy.engine.url import make_url
|
|
||||||
from sqlalchemy_utils import database_exists, create_database
|
|
||||||
from six.moves import input
|
from six.moves import input
|
||||||
|
|
||||||
from CTFd import utils
|
from CTFd import utils
|
||||||
|
|||||||
@@ -158,9 +158,8 @@ class Config(object):
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem'
|
UPLOAD_PROVIDER = os.getenv('UPLOAD_PROVIDER') or 'filesystem'
|
||||||
if UPLOAD_PROVIDER == 'filesystem':
|
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
||||||
UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER') or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
|
if UPLOAD_PROVIDER == 's3':
|
||||||
elif UPLOAD_PROVIDER == 's3':
|
|
||||||
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
||||||
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
||||||
AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET')
|
AWS_S3_BUCKET = os.getenv('AWS_S3_BUCKET')
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from CTFd.utils import get_app_config, get_config, set_config
|
from CTFd.utils import get_app_config
|
||||||
from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp
|
from CTFd.utils.migrations import get_current_revision, create_database, drop_database, upgrade, stamp
|
||||||
from CTFd.models import db, get_class_by_tablename
|
from CTFd.utils.uploads import get_uploader
|
||||||
|
from CTFd.models import db
|
||||||
from CTFd.cache import cache
|
from CTFd.cache import cache
|
||||||
from datafreeze.format import SERIALIZERS
|
from datafreeze.format import SERIALIZERS
|
||||||
from flask import current_app as app
|
from flask import current_app as app
|
||||||
@@ -12,7 +13,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import six
|
import six
|
||||||
import shutil
|
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
|
||||||
@@ -85,6 +85,9 @@ def export_ctf():
|
|||||||
backup_zip.writestr('db/alembic_version.json', result_file.read())
|
backup_zip.writestr('db/alembic_version.json', result_file.read())
|
||||||
|
|
||||||
# Backup uploads
|
# Backup uploads
|
||||||
|
uploader = get_uploader()
|
||||||
|
uploader.sync()
|
||||||
|
|
||||||
upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER'))
|
upload_folder = os.path.join(os.path.normpath(app.root_path), app.config.get('UPLOAD_FOLDER'))
|
||||||
for root, dirs, files in os.walk(upload_folder):
|
for root, dirs, files in os.walk(upload_folder):
|
||||||
for file in files:
|
for file in files:
|
||||||
@@ -199,7 +202,7 @@ def import_ctf(backup, erase=True):
|
|||||||
|
|
||||||
# Extracting files
|
# Extracting files
|
||||||
files = [f for f in backup.namelist() if f.startswith('uploads/')]
|
files = [f for f in backup.namelist() if f.startswith('uploads/')]
|
||||||
upload_folder = app.config.get('UPLOAD_FOLDER')
|
uploader = get_uploader()
|
||||||
for f in files:
|
for f in files:
|
||||||
filename = f.split(os.sep, 1)
|
filename = f.split(os.sep, 1)
|
||||||
|
|
||||||
@@ -207,16 +210,7 @@ def import_ctf(backup, erase=True):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
filename = filename[1] # Get the second entry in the list (the actual filename)
|
filename = filename[1] # Get the second entry in the list (the actual filename)
|
||||||
full_path = os.path.join(upload_folder, filename)
|
|
||||||
dirname = os.path.dirname(full_path)
|
|
||||||
|
|
||||||
# Create any parent directories for the file
|
|
||||||
if not os.path.exists(dirname):
|
|
||||||
os.makedirs(dirname)
|
|
||||||
|
|
||||||
source = backup.open(f)
|
source = backup.open(f)
|
||||||
target = open(full_path, "wb")
|
uploader.store(fileobj=source, filename=filename)
|
||||||
with source, target:
|
|
||||||
shutil.copyfileobj(source, target)
|
|
||||||
|
|
||||||
cache.clear()
|
cache.clear()
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ class BaseUploader(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def store(self, fileobj, filename):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def upload(self, file_obj, filename):
|
def upload(self, file_obj, filename):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -22,32 +25,36 @@ class BaseUploader(object):
|
|||||||
def delete(self, filename):
|
def delete(self, filename):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def sync(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class FilesystemUploader(BaseUploader):
|
class FilesystemUploader(BaseUploader):
|
||||||
def __init__(self, base_path=None):
|
def __init__(self, base_path=None):
|
||||||
super(BaseUploader, self).__init__()
|
super(BaseUploader, self).__init__()
|
||||||
self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER')
|
self.base_path = base_path or current_app.config.get('UPLOAD_FOLDER')
|
||||||
|
|
||||||
|
def store(self, fileobj, filename):
|
||||||
|
location = os.path.join(self.base_path, filename)
|
||||||
|
directory = os.path.dirname(location)
|
||||||
|
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory)
|
||||||
|
|
||||||
|
with open(location, 'wb') as dst:
|
||||||
|
copyfileobj(fileobj, dst, 16384)
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
def upload(self, file_obj, filename):
|
def upload(self, file_obj, filename):
|
||||||
if len(filename) == 0:
|
if len(filename) == 0:
|
||||||
raise Exception('Empty filenames cannot be used')
|
raise Exception('Empty filenames cannot be used')
|
||||||
|
|
||||||
filename = secure_filename(filename)
|
filename = secure_filename(filename)
|
||||||
md5hash = hashlib.md5(os.urandom(64)).hexdigest()
|
md5hash = hashlib.md5(os.urandom(64)).hexdigest()
|
||||||
|
file_path = os.path.join(md5hash, filename)
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(self.base_path, md5hash)):
|
return self.store(file_obj, file_path)
|
||||||
os.makedirs(os.path.join(self.base_path, md5hash))
|
|
||||||
|
|
||||||
location = os.path.join(self.base_path, md5hash, filename)
|
|
||||||
|
|
||||||
dst_file = open(location, 'wb')
|
|
||||||
try:
|
|
||||||
copyfileobj(file_obj, dst_file, 16384)
|
|
||||||
finally:
|
|
||||||
dst_file.close()
|
|
||||||
|
|
||||||
key = os.path.join(md5hash, filename)
|
|
||||||
return key
|
|
||||||
|
|
||||||
def download(self, filename):
|
def download(self, filename):
|
||||||
return send_file(safe_join(self.base_path, filename))
|
return send_file(safe_join(self.base_path, filename))
|
||||||
@@ -58,6 +65,9 @@ class FilesystemUploader(BaseUploader):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def sync(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class S3Uploader(BaseUploader):
|
class S3Uploader(BaseUploader):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -81,6 +91,10 @@ class S3Uploader(BaseUploader):
|
|||||||
if c in string.ascii_letters + string.digits + '-' + '_' + '.':
|
if c in string.ascii_letters + string.digits + '-' + '_' + '.':
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def store(self, fileobj, filename):
|
||||||
|
self.s3.upload_fileobj(fileobj, self.bucket, filename)
|
||||||
|
return filename
|
||||||
|
|
||||||
def upload(self, file_obj, filename):
|
def upload(self, file_obj, filename):
|
||||||
filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_'))
|
filename = filter(self._clean_filename, secure_filename(filename).replace(' ', '_'))
|
||||||
if len(filename) <= 0:
|
if len(filename) <= 0:
|
||||||
@@ -105,3 +119,17 @@ class S3Uploader(BaseUploader):
|
|||||||
def delete(self, filename):
|
def delete(self, filename):
|
||||||
self.s3.delete_object(Bucket=self.bucket, Key=filename)
|
self.s3.delete_object(Bucket=self.bucket, Key=filename)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def sync(self):
|
||||||
|
local_folder = current_app.config.get('UPLOAD_FOLDER')
|
||||||
|
bucket_list = self.s3.list_objects(Bucket=self.bucket)['Contents']
|
||||||
|
|
||||||
|
for s3_key in bucket_list:
|
||||||
|
s3_object = s3_key['Key']
|
||||||
|
|
||||||
|
local_path = os.path.join(local_folder, s3_object)
|
||||||
|
directory = os.path.dirname(local_path)
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
os.makedirs(directory)
|
||||||
|
|
||||||
|
self.s3.download_file(self.bucket, s3_object, local_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user