Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions src/unraphael/dash/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import cv2
import diplib as dip
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from numpy.fft import fft2, ifft2
from skimage.color import rgb2gray
from skimage.feature import ORB, SIFT, match_descriptors, plot_matches

from unraphael.types import ImageType

Expand Down Expand Up @@ -466,3 +469,158 @@ def align_image_to_base(
raise ValueError(f'No such method: {align_method}')

return func(image=image, base_image=base_image, **kwargs)


def feature_based_alignment_visual(
base_image: np.ndarray,
Comment on lines +474 to +475

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest protecting this function by requiring keyword arguments:

Suggested change
def feature_based_alignment_visual(
base_image: np.ndarray,
def feature_based_alignment_visual(
*,
base_image: np.ndarray,

target_image: np.ndarray,
method: str,
base_image_name: str,
target_image_name: str,
display_in_grayscale: bool = True,
max_ratio: float = 0.6,
):
"""Visualize feature-based alignment between two images using a specified
feature detection method.

Parameters
----------
base_image : np.ndarray
The base image to be used as a reference for alignment
target_image : np.ndarray
The target image to be aligned with the base image
method : str
The feature detection method to use for alignment ('SIFT' or 'ORB')
base_image_name : str
Name of the base image for labeling
target_image_name : str
Name of the target image for labeling
display_in_grayscale : bool, optional
Whether to display images in grayscale or original color, by default True
max_ratio : float, optional
The maximum ratio for descriptor matching, by default 0.6

Returns
-------
None
Displays a plot showing the feature-based alignment
Comment on lines +502 to +506

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessary, but please update the function name/documentation that this will plot the aligment

Suggested change
Returns
-------
None
Displays a plot showing the feature-based alignment

"""

# Initialize the feature descriptor based on the method
if method == 'SIFT':
descriptor_extractor = SIFT()
elif method == 'ORB':
descriptor_extractor = ORB()
else:
st.error('Unsupported method selected.')
return
Comment on lines +514 to +516

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider raising a ValueError instead of silently failing.


# Convert images to grayscale if not already grayscale
base_image_gray = rgb2gray(base_image) if base_image.ndim == 3 else base_image
target_image_gray = rgb2gray(target_image) if target_image.ndim == 3 else target_image

# Extract keypoints and descriptors for both images
descriptor_extractor.detect_and_extract(base_image_gray)
keypoints1 = descriptor_extractor.keypoints
descriptors1 = descriptor_extractor.descriptors

descriptor_extractor.detect_and_extract(target_image_gray)
keypoints2 = descriptor_extractor.keypoints
descriptors2 = descriptor_extractor.descriptors

# Check if keypoints are detected
if keypoints1.shape[0] == 0 or keypoints2.shape[0] == 0:
st.error('No keypoints detected in one of the images.')
return

# Match descriptors
matches = match_descriptors(
descriptors1, descriptors2, max_ratio=max_ratio, cross_check=True
)

# Check if matches are found
if len(matches) == 0:
st.error('No matches found between images.')
return

# Plot the matched keypoints with image names as titles
fig, ax = plt.subplots(figsize=(10, 8))
ax.axis('off')

if display_in_grayscale:
plot_matches(ax, base_image_gray, target_image_gray, keypoints1, keypoints2, matches)
else:
plot_matches(ax, base_image, target_image, keypoints1, keypoints2, matches)

# Adding the image names above the corresponding images
ax.text(
0.25,
1.05,
f'{base_image_name}',
ha='center',
va='center',
fontsize=12,
transform=ax.transAxes,
)
ax.text(
0.75,
1.05,
f'{target_image_name}',
ha='center',
va='center',
fontsize=12,
transform=ax.transAxes,
)

st.pyplot(fig)


def feature_alignment_navigation_widget(
base_image: ImageType,
images: list[ImageType],
method: str,
display_in_grayscale: bool = True,
max_ratio: float = 0.6,
):
"""Widget to navigate and display feature-based alignment between the base
image and other images.

Parameters:
- base_image: The base image
- images: List of aligned images to compare with the base image
- method: The feature detection method used (i.e., 'SIFT' or 'ORB')
- display_in_grayscale: Whether to display images in grayscale or original color
- max_ratio: The maximum ratio for descriptor matching
"""

if 'image_index' not in st.session_state:
st.session_state.image_index = 0

def display_current_image():
current_image = images[st.session_state.image_index]
feature_based_alignment_visual(
base_image=base_image.data,
target_image=current_image.data,
method=method,
base_image_name=base_image.name,
target_image_name=current_image.name,
display_in_grayscale=display_in_grayscale,
max_ratio=max_ratio,
)
Comment on lines +599 to +609

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a function?


def next_image():
st.session_state.image_index = (st.session_state.image_index + 1) % len(images)

def previous_image():
st.session_state.image_index = (st.session_state.image_index - 1) % len(images)
Comment on lines +611 to +615

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to have: We use this pattern now in several places. Can we refactor/generalize it so that we avoid this code duplication everywhere?


# Layout for buttons and image display
col1, col2, col3 = st.columns([1, 6, 1])

with col1:
st.button('⏮️ Previous', on_click=previous_image)
with col3:
st.button('Next ⏭️', on_click=next_image)
Comment on lines +620 to +623

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to have: sync these buttons with those from the comparison itself.


with col2:
display_current_image()
44 changes: 40 additions & 4 deletions src/unraphael/dash/pages/4_compare.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from typing import Any, List, Tuple

import imageio.v3 as imageio
import numpy as np
import streamlit as st
from align import align_image_to_base
from align import align_image_to_base, feature_alignment_navigation_widget
from equalize import equalize_image_with_base
from streamlit_image_comparison import image_comparison
from styling import set_custom_css
Expand Down Expand Up @@ -39,7 +41,9 @@ def equalize_images_widget(*, base_image: np.ndarray, images: dict[str, np.ndarr
]


def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> list[ImageType]:
def align_images_widget(
*, base_image: ImageType, images: List[ImageType]
) -> Tuple[List[ImageType], Any, Any]:
"""This widget helps with aligning images."""
st.subheader('Alignment parameters')

Expand Down Expand Up @@ -117,7 +121,7 @@ def align_images_widget(*, base_image: ImageType, images: list[ImageType]) -> li
)
)

return res
return res, align_method, motion_model

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we refactor this in a way that avoids returning a tuple? This makes the code difficult to maintain.



def alignment_help_widget():
Expand Down Expand Up @@ -251,7 +255,39 @@ def main():
images = equalize_images_widget(base_image=base_image, images=images)

with col2:
images = align_images_widget(base_image=base_image, images=images)
images, align_method, motion_model = align_images_widget(
base_image=base_image, images=images
)

# scikit-image includes SIFT and ORB but not SURF
if align_method == 'Feature based alignment' and motion_model in ['SIFT', 'ORB']:
# Add a selection button to allow the user to choose whether to visualize
# the feature-based alignment
visualize = col2.checkbox(
f'Show feature-based alignment visualization using {motion_model}', value=False
)

if visualize:
# Allow user to choose grayscale or original image display
display_in_grayscale = (
col2.radio('Display images in:', ['Grayscale', 'Original color)'])
== 'Grayscale'
)
# Slider to select max_ratio
max_ratio = col2.slider('Max ratio for descriptor matching', 0.5, 0.8, 0.6, 0.01)

st.write('')
st.write('')
st.subheader(f'Feature-based alignment visualization using {motion_model}')
st.write('')
Comment on lines +272 to +282

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be in the widget.


feature_alignment_navigation_widget(
base_image=base_image,
images=images,
method=motion_model,
display_in_grayscale=display_in_grayscale,
max_ratio=max_ratio,
)

with st.expander('Help for parameters for aligning images', expanded=False):
alignment_help_widget()
Expand Down