You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			169 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			169 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
import datetime
import math
import random
import warnings

import cv2

from ImageCompareThread import ImageCompareManageThread
 | |
| 
 | |
| 
 | |
class ImageComparator:
    """Compare images by sampling SIFT keypoints and matching their descriptors.

    The main entry point is :meth:`has_similar_match`: it extracts features
    from a query image and hands them to an ``ImageCompareManageThread``,
    which reports per-candidate hit counts back through :meth:`notify_result`.
    """

    # Fallback default only; each instance stores its own value in __init__.
    concurrent_threads = 10

    def __init__(self, samples: int = 100, dist_thresh: float = 80, match_thresh: float = 0.6,
                 resize_dim: tuple = (500, 500), concurrent_threads: int = 10):
        """Configure the comparator.

        :param samples: number of keypoints kept per image (split evenly over
            four quadrants, so at most ``samples // 4 * 4`` are kept); also the
            denominator when hit counts are turned into ratios.
        :param dist_thresh: distance threshold forwarded to
            ``ImageCompareManageThread``.
        :param match_thresh: minimum hits/samples ratio that counts as a match.
        :param resize_dim: stored on the instance but not used by this class.
        :param concurrent_threads: worker count forwarded to
            ``ImageCompareManageThread``.
        """
        # candidate name -> hit count, filled by notify_result during a search.
        self.results = {}
        # Instance attribute: assigning the class attribute instead would
        # silently change the thread count of every other instance as well.
        self.concurrent_threads = concurrent_threads
        self.icmt = None
        self.samples = samples
        self.dist_thresh = dist_thresh
        self.match_thresh = match_thresh
        self.resize_dim = resize_dim

        self.sift = self._create_sift()

    @staticmethod
    def _create_sift():
        """Return a SIFT detector from whichever OpenCV namespace provides it."""
        # OpenCV >= 4.4 ships SIFT in the main module; older builds only
        # expose it through the contrib ``xfeatures2d`` namespace.
        factory = getattr(cv2, "SIFT_create", None)
        if factory is None:
            factory = cv2.xfeatures2d.SIFT_create
        return factory()

    def match_images(self, path0: str, path1: str, sample_size: int = 100,
                     match_thresh: float = 0.8,
                     dist_thresh: float = 350,
                     diff_min: float = 1.5) -> bool:
        """Deprecated: match two image files by sampled SIFT descriptors.

        A random sample of up to ``sample_size`` descriptors from ``path0`` is
        compared by Euclidean distance against every descriptor of ``path1``.

        :param path0: path of the image whose descriptors are sampled.
        :param path1: path of the image searched for matches.
        :param sample_size: number of descriptors sampled from ``path0``.
        :param match_thresh: fraction (0..1) of sampled descriptors that must
            find a match, e.g. 0.9 -> 90%.
        :param dist_thresh: maximum Euclidean distance for a descriptor match.
        :param diff_min: currently ignored (no closest/second-closest ratio
            test is implemented).
        :return: True if the match ratio reaches ``match_thresh``; False when
            it does not or when ``path0`` yields no descriptors.
        """
        warnings.warn("ImageComparator.match_images is deprecated",
                      DeprecationWarning, stacklevel=2)
        start_time = datetime.datetime.now()
        print("Creating feature lists ...")
        _, des0 = self.get_features(cv2.imread(path0))
        _, des1 = self.get_features(cv2.imread(path1))
        print("Created feature lists!")

        print(datetime.datetime.now() - start_time)

        print("Looking for matches ...")
        # Score exactly the sample whose size is used as the denominator below.
        selection = self._get_random_selection(des0, sample_size)
        hits = self._count_matching_keypoints(selection, des1, dist_thresh)
        print("Looked for matches!")

        print(datetime.datetime.now() - start_time)

        # An empty sample (no descriptors in path0) cannot match anything.
        match_ratio = hits / len(selection) if selection else 0.0
        print("MatchRatio:{0} Hits:{1}".format(match_ratio, hits))

        return bool(selection) and match_ratio >= match_thresh

    def find_matching_keypoints(self, keypoints1: list, keypoints2: list,
                                sample_size: int, dist_thresh: float) -> int:
        """Deprecated: count sampled entries of keypoints1 with a near neighbour in keypoints2.

        Despite the parameter names, the values compared are descriptor
        vectors (sequences of numbers), as passed in by ``match_images``.

        :param keypoints1: descriptors to sample from.
        :param keypoints2: descriptors searched for nearest neighbours.
        :param sample_size: maximum number of entries sampled from keypoints1.
        :param dist_thresh: maximum Euclidean distance that counts as a match.
        :return: number of sampled descriptors that found a match.
        """
        warnings.warn("ImageComparator.find_matching_keypoints is deprecated",
                      DeprecationWarning, stacklevel=2)
        selection = self._get_random_selection(keypoints1, sample_size)
        return self._count_matching_keypoints(selection, keypoints2, dist_thresh)

    @staticmethod
    def _get_random_selection(descriptors, sample_size: int) -> list:
        """Return up to ``sample_size`` items of ``descriptors``, sampled without replacement.

        ``None`` (no descriptors) yields an empty list.
        """
        if descriptors is None:
            return []
        pool = list(descriptors)
        return random.sample(pool, max(0, min(sample_size, len(pool))))

    def _count_matching_keypoints(self, selection, candidates, dist_thresh: float) -> int:
        """Count descriptors in ``selection`` that have a neighbour in ``candidates`` within ``dist_thresh``."""
        return sum(1 for descriptor in selection
                   if self.has_matching_keypoint(descriptor, candidates, dist_thresh)[0])

    @staticmethod
    def has_matching_keypoint(descriptor, candidates, dist_thresh: float) -> tuple:
        """Find the nearest neighbour of ``descriptor`` among ``candidates``.

        :param descriptor: a sequence of numbers.
        :param candidates: sequences of numbers to compare against (or None).
        :param dist_thresh: maximum Euclidean distance that counts as a hit.
        :return: ``(hit, distance)``, where ``distance`` is the smallest
            Euclidean distance found (``inf`` without candidates) and ``hit``
            is True when it does not exceed ``dist_thresh``.
        """
        if candidates is None:
            return False, float("inf")
        nearest = min((math.sqrt(sum((a - b) ** 2 for a, b in zip(descriptor, other)))
                       for other in candidates), default=float("inf"))
        return nearest <= dist_thresh, nearest

    def get_features(self, img) -> tuple:
        """Detect SIFT keypoints in ``img``, keep a balanced subset and describe them.

        :param img: BGR image as returned by ``cv2.imread`` (converted with
            ``COLOR_BGR2GRAY``).
        :return: ``(keypoints, descriptors)``: the keypoints chosen by
            :meth:`select_keypoints` and the descriptors ``self.sift.compute``
            produced for them.
        """
        # The caller passes cv2.imread's result directly; fail clearly on None.
        assert img is not None, "image is None (could not be read?)"
        grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        keypoints = self.sift.detect(grey, None)
        # img.shape is (rows, cols, ...), so this passes (width, height).
        keypoints = self.select_keypoints(keypoints, self.samples, (img.shape[1], img.shape[0]))
        _, descriptors = self.sift.compute(grey, keypoints)

        return keypoints, descriptors

    def select_keypoints(self, kps: list, samples: int, dimension: tuple) -> list:
        """Select a spatially balanced subset of keypoints.

        The keypoints are split at the median x coordinate into left and right
        halves. Each half is split at its median y coordinate. The
        ``samples // 4`` keypoints with the highest ``response`` are then taken
        from each quadrant, in the order up-left, up-right, down-left,
        down-right. Quadrants with fewer keypoints contribute all they have.

        :param kps: keypoints exposing ``pt`` (x, y) and ``response``. Any
            iterable is accepted, and it is not modified.
        :param samples: total number of keypoints wanted.
        :param dimension: (width, height) of the image. Currently unused,
            because the quadrants are split at median coordinates.
        :return: list of selected keypoints.
        """
        per_quadrant = samples // 4
        # sorted() instead of list.sort(): detect() may hand back a tuple (as
        # newer OpenCV Python bindings do), and the caller's sequence should
        # not be reordered as a side effect.
        by_x = sorted(kps, key=lambda kp: kp.pt[0])
        middle = len(by_x) // 2
        left = sorted(by_x[:middle], key=lambda kp: kp.pt[1])
        right = sorted(by_x[middle:], key=lambda kp: kp.pt[1])

        quadrants = (
            left[:len(left) // 2],    # up-left (smaller y)
            right[:len(right) // 2],  # up-right
            left[len(left) // 2:],    # down-left
            right[len(right) // 2:],  # down-right
        )

        result = []
        for quadrant in quadrants:
            strongest = sorted(quadrant, key=lambda kp: kp.response, reverse=True)
            result.extend(strongest[:per_quadrant])
        return result

    def has_similar_match(self, imgPath: str, dbPath: str, picture_data: dict) -> tuple:
        """Compare the image at ``imgPath`` against the candidates in ``picture_data``.

        Features of the query image are extracted here. The comparison itself
        is delegated to ``ImageCompareManageThread``, which reports hit counts
        through :meth:`notify_result`; that callback asks it to stop once a
        candidate reaches ``match_thresh``.

        :param imgPath: path of the query image.
        :param dbPath: not used by this method.
        :param picture_data: candidate data passed unchanged to
            ``ImageCompareManageThread`` (its format is defined there).
        :return: ``(True, name, ratio)`` for the best-scoring candidate when
            its hits/samples ratio reaches ``match_thresh``, otherwise
            ``(False, "", 0.0)``.
        """
        # Start every query with fresh scores; otherwise results from an
        # earlier query could be reported as this query's match.
        self.results = {}

        target_image = cv2.imread(imgPath)
        _, target_descriptors = self.get_features(target_image)

        self.icmt = ImageCompareManageThread(imgPath, target_descriptors, picture_data, self.notify_result,
                                             self.concurrent_threads, self.samples,
                                             self.dist_thresh)
        self.icmt.start()
        self.icmt.join()
        print("Done managing threads")

        return self._best_result()

    def _best_result(self) -> tuple:
        """Pick the highest hits/samples ratio from ``self.results`` and apply ``match_thresh``.

        On ties the first candidate in iteration order wins.
        """
        match, value = "", 0.0
        for name, hits in self.results.items():
            ratio = hits / self.samples
            if value < ratio:
                value, match = ratio, name

        if value >= self.match_thresh:
            return True, match, value
        return False, "", 0.0

    def notify_result(self, name: str, hits: int):
        """Record the hit count for candidate ``name`` and stop the search on a match.

        :param name: candidate identifier reported by the manager thread.
        :param hits: number of matched keypoints for that candidate.
        """
        self.results[name] = hits

        if hits / self.samples >= self.match_thresh:
            print("[{0}] got a match! Aborting search ...".format(name))
            # Clear the manager's `searching` flag (presumably polled by it) to
            # abort; there is no manager if called outside has_similar_match.
            if self.icmt is not None:
                self.icmt.searching = False
| 
 |