-
Notifications
You must be signed in to change notification settings - Fork 2
Visualize feature point alignment between the base image and the aligned image #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,9 +9,12 @@ | |||||||||||
|
|
||||||||||||
| import cv2 | ||||||||||||
| import diplib as dip | ||||||||||||
| import matplotlib.pyplot as plt | ||||||||||||
| import numpy as np | ||||||||||||
| import streamlit as st | ||||||||||||
| from numpy.fft import fft2, ifft2 | ||||||||||||
| from skimage.color import rgb2gray | ||||||||||||
| from skimage.feature import ORB, SIFT, match_descriptors, plot_matches | ||||||||||||
|
|
||||||||||||
| from unraphael.types import ImageType | ||||||||||||
|
|
||||||||||||
|
|
@@ -466,3 +469,158 @@ def align_image_to_base( | |||||||||||
| raise ValueError(f'No such method: {align_method}') | ||||||||||||
|
|
||||||||||||
| return func(image=image, base_image=base_image, **kwargs) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def feature_based_alignment_visual( | ||||||||||||
| base_image: np.ndarray, | ||||||||||||
| target_image: np.ndarray, | ||||||||||||
| method: str, | ||||||||||||
| base_image_name: str, | ||||||||||||
| target_image_name: str, | ||||||||||||
| display_in_grayscale: bool = True, | ||||||||||||
| max_ratio: float = 0.6, | ||||||||||||
| ): | ||||||||||||
| """Visualize feature-based alignment between two images using a specified | ||||||||||||
| feature detection method. | ||||||||||||
|
|
||||||||||||
| Parameters | ||||||||||||
| ---------- | ||||||||||||
| base_image : np.ndarray | ||||||||||||
| The base image to be used as a reference for alignment | ||||||||||||
| target_image : np.ndarray | ||||||||||||
| The target image to be aligned with the base image | ||||||||||||
| method : str | ||||||||||||
| The feature detection method to use for alignment ('SIFT' or 'ORB') | ||||||||||||
| base_image_name : str | ||||||||||||
| Name of the base image for labeling | ||||||||||||
| target_image_name : str | ||||||||||||
| Name of the target image for labeling | ||||||||||||
| display_in_grayscale : bool, optional | ||||||||||||
| Whether to display images in grayscale or original color, by default True | ||||||||||||
| max_ratio : float, optional | ||||||||||||
| The maximum ratio for descriptor matching, by default 0.6 | ||||||||||||
|
|
||||||||||||
| Returns | ||||||||||||
| ------- | ||||||||||||
| None | ||||||||||||
| Displays a plot showing the feature-based alignment | ||||||||||||
|
Comment on lines
+502
to
+506
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not necessary, but please update the function name/documentation that this will plot the aligment
Suggested change
|
||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| # Initialize the feature descriptor based on the method | ||||||||||||
| if method == 'SIFT': | ||||||||||||
| descriptor_extractor = SIFT() | ||||||||||||
| elif method == 'ORB': | ||||||||||||
| descriptor_extractor = ORB() | ||||||||||||
| else: | ||||||||||||
| st.error('Unsupported method selected.') | ||||||||||||
| return | ||||||||||||
|
Comment on lines
+514
to
+516
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider raising a |
||||||||||||
|
|
||||||||||||
| # Convert images to grayscale if not already grayscale | ||||||||||||
| base_image_gray = rgb2gray(base_image) if base_image.ndim == 3 else base_image | ||||||||||||
| target_image_gray = rgb2gray(target_image) if target_image.ndim == 3 else target_image | ||||||||||||
|
|
||||||||||||
| # Extract keypoints and descriptors for both images | ||||||||||||
| descriptor_extractor.detect_and_extract(base_image_gray) | ||||||||||||
| keypoints1 = descriptor_extractor.keypoints | ||||||||||||
| descriptors1 = descriptor_extractor.descriptors | ||||||||||||
|
|
||||||||||||
| descriptor_extractor.detect_and_extract(target_image_gray) | ||||||||||||
| keypoints2 = descriptor_extractor.keypoints | ||||||||||||
| descriptors2 = descriptor_extractor.descriptors | ||||||||||||
|
|
||||||||||||
| # Check if keypoints are detected | ||||||||||||
| if keypoints1.shape[0] == 0 or keypoints2.shape[0] == 0: | ||||||||||||
| st.error('No keypoints detected in one of the images.') | ||||||||||||
| return | ||||||||||||
|
|
||||||||||||
| # Match descriptors | ||||||||||||
| matches = match_descriptors( | ||||||||||||
| descriptors1, descriptors2, max_ratio=max_ratio, cross_check=True | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Check if matches are found | ||||||||||||
| if len(matches) == 0: | ||||||||||||
| st.error('No matches found between images.') | ||||||||||||
| return | ||||||||||||
|
|
||||||||||||
| # Plot the matched keypoints with image names as titles | ||||||||||||
| fig, ax = plt.subplots(figsize=(10, 8)) | ||||||||||||
| ax.axis('off') | ||||||||||||
|
|
||||||||||||
| if display_in_grayscale: | ||||||||||||
| plot_matches(ax, base_image_gray, target_image_gray, keypoints1, keypoints2, matches) | ||||||||||||
| else: | ||||||||||||
| plot_matches(ax, base_image, target_image, keypoints1, keypoints2, matches) | ||||||||||||
|
|
||||||||||||
| # Adding the image names above the corresponding images | ||||||||||||
| ax.text( | ||||||||||||
| 0.25, | ||||||||||||
| 1.05, | ||||||||||||
| f'{base_image_name}', | ||||||||||||
| ha='center', | ||||||||||||
| va='center', | ||||||||||||
| fontsize=12, | ||||||||||||
| transform=ax.transAxes, | ||||||||||||
| ) | ||||||||||||
| ax.text( | ||||||||||||
| 0.75, | ||||||||||||
| 1.05, | ||||||||||||
| f'{target_image_name}', | ||||||||||||
| ha='center', | ||||||||||||
| va='center', | ||||||||||||
| fontsize=12, | ||||||||||||
| transform=ax.transAxes, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| st.pyplot(fig) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def feature_alignment_navigation_widget( | ||||||||||||
| base_image: ImageType, | ||||||||||||
| images: list[ImageType], | ||||||||||||
| method: str, | ||||||||||||
| display_in_grayscale: bool = True, | ||||||||||||
| max_ratio: float = 0.6, | ||||||||||||
| ): | ||||||||||||
| """Widget to navigate and display feature-based alignment between the base | ||||||||||||
| image and other images. | ||||||||||||
|
|
||||||||||||
| Parameters: | ||||||||||||
| - base_image: The base image | ||||||||||||
| - images: List of aligned images to compare with the base image | ||||||||||||
| - method: The feature detection method used (i.e., 'SIFT' or 'ORB') | ||||||||||||
| - display_in_grayscale: Whether to display images in grayscale or original color | ||||||||||||
| - max_ratio: The maximum ratio for descriptor matching | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| if 'image_index' not in st.session_state: | ||||||||||||
| st.session_state.image_index = 0 | ||||||||||||
|
|
||||||||||||
| def display_current_image(): | ||||||||||||
| current_image = images[st.session_state.image_index] | ||||||||||||
| feature_based_alignment_visual( | ||||||||||||
| base_image=base_image.data, | ||||||||||||
| target_image=current_image.data, | ||||||||||||
| method=method, | ||||||||||||
| base_image_name=base_image.name, | ||||||||||||
| target_image_name=current_image.name, | ||||||||||||
| display_in_grayscale=display_in_grayscale, | ||||||||||||
| max_ratio=max_ratio, | ||||||||||||
| ) | ||||||||||||
|
Comment on lines
+599
to
+609
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be a function? |
||||||||||||
|
|
||||||||||||
| def next_image(): | ||||||||||||
| st.session_state.image_index = (st.session_state.image_index + 1) % len(images) | ||||||||||||
|
|
||||||||||||
| def previous_image(): | ||||||||||||
| st.session_state.image_index = (st.session_state.image_index - 1) % len(images) | ||||||||||||
|
Comment on lines
+611
to
+615
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice to have: We use this pattern now in several places. Can we refactor/generalize it so that we avoid this code duplication everywhere? |
||||||||||||
|
|
||||||||||||
| # Layout for buttons and image display | ||||||||||||
| col1, col2, col3 = st.columns([1, 6, 1]) | ||||||||||||
|
|
||||||||||||
| with col1: | ||||||||||||
| st.button('⏮️ Previous', on_click=previous_image) | ||||||||||||
| with col3: | ||||||||||||
| st.button('Next ⏭️', on_click=next_image) | ||||||||||||
|
Comment on lines
+620
to
+623
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice to have: sync these buttons with those from the comparison itself. |
||||||||||||
|
|
||||||||||||
| with col2: | ||||||||||||
| display_current_image() | ||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,11 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, List, Tuple | ||
|
|
||
| import imageio.v3 as imageio | ||
| import numpy as np | ||
| import streamlit as st | ||
| from align import align_image_to_base | ||
| from align import align_image_to_base, feature_alignment_navigation_widget | ||
| from equalize import equalize_image_with_base | ||
| from streamlit_image_comparison import image_comparison | ||
| from styling import set_custom_css | ||
|
|
@@ -39,7 +41,9 @@ def equalize_images_widget(*, base_image: np.ndarray, images: dict[str, np.ndarr | |
| ] | ||
|
|
||
|
|
||
| def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> list[ImageType]: | ||
| def align_images_widget( | ||
| *, base_image: ImageType, images: List[ImageType] | ||
| ) -> Tuple[List[ImageType], Any, Any]: | ||
| """This widget helps with aligning images.""" | ||
| st.subheader('Alignment parameters') | ||
|
|
||
|
|
@@ -117,7 +121,7 @@ def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> li | |
| ) | ||
| ) | ||
|
|
||
| return res | ||
| return res, align_method, motion_model | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we refactor this in a way that avoids returning a tuple? This makes the code difficult to maintain. |
||
|
|
||
|
|
||
| def alignment_help_widget(): | ||
|
|
@@ -251,7 +255,39 @@ def main(): | |
| images = equalize_images_widget(base_image=base_image, images=images) | ||
|
|
||
| with col2: | ||
| images = align_images_widget(base_image=base_image, images=images) | ||
| images, align_method, motion_model = align_images_widget( | ||
| base_image=base_image, images=images | ||
| ) | ||
|
|
||
| # scikit-image includes SIFT and ORB but not SURF | ||
| if align_method == 'Feature based alignment' and motion_model in ['SIFT', 'ORB']: | ||
| # Add a selection button to allow the user to choose whether to visualize | ||
| # the feature-based alignment | ||
| visualize = col2.checkbox( | ||
| f'Show feature-based alignment visualization using {motion_model}', value=False | ||
| ) | ||
|
|
||
| if visualize: | ||
| # Allow user to choose grayscale or original image display | ||
| display_in_grayscale = ( | ||
| col2.radio('Display images in:', ['Grayscale', 'Original color)']) | ||
| == 'Grayscale' | ||
| ) | ||
| # Slider to select max_ratio | ||
| max_ratio = col2.slider('Max ratio for descriptor matching', 0.5, 0.8, 0.6, 0.01) | ||
|
|
||
| st.write('') | ||
| st.write('') | ||
| st.subheader(f'Feature-based alignment visualization using {motion_model}') | ||
| st.write('') | ||
|
Comment on lines
+272
to
+282
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be in the widget. |
||
|
|
||
| feature_alignment_navigation_widget( | ||
| base_image=base_image, | ||
| images=images, | ||
| method=motion_model, | ||
| display_in_grayscale=display_in_grayscale, | ||
| max_ratio=max_ratio, | ||
| ) | ||
|
|
||
| with st.expander('Help for parameters for aligning images', expanded=False): | ||
| alignment_help_widget() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest protecting this function by requiring keyword arguments: