You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
6.4 KiB
Python
179 lines
6.4 KiB
Python
import os
|
|
import cv2
|
|
from ImageComparator import ImageComparator
|
|
from datetime import datetime
|
|
|
|
|
|
class ImageDB:
    """
    Folder-based image store that files imported pictures as new or duplicate.

    Three sub-folders live below ``root_path``: an import folder (incoming files),
    a database folder (accepted images) and a dump folder (rejected duplicates).
    Similarity decisions are delegated to ``ImageComparator``; the features of all
    database images are cached in ``self.picture_data``, keyed by the image path
    built as ``db_path + dir_sep + file_name``.

    The instance owns an open log file (``db_log.txt`` inside ``root_path``).
    Call :meth:`close` or use the instance as a context manager to release it.
    """

    # Lower-case file suffixes that are treated as images.
    supported_file_extensions = [".png", ".jpg", ".jpeg"]

    def __init__(self, root_path: str, import_folder: str = "import", db_folder: str = "images",
                 dump_folder: str = "duplicate",
                 folder_sep: str = "/", samples: int = 100, dist_thresh: float = 80,
                 match_thresh: float = 0.3, resize_dim: tuple = (500, 500), threads: int = 2):
        """
        Create missing folders, open the log file and compute the feature cache.

        :param root_path: base folder containing all other folders and the log file
        :param import_folder: name of the folder holding images waiting to be processed
        :param db_folder: name of the folder holding accepted images
        :param dump_folder: name of the folder receiving images rejected as duplicates
        :param folder_sep: separator used when concatenating folder and file names
        :param samples: forwarded unchanged to ``ImageComparator``
        :param dist_thresh: forwarded unchanged to ``ImageComparator``
        :param match_thresh: forwarded unchanged to ``ImageComparator``
        :param resize_dim: forwarded unchanged to ``ImageComparator``
        :param threads: forwarded unchanged to ``ImageComparator``
        """
        self.start_time = datetime.now()
        # Defined before the first log() call so logging never hits a missing attribute.
        self.log_stream = None

        self.root_path = root_path
        self.dir_sep = folder_sep
        self.import_path = self.root_path + self.dir_sep + import_folder
        self.db_path = self.root_path + self.dir_sep + db_folder
        self.dump_path = self.root_path + self.dir_sep + dump_folder

        log_file = self.root_path + self.dir_sep + "db_log.txt"

        self.imgcomp = ImageComparator(samples, dist_thresh, match_thresh, resize_dim, threads)

        # The log file is not open yet, so folder creation is reported on stdout only.
        for folder in (self.root_path, self.import_path, self.db_path, self.dump_path):
            self._ensure_dir(folder)

        # Explicit UTF-8 so image names with non-ASCII characters can always be logged.
        self.log_stream = open(log_file, "w", encoding="utf-8")

        self.log("Calculating all features ...")
        self.picture_data = self.calc_db_features()
        self.log("Done!")

    def __enter__(self):
        """Allow ``with ImageDB(...) as db:`` so the log file is closed reliably."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

    def close(self):
        """Close the log file. Safe to call more than once."""
        if self.log_stream is not None:
            self.log_stream.close()
            self.log_stream = None

    def _ensure_dir(self, path: str):
        """Create ``path`` (including missing parents) if it is not a directory yet."""
        if not os.path.isdir(path):
            os.makedirs(path, exist_ok=True)
            self.log("Created {0} because it was missing ...".format(path), False)

    def log(self, msg: str, log_to_file: bool = True):
        """
        Print a message prefixed with the time elapsed since construction.

        :param msg: text to report
        :param log_to_file: also append the line to the log file (skipped when the
                            log file is not open)
        :return:
        """
        elapsed = datetime.now() - self.start_time
        line = "[{0}] {1}".format(elapsed, msg)
        print(line)
        if log_to_file and self.log_stream is not None:
            self.log_stream.write(line + "\n")
            # Flush per line so the log survives a crash mid-run.
            self.log_stream.flush()

    def process_all_images(self):
        """
        Process every entry of the import folder, then refresh the feature cache.

        NOTE(review): the cache is refreshed only after the whole batch, so two
        similar images arriving in the same batch are not compared with each
        other - confirm whether that is intended.
        :return:
        """
        for file_name in os.listdir(self.import_path):
            self.process_image(file_name)
        self.update_db_features()

    def process_image(self, img_name: str) -> bool:
        """
        Decide whether an image from the import folder is new and file it accordingly.

        New images are moved into the database folder, images with a similar match
        are moved to the dump folder, unsupported files are left where they are.
        :param img_name: file name, assumed to be located inside the import folder
        :return: True if the image was imported, False otherwise
        """
        img_path = self.import_path + self.dir_sep + img_name

        if not self.is_supported_file(img_path):
            # Previously fell off the end and returned None despite the bool annotation.
            return False

        has_match, matched, ratio = self.is_image_unique(img_path, self.picture_data)
        if has_match:
            self.log("Matched {0} to {1} with ratio {2}".format(img_name, matched, ratio))
            self.move_to_dump(img_name)
            return False

        self.import_image(img_name)
        return True

    def _move_from_import(self, img_name: str, dest_folder: str):
        """Move ``img_name`` from the import folder into ``dest_folder``."""
        img_path = self.import_path + self.dir_sep + img_name
        dest_path = dest_folder + self.dir_sep + img_name
        if not os.path.isfile(img_path):
            # FileNotFoundError is still an Exception, so existing handlers keep working.
            raise FileNotFoundError("Image path has not been a valid file!")
        os.rename(img_path, dest_path)

    def import_image(self, img_name: str):
        """
        Import the given image into the database folder.

        Does no checks, just moves the file.
        :param img_name: image file name, assumed to be located inside the import folder
        :return:
        :raises FileNotFoundError: if the image is not a file inside the import folder
        """
        self._move_from_import(img_name, self.db_path)

    def move_to_dump(self, img_name: str):
        """
        Move the given image to the dump folder.

        :param img_name: image file name, assumed to be located inside the import folder
        :return:
        :raises FileNotFoundError: if the image is not a file inside the import folder
        """
        self._move_from_import(img_name, self.dump_path)

    def is_image_unique(self, img_path: str, picture_data: dict) -> tuple:
        """
        Look for an image in the database that is similar to the given one.

        Despite the name, the first element of the result is True when a similar
        image WAS found (the image is a duplicate) and False when it was not.
        :param img_path: path of the image to check
        :param picture_data: cached features of the database images
        :return: ``(True, matched, match)`` as reported by
                 ``ImageComparator.has_similar_match`` when a match exists,
                 otherwise ``(False, "", 0.0)``
        """
        if not picture_data:
            # Empty database: nothing to compare against.
            return False, "", 0.0

        result, matched, match = self.imgcomp.has_similar_match(img_path, self.db_path,
                                                                picture_data)
        if result:
            return True, matched, match
        return False, "", 0.0

    def _load_features(self, pic_path: str):
        """Return the features of the image at ``pic_path``, or None if it is unreadable."""
        img = cv2.imread(pic_path)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            self.log("Could not read {0}, skipping ...".format(pic_path))
            return None
        return self.imgcomp.get_features(img)

    def calc_db_features(self) -> dict:
        """
        Calculate features for every supported image in the database folder.

        :return: dict mapping image path to the features returned by
                 ``ImageComparator.get_features``; unreadable images are left out
        """
        pictures = dict()
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if not self.is_supported_file(pic_path):
                continue
            features = self._load_features(pic_path)
            if features is not None:
                pictures[pic_path] = features
        return pictures

    def update_db_features(self):
        """
        Add features to the cache for database images that are not cached yet.

        Entries already present in ``self.picture_data`` are not recalculated.
        :return:
        """
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if not self.is_supported_file(pic_path):
                continue
            if pic_path in self.picture_data:
                continue
            self.log("New image {0}! Calculating features for memory...".format(name))
            features = self._load_features(pic_path)
            if features is not None:
                self.picture_data[pic_path] = features
                self.log("Done!")

    def is_supported_file(self, path: str) -> bool:
        """
        Tell whether the path is an existing file with a supported image extension.

        The extension is compared case-insensitively against the END of the path,
        so names that merely contain ".png" somewhere (e.g. "a.png.txt") are rejected.
        :param path: path of the file to check
        :return: True for an existing file with a supported extension, else False
        """
        if os.path.isfile(path):
            if path.lower().endswith(tuple(self.supported_file_extensions)):
                return True
        print("{0} is not a supported file format!".format(path))
        return False

    def get_db_size(self) -> int:
        """
        Return the number of regular files inside the database folder.

        All files are counted, regardless of their extension.
        :return: file count
        """
        return sum(
            1 for entry in os.listdir(self.db_path)
            if os.path.isfile(self.db_path + self.dir_sep + entry)
        )