From 32fcc04c14b3d153aa2ce4c410a53ed4482887bd Mon Sep 17 00:00:00 2001
From: Peery
Date: Sat, 11 May 2019 14:47:23 +0200
Subject: [PATCH] Multi-Threaded SIFT feature comparison

* Comparison based on biggest SIFT features
* Multi-threaded comparison of one image against a set
* Basic folder structure for importing new images into the set
---
Note: the blob hashes in the "index" lines below were not recomputed for
the revised file contents; line counts in the hunk headers were.

 CompDatabase.py       | 202 ++++++++++++++++++++++++++++++++++++++++++
 ImageComparator.py    | 197 +++++++++++++++++++++++++++++++++++++++++
 ImageCompareThread.py | 191 ++++++++++++++++++++++++++++++++++++++++
 main.py               |  34 +++++++
 4 files changed, 624 insertions(+)
 create mode 100644 CompDatabase.py
 create mode 100644 ImageComparator.py
 create mode 100644 ImageCompareThread.py
 create mode 100644 main.py

diff --git a/CompDatabase.py b/CompDatabase.py
new file mode 100644
index 0000000..42c66ad
--- /dev/null
+++ b/CompDatabase.py
@@ -0,0 +1,202 @@
+import os
+import cv2
+from ImageComparator import ImageComparator
+from datetime import datetime
+
+
+class ImageDB:
+    """
+    Folder based image database that only imports images which are not
+    similar to an image that is already part of the database.
+    """
+
+    supported_file_extensions = [".png", ".jpg", ".jpeg"]
+
+    def __init__(self, root_path: str, import_folder: str = "import", db_folder: str = "images",
+                 dump_folder: str = "duplicate",
+                 folder_sep: str = "/", samples: int = 100, dist_thresh: float = 80,
+                 match_thresh: float = 0.3, resize_dim: tuple = (500,500), threads: int = 2):
+        """
+        Create missing folders, open the log file and calculate the features
+        of all images that are already inside the database folder.
+        :param root_path: folder containing the import, database and dump folder
+        :param import_folder: name of the folder new images are read from
+        :param db_folder: name of the folder unique images are moved to
+        :param dump_folder: name of the folder duplicates are moved to
+        :param folder_sep: separator used to build paths
+        :param samples: number of keypoints per image, see ImageComparator
+        :param dist_thresh: maximum descriptor distance, see ImageComparator
+        :param match_thresh: required ratio of hits, see ImageComparator
+        :param resize_dim: passed on to ImageComparator
+        :param threads: number of concurrent comparisons
+        """
+        self.start_time = datetime.now()
+
+        self.root_path = root_path
+        self.import_path = self.root_path + folder_sep + import_folder
+        self.db_path = self.root_path + folder_sep + db_folder
+        self.dump_path = self.root_path + folder_sep + dump_folder
+        self.dir_sep = folder_sep
+
+        log_file = self.root_path + self.dir_sep + "db_log.txt"
+
+        self.imgcomp = ImageComparator(samples, dist_thresh, match_thresh, resize_dim, threads)
+
+        # the log file is not open yet, so these messages only go to stdout
+        for path in [self.root_path, self.import_path, self.db_path, self.dump_path]:
+            if not os.path.isdir(path):
+                os.mkdir(path)
+                self.log("Created {0} because it was missing ...".format(path), False)
+
+        self.log_stream = open(log_file, "w")
+
+        self.log("Calculating all features ...")
+        self.picture_data = self.calc_db_features()
+        self.log("Done!")
+
+    def close(self):
+        """
+        Close the log file, log() must not write to the file afterwards.
+        """
+        self.log_stream.close()
+
+    def log(self, msg: str, log_to_file: bool = True):
+        """
+        Print msg prefixed with the time since start and optionally write it to the log file.
+        :param msg: message to log
+        :param log_to_file: also write the message to the log file
+        """
+        time = datetime.now() - self.start_time
+        msg = "[{0}] {1}".format(time, msg)
+        print(msg)
+        if log_to_file:
+            self.log_stream.write(msg+"\n")
+            self.log_stream.flush()
+
+    def process_all_images(self):
+        """
+        Process every file inside the import folder.
+
+        The features are updated after every import, so duplicates inside
+        the import folder are detected as well.
+        """
+        for file_name in os.listdir(self.import_path):
+            if self.process_image(file_name):
+                self.update_db_features()
+
+    def process_image(self, img_name: str) -> bool:
+        """
+        Processes a given image and decides if it is to be imported and then import or return.
+
+        Assumes image to be inside the import folder.
+        Indicates if it was imported via boolean
+        :param img_name: file name of the image inside the import folder
+        :return: True if the image was imported, False if it was moved to the
+                 dump folder or is not a supported file
+        """
+        img_path = self.import_path + self.dir_sep + img_name
+
+        if not self.is_supported_file(img_path):
+            return False
+
+        # despite its name is_image_unique reports True if a duplicate was found
+        duplicate, matched, ratio = self.is_image_unique(img_path, self.picture_data)
+        if not duplicate:
+            self.import_image(img_name)
+            return True
+
+        self.log("Matched {0} to {1} with ratio {2}".format(img_name, matched, ratio))
+        self.move_to_dump(img_name)
+        return False
+
+    def import_image(self, img_name: str):
+        """
+        Import the given image into the database.
+
+        Does no checks, just imports.
+        :param img_name: image file name, assumed to be located inside the import folder
+        :raises Exception: if the image is not a file inside the import folder
+        """
+        img_path = self.import_path + self.dir_sep + img_name
+        dest_path = self.db_path + self.dir_sep + img_name
+        if os.path.isfile(img_path):
+            os.rename(img_path, dest_path)
+        else:
+            raise Exception("Image path has not been a valid file!")
+
+    def move_to_dump(self, img_name: str):
+        """
+        Move the given image to the dump folder.
+        :param img_name: image file name, assumed to be located inside the import folder
+        :raises Exception: if the image is not a file inside the import folder
+        """
+        img_path = self.import_path + self.dir_sep + img_name
+        dest_path = self.dump_path + self.dir_sep + img_name
+        if os.path.isfile(img_path):
+            os.rename(img_path, dest_path)
+        else:
+            raise Exception("Image path has not been a valid file!")
+
+    def is_image_unique(self, img_path: str, picture_data: dict) -> tuple:
+        """
+        Check if the given image is already in the databank or not.
+
+        Note that the first element of the result is True if a similar image
+        was found, i.e. if the image is NOT unique.
+        :param img_path: path of the image to check
+        :param picture_data: features of the database images
+        :return: (True, matched image, match ratio) if a similar image exists,
+                 (False, "", 0.0) otherwise
+        """
+        if len(picture_data) == 0:
+            return False, "", 0.0
+        return self.imgcomp.has_similar_match(img_path, self.db_path, picture_data)
+
+    def calc_db_features(self) -> dict:
+        """
+        Calculate keypoints for every image in the database
+        :return: dict mapping the image path to its (keypoints, descriptors) tuple
+        """
+        pictures = dict()
+        for name in os.listdir(self.db_path):
+            pic_path = self.db_path + self.dir_sep + name
+            if not self.is_supported_file(pic_path):
+                continue
+            img = cv2.imread(pic_path)
+
+            pictures[pic_path] = self.imgcomp.get_features(img)
+
+        return pictures
+
+    def update_db_features(self):
+        """
+        Update the keypoint dictionary with only new image data
+        """
+        for name in os.listdir(self.db_path):
+            pic_path = self.db_path + self.dir_sep + name
+            if pic_path in self.picture_data or not self.is_supported_file(pic_path):
+                continue
+            self.log("New image {0}! Calculating features for memory...".format(name))
+            img = cv2.imread(pic_path)
+            self.picture_data[pic_path] = self.imgcomp.get_features(img)
+            self.log("Done!")
+
+    def is_supported_file(self, path: str) -> bool:
+        """
+        Returns if the given file has a valid picture file extension.
+        :param path: path of the file to check
+        :return: True if path is an existing file ending in a supported extension
+        """
+        extensions = tuple(ImageDB.supported_file_extensions)
+        if os.path.isfile(path) and path.lower().endswith(extensions):
+            return True
+        print("{0} is not a supported file format!".format(path))
+        return False
+
+    def get_db_size(self) -> int:
+        """
+        Return the number of images inside the databank
+        :return: number of files inside the database folder
+        """
+        return sum(1 for entry in os.listdir(self.db_path)
+                   if os.path.isfile(self.db_path+self.dir_sep+entry))
diff --git a/ImageComparator.py b/ImageComparator.py
new file mode 100644
index 0000000..372fa5c
--- /dev/null
+++ b/ImageComparator.py
@@ -0,0 +1,197 @@
+import cv2
+from ImageCompareThread import ImageCompareManageThread, count_descriptor_matches
+
+import datetime
+import warnings
+
+
+class ImageComparator:
+    """
+    Computes SIFT features of images and compares one image against a set
+    of precomputed features using multiple threads.
+    """
+
+    concurrent_threads = 10
+
+    def __init__(self, samples: int = 100, dist_thresh: float = 80, match_thresh: float = 0.6,
+                 resize_dim: tuple = (500, 500), concurrent_threads: int = 10):
+        """
+        :param samples: number of keypoints to keep per image
+        :param dist_thresh: maximum descriptor distance that counts as a hit
+        :param match_thresh: ratio of hits to samples (0 to 1) required for a match
+        :param resize_dim: stored but not used by the comparison yet
+        :param concurrent_threads: number of images compared at the same time
+        """
+        self.results = {}
+        # instance attribute, so several comparators do not overwrite each other
+        self.concurrent_threads = concurrent_threads
+        self.icmt = None
+        self.samples = samples
+        self.dist_thresh = dist_thresh
+        self.match_thresh = match_thresh
+        self.resize_dim = resize_dim
+
+        # newer OpenCV releases ship SIFT in the main module, older ones in contrib
+        if hasattr(cv2, "SIFT_create"):
+            self.sift = cv2.SIFT_create()
+        else:
+            self.sift = cv2.xfeatures2d.SIFT_create()
+
+    def match_images(self, path0: str, path1: str, sample_size: int = 100,
+                     match_thresh: float = 0.8,
+                     dist_thresh: float = 350,
+                     diff_min: float = 1.5) -> bool:
+        """
+        Matches the given images using SIFT featuring and euclidian distance comparison
+        of their selected keypoints.
+
+        True if at least match_thresh many keypoints
+        have been successfully matched (e.g. 0.9 -> 90%).
+
+        Deprecated.
+        :param path0: path of the first image
+        :param path1: path of the second image
+        :param sample_size: unused, the sample size is set in the constructor
+        :param match_thresh: float from 0 to 1 in percent of keypoint matches required
+        :param dist_thresh: float in max distance of keypoints to match
+        :param diff_min: unused
+        :return: True if the ratio of matched keypoints reaches match_thresh
+        """
+        warnings.warn("match_images is deprecated", DeprecationWarning, stacklevel=2)
+        start_time = datetime.datetime.now()
+        print("Creating feature lists ...")
+        ft0, des0 = self.get_features(cv2.imread(path0))
+        ft1, des1 = self.get_features(cv2.imread(path1))
+        print("Created feature lists!")
+
+        print(datetime.datetime.now() - start_time)
+
+        print("Looking for matches ...")
+        hits = count_descriptor_matches(des0, des1, dist_thresh)
+        print("Looked for matches!")
+
+        print(datetime.datetime.now() - start_time)
+
+        compared = 0 if des0 is None else len(des0)
+        match_ratio = hits / compared if compared else 0.0
+        print("MatchRatio:{0} Hits:{1}".format(match_ratio, hits))
+
+        return match_ratio >= match_thresh
+
+    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
+                                sample_size: int, dist_thresh: float) -> int:
+        """
+        Find nearest neighbours for each point in keypoints1 in keypoints2.
+
+        Returns number of sufficiently matching keypoints
+
+        Deprecated.
+        :param keypoints1: descriptors to look up
+        :param keypoints2: descriptors to search in
+        :param sample_size: unused, all of keypoints1 is compared
+        :param dist_thresh: maximum distance that still counts as a hit
+        :return: number of points in keypoints1 with a close enough neighbour
+        """
+        warnings.warn("find_matching_keypoints is deprecated", DeprecationWarning,
+                      stacklevel=2)
+        return count_descriptor_matches(keypoints1, keypoints2, dist_thresh)
+
+    def get_features(self, img) -> tuple:
+        """
+        Calculate the SIFT keypoints and descriptors of the given image.
+        :param img: BGR image as returned by cv2.imread
+        :return: (keypoints, descriptors) of the selected keypoints
+        :raises ValueError: if img is None, e.g. because the file could not be read
+        """
+        if img is None:
+            raise ValueError("Image is missing, the file could not be read!")
+        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        kp = self.sift.detect(grey, None)
+        kp = self.select_keypoints(kp, self.samples, (img.shape[1], img.shape[0]))
+        des = self.sift.compute(grey, kp)
+
+        return kp, des[1]
+
+    def select_keypoints(self, kps: list, samples: int, dimension: tuple):
+        """
+        Select a good sample of keypoints among all keypoints.
+
+        The keypoints are split into four groups (left/right half by
+        x-coordinate, then upper/lower half by y-coordinate) and the
+        samples // 4 keypoints with the biggest response are taken from each.
+        :param kps: keypoints, any sequence; it is not modified
+        :param samples: total number of keypoints to select
+        :param dimension: unused
+        :return: list of the selected keypoints
+        """
+        # sorted() copies, kps may be a tuple depending on the OpenCV version
+        by_x = sorted(kps, key=lambda x: x.pt[0])
+
+        kps_left = sorted(by_x[:len(by_x)//2], key=lambda x: x.pt[1])
+        kps_right = sorted(by_x[len(by_x)//2:], key=lambda x: x.pt[1])
+
+        kps_up_left = kps_left[:len(kps_left)//2]
+        kps_down_left = kps_left[len(kps_left)//2:]
+        kps_up_right = kps_right[:len(kps_right)//2]
+        kps_down_right = kps_right[len(kps_right)//2:]
+
+        sample = samples // 4
+
+        result = []
+        for quad_kps in [kps_up_left, kps_up_right, kps_down_left, kps_down_right]:
+            quad_kps.sort(key=lambda x: x.response, reverse=True)
+
+            result += quad_kps[:sample]
+
+        return result
+
+    def has_similar_match(self, imgPath: str, dbPath: str, picture_data: dict) -> tuple:
+        """
+        Calculate similarity of imgPath and all images in picture_data.
+        Stops starting new comparisons as soon as a match has been found.
+        :param imgPath: path of the image to look up
+        :param dbPath: unused, the database images are given by picture_data
+        :param picture_data: maps a database image to its (keypoints, descriptors) tuple
+        :return: (True, best matching key, ratio) if a match was found,
+                 (False, "", 0.0) otherwise
+        """
+        # Calculating features
+        targImg = cv2.imread(imgPath)
+        targFeat, targDesc = self.get_features(targImg)
+
+        # Matching features
+
+        # drop the results of the previous search, they belong to another image
+        self.results = {}
+        self.icmt = ImageCompareManageThread(imgPath, targDesc, picture_data, self.notify_result,
+                                             self.concurrent_threads, self.samples,
+                                             self.dist_thresh)
+        self.icmt.start()
+        self.icmt.join()
+        print("Done managing threads")
+
+        match = ""
+        value = 0.0
+        for key, hits in self.results.items():
+            if value < hits / self.samples:
+                value = hits / self.samples
+                match = key
+
+        if value >= self.match_thresh:
+            return True, match, value
+        else:
+            return False, "", 0.0
+
+    def notify_result(self, name: str, hits: int):
+        """
+        Callback for finished comparisons; stores the result and aborts the
+        search if the image is a match.
+        :param name: key of the compared database image
+        :param hits: number of matching descriptors
+        """
+        self.results[name] = hits
+
+        if hits/self.samples >= self.match_thresh:
+            print("[{0}] got a match! Aborting search ...".format(name))
+            self.icmt.searching = False
diff --git a/ImageCompareThread.py b/ImageCompareThread.py
new file mode 100644
index 0000000..50b95ea
--- /dev/null
+++ b/ImageCompareThread.py
@@ -0,0 +1,191 @@
+from threading import Condition, Thread
+from time import sleep
+import numpy as np
+import random
+from scipy.spatial import KDTree
+
+
+def count_descriptor_matches(candidate_desc, db_desc, dist_thresh: float) -> int:
+    """
+    Count the candidate descriptors whose nearest neighbour in db_desc
+    is at most dist_thresh away.
+
+    The KD-tree is built once per call and queried for all candidate
+    descriptors at the same time.
+    :param candidate_desc: 2D array-like of descriptors to look up, may be None
+    :param db_desc: 2D array-like of descriptors to search in, may be None
+    :param dist_thresh: maximum distance that still counts as a hit
+    :return: number of candidate descriptors with a close enough neighbour
+    """
+    if candidate_desc is None or db_desc is None:
+        return 0
+    if len(candidate_desc) == 0 or len(db_desc) == 0:
+        return 0
+
+    dist, _ = KDTree(np.asarray(db_desc)).query(np.asarray(candidate_desc), k=1)
+    return int(np.count_nonzero(np.atleast_1d(dist) <= dist_thresh))
+
+
+class ImageCompareManageThread(Thread):
+    """
+    Compares one candidate image against all database images using at most
+    concurrent_threads worker threads at a time.
+
+    Set searching to False (e.g. from inside the callback) to stop new
+    comparisons from being started; run() returns once all started workers
+    have reported their result.
+    """
+
+    def __init__(self, name: str, candidate_desc: np.array, descriptors: dict, callback,
+                 concurrent_threads: int = 10, samples: int = 100, dist_thresh: float = 80):
+        """
+        :param name: thread name, used for log output
+        :param candidate_desc: descriptors of the image to look up
+        :param descriptors: maps a database image to its (keypoints, descriptors) tuple
+        :param callback: called as callback(name, hits) for every finished comparison
+        :param concurrent_threads: maximum number of simultaneously running workers
+        :param samples: sample size handed to the workers
+        :param dist_thresh: maximum descriptor distance that counts as a hit
+        """
+        super().__init__(name=name)
+
+        self.candidate_desc = candidate_desc
+        self.descriptors = descriptors
+        self.samples = samples
+        self.dist_thresh = dist_thresh
+        self.callback = callback
+        self.concurrent_threads = concurrent_threads
+
+        self.todo = []
+        self.threads = {}
+        self.searching = False
+
+        # guards todo, threads and the worker counter; run() waits on it
+        self._state = Condition()
+        self._active = 0
+
+    def run(self):
+        """
+        Start the first batch of workers and block until every started worker
+        has finished. Returns immediately if there is nothing to compare.
+        """
+        print("[{0}] Starting management ...".format(self.name))
+        with self._state:
+            self.searching = True
+            self.todo = list(self.descriptors.keys())
+
+            for _ in range(self.concurrent_threads):
+                if not self._start_next_worker():
+                    break
+
+            # wait() releases the lock, so finish_thread() can make progress
+            while self._active > 0:
+                self._state.wait()
+            self.searching = False
+
+    def _start_next_worker(self) -> bool:
+        """
+        Start a worker for the next pending image. Caller must hold _state.
+        :return: True if a worker was started, False if there is no work left
+                 or the search has been aborted
+        """
+        if not self.searching or len(self.todo) == 0:
+            return False
+
+        key = self.todo.pop()
+        ict = ImageCompareThread(key, self.candidate_desc, self.descriptors[key][1],
+                                 self.samples, self.dist_thresh, self.finish_thread)
+        self.threads[key] = ict
+        self._active += 1
+        ict.start()
+        return True
+
+    def finish_thread(self, name: str, hits: int):
+        """
+        Called by a worker when it is done; reports the result and starts the
+        next comparison if there is work left and the search is still active.
+        :param name: name of the finished worker (the database image key)
+        :param hits: number of matching descriptors the worker found
+        """
+        with self._state:
+            # the callback may set searching to False to abort the search
+            self.callback(name, hits)
+            print("[{0}] finished with {1}".format(name, hits))
+
+            print("{0} jobs left ...".format(len(self.todo)))
+            self._start_next_worker()
+            self._active -= 1
+            self._state.notify_all()
+
+
+class ImageCompareThread(Thread):
+    """
+    Worker thread that counts the matching descriptors between the candidate
+    image and one database image and reports the count via callback.
+    """
+
+    def __init__(self, name: str, candidate_desc: np.array, db_desc: np.array,
+                 sample_size: int, dist_thresh: float, callback):
+        """
+        :param name: thread name, passed back to the callback
+        :param candidate_desc: descriptors of the image to look up
+        :param db_desc: descriptors of the database image
+        :param sample_size: kept for interface compatibility, not used for matching
+        :param dist_thresh: maximum descriptor distance that counts as a hit
+        :param callback: called as callback(name, hits) when the comparison is done
+        """
+        super().__init__(name=name)
+
+        self.candidate_desc = candidate_desc
+        self.db_desc = db_desc
+        self.samples = sample_size
+        self.dist_thresh = dist_thresh
+        self.callback = callback
+
+    def run(self):
+        """
+        Run the comparison and always report a result, so the manager is not
+        left waiting if the comparison fails.
+        """
+        print("[{0}] starting ...".format(self.name))
+        hits = 0
+        try:
+            hits = self.find_matching_keypoints(self.candidate_desc, self.db_desc,
+                                                self.samples, self.dist_thresh)
+        finally:
+            self.callback(self.name, hits)
+
+    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
+                                sample_size: int, dist_thresh: float) -> int:
+        """
+        Find nearest neighbours for each point in keypoints1 in keypoints2.
+
+        Returns number of sufficiently matching keypoints
+        :param keypoints1: descriptors to look up
+        :param keypoints2: descriptors to search in
+        :param sample_size: unused, all of keypoints1 is compared
+        :param dist_thresh: maximum distance that still counts as a hit
+        :return: number of points in keypoints1 with a close enough neighbour
+        """
+        return count_descriptor_matches(keypoints1, keypoints2, dist_thresh)
+
+    def __get_random_selection(self, l: list, num: int) -> list:
+        """
+        Draw num random elements (with replacement) from l.
+        """
+        return [random.choice(l) for _ in range(num)]
+
+    def has_matching_keypoint(self, point: np.ndarray, points: list, max_dist: float) -> tuple:
+        """
+        Find nearest neighbour for point in points.
+
+        Builds a new KD-tree on every call; use find_matching_keypoints to
+        compare many points against the same set.
+        :param point: single descriptor to look up
+        :param points: descriptors to search in
+        :param max_dist: maximum distance that still counts as a hit
+        :return: (hit, distance to the nearest neighbour)
+        """
+        dist, _ = KDTree(points).query(point, k=1)
+
+        return bool(dist <= max_dist), float(dist)
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..1082d16
--- /dev/null
+++ b/main.py
@@ -0,0 +1,34 @@
+from CompDatabase import ImageDB
+from datetime import datetime
+
+start_time = datetime.now()
+# TODO save & load feature lists for the database, fix resized matching (injective resizing?)
+
+root_folder = "Pictures"
+import_folder = "import"
+db_folder = "images"
+samples = 200
+dist_thresh = 80
+match_thresh = 0.8
+resize_dim = (500, 500)
+threads = 4
+
+
+def main():
+    """
+    Import all new images from the import folder into the database.
+    """
+    img_db = ImageDB(root_folder, import_folder, db_folder, samples=samples,
+                     dist_thresh=dist_thresh, match_thresh=match_thresh,
+                     resize_dim=resize_dim, threads=threads)
+    try:
+        img_db.log("Starting work ...")
+        img_db.log("DB size: {0}".format(img_db.get_db_size()))
+        img_db.process_all_images()
+        img_db.log("Work done!")
+    finally:
+        img_db.close()
+
+
+if __name__ == "__main__":
+    main()