diff --git a/src/unraphael/dash/align.py b/src/unraphael/dash/align.py index 1d8bc2e..d345585 100644 --- a/src/unraphael/dash/align.py +++ b/src/unraphael/dash/align.py @@ -9,9 +9,12 @@ import cv2 import diplib as dip +import matplotlib.pyplot as plt import numpy as np +import streamlit as st from numpy.fft import fft2, ifft2 from skimage.color import rgb2gray +from skimage.feature import ORB, SIFT, match_descriptors, plot_matches from unraphael.types import ImageType @@ -466,3 +469,158 @@ def align_image_to_base( raise ValueError(f'No such method: {align_method}') return func(image=image, base_image=base_image, **kwargs) + + +def feature_based_alignment_visual( + base_image: np.ndarray, + target_image: np.ndarray, + method: str, + base_image_name: str, + target_image_name: str, + display_in_grayscale: bool = True, + max_ratio: float = 0.6, +): + """Visualize feature-based alignment between two images using a specified + feature detection method. + + Parameters + ---------- + base_image : np.ndarray + The base image to be used as a reference for alignment + target_image : np.ndarray + The target image to be aligned with the base image + method : str + The feature detection method to use for alignment ('SIFT' or 'ORB') + base_image_name : str + Name of the base image for labeling + target_image_name : str + Name of the target image for labeling + display_in_grayscale : bool, optional + Whether to display images in grayscale or original color, by default True + max_ratio : float, optional + The maximum ratio for descriptor matching, by default 0.6 + + Returns + ------- + None + Displays a plot showing the feature-based alignment + """ + + # Initialize the feature descriptor based on the method + if method == 'SIFT': + descriptor_extractor = SIFT() + elif method == 'ORB': + descriptor_extractor = ORB() + else: + st.error('Unsupported method selected.') + return + + # Convert images to grayscale if not already grayscale + base_image_gray = rgb2gray(base_image) if base_image.ndim == 3 else base_image + target_image_gray = rgb2gray(target_image) if target_image.ndim == 3 else target_image + + # Extract keypoints and descriptors for both images + descriptor_extractor.detect_and_extract(base_image_gray) + keypoints1 = descriptor_extractor.keypoints + descriptors1 = descriptor_extractor.descriptors + + descriptor_extractor.detect_and_extract(target_image_gray) + keypoints2 = descriptor_extractor.keypoints + descriptors2 = descriptor_extractor.descriptors + + # Check if keypoints are detected + if keypoints1.shape[0] == 0 or keypoints2.shape[0] == 0: + st.error('No keypoints detected in one of the images.') + return + + # Match descriptors + matches = match_descriptors( + descriptors1, descriptors2, max_ratio=max_ratio, cross_check=True + ) + + # Check if matches are found + if len(matches) == 0: + st.error('No matches found between images.') + return + + # Plot the matched keypoints with image names as titles + fig, ax = plt.subplots(figsize=(10, 8)) + ax.axis('off') + + if display_in_grayscale: + plot_matches(ax, base_image_gray, target_image_gray, keypoints1, keypoints2, matches) + else: + plot_matches(ax, base_image, target_image, keypoints1, keypoints2, matches) + + # Adding the image names above the corresponding images + ax.text( + 0.25, + 1.05, + f'{base_image_name}', + ha='center', + va='center', + fontsize=12, + transform=ax.transAxes, + ) + ax.text( + 0.75, + 1.05, + f'{target_image_name}', + ha='center', + va='center', + fontsize=12, + transform=ax.transAxes, + ) + + st.pyplot(fig) + + +def feature_alignment_navigation_widget( + base_image: ImageType, + images: list[ImageType], + method: str, + display_in_grayscale: bool = True, + max_ratio: float = 0.6, +): + """Widget to navigate and display feature-based alignment between the base + image and other images. + + Parameters: + - base_image: The base image + - images: List of aligned images to compare with the base image + - method: The feature detection method used (i.e., 'SIFT' or 'ORB') + - display_in_grayscale: Whether to display images in grayscale or original color + - max_ratio: The maximum ratio for descriptor matching + """ + + if 'image_index' not in st.session_state: + st.session_state.image_index = 0 + + def display_current_image(): + current_image = images[st.session_state.image_index] + feature_based_alignment_visual( + base_image=base_image.data, + target_image=current_image.data, + method=method, + base_image_name=base_image.name, + target_image_name=current_image.name, + display_in_grayscale=display_in_grayscale, + max_ratio=max_ratio, + ) + + def next_image(): + st.session_state.image_index = (st.session_state.image_index + 1) % len(images) + + def previous_image(): + st.session_state.image_index = (st.session_state.image_index - 1) % len(images) + + # Layout for buttons and image display + col1, col2, col3 = st.columns([1, 6, 1]) + + with col1: + st.button('⏮️ Previous', on_click=previous_image) + with col3: + st.button('Next ⏭️', on_click=next_image) + + with col2: + display_current_image() diff --git a/src/unraphael/dash/pages/4_compare.py b/src/unraphael/dash/pages/4_compare.py index e72d6db..a28b65f 100644 --- a/src/unraphael/dash/pages/4_compare.py +++ b/src/unraphael/dash/pages/4_compare.py @@ -1,9 +1,11 @@ from __future__ import annotations +from typing import Any, List, Tuple + import imageio.v3 as imageio import numpy as np import streamlit as st -from align import align_image_to_base +from align import align_image_to_base, feature_alignment_navigation_widget from equalize import equalize_image_with_base from streamlit_image_comparison import image_comparison from styling import set_custom_css @@ -39,7 +41,9 @@ def equalize_images_widget(*, base_image: np.ndarray, images: dict[str, np.ndarr ] -def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> list[ImageType]: +def align_images_widget( + *, base_image: ImageType, images: List[ImageType] +) -> Tuple[List[ImageType], Any, Any]: """This widget helps with aligning images.""" st.subheader('Alignment parameters') @@ -117,7 +121,7 @@ def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> li ) ) - return res + return res, align_method, motion_model def alignment_help_widget(): @@ -251,7 +255,39 @@ def main(): images = equalize_images_widget(base_image=base_image, images=images) with col2: - images = align_images_widget(base_image=base_image, images=images) + images, align_method, motion_model = align_images_widget( + base_image=base_image, images=images + ) + + # scikit-image includes SIFT and ORB but not SURF + if align_method == 'Feature based alignment' and motion_model in ['SIFT', 'ORB']: + # Add a selection button to allow the user to choose whether to visualize + # the feature-based alignment + visualize = col2.checkbox( + f'Show feature-based alignment visualization using {motion_model}', value=False + ) + + if visualize: + # Allow user to choose grayscale or original image display + display_in_grayscale = ( + col2.radio('Display images in:', ['Grayscale', 'Original color)']) + == 'Grayscale' + ) + # Slider to select max_ratio + max_ratio = col2.slider('Max ratio for descriptor matching', 0.5, 0.8, 0.6, 0.01) + + st.write('') + st.write('') + st.subheader(f'Feature-based alignment visualization using {motion_model}') + st.write('') + + feature_alignment_navigation_widget( + base_image=base_image, + images=images, + method=motion_model, + display_in_grayscale=display_in_grayscale, + max_ratio=max_ratio, + ) with st.expander('Help for parameters for aligning images', expanded=False): alignment_help_widget()