You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

179 lines
6.4 KiB
Python

import os
import cv2
from ImageComparator import ImageComparator
from datetime import datetime
class ImageDB:
    """
    Folder-based image database that imports new pictures and weeds out duplicates.

    Layout below ``root_path``:
      * ``import_folder`` - drop zone for pictures waiting to be processed,
      * ``db_folder``     - pictures accepted into the database,
      * ``dump_folder``   - pictures rejected as duplicates,
      * ``db_log.txt``    - log of the current run (overwritten on every start).

    Similarity decisions are delegated to ``ImageComparator``; this class manages
    the files and the in-memory feature cache ``picture_data`` (db image path ->
    features returned by ``ImageComparator.get_features``).
    The log file stays open for the lifetime of the object; call ``close()`` or
    use the instance as a context manager to release it.
    """
    supported_file_extensions = [".png", ".jpg", ".jpeg"]

    def __init__(self, root_path: str, import_folder: str = "import", db_folder: str = "images",
                 dump_folder: str = "duplicate",
                 folder_sep: str = "/", samples: int = 100, dist_thresh: float = 80,
                 match_thresh: float = 0.3, resize_dim: tuple = (500, 500), threads: int = 2):
        """
        Create any missing folders, open the log file and compute features for all db images.

        :param root_path: base folder that holds the sub-folders and the log file
        :param import_folder: name of the folder scanned for new pictures
        :param db_folder: name of the folder holding accepted pictures
        :param dump_folder: name of the folder receiving rejected duplicates
        :param folder_sep: separator used to build every path (also the key format of picture_data)
        :param samples: forwarded unchanged to ImageComparator
        :param dist_thresh: forwarded unchanged to ImageComparator
        :param match_thresh: forwarded unchanged to ImageComparator
        :param resize_dim: forwarded unchanged to ImageComparator
        :param threads: forwarded unchanged to ImageComparator
        """
        self.start_time = datetime.now()
        self.root_path = root_path
        self.dir_sep = folder_sep
        self.import_path = self.root_path + folder_sep + import_folder
        self.db_path = self.root_path + folder_sep + db_folder
        self.dump_path = self.root_path + folder_sep + dump_folder
        log_file = self.root_path + self.dir_sep + "db_log.txt"
        self.imgcomp = ImageComparator(samples, dist_thresh, match_thresh, resize_dim, threads)

        # The root has to exist before the log file can be created inside it;
        # makedirs also copes with a root whose parents are missing.
        root_was_missing = not os.path.isdir(self.root_path)
        os.makedirs(self.root_path, exist_ok=True)
        self.log_stream = open(log_file, "w", encoding="utf-8")
        try:
            if root_was_missing:
                self.log("Created {0} because it was missing ...".format(self.root_path))
            for path in (self.import_path, self.db_path, self.dump_path):
                self._ensure_dir(path)
            self.log("Calculating all features ...")
            self.picture_data = self.calc_db_features()
            self.log("Done!")
        except BaseException:
            # Do not leak the log file handle when construction fails half-way.
            self.log_stream.close()
            raise

    def _ensure_dir(self, path: str):
        """Create ``path`` (including parents) if it does not exist yet and log the creation."""
        if not os.path.isdir(path):
            os.makedirs(path, exist_ok=True)
            self.log("Created {0} because it was missing ...".format(path))

    def log(self, msg: str, log_to_file: bool = True):
        """
        Print ``msg`` prefixed with the time elapsed since construction.

        :param msg: message text
        :param log_to_file: also append the line to the log file (skipped once the log is closed)
        """
        elapsed = datetime.now() - self.start_time
        line = "[{0}] {1}".format(elapsed, msg)
        print(line)
        if log_to_file and not self.log_stream.closed:
            self.log_stream.write(line + "\n")
            # Flush every line so the log stays useful if the process dies mid-run.
            self.log_stream.flush()

    def close(self):
        """Close the log file. Safe to call more than once."""
        if not self.log_stream.closed:
            self.log_stream.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def process_all_images(self):
        """Run process_image on every entry of the import folder, then refresh the feature cache."""
        for file_name in os.listdir(self.import_path):
            self.process_image(file_name)
        self.update_db_features()

    def process_image(self, img_name: str) -> bool:
        """
        Decide whether an image from the import folder is new and file it accordingly.

        New images are moved into the database and their features are cached right
        away, so a duplicate later in the same import batch is recognised. Images
        matching an existing db image are moved to the dump folder. Unsupported
        files are left where they are.

        :param img_name: file name, assumed to be located inside the import folder
        :return: True if the image was imported into the database, False otherwise
        """
        img_path = self.import_path + self.dir_sep + img_name
        if not self.is_supported_file(img_path):
            return False
        # NOTE: despite its name, is_image_unique returns True when a similar image EXISTS.
        has_match, matched, ratio = self.is_image_unique(img_path, self.picture_data)
        if has_match:
            self.log("Matched {0} to {1} with ratio {2}".format(img_name, matched, ratio))
            self.move_to_dump(img_name)
            return False
        dest_path = self.import_image(img_name)
        self._store_features(dest_path, self.picture_data)
        return True

    def import_image(self, img_name: str) -> str:
        """
        Move the given image from the import folder into the database, without any checks.

        An existing database file is never overwritten; on a name clash a numeric
        suffix is appended (``img.png`` -> ``img_1.png``).

        :param img_name: image file name, assumed to be located inside the import folder
        :return: path of the imported file inside the database folder
        :raises FileNotFoundError: if the image is not a file in the import folder
        """
        return self._move_from_import(img_name, self.db_path)

    def move_to_dump(self, img_name: str) -> str:
        """
        Move the given image from the import folder to the dump folder.

        An existing file in the dump folder is never overwritten; on a name clash a
        numeric suffix is appended.

        :param img_name: image file name, assumed to be located inside the import folder
        :return: path of the moved file inside the dump folder
        :raises FileNotFoundError: if the image is not a file in the import folder
        """
        return self._move_from_import(img_name, self.dump_path)

    def _move_from_import(self, img_name: str, dest_folder: str) -> str:
        """Move ``img_name`` out of the import folder into ``dest_folder`` without clobbering files."""
        img_path = self.import_path + self.dir_sep + img_name
        if not os.path.isfile(img_path):
            # FileNotFoundError subclasses Exception, so existing handlers keep working.
            raise FileNotFoundError("Image path has not been a valid file!")
        dest_path = self._free_destination(dest_folder, img_name)
        if dest_path != dest_folder + self.dir_sep + img_name:
            self.log("{0} already exists in {1}, storing it as {2}".format(img_name, dest_folder, dest_path))
        # os.rename would silently replace an existing file on POSIX, hence the free name above.
        os.rename(img_path, dest_path)
        return dest_path

    def _free_destination(self, folder: str, img_name: str) -> str:
        """Return a path for ``img_name`` inside ``folder`` that does not exist yet."""
        dest_path = folder + self.dir_sep + img_name
        stem, ext = os.path.splitext(img_name)
        counter = 1
        while os.path.exists(dest_path):
            dest_path = "{0}{1}{2}_{3}{4}".format(folder, self.dir_sep, stem, counter, ext)
            counter += 1
        return dest_path

    def is_image_unique(self, img_path: str, picture_data: dict) -> tuple:
        """
        Look for an image in the database that is similar to the given one.

        Despite the name, the first element is True when a similar image WAS found
        (i.e. the image is NOT unique).

        :param img_path: path of the image to check
        :param picture_data: feature cache to compare against (db image path -> features)
        :return: (match_found, matched_name, match_value); (False, "", 0.0) when nothing matches
        """
        if not picture_data:
            return False, "", 0.0
        result, matched, match = self.imgcomp.has_similar_match(img_path, self.db_path,
                                                                picture_data)
        if result:
            return True, matched, match
        return False, "", 0.0

    def _store_features(self, pic_path: str, pictures: dict) -> bool:
        """
        Load ``pic_path`` and store its features in ``pictures`` under that path.

        :return: False (and nothing stored) if OpenCV could not read the file
        """
        img = cv2.imread(pic_path)
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None instead of raising.
            self.log("Could not read {0}, skipping it ...".format(pic_path))
            return False
        pictures[pic_path] = self.imgcomp.get_features(img)
        return True

    def calc_db_features(self) -> dict:
        """
        Calculate features for every readable, supported image in the database.

        :return: dict mapping db image path -> features
        """
        pictures = dict()
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if self.is_supported_file(pic_path):
                self._store_features(pic_path, pictures)
        return pictures

    def update_db_features(self):
        """
        Sync the feature cache with the database folder.

        Features are calculated only for images not cached yet; entries whose file
        has disappeared from the database folder are dropped.
        """
        on_disk = set()
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if not self.is_supported_file(pic_path):
                continue
            on_disk.add(pic_path)
            if pic_path in self.picture_data:
                continue
            self.log("New image {0}! Calculating features for memory...".format(name))
            if self._store_features(pic_path, self.picture_data):
                self.log("Done!")
        stale = [path for path in self.picture_data if path not in on_disk]
        for path in stale:
            self.log("Image {0} is gone from the database, forgetting its features ...".format(path))
            del self.picture_data[path]

    @staticmethod
    def _has_supported_extension(path: str) -> bool:
        """Return True if ``path`` ends with one of the supported extensions (case-insensitive)."""
        return path.lower().endswith(tuple(ImageDB.supported_file_extensions))

    def is_supported_file(self, path: str) -> bool:
        """
        Return whether ``path`` is an existing file with a supported picture extension.

        The extension must be the actual suffix: ``a.png.txt`` or files inside a
        folder named ``x.jpg_dir`` are rejected.
        """
        if os.path.isfile(path) and self._has_supported_extension(path):
            return True
        print("{0} is not a supported file format!".format(path))
        return False

    def get_db_size(self) -> int:
        """
        Return the number of images (files with a supported extension) inside the database.
        """
        return sum(
            1
            for entry in os.listdir(self.db_path)
            if os.path.isfile(self.db_path + self.dir_sep + entry)
            and self._has_supported_extension(entry)
        )