mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 14:24:27 +01:00
first commit
This commit is contained in:
199
basicsr/utils/lmdb_util.py
Normal file
199
basicsr/utils/lmdb_util.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import cv2
|
||||
import lmdb
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
from os import path as osp
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:

    ::

        example.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path. Must end with '.lmdb'.
        img_path_list (list[str]): Image paths, relative to ``data_path``.
        keys (list[str]): Lmdb keys, one per entry of ``img_path_list``.
            Each key is encoded as ASCII before being written.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): Number of worker processes used when
            ``multiprocessing_read`` is True. Default: 40.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images (ten times the encoded size of the
            first image multiplied by the number of images). Default: None

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
        SystemExit: If ``lmdb_path`` already exists.
        Exception: With ``multiprocessing_read``, the first exception raised
            by a worker while reading an image is re-raised here.
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # keyed by lmdb key, so completion order does not matter
        shapes = {}
        worker_errors = []
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        def error_callback(exc):
            """remember worker failures so they are not silently dropped."""
            worker_errors.append(exc)

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(
                read_img_worker,
                args=(osp.join(data_path, path), key, compress_level),
                callback=callback,
                error_callback=error_callback)
        pool.close()
        pool.join()
        pbar.close()
        if worker_errors:
            # Without this, a failed read would only show up later as a
            # KeyError on dataset[key], hiding the real cause.
            raise worker_errors[0]
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    try:
        txn = env.begin(write=True)
        with open(osp.join(lmdb_path, 'meta_info.txt'), 'w') as txt_file:
            for idx, (path, key) in enumerate(zip(img_path_list, keys)):
                pbar.update(1)
                pbar.set_description(f'Write {key}')
                key_byte = key.encode('ascii')
                if multiprocessing_read:
                    img_byte = dataset[key]
                    h, w, c = shapes[key]
                else:
                    _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
                    h, w, c = img_shape

                txn.put(key_byte, img_byte)
                # write meta information
                txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
                # commit once every `batch` images (not right after the first one)
                if (idx + 1) % batch == 0:
                    txn.commit()
                    txn = env.begin(write=True)
        txn.commit()
    finally:
        # release the progress bar and the environment even if writing failed
        pbar.close()
        env.close()
    print('\nFinish writing lmdb.')
|
||||
|
||||
|
||||
def read_img_worker(path, key, compress_level):
    """Read image worker.

    Reads the image at ``path`` unchanged (no colour/depth conversion
    requested) and re-encodes it as PNG with the given compression level.

    Args:
        path (str): Image path.
        key (str): Image key. It is returned untouched so that asynchronous
            callers can match the result to its request.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte (the encoded PNG buffer returned by
            ``cv2.imencode``).
        tuple[int]: Image shape as (h, w, c); c is 1 for 2-D images.

    Raises:
        OSError: If the image cannot be read from ``path``.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than on `img.ndim` below.
        raise OSError(f'Failed to read image: {path}')
    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))
|
||||
|
||||
|
||||
class LmdbMaker():
    """LMDB Maker.

    Incrementally writes encoded images into an lmdb environment and records
    one line per image in ``meta_info.txt`` inside the lmdb folder. It can be
    used as a context manager, in which case :meth:`close` is called on exit.

    Args:
        lmdb_path (str): Lmdb save path. Must end with '.lmdb'.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
            It is only recorded in the meta information here; the images
            passed to :meth:`put` are stored as given.

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
        SystemExit: If ``lmdb_path`` already exists.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        try:
            self.txn = self.env.begin(write=True)
            self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        except Exception:
            # do not leak the opened environment if setup fails half-way
            self.env.close()
            raise
        self.counter = 0  # number of images written so far

    def __enter__(self):
        """Return the maker itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the maker on leaving a ``with`` block; exceptions propagate."""
        self.close()
        return False

    def put(self, img_byte, key, img_shape):
        """Store one encoded image and append its meta information.

        Args:
            img_byte: Encoded image data, written as the lmdb value.
            key (str): Lmdb key; encoded as ASCII before being written.
            img_shape (tuple[int]): Image shape as (h, w, c).
        """
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        # commit and start a fresh transaction after every `batch` images
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Commit pending writes, then close the lmdb env and the meta file.

        The environment and the meta file are closed even if the final
        commit raises.
        """
        try:
            self.txn.commit()
        finally:
            try:
                self.env.close()
            finally:
                self.txt_file.close()
|
||||
Reference in New Issue
Block a user