You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.6 KiB
Python

import datetime
import random
import warnings

import cv2

from ImageCompareThread import ImageCompareManageThread
class ImageComparator:
    """
    Compare images by sampling SIFT keypoints and matching their descriptors.

    The main entry point is :meth:`has_similar_match`. It extracts features for
    one target image and hands the comparison against ``picture_data`` to an
    ``ImageCompareManageThread``. That thread reports per-image hit counts back
    through :meth:`notify_result`.
    """

    # Class-wide default. __init__ also mirrors the configured value here for
    # backward compatibility with code that reads ImageComparator.concurrent_threads.
    concurrent_threads = 10

    def __init__(self, samples: int = 100, dist_thresh: float = 80, match_thresh: float = 0.6,
                 resize_dim: tuple = (500, 500), concurrent_threads: int = 10):
        """
        :param samples: number of keypoints kept per image. It is also the
            denominator of the match ratio (hits / samples), so it must be positive.
        :param dist_thresh: maximum descriptor distance; forwarded to the
            comparison thread.
        :param match_thresh: minimum hits/samples ratio (0 to 1) for a match.
        :param resize_dim: stored on the instance; not used by this class's methods.
        :param concurrent_threads: number of comparison threads to request.
        :raises ValueError: if ``samples`` is not positive.
        """
        if samples <= 0:
            raise ValueError("samples must be a positive integer")
        self.results = {}
        # Kept for backward compatibility: earlier versions stored this only on the class.
        ImageComparator.concurrent_threads = concurrent_threads
        # Per-instance value, so comparators with different settings don't clobber each other.
        self.concurrent_threads = concurrent_threads
        self.icmt = None
        self.samples = samples
        self.dist_thresh = dist_thresh
        self.match_thresh = match_thresh
        self.resize_dim = resize_dim
        self.sift = self._create_sift()

    @staticmethod
    def _create_sift():
        """Create a SIFT detector, preferring the main-module factory (OpenCV >= 4.4)."""
        factory = getattr(cv2, "SIFT_create", None)
        if factory is None:
            # Older OpenCV builds only expose SIFT through the contrib module.
            factory = cv2.xfeatures2d.SIFT_create
        return factory()

    def match_images(self, path0: str, path1: str, sample_size: int = 100,
                     match_thresh: float = 0.8,
                     dist_thresh: float = 350,
                     diff_min: float = 1.5) -> bool:
        """
        Deprecated: match two images by nearest-neighbour descriptor distance.

        Up to ``sample_size`` descriptors of the first image are drawn at random.
        Each one counts as a hit when its nearest descriptor in the second image
        is within ``dist_thresh``.

        :param path0: path of the first image.
        :param path1: path of the second image.
        :param sample_size: number of descriptors of the first image to test.
        :param match_thresh: required fraction of hits (e.g. 0.9 -> 90%).
        :param dist_thresh: maximum Euclidean descriptor distance for a hit.
        :param diff_min: accepted for compatibility but currently unused.
        :return: True if hits / tested descriptors >= match_thresh.
        """
        warnings.warn("ImageComparator.match_images is deprecated",
                      DeprecationWarning, stacklevel=2)
        start_time = datetime.datetime.now()
        print("Creating feature lists ...")
        _, des0 = self.get_features(cv2.imread(path0))
        _, des1 = self.get_features(cv2.imread(path1))
        print("Created feature lists!")
        print(datetime.datetime.now() - start_time)
        print("Looking for matches ...")
        # Draw the sample once. The original drew a second, unrelated sample
        # just to compute its length.
        selection = self.__get_random_selection(des0, sample_size)
        hits = self._count_matching_keypoints(selection, des1, dist_thresh)
        print("Looked for matches!")
        print(datetime.datetime.now() - start_time)
        if not selection:
            # No descriptors in the first image: nothing can match.
            return False
        match_ratio = hits / len(selection)
        print("MatchRatio:{0} Hits:{1}".format(match_ratio, hits))
        return match_ratio >= match_thresh

    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
                                sample_size: int, dist_thresh: float) -> int:
        """
        Deprecated: count sampled descriptors of ``keypoints1`` that have a close neighbour.

        :param keypoints1: descriptors to sample from.
        :param keypoints2: descriptors to search for neighbours.
        :param sample_size: number of descriptors drawn from ``keypoints1``.
        :param dist_thresh: maximum Euclidean distance for a hit.
        :return: number of sampled descriptors with a neighbour within ``dist_thresh``.
        """
        warnings.warn("ImageComparator.find_matching_keypoints is deprecated",
                      DeprecationWarning, stacklevel=2)
        selection = self.__get_random_selection(keypoints1, sample_size)
        return self._count_matching_keypoints(selection, keypoints2, dist_thresh)

    def _count_matching_keypoints(self, selection: list, keypoints2, dist_thresh: float) -> int:
        """Count the descriptors in ``selection`` that have a neighbour in ``keypoints2``."""
        return sum(1 for descriptor in selection
                   if self.has_matching_keypoint(descriptor, keypoints2, dist_thresh)[0])

    def has_matching_keypoint(self, keypoint, keypoints2, dist_thresh: float) -> tuple:
        """
        Find the nearest neighbour of one descriptor among ``keypoints2``.

        NOTE(review): the original file called this helper without defining it.
        This is a minimal brute-force Euclidean nearest-neighbour implementation;
        confirm it matches the intended semantics.

        :param keypoint: one descriptor (a sequence of numbers).
        :param keypoints2: iterable of descriptors, or None.
        :param dist_thresh: maximum distance (inclusive) for a hit.
        :return: ``(hit, distance)``; distance is ``inf`` when ``keypoints2`` is empty or None.
        """
        best = float("inf")
        if keypoints2 is not None:
            for candidate in keypoints2:
                dist = sum((float(a) - float(b)) ** 2 for a, b in zip(keypoint, candidate)) ** 0.5
                if dist < best:
                    best = dist
        return best <= dist_thresh, best

    def __get_random_selection(self, items, sample_size: int) -> list:
        """
        Return up to ``sample_size`` randomly chosen entries of ``items``, in index order.

        NOTE(review): the original file called this helper without defining it.
        """
        if items is None:
            return []
        count = min(max(sample_size, 0), len(items))
        return [items[i] for i in sorted(random.sample(range(len(items)), count))]

    def get_features(self, img) -> tuple:
        """
        Detect SIFT keypoints, keep a spatially spread sample, and compute descriptors.

        :param img: a BGR image as returned by ``cv2.imread``.
        :return: ``(keypoints, descriptors)`` as returned by ``sift.compute``. The
            keypoints are converted to a list. Descriptors are presumably None when
            no keypoints survive (OpenCV behaviour; confirm for your version).
        :raises ValueError: if ``img`` is None (e.g. ``cv2.imread`` failed).
        """
        # An explicit check instead of `assert`, so it still runs under `python -O`.
        if img is None:
            raise ValueError("image could not be loaded (got None)")
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        keypoints = self.sift.detect(grey, None)
        keypoints = self.select_keypoints(keypoints, self.samples, (img.shape[1], img.shape[0]))
        # Use the keypoints returned by compute() so they stay aligned with the descriptors.
        keypoints, descriptors = self.sift.compute(grey, keypoints)
        return list(keypoints), descriptors

    def select_keypoints(self, kps: list, samples: int, dimension: tuple):
        """
        Select a spatially balanced sample of at most ``samples`` keypoints.

        The keypoints are split at the median x coordinate into a left and a
        right half. Each half is split at its median y coordinate, giving four
        quadrants. The strongest keypoints (by ``response``) are taken from each
        quadrant. When ``samples`` is not a multiple of 4, the first quadrants
        (up-left, up-right, ...) take one extra keypoint each.

        The input sequence is not modified, and tuples are accepted (newer
        OpenCV returns tuples from ``detect``).

        :param kps: keypoints exposing ``.pt`` (x, y) and ``.response``.
        :param samples: total number of keypoints to keep.
        :param dimension: image (width, height); currently unused.
        :return: list of selected keypoints, quadrant by quadrant.
        """
        if samples <= 0:
            return []
        by_x = sorted(kps, key=lambda kp: kp.pt[0])
        half = len(by_x) // 2
        left = sorted(by_x[:half], key=lambda kp: kp.pt[1])
        right = sorted(by_x[half:], key=lambda kp: kp.pt[1])
        quadrants = [
            left[:len(left) // 2],    # up-left
            right[:len(right) // 2],  # up-right
            left[len(left) // 2:],    # down-left
            right[len(right) // 2:],  # down-right
        ]
        per_quadrant, remainder = divmod(samples, 4)
        result = []
        for index, quadrant in enumerate(quadrants):
            quota = per_quadrant + (1 if index < remainder else 0)
            result += sorted(quadrant, key=lambda kp: kp.response, reverse=True)[:quota]
        return result

    def has_similar_match(self, imgPath: str, dbPath: str, picture_data: dict) -> tuple:
        """
        Compare the image at ``imgPath`` against ``picture_data`` and report the best match.

        The comparison runs in an ``ImageCompareManageThread``, and this method
        waits for it to finish. Workers report hits through
        :meth:`notify_result`, which stops the search once a match is found.

        :param imgPath: path of the target image.
        :param dbPath: accepted for compatibility; not used by this method.
        :param picture_data: reference data forwarded to the comparison thread.
        :return: ``(True, name, ratio)`` for the best-scoring entry if its
            hits/samples ratio reaches ``match_thresh``, else ``(False, "", 0.0)``.
        :raises ValueError: if the target image cannot be loaded.
        """
        target_img = cv2.imread(imgPath)
        _, target_desc = self.get_features(target_img)
        # Start from a clean slate. Hit counts from a previous query (especially
        # one that aborted early) must not be reported as a match for this image.
        self.results = {}
        self.icmt = ImageCompareManageThread(imgPath, target_desc, picture_data, self.notify_result,
                                             self.concurrent_threads, self.samples,
                                             self.dist_thresh)
        self.icmt.start()
        self.icmt.join()
        print("Done managing threads")
        match, value = "", 0.0
        for name, hits in self.results.items():
            ratio = hits / self.samples
            if ratio > value:  # strict: the first entry with the best ratio wins
                match, value = name, ratio
        if value >= self.match_thresh:
            return True, match, value
        return False, "", 0.0

    def notify_result(self, name: str, hits: int):
        """
        Record the hit count for ``name``. This is the callback passed to the comparison thread.

        If the hit ratio reaches ``match_thresh``, the running manager thread is
        told to stop searching.

        :param name: identifier of the compared reference image.
        :param hits: number of matching keypoints found.
        """
        self.results[name] = hits
        if hits / self.samples >= self.match_thresh:
            print("[{0}] got a match! Aborting search ...".format(name))
            self.icmt.searching = False