Multi-Threaded SIFT feature comparison
* Comparison based on biggest SIFT features
* Multi-threaded comparison of one image against a set
* Basic folder structure for importing new images into the set
parent
276db4b9fb
commit
32fcc04c14
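
As a rough overview, a minimal usage sketch of the pieces added in this commit (the "Pictures" root folder is only an example; it mirrors the main script at the end of the diff):

from CompDatabase import ImageDB

# Creates Pictures/import, Pictures/images and Pictures/duplicate if they are missing,
# then computes SIFT features for everything already in Pictures/images.
db = ImageDB("Pictures")
# Every supported file dropped into Pictures/import is compared against the set:
# unique images are moved into Pictures/images, likely duplicates into Pictures/duplicate.
db.process_all_images()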
@@ -0,0 +1,178 @@
import os
import cv2
from ImageComparator import ImageComparator
from datetime import datetime


class ImageDB:

    supported_file_extensions = [".png", ".jpg", ".jpeg"]

    def __init__(self, root_path: str, import_folder: str = "import", db_folder: str = "images",
                 dump_folder: str = "duplicate",
                 folder_sep: str = "/", samples: int = 100, dist_thresh: float = 80,
                 match_thresh: float = 0.3, resize_dim: tuple = (500, 500), threads: int = 2):
        self.start_time = datetime.now()

        self.root_path = root_path
        self.import_path = self.root_path + folder_sep + import_folder
        self.db_path = self.root_path + folder_sep + db_folder
        self.dump_path = self.root_path + folder_sep + dump_folder
        self.dir_sep = folder_sep

        log_file = self.root_path + self.dir_sep + "db_log.txt"

        self.imgcomp = ImageComparator(samples, dist_thresh, match_thresh, resize_dim, threads)

        if not os.path.isdir(self.root_path):
            os.mkdir(self.root_path)
            self.log("Created {0} because it was missing ...".format(self.root_path), False)
        if not os.path.isdir(self.import_path):
            os.mkdir(self.import_path)
            self.log("Created {0} because it was missing ...".format(self.import_path), False)
        if not os.path.isdir(self.db_path):
            os.mkdir(self.db_path)
            self.log("Created {0} because it was missing ...".format(self.db_path), False)
        if not os.path.isdir(self.dump_path):
            os.mkdir(self.dump_path)
            self.log("Created {0} because it was missing ...".format(self.dump_path), False)

        self.log_stream = open(log_file, "w")

        self.log("Calculating all features ...")
        self.picture_data = self.calc_db_features()
        self.log("Done!")

    def log(self, msg: str, log_to_file: bool = True):
        time = datetime.now() - self.start_time
        msg = "[{0}] {1}".format(time, msg)
        print(msg)
        if log_to_file:
            self.log_stream.write(msg + "\n")
            self.log_stream.flush()

    def process_all_images(self):
        for file_name in os.listdir(self.import_path):
            self.process_image(file_name)
        self.update_db_features()

    def process_image(self, img_name: str) -> bool:
        """
        Process a given image and decide whether it should be imported;
        import it into the database or move it to the dump folder accordingly.

        Assumes the image to be inside the import folder.
        Indicates via the returned boolean whether it was imported.
        :param img_name:
        :return:
        """
        img_path = self.import_path + self.dir_sep + img_name

        if self.is_supported_file(img_path):
            result, matched, ratio = self.is_image_unique(img_path, self.picture_data)
            if not result:
                self.import_image(img_name)
                return True
            else:
                self.log("Matched {0} to {1} with ratio {2}".format(img_name, matched, ratio))
                self.move_to_dump(img_name)
        return False

    def import_image(self, img_name: str):
        """
        Import the given image into the database.

        Does no checks, just imports.
        :param img_name: image file name, assumed to be located inside the import folder
        :return:
        """
        img_path = self.import_path + self.dir_sep + img_name
        dest_path = self.db_path + self.dir_sep + img_name
        if os.path.isfile(img_path):
            os.rename(img_path, dest_path)
        else:
            raise Exception("Image path is not a valid file!")

    def move_to_dump(self, img_name: str):
        """
        Move the given image to the dump folder.
        :param img_name:
        :return:
        """
        img_path = self.import_path + self.dir_sep + img_name
        dest_path = self.dump_path + self.dir_sep + img_name
        if os.path.isfile(img_path):
            os.rename(img_path, dest_path)
        else:
            raise Exception("Image path is not a valid file!")

    def is_image_unique(self, img_path: str, picture_data: dict) -> tuple:
        """
        Check whether the given image already has a match in the database.
        :param img_path:
        :param picture_data:
        :return: (match_found, matched_path, match_ratio)
        """
        if len(picture_data) == 0:
            return False, "", 0.0
        result, matched, match = self.imgcomp.has_similar_match(img_path, self.db_path,
                                                                picture_data)

        if result:
            return True, matched, match
        else:
            return False, "", 0.0

    def calc_db_features(self) -> dict:
        """
        Calculate keypoints for every image in the database.
        :return:
        """
        pictures = dict()
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if not self.is_supported_file(pic_path):
                continue
            img = cv2.imread(pic_path)

            pictures[pic_path] = self.imgcomp.get_features(img)

        return pictures

    def update_db_features(self):
        """
        Update the keypoint dictionary with only the new image data.
        :return:
        """
        for name in os.listdir(self.db_path):
            pic_path = self.db_path + self.dir_sep + name
            if not self.is_supported_file(pic_path):
                continue
            if pic_path not in self.picture_data.keys():
                self.log("New image {0}! Calculating features for memory...".format(name))
                img = cv2.imread(pic_path)
                self.picture_data[pic_path] = self.imgcomp.get_features(img)
                self.log("Done!")

    def is_supported_file(self, path: str) -> bool:
        """
        Return whether the given file has a supported picture file extension.
        :param path:
        :return:
        """
        if os.path.isfile(path):
            for ext in ImageDB.supported_file_extensions:
                if path.lower().endswith(ext):
                    return True
        print("{0} is not a supported file format!".format(path))
        return False

    def get_db_size(self) -> int:
        """
        Return the number of images inside the database.
        :return:
        """
        size = 0
        for entry in os.listdir(self.db_path):
            if os.path.isfile(self.db_path + self.dir_sep + entry):
                size += 1
        return size
@@ -0,0 +1,168 @@
import cv2
from ImageCompareThread import ImageCompareManageThread

import datetime


class ImageComparator:

    concurrent_threads = 10

    def __init__(self, samples: int = 100, dist_thresh: float = 80, match_thresh: float = 0.6,
                 resize_dim: tuple = (500, 500), concurrent_threads: int = 10):
        self.results = {}
        ImageComparator.concurrent_threads = concurrent_threads
        self.icmt = None
        self.samples = samples
        self.dist_thresh = dist_thresh
        self.match_thresh = match_thresh
        self.resize_dim = resize_dim

        self.sift = cv2.xfeatures2d.SIFT_create()

    # Marked deprecated: applying DeprecationWarning as a decorator makes the method
    # uncallable, and it relies on helpers that live in ImageCompareThread.
    @DeprecationWarning
    def match_images(self, path0: str, path1: str, sample_size: int = 100,
                     match_thresh: float = 0.8,
                     dist_thresh: float = 350,
                     diff_min: float = 1.5) -> bool:
        """
        Match the given images using SIFT features and euclidean distance comparison
        of a random sample of keypoints.

        True if at least match_thresh many keypoints
        have been successfully matched (e.g. 0.9 -> 90%).
        :param path0:
        :param path1:
        :param sample_size:
        :param match_thresh: float from 0 to 1, fraction of keypoint matches required
        :param dist_thresh: float, maximum distance at which two keypoints match
        :param diff_min: float by which the second closest match must be bigger
        :return:
        """
        start_time = datetime.datetime.now()
        print("Creating feature lists ...")
        ft0, des0 = self.get_features(cv2.imread(path0))
        ft1, des1 = self.get_features(cv2.imread(path1))
        print("Created feature lists!")

        print(datetime.datetime.now() - start_time)

        print("Looking for matches ...")
        selection = self.__get_random_selection(des0, sample_size)
        hits = self.find_matching_keypoints(des0, des1, sample_size, dist_thresh)
        print("Looked for matches!")

        print(datetime.datetime.now() - start_time)

        match_ratio = hits / len(selection)
        print("MatchRatio:{0} Hits:{1}".format(match_ratio, hits))

        if match_ratio >= match_thresh:
            return True
        else:
            return False

    @DeprecationWarning
    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
                                sample_size: int, dist_thresh: float) -> int:
        """
        Find nearest neighbours for each point in keypoints1 in keypoints2.

        Returns the number of sufficiently matching keypoints.
        :param keypoints1:
        :param keypoints2:
        :param sample_size:
        :param dist_thresh:
        :return:
        """
        selection = self.__get_random_selection(keypoints1, sample_size)
        hits, sum = 0, 0
        for i in range(len(selection)):
            hit, dist = self.has_matching_keypoint(selection[i], keypoints2, dist_thresh)
            if hit:
                hits += 1
                sum += dist

        return hits

    def get_features(self, img) -> tuple:
        assert (img is not None)
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        kp = self.sift.detect(grey, None)
        kp = self.select_keypoints(kp, self.samples, (img.shape[1], img.shape[0]))
        des = self.sift.compute(grey, kp)

        return kp, des[1]

    def select_keypoints(self, kps: list, samples: int, dimension: tuple):
        """
        Select a good sample of keypoints among all keypoints:
        the strongest (highest response) keypoints from each quadrant of the keypoint set.
        :param kps:
        :return:
        """
        result = []
        kps.sort(key=lambda x: x.pt[0])  # sort by x-coordinate

        kps_left = kps[:len(kps) // 2]
        kps_left.sort(key=lambda x: x.pt[1])

        kps_up_left = kps_left[:len(kps_left) // 2]
        kps_down_left = kps_left[len(kps_left) // 2:]

        kps_right = kps[len(kps) // 2:]
        kps_right.sort(key=lambda x: x.pt[1])

        kps_up_right = kps_right[:len(kps_right) // 2]
        kps_down_right = kps_right[len(kps_right) // 2:]

        sample = samples // 4

        for quad_kps in [kps_up_left, kps_up_right, kps_down_left, kps_down_right]:
            quad_kps.sort(key=lambda x: x.response, reverse=True)

            result += quad_kps[:sample]

        return result

    def has_similar_match(self, imgPath: str, dbPath: str, picture_data: dict) -> tuple:
        """
        Calculate the similarity of imgPath to all images in dbPath.

        Returns as soon as a match has been found.
        :param imgPath:
        :param dbPath:
        :param picture_data:
        :return:
        """
        # Calculate features of the candidate image
        targImg = cv2.imread(imgPath)
        targFeat, targDesc = self.get_features(targImg)

        # Match features against the database in worker threads
        self.icmt = ImageCompareManageThread(imgPath, targDesc, picture_data, self.notify_result,
                                             self.concurrent_threads, self.samples,
                                             self.dist_thresh)
        self.icmt.start()
        self.icmt.join()
        print("Done managing threads")

        match = ""
        value = 0.0
        for key in self.results.keys():
            if value < self.results[key] / self.samples:
                value = self.results[key] / self.samples
                match = key

        if value >= self.match_thresh:
            return True, match, value
        else:
            return False, "", 0.0

    def notify_result(self, name: str, hits: int):
        self.results[name] = hits

        if hits / self.samples >= self.match_thresh:
            print("[{0}] got a match! Aborting search ...".format(name))
            self.icmt.searching = False
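
As an aside, a hedged standalone sketch of the "biggest SIFT features" idea from the commit message: keep only the keypoints with the highest response, which select_keypoints above additionally does per quadrant of the keypoint set (the file name "example.jpg" and the sample size of 100 are assumptions for illustration):

import cv2

img = cv2.imread("example.jpg", cv2.IMREAD_GRAYSCALE)
sift = cv2.xfeatures2d.SIFT_create()
kps = sift.detect(img, None)
# keep the 100 keypoints with the strongest response, then compute their descriptors
strongest = sorted(kps, key=lambda kp: kp.response, reverse=True)[:100]
kps, descs = sift.compute(img, strongest)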
@@ -0,0 +1,121 @@
from threading import Thread
from time import sleep
import numpy as np
import random
from scipy.spatial import KDTree


class ImageCompareManageThread(Thread):

    def __init__(self, name: str, candidate_desc: np.array, descriptors: dict, callback,
                 concurrent_threads: int = 10, samples: int = 100, dist_thresh: float = 80):
        super().__init__(name=name)

        self.candidate_desc = candidate_desc
        self.descriptors = descriptors
        self.samples = samples
        self.dist_thresh = dist_thresh
        self.callback = callback
        self.concurrent_threads = concurrent_threads

        self.todo = []
        self.threads = {}
        self.searching = False

    def run(self):
        print("[{0}] Starting management ...".format(self.name))
        self.searching = True
        self.todo = list(self.descriptors.keys())

        for i in range(self.concurrent_threads):
            if len(self.todo) == 0:
                break
            key = self.todo.pop()
            ict = ImageCompareThread(key, self.candidate_desc, self.descriptors[key][1],
                                     self.samples, self.dist_thresh, self.finish_thread)
            self.threads[key] = ict
            ict.start()

        while self.searching:
            sleep(2)

    def finish_thread(self, name: str, hits: int):
        self.callback(name, hits)
        print("[{0}] finished with {1}".format(name, hits))

        print("{0} jobs left ...".format(len(self.todo)))
        if len(self.todo) > 0 and self.searching:  # still work to do, start another thread
            key = self.todo.pop()
            ict = ImageCompareThread(key, self.candidate_desc, self.descriptors[key][1],
                                     self.samples, self.dist_thresh, self.finish_thread)
            self.threads[key] = ict
            ict.start()
        else:
            self.searching = False


class ImageCompareThread(Thread):

    def __init__(self, name: str, candidate_desc: np.array, db_desc: np.array,
                 sample_size: int, dist_thresh: float, callback):
        super().__init__(name=name)

        self.candidate_desc = candidate_desc
        self.db_desc = db_desc
        self.samples = sample_size
        self.dist_thresh = dist_thresh
        self.callback = callback

    def run(self):
        print("[{0}] starting ...".format(self.name))
        hits = self.find_matching_keypoints(self.candidate_desc, self.db_desc,
                                            self.samples, self.dist_thresh)

        self.callback(self.name, hits)

    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
                                sample_size: int, dist_thresh: float) -> int:
        """
        Find nearest neighbours for each point in keypoints1 in keypoints2.

        Returns the number of sufficiently matching keypoints.
        :param keypoints1:
        :param keypoints2:
        :param sample_size:
        :param dist_thresh:
        :return:
        """
        hits, sum = 0, 0
        for i in range(len(keypoints1)):
            hit, dist = self.has_matching_keypoint(keypoints1[i], keypoints2, dist_thresh)
            if hit:
                hits += 1
                sum += dist

        return hits

    def __get_random_selection(self, l: list, num: int) -> list:
        result = []
        for k in range(num):
            result.append(random.choice(l))

        return result

    def has_matching_keypoint(self, point: np.ndarray, points: list, max_dist: float) -> tuple:
        """
        Find the nearest neighbour for point in points.
        :param point:
        :param points:
        :param max_dist:
        :return:
        """
        tree = KDTree(points)

        dist, ind = tree.query([point], k=2)
        dist = dist[0]  # just resolving nested lists
        # print("Distances:{0} Indexes:{1} MaxDist:{2}".format(dist, ind[0], max_dist))

        if dist[0] <= max_dist:  # nearest neighbour is close enough, valid hit
            return True, dist[0]
        else:
            return False, dist[0]
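
For reference, a hedged sketch of the per-descriptor nearest-neighbour test performed by ImageCompareThread.has_matching_keypoint, using random arrays in place of real SIFT descriptors (the array sizes and the threshold of 80 are illustrative only):

import numpy as np
from scipy.spatial import KDTree

db_desc = np.random.rand(200, 128) * 100   # stand-in for one database image's descriptors
candidate = np.random.rand(128) * 100      # stand-in for one candidate descriptor

tree = KDTree(db_desc)
dist, ind = tree.query([candidate], k=2)   # distances/indices of the two nearest neighbours
is_hit = dist[0][0] <= 80                  # same criterion as has_matching_keypoint
print(is_hit, dist[0][0])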
@@ -0,0 +1,22 @@
from CompDatabase import ImageDB
from datetime import datetime

start_time = datetime.now()
# TODO save & load feature lists for the database, fix resized matching (injective resizing?)

root_folder = "Pictures"
import_folder = "import"
db_folder = "images"
samples = 200
dist_thresh = 80
match_thresh = 0.8
resize_dim = (500, 500)
threads = 4

img_db = ImageDB(root_folder, import_folder, db_folder, samples=samples, dist_thresh=dist_thresh,
                 match_thresh=match_thresh, resize_dim=resize_dim, threads=threads)

img_db.log("Starting work ...")
img_db.log("DB size: {0}".format(img_db.get_db_size()))
img_db.process_all_images()
img_db.log("Work done!")