# Interactive Image Segmentation

Last Updated: 2025-06-19
## Table of Contents

- Overview
- Interactive Segmentation Paradigms
- State-of-the-Art Methods
- Implementation Guide
- Best Practices
- Evaluation Metrics
- Applications
- Future Directions
## Overview

Interactive image segmentation lets users segment objects through lightweight guidance such as clicks, scribbles, or bounding boxes. The field bridges the gap between fully automatic and fully manual segmentation, producing accurate masks with minimal user effort.

Key advantages:
- Precision: User guidance ensures accurate segmentation
- Efficiency: can cut manual annotation time dramatically (reductions of 90%+ are often reported)
- Flexibility: Adapts to various object types and scenarios
- Real-time feedback: Iterative refinement based on user input
## Interactive Segmentation Paradigms

### Click-Based Segmentation

Users provide positive clicks on the object and negative clicks on the background.

```python
import time

import numpy as np
import torch
class InteractiveClickSegmenter:
def __init__(self, model_path):
self.model = self.load_model(model_path)
self.click_history = []
def add_click(self, x, y, is_positive=True):
"""Add user click to the interaction history"""
click = {
'coords': (x, y),
'type': 'positive' if is_positive else 'negative',
'timestamp': time.time()
}
self.click_history.append(click)
def segment(self, image, clicks):
"""Perform segmentation based on clicks"""
# Encode clicks as distance maps
click_maps = self.encode_clicks(image.shape[:2], clicks)
# Concatenate image and click maps
inputs = torch.cat([image, click_maps], dim=0)
# Forward pass
with torch.no_grad():
mask = self.model(inputs.unsqueeze(0))
        return torch.sigmoid(mask).squeeze().cpu().numpy()
```
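The `encode_clicks` helper referenced above is not defined in this guide. A minimal sketch, assuming clicks are encoded as truncated Euclidean distance maps (one channel per click polarity, a common encoding in the click-based literature):

```python
import numpy as np
import torch


def encode_clicks(spatial_shape, clicks, truncate=255.0):
    """Encode clicks as truncated distance maps: one channel for positive
    clicks, one for negative, each holding the distance to the nearest click."""
    h, w = spatial_shape
    ys, xs = np.mgrid[:h, :w]
    maps = np.full((2, h, w), truncate, dtype=np.float32)
    for click in clicks:
        x, y = click['coords']
        channel = 0 if click['type'] == 'positive' else 1
        dist = np.sqrt((xs - x) ** 2 + (ys - y) ** 2)
        maps[channel] = np.minimum(maps[channel], dist)
    # Scale to [0, 1] so the network sees a bounded input range
    return torch.from_numpy(maps / truncate)
```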
### Scribble-Based Segmentation

Users draw rough scribbles to mark foreground and background regions.

```python
import cv2
import numpy as np


class ScribbleSegmenter:
def __init__(self, backbone='resnet50'):
self.encoder = self.build_encoder(backbone)
self.decoder = self.build_decoder()
def process_scribbles(self, scribbles):
"""Convert scribbles to distance transforms"""
fg_scribble, bg_scribble = scribbles['foreground'], scribbles['background']
        # Distance from every pixel to the nearest scribble pixel
        # (cv2.distanceTransform measures distance to the nearest zero pixel,
        # so the scribble masks are inverted first)
        fg_dist = cv2.distanceTransform((fg_scribble == 0).astype(np.uint8), cv2.DIST_L2, 5)
        bg_dist = cv2.distanceTransform((bg_scribble == 0).astype(np.uint8), cv2.DIST_L2, 5)
# Normalize distances
fg_dist = (fg_dist - fg_dist.min()) / (fg_dist.max() - fg_dist.min() + 1e-8)
bg_dist = (bg_dist - bg_dist.min()) / (bg_dist.max() - bg_dist.min() + 1e-8)
        return np.stack([fg_dist, bg_dist], axis=-1)
```
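For illustration, a scribble is just a binary mask the same size as the image; the distance transform then spreads that sparse guidance across the whole frame (hypothetical coordinates):

```python
import cv2
import numpy as np

# A hypothetical foreground scribble: a short stroke across the object
h, w = 480, 640
fg = np.zeros((h, w), np.uint8)
cv2.line(fg, (200, 240), (400, 240), color=1, thickness=5)

# Distance from every pixel to the nearest scribble pixel
fg_dist = cv2.distanceTransform((fg == 0).astype(np.uint8), cv2.DIST_L2, 5)
print(fg_dist[240, 300], fg_dist[0, 0])  # ~0 on the stroke, large far away
```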
### Bounding Box Segmentation

Users provide a bounding box around the object.

```python
import cv2
import numpy as np


class BBoxSegmenter:
    """Classical GrabCut segmentation initialized from a bounding box."""
def segment_with_bbox(self, image, bbox):
"""Segment using bounding box initialization"""
x, y, w, h = bbox
mask = np.zeros(image.shape[:2], np.uint8)
# Initialize GrabCut
bgd_model = np.zeros((1, 65), np.float64)
fgd_model = np.zeros((1, 65), np.float64)
# Apply GrabCut
cv2.grabCut(image, mask, (x, y, w, h),
bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
# Extract foreground
mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
        return mask2
```
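Usage is straightforward since GrabCut ships with OpenCV; the box is given as (x, y, width, height), and the file name here is only a placeholder:

```python
import cv2

image = cv2.imread("example.jpg")  # placeholder path
segmenter = BBoxSegmenter()
mask = segmenter.segment_with_bbox(image, bbox=(100, 100, 200, 150))

# Overlay the extracted foreground on black
foreground = image * mask[:, :, None]
```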
## State-of-the-Art Methods

### SAM (Segment Anything Model)

Meta AI's foundation model for promptable, interactive segmentation.

```python
import numpy as np

from segment_anything import SamPredictor, sam_model_registry
class SAMInteractive:
def __init__(self, checkpoint="sam_vit_h_4b8939.pth"):
self.sam = sam_model_registry["vit_h"](checkpoint=checkpoint)
self.sam.to(device='cuda')
self.predictor = SamPredictor(self.sam)
def set_image(self, image):
"""Precompute image embeddings"""
self.predictor.set_image(image)
def predict_with_clicks(self, point_coords, point_labels):
"""Segment with click coordinates"""
masks, scores, logits = self.predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True,
)
# Select best mask
best_idx = np.argmax(scores)
        return masks[best_idx], scores[best_idx]
```
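A typical loop with the class above (the checkpoint must be downloaded from the segment-anything release page; the image path is a placeholder):

```python
import cv2

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

sam = SAMInteractive(checkpoint="sam_vit_h_4b8939.pth")
sam.set_image(image)  # embeddings computed once, reused for every click

# One positive click (label 1) on the object, one negative click (label 0)
mask, score = sam.predict_with_clicks(
    point_coords=[[320, 240], [50, 50]],
    point_labels=[1, 0],
)
print(mask.shape, float(score))
```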
### f-BRS (Feature Backpropagating Refinement Scheme)

A feature backpropagating refinement scheme that optimizes the prediction at inference time for improved accuracy.

```python
import torch


class FBRS:
def __init__(self, model, device='cuda'):
self.model = model.to(device)
self.device = device
self.optimization_steps = 10
def refine_with_backprop(self, image, initial_mask, clicks):
"""Refine segmentation using feature backpropagation"""
image_tensor = self.preprocess(image)
# Initialize mask
mask = initial_mask.clone().requires_grad_(True)
optimizer = torch.optim.Adam([mask], lr=0.1)
for step in range(self.optimization_steps):
# Forward pass
features = self.model.encoder(image_tensor)
refined_mask = self.model.decoder(features, mask)
# Compute loss based on clicks
click_loss = self.compute_click_loss(refined_mask, clicks)
smoothness_loss = self.compute_smoothness_loss(refined_mask)
total_loss = click_loss + 0.1 * smoothness_loss
# Backward pass
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
        return refined_mask.detach()
```
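`compute_click_loss` is left undefined above; a minimal sketch, assuming it simply applies binary cross-entropy at the clicked pixels (positive clicks should be foreground, negative clicks background):

```python
import torch
import torch.nn.functional as F


def compute_click_loss(pred_logits, clicks):
    """BCE evaluated only at the clicked pixels. pred_logits: (1, 1, H, W)."""
    losses = []
    for click in clicks:
        x, y = click['coords']
        target = torch.tensor(1.0 if click['type'] == 'positive' else 0.0,
                              device=pred_logits.device)
        losses.append(F.binary_cross_entropy_with_logits(
            pred_logits[0, 0, y, x], target))
    return torch.stack(losses).mean()
```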
### RITM (Reviving Iterative Training with Mask Guidance)

An efficient iterative refinement approach that feeds the previous mask back into the network together with each new click.

```python
import torch


class RITM:
def __init__(self, backbone='hrnet32'):
self.backbone = self.build_backbone(backbone)
self.head = self.build_segmentation_head()
def iterative_training(self, image, gt_mask, max_iters=20):
"""Train with simulated interactions"""
current_mask = torch.zeros_like(gt_mask)
for iteration in range(max_iters):
# Simulate user click
click = self.simulate_click(current_mask, gt_mask)
# Update click encoding
click_map = self.encode_click(click, image.shape)
# Predict new mask
inputs = torch.cat([image, current_mask, click_map], dim=1)
current_mask = self.forward(inputs)
# Check convergence
iou = self.compute_iou(current_mask, gt_mask)
if iou > 0.95:
break
        return current_mask
```
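`simulate_click` is assumed above. The common strategy from the click-simulation literature places the next click deep inside the largest mispredicted region; a sketch:

```python
import cv2
import numpy as np


def simulate_click(pred_mask, gt_mask, threshold=0.5):
    """Click at the point farthest inside the prediction-error region,
    mimicking a user correcting the most obvious mistake."""
    pred = pred_mask.detach().squeeze().cpu().numpy() > threshold
    gt = gt_mask.squeeze().cpu().numpy() > threshold
    error = np.logical_xor(pred, gt).astype(np.uint8)
    # The distance transform peaks at the interior of the error region
    dist = cv2.distanceTransform(error, cv2.DIST_L2, 5)
    y, x = np.unravel_index(dist.argmax(), dist.shape)
    polarity = 'positive' if gt[y, x] else 'negative'
    return {'coords': (int(x), int(y)), 'type': polarity}
```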
## Implementation Guide

### Complete Pipeline

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import cv2
import numpy as np
class InteractiveSegmentationPipeline:
def __init__(self, model_type='sam'):
self.model = self.load_model(model_type)
self.transform = self.get_transforms()
self.interaction_history = []
def load_model(self, model_type):
"""Load pre-trained model"""
if model_type == 'sam':
from segment_anything import sam_model_registry
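            # Note: SAM normally consumes prompts via SamPredictor rather
            # than extra input channels; loading is shown here only for
            # symmetry with the other backends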
model = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
elif model_type == 'ritm':
model = RITM(backbone='hrnet32')
model.load_state_dict(torch.load('ritm_hrnet32.pth'))
else:
raise ValueError(f"Unknown model type: {model_type}")
return model.eval().cuda()
def get_transforms(self):
"""Define image preprocessing"""
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def encode_interactions(self, interactions, image_shape):
"""Encode user interactions as input channels"""
h, w = image_shape[:2]
pos_clicks = np.zeros((h, w), dtype=np.float32)
neg_clicks = np.zeros((h, w), dtype=np.float32)
for interaction in interactions:
x, y = interaction['coords']
sigma = 10 # Gaussian sigma
# Create Gaussian centered at click
y_grid, x_grid = np.ogrid[:h, :w]
gaussian = np.exp(-((x_grid - x)**2 + (y_grid - y)**2) / (2 * sigma**2))
if interaction['type'] == 'positive':
pos_clicks = np.maximum(pos_clicks, gaussian)
else:
neg_clicks = np.maximum(neg_clicks, gaussian)
return np.stack([pos_clicks, neg_clicks], axis=-1)
def segment_interactive(self, image, interactions):
"""Main segmentation function"""
# Preprocess image
image_tensor = self.transform(image).unsqueeze(0).cuda()
# Encode interactions
interaction_maps = self.encode_interactions(interactions, image.shape)
interaction_tensor = torch.from_numpy(interaction_maps).permute(2, 0, 1)
interaction_tensor = interaction_tensor.unsqueeze(0).cuda()
# Concatenate inputs
inputs = torch.cat([image_tensor, interaction_tensor], dim=1)
# Forward pass
with torch.no_grad():
logits = self.model(inputs)
mask = torch.sigmoid(logits) > 0.5
return mask.squeeze().cpu().numpy()
def refine_segmentation(self, image, mask, new_interaction):
"""Refine existing segmentation with new interaction"""
self.interaction_history.append(new_interaction)
# Re-segment with all interactions
refined_mask = self.segment_interactive(image, self.interaction_history)
# Post-processing
refined_mask = self.post_process(refined_mask)
return refined_mask
def post_process(self, mask):
"""Apply morphological operations"""
# Remove small components
kernel = np.ones((5, 5), np.uint8)
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
# Find largest connected component
num_labels, labels = cv2.connectedComponents(mask)
if num_labels > 1:
largest_component = 1 + np.argmax(
[np.sum(labels == i) for i in range(1, num_labels)]
)
mask = (labels == largest_component).astype(np.uint8)
        return mask
```
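Assuming the referenced checkpoints are available and a CUDA device is present, the pipeline can be driven like this (placeholder image path):

```python
import cv2

pipeline = InteractiveSegmentationPipeline(model_type='ritm')
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# First positive click on the object
interactions = [{'coords': (320, 240), 'type': 'positive'}]
mask = pipeline.segment_interactive(image, interactions)

# Correct an over-segmented region with a negative click
mask = pipeline.refine_segmentation(
    image, mask, {'coords': (100, 80), 'type': 'negative'})
```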
### Advanced Techniques

```python
import cv2
import numpy as np
import torch
import torch.nn as nn


class AdvancedInteractiveSegmentation:
    def __init__(self):
self.model = self.build_model()
self.feature_extractor = self.build_feature_extractor()
def build_model(self):
"""Build segmentation model with attention mechanism"""
return nn.Sequential(
# Encoder
nn.Conv2d(5, 64, 3, padding=1), # RGB + 2 interaction channels
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
# Attention module
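            # (SpatialAttention is assumed to be a user-defined module;
            # it is not provided in this guide)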
SpatialAttention(128),
# Decoder
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 1, 1),
)
def multi_scale_inference(self, image, interactions):
"""Perform multi-scale segmentation"""
scales = [0.5, 1.0, 1.5]
predictions = []
for scale in scales:
# Resize image
h, w = image.shape[:2]
new_h, new_w = int(h * scale), int(w * scale)
scaled_image = cv2.resize(image, (new_w, new_h))
# Scale interactions
scaled_interactions = self.scale_interactions(interactions, scale)
# Segment
mask = self.segment_interactive(scaled_image, scaled_interactions)
# Resize back
mask = cv2.resize(mask.astype(np.float32), (w, h))
predictions.append(mask)
# Ensemble predictions
final_mask = np.mean(predictions, axis=0) > 0.5
return final_mask
def uncertainty_estimation(self, image, interactions, n_samples=10):
"""Estimate segmentation uncertainty using Monte Carlo dropout"""
self.model.train() # Enable dropout
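        # (this only adds stochasticity if the model contains dropout layers)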
predictions = []
for _ in range(n_samples):
mask = self.segment_interactive(image, interactions)
predictions.append(mask)
predictions = np.stack(predictions)
# Compute uncertainty metrics
mean_pred = np.mean(predictions, axis=0)
uncertainty = np.var(predictions, axis=0)
        return mean_pred, uncertainty
```

## Best Practices

### Interaction Design

```python
import cv2
import numpy as np


class InteractionDesign:
@staticmethod
def suggest_next_click(mask, uncertainty_map):
"""Suggest optimal location for next user click"""
# Find high uncertainty regions near boundaries
edges = cv2.Canny((mask * 255).astype(np.uint8), 100, 200)
uncertain_edges = edges * uncertainty_map
# Find peak uncertainty
max_loc = np.unravel_index(uncertain_edges.argmax(), uncertain_edges.shape)
return max_loc
@staticmethod
def validate_interaction(interaction, image_shape):
"""Validate user interaction"""
x, y = interaction['coords']
h, w = image_shape[:2]
if not (0 <= x < w and 0 <= y < h):
raise ValueError("Click coordinates out of bounds")
        return True
```
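Together with the Monte Carlo uncertainty estimate above, `suggest_next_click` can drive a simple active-guidance loop (sketch; `advanced`, `image`, and `interactions` are assumed to exist from the previous examples):

```python
mean_pred, uncertainty = advanced.uncertainty_estimation(image, interactions)
mask = mean_pred > 0.5

# Returned as (row, col); convert to (x, y) click coordinates
y, x = InteractionDesign.suggest_next_click(mask, uncertainty)
suggestion = {'coords': (int(x), int(y)), 'type': 'positive'}  # polarity chosen by the user
InteractionDesign.validate_interaction(suggestion, image.shape)
interactions.append(suggestion)
```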
### Performance Optimization

```python
import time

import torch
import torch.nn as nn


class PerformanceOptimizer:
    def __init__(self):
self.cache = {}
def cache_embeddings(self, image_id, embeddings):
"""Cache image embeddings for faster inference"""
self.cache[image_id] = {
'embeddings': embeddings,
'timestamp': time.time()
}
def optimize_model(self, model):
"""Apply optimization techniques"""
        # Dynamic quantization (PyTorch applies it to Linear/RNN layers;
        # Conv2d layers would need static quantization instead)
        quantized_model = torch.quantization.quantize_dynamic(
            model, {nn.Linear}, dtype=torch.qint8
        )
# TorchScript compilation
scripted_model = torch.jit.script(quantized_model)
        return scripted_model
```
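A quick sanity check of whether the optimization actually helps (generic sketch using a random input of the pipeline's 5-channel shape):

```python
import time

import torch


def measure_latency(model, input_shape=(1, 5, 256, 256), n_runs=20):
    """Average forward-pass latency in seconds over n_runs."""
    x = torch.randn(*input_shape)
    with torch.no_grad():
        model(x)  # warm-up
        start = time.perf_counter()
        for _ in range(n_runs):
            model(x)
    return (time.perf_counter() - start) / n_runs
```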
### Quality Assurance

```python
import cv2
import numpy as np


class QualityAssurance:
    @staticmethod
def validate_segmentation(mask, min_area=100):
"""Validate segmentation quality"""
# Check minimum area
if np.sum(mask) < min_area:
return False, "Segmented area too small"
# Check connectivity
num_components = cv2.connectedComponents(mask.astype(np.uint8))[0]
if num_components > 2: # Background + 1 object
return False, "Multiple disconnected components"
# Check boundary smoothness
perimeter = cv2.arcLength(
cv2.findContours(mask.astype(np.uint8),
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)[0][0],
True
)
area = np.sum(mask)
compactness = (perimeter ** 2) / (4 * np.pi * area)
if compactness > 2.0:
return False, "Boundary too irregular"
        return True, "Valid segmentation"
```

## Evaluation Metrics

```python
import cv2
import numpy as np


class SegmentationMetrics:
@staticmethod
def iou(pred_mask, gt_mask):
"""Intersection over Union"""
intersection = np.logical_and(pred_mask, gt_mask).sum()
union = np.logical_or(pred_mask, gt_mask).sum()
return intersection / (union + 1e-8)
@staticmethod
def dice_coefficient(pred_mask, gt_mask):
"""Dice coefficient"""
intersection = 2.0 * np.logical_and(pred_mask, gt_mask).sum()
return intersection / (pred_mask.sum() + gt_mask.sum() + 1e-8)
    @staticmethod
    def get_boundary(mask, dilation):
        """Extract a boundary band of the given width via erosion."""
        kernel = np.ones((3, 3), np.uint8)
        eroded = cv2.erode(mask.astype(np.uint8), kernel, iterations=dilation)
        return mask.astype(np.uint8) - eroded

    @staticmethod
    def boundary_iou(pred_mask, gt_mask, dilation_ratio=0.02):
        """Boundary IoU metric"""
        h, w = pred_mask.shape
        dilation = max(1, int(round(dilation_ratio * np.sqrt(h * w))))
        # Get boundary bands
        pred_boundary = SegmentationMetrics.get_boundary(pred_mask, dilation)
        gt_boundary = SegmentationMetrics.get_boundary(gt_mask, dilation)
        return SegmentationMetrics.iou(pred_boundary, gt_boundary)
@staticmethod
def number_of_clicks(clicks_history, target_iou=0.85):
"""Number of clicks to reach target IoU"""
for i, (mask, iou) in enumerate(clicks_history):
if iou >= target_iou:
return i + 1
        return len(clicks_history)
```
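A tiny worked example to check the metrics against hand-computed values:

```python
import numpy as np

pred = np.zeros((4, 4), bool)
gt = np.zeros((4, 4), bool)
pred[1:3, 1:4] = True  # 6 pixels
gt[1:3, 1:3] = True    # 4 pixels

# intersection = 4, union = 6  ->  IoU = 4/6 ≈ 0.667
print(SegmentationMetrics.iou(pred, gt))
# Dice = 2*4 / (6 + 4) = 0.8
print(SegmentationMetrics.dice_coefficient(pred, gt))
```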
## Applications

### Medical Imaging

```python
import numpy as np


class MedicalSegmentation:
    def __init__(self):
self.model = self.load_medical_model()
def segment_tumor(self, mri_scan, radiologist_clicks):
"""Segment tumor with radiologist guidance"""
# Preprocess medical image
normalized_scan = self.normalize_hu_values(mri_scan)
# Multi-slice processing
segmented_slices = []
for slice_idx, slice_2d in enumerate(normalized_scan):
slice_clicks = radiologist_clicks.get(slice_idx, [])
mask = self.segment_interactive(slice_2d, slice_clicks)
segmented_slices.append(mask)
# 3D reconstruction
tumor_volume = np.stack(segmented_slices)
        return tumor_volume
```
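`normalize_hu_values` is assumed above; note that Hounsfield units strictly apply to CT, while MRI intensities are scanner-dependent, so a robust percentile normalization is a reasonable stand-in for either modality:

```python
import numpy as np


def normalize_intensities(volume, lower=0.5, upper=99.5):
    """Clip to robust percentiles, then rescale to [0, 1]."""
    lo, hi = np.percentile(volume, [lower, upper])
    volume = np.clip(volume, lo, hi)
    return (volume - lo) / (hi - lo + 1e-8)
```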
### Video Object Segmentation

```python
class VideoInteractiveSegmentation:
    def __init__(self):
self.propagation_model = self.build_propagation_model()
def segment_video_object(self, video_frames, first_frame_clicks):
"""Segment object across video frames"""
masks = []
# Segment first frame
first_mask = self.segment_interactive(video_frames[0], first_frame_clicks)
masks.append(first_mask)
# Propagate to subsequent frames
for i in range(1, len(video_frames)):
propagated_mask = self.propagate_mask(
video_frames[i-1], video_frames[i], masks[-1]
)
masks.append(propagated_mask)
        return masks
```
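`propagate_mask` is assumed above. A simple sketch using dense Farnebäck optical flow from OpenCV to warp the previous mask into the next frame (modern methods use learned memory networks instead):

```python
import cv2
import numpy as np


def propagate_mask(prev_frame, next_frame, prev_mask):
    """Warp the previous frame's mask into the next frame via optical flow."""
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    next_gray = cv2.cvtColor(next_frame, cv2.COLOR_BGR2GRAY)
    # Flow from next back to prev, so each next-frame pixel knows where
    # to sample the previous mask
    flow = cv2.calcOpticalFlowFarneback(next_gray, prev_gray, None,
                                        0.5, 3, 15, 3, 5, 1.2, 0)
    h, w = prev_mask.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    map_x = (grid_x + flow[..., 0]).astype(np.float32)
    map_y = (grid_y + flow[..., 1]).astype(np.float32)
    warped = cv2.remap(prev_mask.astype(np.float32), map_x, map_y,
                       cv2.INTER_LINEAR)
    return warped > 0.5
```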
### Augmented Reality

```python
class ARSegmentation:
    def __init__(self):
self.real_time_model = self.load_lightweight_model()
def segment_for_ar(self, frame, touch_points):
"""Real-time segmentation for AR applications"""
# Convert touch points to clicks
clicks = self.touch_to_clicks(touch_points)
# Fast segmentation
mask = self.real_time_model.segment(frame, clicks)
# Apply AR effects
ar_frame = self.apply_ar_effect(frame, mask)
        return ar_frame, mask
```

## Future Directions

### Few-Shot Interactive Segmentation

```python
class FewShotInteractiveSegmentation:
def __init__(self):
self.meta_learner = self.build_meta_learner()
def adapt_to_new_class(self, support_images, support_masks, support_clicks):
"""Adapt model to new object class with few examples"""
# Extract class prototype
class_prototype = self.extract_prototype(
support_images, support_masks, support_clicks
)
# Fine-tune model
adapted_model = self.meta_learner.adapt(class_prototype)
        return adapted_model
```

### Language-Guided Segmentation

```python
class LanguageGuidedSegmentation:
def __init__(self):
self.vision_language_model = self.load_clip_model()
def segment_with_text(self, image, text_description, optional_clicks=None):
"""Segment objects using natural language description"""
# Encode text
text_features = self.encode_text(text_description)
# Generate attention map
attention_map = self.compute_attention(image, text_features)
# Combine with clicks if provided
if optional_clicks:
click_map = self.encode_clicks(optional_clicks)
attention_map = self.fuse_guidance(attention_map, click_map)
# Segment
mask = self.segment_with_attention(image, attention_map)
        return mask
```

### Self-Supervised Pre-training

```python
class SelfSupervisedPretraining:
def __init__(self):
self.pretext_model = self.build_pretext_model()
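        # (an optimizer over pretext_model parameters, e.g. torch.optim.Adam,
        # is assumed to be created here as self.optimizer)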
def pretrain_with_synthetic_clicks(self, unlabeled_images):
"""Pre-train using synthetic interactions"""
for image in unlabeled_images:
# Generate pseudo-mask
pseudo_mask = self.generate_pseudo_mask(image)
# Simulate user interactions
synthetic_clicks = self.simulate_clicks(pseudo_mask)
# Self-supervised training
predicted_mask = self.pretext_model(image, synthetic_clicks)
loss = self.compute_consistency_loss(predicted_mask, pseudo_mask)
            # Update model
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
```

## Conclusion

Interactive segmentation represents a crucial bridge between fully automatic and fully manual annotation. By pairing state-of-the-art deep learning techniques with efficient user-interaction paradigms, high-quality segmentation results can be achieved with minimal user effort. The field continues to evolve with advances in foundation models, few-shot learning, and multi-modal interaction methods.
*Originally from the umitkacar/Interactive-Image-Segmentation repository.*