Skip to content

Latest commit

 

History

History
661 lines (513 loc) · 22 KB

File metadata and controls

661 lines (513 loc) · 22 KB

Interactive Image Segmentation: State-of-the-Art Methods and Implementations

Last Updated: 2025-06-19

Table of Contents

Overview

Interactive image segmentation allows users to segment objects with minimal interaction, typically through clicks, scribbles, or bounding boxes. This field bridges the gap between fully automatic and manual segmentation, providing accurate results with minimal user effort.

Key Advantages

  • Precision: User guidance ensures accurate segmentation
  • Efficiency: Reduces manual annotation time by 90%+
  • Flexibility: Adapts to various object types and scenarios
  • Real-time feedback: Iterative refinement based on user input

Interactive Segmentation Paradigms

1. Click-based Methods

Users provide positive/negative clicks to indicate object/background regions.

import torch
import numpy as np
from interactive_seg import ClickBasedSegmenter

class InteractiveClickSegmenter:
    def __init__(self, model_path):
        self.model = self.load_model(model_path)
        self.click_history = []
        
    def add_click(self, x, y, is_positive=True):
        """Add user click to the interaction history"""
        click = {
            'coords': (x, y),
            'type': 'positive' if is_positive else 'negative',
            'timestamp': time.time()
        }
        self.click_history.append(click)
        
    def segment(self, image, clicks):
        """Perform segmentation based on clicks"""
        # Encode clicks as distance maps
        click_maps = self.encode_clicks(image.shape[:2], clicks)
        
        # Concatenate image and click maps
        inputs = torch.cat([image, click_maps], dim=0)
        
        # Forward pass
        with torch.no_grad():
            mask = self.model(inputs.unsqueeze(0))
            
        return torch.sigmoid(mask).squeeze().numpy()

2. Scribble-based Methods

Users draw rough scribbles to indicate regions.

class ScribbleSegmenter:
    def __init__(self, backbone='resnet50'):
        self.encoder = self.build_encoder(backbone)
        self.decoder = self.build_decoder()
        
    def process_scribbles(self, scribbles):
        """Convert scribbles to distance transforms"""
        fg_scribble, bg_scribble = scribbles['foreground'], scribbles['background']
        
        # Distance transform
        fg_dist = cv2.distanceTransform(fg_scribble, cv2.DIST_L2, 5)
        bg_dist = cv2.distanceTransform(bg_scribble, cv2.DIST_L2, 5)
        
        # Normalize distances
        fg_dist = (fg_dist - fg_dist.min()) / (fg_dist.max() - fg_dist.min() + 1e-8)
        bg_dist = (bg_dist - bg_dist.min()) / (bg_dist.max() - bg_dist.min() + 1e-8)
        
        return np.stack([fg_dist, bg_dist], axis=-1)

3. Bounding Box Methods

Users provide bounding boxes around objects.

class BBoxSegmenter:
    def __init__(self):
        self.grabcut = cv2.createBackgroundSubtractorMOG2()
        
    def segment_with_bbox(self, image, bbox):
        """Segment using bounding box initialization"""
        x, y, w, h = bbox
        mask = np.zeros(image.shape[:2], np.uint8)
        
        # Initialize GrabCut
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        
        # Apply GrabCut
        cv2.grabCut(image, mask, (x, y, w, h), 
                    bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
        
        # Extract foreground
        mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
        return mask2

State-of-the-Art Methods

1. Segment Anything Model (SAM)

Meta's foundation model for interactive segmentation.

from segment_anything import SamPredictor, sam_model_registry

class SAMInteractive:
    def __init__(self, checkpoint="sam_vit_h_4b8939.pth"):
        self.sam = sam_model_registry["vit_h"](checkpoint=checkpoint)
        self.sam.to(device='cuda')
        self.predictor = SamPredictor(self.sam)
        
    def set_image(self, image):
        """Precompute image embeddings"""
        self.predictor.set_image(image)
        
    def predict_with_clicks(self, point_coords, point_labels):
        """Segment with click coordinates"""
        masks, scores, logits = self.predictor.predict(
            point_coords=np.array(point_coords),
            point_labels=np.array(point_labels),
            multimask_output=True,
        )
        
        # Select best mask
        best_idx = np.argmax(scores)
        return masks[best_idx], scores[best_idx]

2. Interactive Deep Learning (f-BRS)

Feature Backpropagating Refinement Scheme for improved accuracy.

class FBRS:
    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device
        self.optimization_steps = 10
        
    def refine_with_backprop(self, image, initial_mask, clicks):
        """Refine segmentation using feature backpropagation"""
        image_tensor = self.preprocess(image)
        
        # Initialize mask
        mask = initial_mask.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([mask], lr=0.1)
        
        for step in range(self.optimization_steps):
            # Forward pass
            features = self.model.encoder(image_tensor)
            refined_mask = self.model.decoder(features, mask)
            
            # Compute loss based on clicks
            click_loss = self.compute_click_loss(refined_mask, clicks)
            smoothness_loss = self.compute_smoothness_loss(refined_mask)
            
            total_loss = click_loss + 0.1 * smoothness_loss
            
            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
        return refined_mask.detach()

3. RITM (Reviving Iterative Training with Mask Guidance)

Efficient iterative refinement approach.

class RITM:
    def __init__(self, backbone='hrnet32'):
        self.backbone = self.build_backbone(backbone)
        self.head = self.build_segmentation_head()
        
    def iterative_training(self, image, gt_mask, max_iters=20):
        """Train with simulated interactions"""
        current_mask = torch.zeros_like(gt_mask)
        
        for iteration in range(max_iters):
            # Simulate user click
            click = self.simulate_click(current_mask, gt_mask)
            
            # Update click encoding
            click_map = self.encode_click(click, image.shape)
            
            # Predict new mask
            inputs = torch.cat([image, current_mask, click_map], dim=1)
            current_mask = self.forward(inputs)
            
            # Check convergence
            iou = self.compute_iou(current_mask, gt_mask)
            if iou > 0.95:
                break
                
        return current_mask

Implementation Guide

Complete Interactive Segmentation Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import cv2
import numpy as np

class InteractiveSegmentationPipeline:
    def __init__(self, model_type='sam'):
        self.model = self.load_model(model_type)
        self.transform = self.get_transforms()
        self.interaction_history = []
        
    def load_model(self, model_type):
        """Load pre-trained model"""
        if model_type == 'sam':
            from segment_anything import sam_model_registry
            model = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
        elif model_type == 'ritm':
            model = RITM(backbone='hrnet32')
            model.load_state_dict(torch.load('ritm_hrnet32.pth'))
        else:
            raise ValueError(f"Unknown model type: {model_type}")
            
        return model.eval().cuda()
    
    def get_transforms(self):
        """Define image preprocessing"""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
    
    def encode_interactions(self, interactions, image_shape):
        """Encode user interactions as input channels"""
        h, w = image_shape[:2]
        pos_clicks = np.zeros((h, w), dtype=np.float32)
        neg_clicks = np.zeros((h, w), dtype=np.float32)
        
        for interaction in interactions:
            x, y = interaction['coords']
            sigma = 10  # Gaussian sigma
            
            # Create Gaussian centered at click
            y_grid, x_grid = np.ogrid[:h, :w]
            gaussian = np.exp(-((x_grid - x)**2 + (y_grid - y)**2) / (2 * sigma**2))
            
            if interaction['type'] == 'positive':
                pos_clicks = np.maximum(pos_clicks, gaussian)
            else:
                neg_clicks = np.maximum(neg_clicks, gaussian)
                
        return np.stack([pos_clicks, neg_clicks], axis=-1)
    
    def segment_interactive(self, image, interactions):
        """Main segmentation function"""
        # Preprocess image
        image_tensor = self.transform(image).unsqueeze(0).cuda()
        
        # Encode interactions
        interaction_maps = self.encode_interactions(interactions, image.shape)
        interaction_tensor = torch.from_numpy(interaction_maps).permute(2, 0, 1)
        interaction_tensor = interaction_tensor.unsqueeze(0).cuda()
        
        # Concatenate inputs
        inputs = torch.cat([image_tensor, interaction_tensor], dim=1)
        
        # Forward pass
        with torch.no_grad():
            logits = self.model(inputs)
            mask = torch.sigmoid(logits) > 0.5
            
        return mask.squeeze().cpu().numpy()
    
    def refine_segmentation(self, image, mask, new_interaction):
        """Refine existing segmentation with new interaction"""
        self.interaction_history.append(new_interaction)
        
        # Re-segment with all interactions
        refined_mask = self.segment_interactive(image, self.interaction_history)
        
        # Post-processing
        refined_mask = self.post_process(refined_mask)
        
        return refined_mask
    
    def post_process(self, mask):
        """Apply morphological operations"""
        # Remove small components
        kernel = np.ones((5, 5), np.uint8)
        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        
        # Find largest connected component
        num_labels, labels = cv2.connectedComponents(mask)
        if num_labels > 1:
            largest_component = 1 + np.argmax(
                [np.sum(labels == i) for i in range(1, num_labels)]
            )
            mask = (labels == largest_component).astype(np.uint8)
            
        return mask

Advanced Features Implementation

class AdvancedInteractiveSegmentation:
    def __init__(self):
        self.model = self.build_model()
        self.feature_extractor = self.build_feature_extractor()
        
    def build_model(self):
        """Build segmentation model with attention mechanism"""
        return nn.Sequential(
            # Encoder
            nn.Conv2d(5, 64, 3, padding=1),  # RGB + 2 interaction channels
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            
            # Attention module
            SpatialAttention(128),
            
            # Decoder
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
        )
    
    def multi_scale_inference(self, image, interactions):
        """Perform multi-scale segmentation"""
        scales = [0.5, 1.0, 1.5]
        predictions = []
        
        for scale in scales:
            # Resize image
            h, w = image.shape[:2]
            new_h, new_w = int(h * scale), int(w * scale)
            scaled_image = cv2.resize(image, (new_w, new_h))
            
            # Scale interactions
            scaled_interactions = self.scale_interactions(interactions, scale)
            
            # Segment
            mask = self.segment_interactive(scaled_image, scaled_interactions)
            
            # Resize back
            mask = cv2.resize(mask.astype(np.float32), (w, h))
            predictions.append(mask)
            
        # Ensemble predictions
        final_mask = np.mean(predictions, axis=0) > 0.5
        return final_mask
    
    def uncertainty_estimation(self, image, interactions, n_samples=10):
        """Estimate segmentation uncertainty using Monte Carlo dropout"""
        self.model.train()  # Enable dropout
        
        predictions = []
        for _ in range(n_samples):
            mask = self.segment_interactive(image, interactions)
            predictions.append(mask)
            
        predictions = np.stack(predictions)
        
        # Compute uncertainty metrics
        mean_pred = np.mean(predictions, axis=0)
        uncertainty = np.var(predictions, axis=0)
        
        return mean_pred, uncertainty

Best Practices

1. User Interaction Design

class InteractionDesign:
    @staticmethod
    def suggest_next_click(mask, uncertainty_map):
        """Suggest optimal location for next user click"""
        # Find high uncertainty regions near boundaries
        edges = cv2.Canny((mask * 255).astype(np.uint8), 100, 200)
        uncertain_edges = edges * uncertainty_map
        
        # Find peak uncertainty
        max_loc = np.unravel_index(uncertain_edges.argmax(), uncertain_edges.shape)
        
        return max_loc
    
    @staticmethod
    def validate_interaction(interaction, image_shape):
        """Validate user interaction"""
        x, y = interaction['coords']
        h, w = image_shape[:2]
        
        if not (0 <= x < w and 0 <= y < h):
            raise ValueError("Click coordinates out of bounds")
            
        return True

2. Performance Optimization

class PerformanceOptimizer:
    def __init__(self):
        self.cache = {}
        
    def cache_embeddings(self, image_id, embeddings):
        """Cache image embeddings for faster inference"""
        self.cache[image_id] = {
            'embeddings': embeddings,
            'timestamp': time.time()
        }
        
    def optimize_model(self, model):
        """Apply optimization techniques"""
        # Quantization
        quantized_model = torch.quantization.quantize_dynamic(
            model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
        )
        
        # TorchScript compilation
        scripted_model = torch.jit.script(quantized_model)
        
        return scripted_model

3. Quality Assurance

class QualityAssurance:
    @staticmethod
    def validate_segmentation(mask, min_area=100):
        """Validate segmentation quality"""
        # Check minimum area
        if np.sum(mask) < min_area:
            return False, "Segmented area too small"
            
        # Check connectivity
        num_components = cv2.connectedComponents(mask.astype(np.uint8))[0]
        if num_components > 2:  # Background + 1 object
            return False, "Multiple disconnected components"
            
        # Check boundary smoothness
        perimeter = cv2.arcLength(
            cv2.findContours(mask.astype(np.uint8), 
                           cv2.RETR_EXTERNAL, 
                           cv2.CHAIN_APPROX_SIMPLE)[0][0], 
            True
        )
        area = np.sum(mask)
        compactness = (perimeter ** 2) / (4 * np.pi * area)
        
        if compactness > 2.0:
            return False, "Boundary too irregular"
            
        return True, "Valid segmentation"

Evaluation Metrics

Implementation of Common Metrics

class SegmentationMetrics:
    @staticmethod
    def iou(pred_mask, gt_mask):
        """Intersection over Union"""
        intersection = np.logical_and(pred_mask, gt_mask).sum()
        union = np.logical_or(pred_mask, gt_mask).sum()
        return intersection / (union + 1e-8)
    
    @staticmethod
    def dice_coefficient(pred_mask, gt_mask):
        """Dice coefficient"""
        intersection = 2.0 * np.logical_and(pred_mask, gt_mask).sum()
        return intersection / (pred_mask.sum() + gt_mask.sum() + 1e-8)
    
    @staticmethod
    def boundary_iou(pred_mask, gt_mask, dilation_ratio=0.02):
        """Boundary IoU metric"""
        h, w = pred_mask.shape
        dilation = int(round(dilation_ratio * np.sqrt(h * w)))
        
        # Get boundaries
        pred_boundary = self.get_boundary(pred_mask, dilation)
        gt_boundary = self.get_boundary(gt_mask, dilation)
        
        return self.iou(pred_boundary, gt_boundary)
    
    @staticmethod
    def number_of_clicks(clicks_history, target_iou=0.85):
        """Number of clicks to reach target IoU"""
        for i, (mask, iou) in enumerate(clicks_history):
            if iou >= target_iou:
                return i + 1
        return len(clicks_history)

Applications

1. Medical Image Annotation

class MedicalSegmentation:
    def __init__(self):
        self.model = self.load_medical_model()
        
    def segment_tumor(self, mri_scan, radiologist_clicks):
        """Segment tumor with radiologist guidance"""
        # Preprocess medical image
        normalized_scan = self.normalize_hu_values(mri_scan)
        
        # Multi-slice processing
        segmented_slices = []
        for slice_idx, slice_2d in enumerate(normalized_scan):
            slice_clicks = radiologist_clicks.get(slice_idx, [])
            mask = self.segment_interactive(slice_2d, slice_clicks)
            segmented_slices.append(mask)
            
        # 3D reconstruction
        tumor_volume = np.stack(segmented_slices)
        return tumor_volume

2. Video Object Segmentation

class VideoInteractiveSegmentation:
    def __init__(self):
        self.propagation_model = self.build_propagation_model()
        
    def segment_video_object(self, video_frames, first_frame_clicks):
        """Segment object across video frames"""
        masks = []
        
        # Segment first frame
        first_mask = self.segment_interactive(video_frames[0], first_frame_clicks)
        masks.append(first_mask)
        
        # Propagate to subsequent frames
        for i in range(1, len(video_frames)):
            propagated_mask = self.propagate_mask(
                video_frames[i-1], video_frames[i], masks[-1]
            )
            masks.append(propagated_mask)
            
        return masks

3. Augmented Reality Applications

class ARSegmentation:
    def __init__(self):
        self.real_time_model = self.load_lightweight_model()
        
    def segment_for_ar(self, frame, touch_points):
        """Real-time segmentation for AR applications"""
        # Convert touch points to clicks
        clicks = self.touch_to_clicks(touch_points)
        
        # Fast segmentation
        mask = self.real_time_model.segment(frame, clicks)
        
        # Apply AR effects
        ar_frame = self.apply_ar_effect(frame, mask)
        
        return ar_frame, mask

Future Directions

1. Few-Shot Learning Integration

class FewShotInteractiveSegmentation:
    def __init__(self):
        self.meta_learner = self.build_meta_learner()
        
    def adapt_to_new_class(self, support_images, support_masks, support_clicks):
        """Adapt model to new object class with few examples"""
        # Extract class prototype
        class_prototype = self.extract_prototype(
            support_images, support_masks, support_clicks
        )
        
        # Fine-tune model
        adapted_model = self.meta_learner.adapt(class_prototype)
        
        return adapted_model

2. Natural Language Guidance

class LanguageGuidedSegmentation:
    def __init__(self):
        self.vision_language_model = self.load_clip_model()
        
    def segment_with_text(self, image, text_description, optional_clicks=None):
        """Segment objects using natural language description"""
        # Encode text
        text_features = self.encode_text(text_description)
        
        # Generate attention map
        attention_map = self.compute_attention(image, text_features)
        
        # Combine with clicks if provided
        if optional_clicks:
            click_map = self.encode_clicks(optional_clicks)
            attention_map = self.fuse_guidance(attention_map, click_map)
            
        # Segment
        mask = self.segment_with_attention(image, attention_map)
        
        return mask

3. Self-Supervised Pre-training

class SelfSupervisedPretraining:
    def __init__(self):
        self.pretext_model = self.build_pretext_model()
        
    def pretrain_with_synthetic_clicks(self, unlabeled_images):
        """Pre-train using synthetic interactions"""
        for image in unlabeled_images:
            # Generate pseudo-mask
            pseudo_mask = self.generate_pseudo_mask(image)
            
            # Simulate user interactions
            synthetic_clicks = self.simulate_clicks(pseudo_mask)
            
            # Self-supervised training
            predicted_mask = self.pretext_model(image, synthetic_clicks)
            loss = self.compute_consistency_loss(predicted_mask, pseudo_mask)
            
            # Update model
            loss.backward()
            self.optimizer.step()

Conclusion

Interactive segmentation represents a crucial bridge between fully automatic and manual annotation methods. By leveraging state-of-the-art deep learning techniques with efficient user interaction paradigms, we can achieve high-quality segmentation results with minimal user effort. The field continues to evolve with advances in foundation models, few-shot learning, and multi-modal interaction methods.

Originally from umitkacar/Interactive-Image-Segmentation repository