Skip to content

Software

JPK AFM Image Viewer & Loader (Python)

JPK AFM Viewer is a lightweight yet powerful Python utility designed to load, process, and visualize Atomic Force Microscopy (AFM) data acquired from JPK / Bruker NanoWizard 1 systems. It provides full access to all channels (Height, Phase, Error, etc.), automatically handles the complex JPK TIFF metadata for correct physical scaling, and offers robust image leveling routines.

The tool is ideal for researchers needing quick, reproducible visualization and comparison of AFM scans without relying on proprietary vendor software.

Example Output: Multi-File Comparison of Height channels with automatic scaling and plane leveling.
  • Direct Data Access: Loads .jpk files by correctly interpreting the custom JPK TIFF tags.
  • Automated Scaling: Ensures precise conversion of raw sensor data into physical units (nm, µm, V, °).
  • Flexible Leveling: Supports standard leveling methods including Plane Fit, Line-by-Line Correction, and Robust Leveling (which minimizes the influence of particles/holes).
  • Trace/Retrace Handling: Full support for visualizing and comparing Trace and Retrace passes for all channels.
  • File Ordering & Comparison: Easily compare multiple files in a grid view. File selection supports:
    ◦ Automatic Sorting: Alphabetical or by modification date.
    ◦ Interactive Selection: Enables precise, manual ordering of files via a sequential selection dialog (using the -i flag).

Installation and Usage

Dependencies (Python 3.x)

The core functionality relies on standard scientific Python libraries and tifffile to handle the image format:

pip install numpy matplotlib scipy tifffile

Command Line Examples

# 1. Quick view of a single file (default: height channel, plane leveling)
python jpk_afm_viewer_v9.py scan_001.jpk

# 2. View all channels (Trace only) for a single scan
python jpk_afm_viewer_v9.py scan_001.jpk --all

# 3. Compare multiple scans with interactive (manual) file selection order
python jpk_afm_viewer_v9.py -i --channel height --leveling line_robust

Download

#!/usr/bin/env python3
"""
JPK AFM Image Loader & Viewer
=============================
Loads JPK TIFF-based AFM images with all channels,
correct scaling and leveling.

Tested with JPK NanoWizard files.

Dependencies:
    conda install -c conda-forge numpy matplotlib scipy tifffile

Author: Claude and Frank Balzer (ORCID: 0000-0002-6228-6839)
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
import warnings
import sys

try:
    import tifffile
except ImportError:
    raise ImportError(
        "tifffile is required.\n"
        "Installation: conda install -c conda-forge tifffile"
    )


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class JPKChannel:
    """One data channel of an AFM scan together with its scaling metadata."""
    name: str                  # channel name
    data: np.ndarray           # channel values
    unit: str                  # unit string belonging to ``data``
    is_retrace: bool = False   # True for the retrace pass, False for trace
    multiplier: float = 1.0    # scaling factor stored alongside the data
    offset: float = 0.0        # scaling offset stored alongside the data

    @property
    def display_name(self) -> str:
        """Return the channel name followed by " (trace)" or " (retrace)"."""
        direction = "retrace" if self.is_retrace else "trace"
        return self.name + f" ({direction})"


@dataclass
class JPKImage:
    """A complete AFM scan: geometry, channels and properties."""
    filename: str      # file name of the scan
    width_m: float     # scan width (meters, going by the field name)
    height_m: float    # scan height (meters, going by the field name)
    pixels_x: int      # number of pixel columns
    pixels_y: int      # number of pixel rows
    channels: dict     # channel objects; only the values are inspected here
    properties: dict   # additional properties

    @property
    def pixel_size_x(self) -> float:
        """Scan width divided by the number of pixel columns."""
        return self.width_m / self.pixels_x

    @property
    def pixel_size_y(self) -> float:
        """Scan height divided by the number of pixel rows."""
        return self.height_m / self.pixels_y

    def get_channel(self, name: str, retrace: bool = False) -> Optional[JPKChannel]:
        """
        Look up a channel by case-insensitive substring match on its name.

        The first channel whose name contains ``name`` and whose direction
        equals ``retrace`` wins. If no such channel exists, the first channel
        whose name matches is returned regardless of direction.

        Returns:
            The matching channel, or None when no channel name matches.
        """
        def name_matches(ch) -> bool:
            return name.lower() in ch.name.lower()

        exact = next(
            (ch for ch in self.channels.values()
             if name_matches(ch) and ch.is_retrace == retrace),
            None,
        )
        if exact is not None:
            return exact

        # Fallback: ignore the trace/retrace direction
        return next((ch for ch in self.channels.values() if name_matches(ch)), None)


# =============================================================================
# JPK TIFF TAGS (Custom Tags by JPK Instruments)
# =============================================================================

class JPKTiffTags:
    """Numeric IDs of the JPK-specific TIFF tags used by this module."""

    # --- Scan geometry (read from page 0 / thumbnail) ---
    SCAN_WIDTH = 32834
    SCAN_HEIGHT = 32835
    PIXELS_X = 32838
    PIXELS_Y = 32839

    # --- Channel identification ---
    CHANNEL_NAME_SHORT = 32848
    CHANNEL_INDEX = 32849
    CHANNEL_NAME_FULL = 32850
    CHANNEL_PROPERTIES = 32851

    # --- Alternative scalings: volts and nominal ---
    VOLTS_MULTIPLIER = 32980
    VOLTS_OFFSET = 32981
    NOMINAL_MULTIPLIER = 33028
    NOMINAL_OFFSET = 33029

    # --- Calibrated scaling (the preferred one) ---
    CALIB_NAME = 33072
    CALIB_DTYPE = 33073
    CALIB_UNIT = 33074
    CALIB_TYPE = 33075
    CALIB_MULTIPLIER = 33076
    CALIB_OFFSET = 33077


# =============================================================================
# JPK TIFF LOADER
# =============================================================================

class JPKTiffLoader:
    """Reads a JPK TIFF-based AFM file into a JPKImage."""

    def __init__(self, filepath: str | Path):
        """Store the file location; nothing is read until load() is called."""
        self.filepath = Path(filepath)
        self.properties = {}
        self.channels = {}

    def load(self) -> JPKImage:
        """
        Read all sufficiently large TIFF pages as channels.

        Returns:
            JPKImage with scan size, pixel counts and the collected channels.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if not self.filepath.exists():
            raise FileNotFoundError(f"File not found: {self.filepath}")

        pixels_x = pixels_y = 0

        with tifffile.TiffFile(self.filepath) as tif:
            scan_width, scan_height = self._extract_scan_size(tif)

            for page_index, page in enumerate(tif.pages):
                # Pages with an edge below 128 px are treated as thumbnails
                is_thumbnail = page.shape[0] < 128 or page.shape[1] < 128
                if is_thumbnail:
                    continue

                # The first accepted page defines the pixel dimensions
                if pixels_x == 0:
                    pixels_y, pixels_x = page.shape

                channel = self._extract_channel(page, page_index)
                if channel is None:
                    continue

                # Key is unique per channel name and scan direction
                direction = 'retrace' if channel.is_retrace else 'trace'
                self.channels[f"{channel.name}_{direction}"] = channel

        # No scan size in the metadata: fall back to 10 nm per pixel
        if scan_width == 0:
            scan_width = pixels_x * 1e-8
        if scan_height == 0:
            scan_height = pixels_y * 1e-8

        return JPKImage(
            filename=self.filepath.name,
            width_m=scan_width,
            height_m=scan_height,
            pixels_x=pixels_x,
            pixels_y=pixels_y,
            channels=self.channels,
            properties=self.properties,
        )

    @staticmethod
    def _first_tag(tags, tag_ids, convert, default):
        """Return ``convert(value)`` of the first tag in ``tag_ids`` that exists, else ``default``."""
        for tag_id in tag_ids:
            if tag_id in tags:
                return convert(tags[tag_id].value)
        return default

    def _extract_scan_size(self, tif: tifffile.TiffFile) -> tuple[float, float]:
        """
        Read scan width and height from the custom tags of page 0.

        Returns:
            (width, height); a value is 0.0 when its tag is missing or the
            file has no pages.
        """
        if len(tif.pages) == 0:
            return 0.0, 0.0

        # JPK stores the scan size in custom tags of page 0 (thumbnail)
        tags = tif.pages[0].tags
        width = self._first_tag(tags, (JPKTiffTags.SCAN_WIDTH,), float, 0.0)
        height = self._first_tag(tags, (JPKTiffTags.SCAN_HEIGHT,), float, 0.0)
        return width, height

    def _extract_channel(self, page: tifffile.TiffPage, page_index: int) -> Optional[JPKChannel]:
        """
        Build a calibrated channel from one TIFF page.

        Args:
            page: TIFF page holding the raw channel data and its tags
            page_index: Index of the page, used for the default name and warnings

        Returns:
            The channel, or None (after a warning) if the pixel data could
            not be read.
        """
        tags = page.tags

        # Channel name: short name preferred, then full name, then a generic one
        channel_name = self._first_tag(
            tags,
            (JPKTiffTags.CHANNEL_NAME_SHORT, JPKTiffTags.CHANNEL_NAME_FULL),
            str,
            f"channel_{page_index}",
        )

        # Trace/retrace is encoded as text in the channel properties
        is_retrace = False
        if JPKTiffTags.CHANNEL_PROPERTIES in tags:
            props = str(tags[JPKTiffTags.CHANNEL_PROPERTIES].value).lower()
            is_retrace = any(marker in props for marker in ('retrace : true', 'retrace: true'))

        # Scaling: calibrated values preferred, nominal values as fallback
        multiplier = self._first_tag(
            tags, (JPKTiffTags.CALIB_MULTIPLIER, JPKTiffTags.NOMINAL_MULTIPLIER), float, 1.0)
        offset = self._first_tag(
            tags, (JPKTiffTags.CALIB_OFFSET, JPKTiffTags.NOMINAL_OFFSET), float, 0.0)
        unit = self._first_tag(tags, (JPKTiffTags.CALIB_UNIT,), str, 'a.u.')

        # Load raw data and apply the linear calibration
        try:
            calibrated = page.asarray().astype(np.float64) * multiplier + offset
        except Exception as e:
            warnings.warn(f"Error loading page {page_index}: {e}")
            return None

        return JPKChannel(
            name=channel_name,
            data=calibrated,
            unit=unit,
            is_retrace=is_retrace,
            multiplier=multiplier,
            offset=offset,
        )


# =============================================================================
# LEVELING FUNCTIONS
# =============================================================================

def level_plane_fit(data: np.ndarray) -> np.ndarray:
    """
    Subtract the least-squares plane ``a*x + b*y + c`` from a 2D array.

    Non-finite samples are left out of the fit; the plane is still
    subtracted at every position.
    """
    row_idx, col_idx = np.indices(data.shape)
    n_rows, n_cols = data.shape

    design = np.column_stack(
        [col_idx.ravel(), row_idx.ravel(), np.ones(n_cols * n_rows)]
    )
    heights = data.ravel()

    # Fit on finite samples only
    finite = np.isfinite(heights)
    if finite.all():
        solution = np.linalg.lstsq(design, heights, rcond=None)
    else:
        solution = np.linalg.lstsq(design[finite], heights[finite], rcond=None)
    fit = solution[0]

    return data - (fit[0] * col_idx + fit[1] * row_idx + fit[2])


def level_line_by_line(data: np.ndarray, order: int = 1) -> np.ndarray:
    """
    Line-by-line leveling: subtract a polynomial baseline from every row.

    Args:
        data: 2D array; rows are the scan lines
        order: Polynomial order of the per-row baseline (default: 1 = linear)

    Returns:
        Leveled array of the same shape. Floating-point input keeps its
        dtype; any other dtype is returned as float64 so that the fractional
        residuals are not truncated.
    """
    data = np.asarray(data)
    # An integer result buffer would silently truncate the residuals
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    leveled = data.astype(out_dtype, copy=True)

    ny, nx = data.shape
    x = np.arange(nx)

    for i in range(ny):
        line = data[i, :]
        mask = np.isfinite(line)
        # Rows without enough finite samples for the fit stay unchanged
        if mask.sum() > order + 1:
            coeffs = np.polyfit(x[mask], line[mask], order)
            leveled[i, :] = line - np.polyval(coeffs, x)

    return leveled


def level_median_diff(data: np.ndarray) -> np.ndarray:
    """
    Median-difference leveling for scan artifacts (row offsets).

    The median differences between consecutive rows are removed
    cumulatively, which is the same as shifting every row so that all row
    medians coincide. Finally the global median is set to zero.

    Subtracting only the difference to the directly preceding row (without
    accumulating) would merely move each row's median to that of its
    predecessor and leave the steps in the image.

    Args:
        data: 2D array; rows are the scan lines. NaN values are ignored when
            computing medians.

    Returns:
        New leveled array (float for integer input); ``data`` is not modified.
    """
    values = np.asarray(data)

    # Removing the accumulated median differences == removing each row's median
    row_medians = np.nanmedian(values, axis=1, keepdims=True)
    leveled = values - row_medians

    leveled -= np.nanmedian(leveled)
    return leveled


def level_robust(data: np.ndarray, percentile_low: float = 10, percentile_high: float = 90) -> np.ndarray:
    """
    Two-stage robust leveling.

    Stage 1 removes the global tilt with a plane fit. Stage 2 subtracts a
    linear baseline from every row; only samples between the two global
    percentiles of the stage-1 result take part in the fit, so holes and
    particles do not distort it.

    Args:
        data: 2D height data
        percentile_low: Lower percentile bound for the row fits (default: 10%)
        percentile_high: Upper percentile bound for the row fits (default: 90%)

    Returns:
        Leveled array
    """
    n_rows, n_cols = data.shape
    columns = np.arange(n_cols)

    # Stage 1: global plane fit
    detilted = level_plane_fit(data)

    # Bounds are taken from the already de-tilted data
    lower = np.nanpercentile(detilted, percentile_low)
    upper = np.nanpercentile(detilted, percentile_high)

    # Stage 2: per-row linear fit on the in-bounds samples
    result = np.zeros_like(detilted)
    for row_index in range(n_rows):
        row = detilted[row_index, :]
        in_bounds = np.isfinite(row) & (row <= upper) & (row >= lower)

        if in_bounds.sum() <= 2:
            # Too few usable samples: keep the row as it is
            result[row_index, :] = row
            continue

        line_fit = np.polyfit(columns[in_bounds], row[in_bounds], 1)
        result[row_index, :] = row - np.polyval(line_fit, columns)

    return result


def _polyfit_level_rows(rows: np.ndarray, order: int, lower=None, upper=None) -> np.ndarray:
    """Subtract a per-row polynomial baseline, fitted to finite samples within [lower, upper] if given."""
    n_rows, n_cols = rows.shape
    x = np.arange(n_cols)
    leveled = rows.copy()

    for i in range(n_rows):
        line = rows[i, :]
        mask = np.isfinite(line)
        if lower is not None:
            # Only "normal" values (no holes/particles) take part in the fit
            mask = (line >= lower) & (line <= upper) & mask
        # Rows without enough usable samples stay unchanged
        if mask.sum() > order + 1:
            coeffs = np.polyfit(x[mask], line[mask], order)
            leveled[i, :] = line - np.polyval(coeffs, x)

    return leveled


def level_line_robust(data: np.ndarray, order: int = 1, 
                      percentile_low: float = 10, percentile_high: float = 90) -> np.ndarray:
    """
    Robust line-by-line leveling in two steps:
    1. Normal line-by-line leveling
    2. Robust line-by-line leveling (ignores holes/particles)

    Args:
        data: 2D height data
        order: Polynomial order for fit (default: 1 = linear)
        percentile_low: Lower percentile for robust fit (default: 10%)
        percentile_high: Upper percentile for robust fit (default: 90%)

    Returns:
        Leveled array. Floating-point input keeps its dtype; any other dtype
        is processed as float64 so the residuals are not truncated.
    """
    data = np.asarray(data)
    # Integer buffers would silently truncate the fractional residuals
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(np.float64)

    # Step 1: Normal line-by-line leveling
    step1 = _polyfit_level_rows(data, order)

    # Step 2: Robust line-by-line leveling
    # Percentile bounds from already leveled data
    p_low = np.nanpercentile(step1, percentile_low)
    p_high = np.nanpercentile(step1, percentile_high)

    return _polyfit_level_rows(step1, order, p_low, p_high)


def apply_leveling(data: np.ndarray, method: str = 'plane', 
                   percentile_low: float = 10, percentile_high: float = 90) -> np.ndarray:
    """
    Dispatch to the leveling routine selected by ``method``.

    Args:
        data: 2D height data
        method: One of 'plane', 'line', 'median', 'robust', 'line_robust'
        percentile_low: Lower percentile, used by the robust methods only (default: 10%)
        percentile_high: Upper percentile, used by the robust methods only (default: 90%)

    Returns:
        Leveled array

    Raises:
        ValueError: If ``method`` is not one of the known names.
    """
    # Lambdas keep the calls lazy: only the selected routine is evaluated
    routines = {
        'plane': lambda: level_plane_fit(data),
        'line': lambda: level_line_by_line(data),
        'median': lambda: level_median_diff(data),
        'robust': lambda: level_robust(data, percentile_low, percentile_high),
        'line_robust': lambda: level_line_robust(data, percentile_low=percentile_low,
                                                 percentile_high=percentile_high),
    }

    run = routines.get(method)
    if run is None:
        raise ValueError(f"Method '{method}' unknown. Available: plane, line, median, robust, line_robust")
    return run()


# =============================================================================
# VISUALIZATION
# =============================================================================

def create_afm_colormap():
    """Return the typical AFM color scale (brown to gold), registered as 'afm_brown'."""
    # RGB stops from dark brown to light gold
    brown_to_gold = [
        (0.2, 0.1, 0.0),
        (0.4, 0.2, 0.1),
        (0.6, 0.4, 0.2),
        (0.8, 0.6, 0.3),
        (1.0, 0.9, 0.7),
    ]
    afm_map = LinearSegmentedColormap.from_list('afm_brown', brown_to_gold)
    return afm_map


def format_scale(value_m: float) -> tuple[float, str]:
    """
    Express a length given in meters in mm, µm, nm or pm.

    The largest unit whose size does not exceed ``abs(value_m)`` is chosen;
    anything below 1 nm (including 0) is reported in pm.

    Returns:
        (scaled value, unit string)
    """
    magnitude = abs(value_m)
    # (smallest magnitude for this unit, conversion factor, unit)
    for threshold, factor, unit in ((1e-3, 1e3, 'mm'), (1e-6, 1e6, 'µm'), (1e-9, 1e9, 'nm')):
        if magnitude >= threshold:
            return value_m * factor, unit
    return value_m * 1e12, 'pm'


def plot_afm_image(image: JPKImage, 
                   channel_name: str = 'height',
                   retrace: bool = False,
                   leveling: str = 'plane',
                   colormap: str = 'afm',
                   figsize: tuple = (10, 8),
                   show_histogram: bool = False,
                   enhance_contrast: bool = True,
                   percentile_low: float = 10,
                   percentile_high: float = 90) -> plt.Figure:
    """
    Displays one channel of an AFM image with colorbar and statistics box.

    Contrast enhancement only limits the color range to the 2nd-98th
    percentile; the statistics (Z-Range, RMS) and the histogram are computed
    from the unclipped data. Both axes share one length unit so that
    ``aspect='equal'`` shows the true geometry.

    Args:
        image: JPKImage object
        channel_name: Channel name (e.g. 'height', 'error', 'adhesion')
        retrace: True for retrace, False for trace
        leveling: 'plane', 'line', 'median', 'robust', 'line_robust' or None;
            only applied to channels with unit 'm'
        colormap: 'afm' or any matplotlib colormap name (phase channels
            always use 'gray')
        figsize: Figure size
        show_histogram: Add a histogram of the displayed values below the image
        enhance_contrast: Limit the color range to the 2nd-98th percentile
            (height and phase channels only)
        percentile_low: Lower percentile for robust leveling
        percentile_high: Upper percentile for robust leveling

    Returns:
        The created matplotlib Figure

    Raises:
        ValueError: If no channel matches ``channel_name``.
    """
    # Find channel
    channel = image.get_channel(channel_name, retrace)

    if channel is None:
        available = [f"{ch.name} ({'retrace' if ch.is_retrace else 'trace'})" 
                     for ch in image.channels.values()]
        raise ValueError(f"Channel '{channel_name}' not found.\nAvailable: {available}")

    # Data and leveling (leveling is applied to height-like data only)
    data = channel.data.copy()
    leveling_applied = bool(leveling) and channel.unit == 'm'
    if leveling_applied:
        data = apply_leveling(data, method=leveling, 
                              percentile_low=percentile_low, 
                              percentile_high=percentile_high)

    # One common unit for both axes; independent units per axis would
    # distort non-square scans under aspect='equal'
    unit_factors = {'mm': 1e3, 'µm': 1e6, 'nm': 1e9, 'pm': 1e12}
    _, xy_unit = format_scale(max(abs(image.width_m), abs(image.height_m)))
    width_scaled = image.width_m * unit_factors[xy_unit]
    height_scaled = image.height_m * unit_factors[xy_unit]

    # Z unit, label and colormap depending on channel
    named_cmap = create_afm_colormap() if colormap == 'afm' else colormap
    stretch_contrast = False
    if channel.unit == 'm':
        z_range = np.nanmax(data) - np.nanmin(data)
        _, z_unit = format_scale(z_range)
        data_display = data * unit_factors.get(z_unit, 1)
        # Set minimum to 0
        data_display = data_display - np.nanmin(data_display)
        stretch_contrast = enhance_contrast
        colorbar_label = f'Height ({z_unit})'
        use_cmap = named_cmap
    elif channel.unit == 'V':
        data_display = data * 1e3  # mV
        z_unit = 'mV'
        colorbar_label = f'{channel.name} ({z_unit})'
        use_cmap = named_cmap
    elif channel.unit == '°' or 'phase' in channel.name.lower():
        data_display = data
        stretch_contrast = enhance_contrast
        z_unit = '°'
        colorbar_label = 'Phase (°)'
        use_cmap = 'gray'
    else:
        data_display = data
        z_unit = channel.unit if channel.unit else 'a.u.'
        colorbar_label = f'{channel.name} ({z_unit})'
        use_cmap = named_cmap

    # Contrast stretch via color limits, so the data itself stays unclipped
    vmin = vmax = None
    if stretch_contrast:
        p_low, p_high = np.nanpercentile(data_display, [2, 98])
        if np.isfinite(p_low) and np.isfinite(p_high):
            vmin, vmax = p_low, p_high

    # Create plot
    if show_histogram:
        fig, (ax, ax_hist) = plt.subplots(2, 1, figsize=(figsize[0], figsize[1] + 2),
                                          gridspec_kw={'height_ratios': [4, 1]})
    else:
        fig, ax = plt.subplots(figsize=figsize)

    extent = [0, width_scaled, 0, height_scaled]

    im = ax.imshow(data_display, origin='lower', extent=extent,
                   cmap=use_cmap, aspect='equal', vmin=vmin, vmax=vmax)

    ax.set_xlabel(f'X ({xy_unit})')
    ax.set_ylabel(f'Y ({xy_unit})')

    trace_str = "Retrace" if channel.is_retrace else "Trace"
    # Mention the leveling method only when it was really applied
    level_str = f", Leveling: {leveling}" if leveling_applied else ""
    ax.set_title(f"{image.filename}\n{channel.name} ({trace_str}){level_str}")

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(colorbar_label)

    # Statistics (from the unclipped data)
    stats = (
        f"Size: {width_scaled:.2f} × {height_scaled:.2f} {xy_unit}\n"
        f"Pixel: {image.pixels_x} × {image.pixels_y}\n"
        f"Z-Range: {np.nanmax(data_display) - np.nanmin(data_display):.2f} {z_unit}\n"
        f"RMS: {np.nanstd(data_display):.2f} {z_unit}"
    )

    ax.text(0.02, 0.98, stats, transform=ax.transAxes,
            verticalalignment='top', fontsize=9, family='monospace',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Histogram (from the unclipped data)
    if show_histogram:
        hist_data = data_display.flatten()
        hist_data = hist_data[np.isfinite(hist_data)]
        mean_value = np.mean(hist_data)
        median_value = np.median(hist_data)
        ax_hist.hist(hist_data, bins=100, color='steelblue', edgecolor='none', alpha=0.7)
        ax_hist.set_xlabel(colorbar_label)
        ax_hist.set_ylabel('Count')
        ax_hist.set_title('Height Distribution')
        ax_hist.axvline(mean_value, color='red', linestyle='--', 
                        label=f'Mean: {mean_value:.2f}')
        ax_hist.axvline(median_value, color='orange', linestyle='-', 
                        label=f'Median: {median_value:.2f}')
        ax_hist.legend(fontsize=8)

    fig.tight_layout()
    enable_clipboard_shortcut(fig)
    return fig


def plot_all_channels(image: JPKImage, 
                      leveling: str = 'plane',
                      trace_only: bool = True,
                      figsize: tuple = (16, 12),
                      enhance_contrast: bool = True,
                      percentile_low: float = 10,
                      percentile_high: float = 90) -> plt.Figure:
    """
    Displays all channels in a grid (at most three columns).

    Contrast enhancement only limits the color range to the 2nd-98th
    percentile; the Z range in each title is computed from the unclipped
    data. Both axes share one length unit so that ``aspect='equal'`` shows
    the true geometry.

    Args:
        image: JPKImage object
        leveling: Leveling method or None; only applied to channels with unit 'm'
        trace_only: Show trace channels only (True) or trace and retrace (False)
        figsize: Figure size
        enhance_contrast: Limit the color range to the 2nd-98th percentile
            (height and phase channels only)
        percentile_low: Lower percentile for robust leveling
        percentile_high: Upper percentile for robust leveling

    Returns:
        The created matplotlib Figure

    Raises:
        ValueError: If there is no channel to show.
    """
    # Filter channels
    if trace_only:
        channels = [ch for ch in image.channels.values() if not ch.is_retrace]
    else:
        channels = list(image.channels.values())

    n_channels = len(channels)
    if n_channels == 0:
        raise ValueError("No channels found")

    # Calculate grid layout
    n_cols = min(3, n_channels)
    n_rows = (n_channels + n_cols - 1) // n_cols

    # squeeze=False always yields a 2D axes array, also for a single channel
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Extent for all images, in one common unit for both axes
    unit_factors = {'mm': 1e3, 'µm': 1e6, 'nm': 1e9, 'pm': 1e12}
    _, xy_unit = format_scale(max(abs(image.width_m), abs(image.height_m)))
    width_scaled = image.width_m * unit_factors[xy_unit]
    height_scaled = image.height_m * unit_factors[xy_unit]
    extent = [0, width_scaled, 0, height_scaled]

    cmap = create_afm_colormap()

    for ax, channel in zip(axes, channels):
        data = channel.data.copy()
        if leveling and channel.unit == 'm':
            data = apply_leveling(data, method=leveling,
                                  percentile_low=percentile_low,
                                  percentile_high=percentile_high)

        # Scaling, label and colormap depending on channel
        stretch_contrast = False
        if channel.unit == 'm':
            z_range = np.nanmax(data) - np.nanmin(data)
            _, z_unit = format_scale(z_range)
            data_display = data * unit_factors.get(z_unit, 1)
            # Set minimum to 0
            data_display = data_display - np.nanmin(data_display)
            stretch_contrast = enhance_contrast
            colorbar_label = f'Height ({z_unit})'
            use_cmap = cmap
        elif channel.unit == 'V':
            data_display = data * 1e3  # mV
            z_unit = 'mV'
            colorbar_label = f'{channel.name} ({z_unit})'
            use_cmap = cmap
        elif channel.unit == '°' or 'phase' in channel.name.lower():
            data_display = data
            stretch_contrast = enhance_contrast
            z_unit = '°'
            colorbar_label = 'Phase (°)'
            use_cmap = 'gray'
        else:
            data_display = data
            z_unit = channel.unit if channel.unit else 'a.u.'
            colorbar_label = f'{channel.name} ({z_unit})'
            use_cmap = cmap

        # Contrast stretch via color limits, so the data itself stays unclipped
        vmin = vmax = None
        if stretch_contrast:
            p_low, p_high = np.nanpercentile(data_display, [2, 98])
            if np.isfinite(p_low) and np.isfinite(p_high):
                vmin, vmax = p_low, p_high

        im = ax.imshow(data_display, origin='lower', extent=extent,
                       cmap=use_cmap, aspect='equal', vmin=vmin, vmax=vmax)

        ax.set_xlabel(f'X ({xy_unit})')
        ax.set_ylabel(f'Y ({xy_unit})')

        trace_str = "R" if channel.is_retrace else "T"
        z_span = np.nanmax(data_display) - np.nanmin(data_display)
        ax.set_title(f"{channel.name} ({trace_str})\nZ: {z_span:.2f} {z_unit}")

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label(colorbar_label)

    # Hide empty subplots
    for unused_ax in axes[n_channels:]:
        unused_ax.set_visible(False)

    fig.suptitle(f"{image.filename}", fontsize=14, fontweight='bold')
    fig.tight_layout()
    enable_clipboard_shortcut(fig)
    return fig


# =============================================================================
# MULTI-FILE FUNCTIONS
# =============================================================================

def copy_figure_to_clipboard(fig=None):
    """
    Copies a figure to the clipboard (Windows).

    The figure is rendered as PNG. If PIL or win32clipboard cannot be
    imported, the PNG is written to ``afm_clipboard_temp.png`` in the home
    directory instead and a hint is printed.

    Args:
        fig: matplotlib Figure (if None, current figure is used)
    """
    import io

    if fig is None:
        fig = plt.gcf()

    # Render the figure as PNG into a buffer that is closed in every case
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', 
                    facecolor='white', edgecolor='none')
        buf.seek(0)

        try:
            # Windows: PIL + win32clipboard
            from PIL import Image
            import win32clipboard
        except ImportError:
            # Fallback: Save as file
            temp_path = Path.home() / "afm_clipboard_temp.png"
            temp_path.write_bytes(buf.getvalue())
            print("win32clipboard not installed.")
            print("Installation: conda install pywin32")
            print(f"Image saved as: {temp_path}")
            return

        # Convert PNG to BMP; CF_DIB expects the bitmap without the
        # 14-byte BMP file header
        with io.BytesIO() as output:
            Image.open(buf).convert('RGB').save(output, 'BMP')
            dib_data = output.getvalue()[14:]

        win32clipboard.OpenClipboard()
        try:
            win32clipboard.EmptyClipboard()
            win32clipboard.SetClipboardData(win32clipboard.CF_DIB, dib_data)
        finally:
            # Always release the clipboard, otherwise it stays locked
            # for this process after a failure
            win32clipboard.CloseClipboard()

        print("✓ Figure copied to clipboard!")


def enable_clipboard_shortcut(fig=None):
    """
    Register a key handler so that Ctrl+C copies the figure to the clipboard.

    Args:
        fig: matplotlib Figure (if None, the current figure is used)
    """
    target = plt.gcf() if fig is None else fig

    def _copy_on_ctrl_c(event):
        # Every other key press is ignored
        if event.key != 'ctrl+c':
            return
        copy_figure_to_clipboard(target)

    target.canvas.mpl_connect('key_press_event', _copy_on_ctrl_c)
    print("Tip: Ctrl+C copies the image to clipboard")

def plot_multiple_files(filepaths: list,
                        channel_name: str = 'height',
                        retrace: bool = False,
                        leveling: str = 'plane',
                        colormap: str = 'afm',
                        figsize: tuple = None,
                        titles: list = None,
                        show_histogram: bool = False,
                        enhance_contrast: bool = True,
                        percentile_low: float = 10,
                        percentile_high: float = 90) -> plt.Figure:
    """
    Displays multiple JPK files side by side.

    Contrast enhancement only limits the color range to the 2nd-98th
    percentile; the statistics box and the histogram are computed from the
    unclipped data. Both axes of an image share one length unit so that
    ``aspect='equal'`` shows the true geometry. Files that fail to load are
    reported on stdout and their grid cell is hidden.

    Args:
        filepaths: List of file paths
        channel_name: Channel name
        retrace: True for retrace
        leveling: Leveling method; only applied to channels with unit 'm'
        colormap: Color scale
        figsize: Figure size (automatic if None)
        titles: Optional titles for each image
        show_histogram: Show histogram below each image
        enhance_contrast: Contrast enhancement (2-98 percentile color range)
        percentile_low: Lower percentile for robust leveling
        percentile_high: Upper percentile for robust leveling

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``filepaths`` is empty.
    """
    n_files = len(filepaths)
    if n_files == 0:
        raise ValueError("No files specified")

    # Calculate grid layout
    if n_files <= 3:
        n_cols = n_files
        n_rows = 1
    elif n_files <= 6:
        n_cols = 3
        n_rows = 2
    elif n_files <= 9:
        n_cols = 3
        n_rows = 3
    else:
        n_cols = 4
        n_rows = (n_files + 3) // 4

    # Figure size
    if figsize is None:
        if show_histogram:
            figsize = (5 * n_cols, 6 * n_rows)  # More space for histograms
        else:
            figsize = (5 * n_cols, 4.5 * n_rows)

    # With histogram: Double rows (image + histogram)
    if show_histogram:
        fig, axes = plt.subplots(n_rows * 2, n_cols, figsize=figsize, squeeze=False,
                                  gridspec_kw={'height_ratios': [3, 1] * n_rows})
        img_axes = axes[0::2].flatten()  # Every second row (images)
        hist_axes = axes[1::2].flatten()  # Every second row (histograms)
    else:
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        img_axes = axes.flatten()
        hist_axes = [None] * len(img_axes)

    cmap = create_afm_colormap() if colormap == 'afm' else colormap
    unit_factors = {'mm': 1e3, 'µm': 1e6, 'nm': 1e9, 'pm': 1e12}

    # Load all images; a failed file keeps its grid position as None
    images = []
    for fp in filepaths:
        try:
            img = load_jpk_image(fp)
            images.append(img)
        except Exception as e:
            print(f"Error loading {fp}: {e}")
            images.append(None)

    for i, (ax, ax_hist, image) in enumerate(zip(img_axes, hist_axes, images)):
        if image is None:
            ax.set_visible(False)
            if ax_hist is not None:
                ax_hist.set_visible(False)
            continue

        # Get channel
        channel = image.get_channel(channel_name, retrace)
        if channel is None:
            ax.text(0.5, 0.5, f"Channel '{channel_name}'\nnot found",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_title(image.filename)
            # Nothing to show in the histogram either
            if ax_hist is not None:
                ax_hist.set_visible(False)
            continue

        # Data and leveling
        data = channel.data.copy()
        if leveling and channel.unit == 'm':
            data = apply_leveling(data, method=leveling,
                                  percentile_low=percentile_low,
                                  percentile_high=percentile_high)

        # One common unit for both axes; independent units per axis would
        # distort non-square scans under aspect='equal'
        _, xy_unit = format_scale(max(abs(image.width_m), abs(image.height_m)))
        width_scaled = image.width_m * unit_factors[xy_unit]
        height_scaled = image.height_m * unit_factors[xy_unit]

        # Z unit, label and colormap depending on channel
        stretch_contrast = False
        if channel.unit == 'm':
            z_range = np.nanmax(data) - np.nanmin(data)
            _, z_unit = format_scale(z_range)
            data_display = data * unit_factors.get(z_unit, 1)
            # Set minimum to 0
            data_display = data_display - np.nanmin(data_display)
            stretch_contrast = enhance_contrast
            colorbar_label = f'Height ({z_unit})'
            use_cmap = cmap
        elif channel.unit == 'V':
            data_display = data * 1e3  # mV
            z_unit = 'mV'
            colorbar_label = f'{channel.name} ({z_unit})'
            use_cmap = cmap
        elif channel.unit == '°' or 'phase' in channel.name.lower():
            data_display = data
            stretch_contrast = enhance_contrast
            z_unit = '°'
            colorbar_label = 'Phase (°)'
            use_cmap = 'gray'
        else:
            data_display = data
            z_unit = channel.unit if channel.unit else 'a.u.'
            colorbar_label = f'{channel.name} ({z_unit})'
            use_cmap = cmap

        # Contrast stretch via color limits, so the data itself stays unclipped
        vmin = vmax = None
        if stretch_contrast:
            p_low, p_high = np.nanpercentile(data_display, [2, 98])
            if np.isfinite(p_low) and np.isfinite(p_high):
                vmin, vmax = p_low, p_high

        extent = [0, width_scaled, 0, height_scaled]

        im = ax.imshow(data_display, origin='lower', extent=extent,
                       cmap=use_cmap, aspect='equal', vmin=vmin, vmax=vmax)

        ax.set_xlabel(f'X ({xy_unit})')
        ax.set_ylabel(f'Y ({xy_unit})')

        # Title
        if titles and i < len(titles):
            title = titles[i]
        else:
            title = Path(filepaths[i]).stem  # Filename without extension
        ax.set_title(f"{title}")

        # Statistics box (from the unclipped data)
        z_range_display = np.nanmax(data_display) - np.nanmin(data_display)
        stats = (
            f"{width_scaled:.1f}×{height_scaled:.1f} {xy_unit}\n"
            f"Z: {z_range_display:.2f} {z_unit}\n"
            f"RMS: {np.nanstd(data_display):.2f} {z_unit}"
        )
        ax.text(0.02, 0.98, stats, transform=ax.transAxes,
                verticalalignment='top', fontsize=7, family='monospace',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label(colorbar_label)

        # Histogram (from the unclipped data)
        if ax_hist is not None:
            hist_data = data_display.flatten()
            hist_data = hist_data[np.isfinite(hist_data)]
            ax_hist.hist(hist_data, bins=50, color='steelblue', edgecolor='none', alpha=0.7)
            ax_hist.set_xlabel(colorbar_label, fontsize=8)
            ax_hist.set_ylabel('Count', fontsize=8)
            ax_hist.tick_params(labelsize=7)

    # Hide empty subplots
    for i in range(n_files, len(img_axes)):
        img_axes[i].set_visible(False)
        if show_histogram:
            hist_axes[i].set_visible(False)

    fig.tight_layout()
    enable_clipboard_shortcut(fig)
    return fig


def select_files_dialog(sort_by: str = 'name') -> list:
    """
    Opens a file selection dialog for multiple JPK files.
    
    Args:
        sort_by: Sorting - 'name' (alphabetical, case-insensitive file name),
                 'date' (modification time, oldest first),
                 'none' (as returned by dialog). Any other value also
                 leaves the order untouched.
    
    Returns:
        List of selected file paths (empty if nothing was selected)
    
    Raises:
        ImportError: If tkinter is not available.
    """
    try:
        import tkinter as tk
        from tkinter import filedialog
    except ImportError as err:
        # Keep the original failure attached for easier debugging
        raise ImportError("tkinter is required for the file dialog") from err
    
    root = tk.Tk()
    try:
        root.withdraw()  # Hide main window
        root.attributes('-topmost', True)  # Bring dialog to foreground
        
        selected = filedialog.askopenfilenames(
            title="Select JPK files",
            filetypes=[
                ("JPK Files", "*.jpk *.jpk-qi-image *.jpk-force-map"),
                ("All Files", "*.*")
            ]
        )
    finally:
        # Always release the hidden Tk root, even if the dialog raised
        root.destroy()
    
    filepaths = list(selected)
    
    # Apply sorting
    if sort_by == 'name':
        filepaths.sort(key=lambda x: Path(x).name.lower())
    elif sort_by == 'date':
        filepaths.sort(key=lambda x: Path(x).stat().st_mtime)
    # 'none' = no sorting
    
    return filepaths


def select_files_interactive() -> list:
    """
    Interactive file selection with order control.
    Opens the dialog multiple times and adds files in the desired order.
    Files that were already picked in an earlier round are ignored, so
    every path appears at most once. The loop ends as soon as a dialog
    returns no selection (e.g. it was cancelled).
    
    Returns:
        List of file paths in selection order
    
    Raises:
        ImportError: If tkinter is not available.
    """
    try:
        import tkinter as tk
        from tkinter import filedialog, messagebox
    except ImportError as err:
        # Keep the original failure attached for easier debugging
        raise ImportError("tkinter is required for the file dialog") from err
    
    filepaths = []
    
    root = tk.Tk()
    try:
        root.withdraw()
        
        print("Interactive file selection:")
        print("  - Select files individually or in groups")
        print("  - Click 'Cancel' when done")
        print()
        
        while True:
            # Re-assert topmost before every dialog round
            root.attributes('-topmost', True)
            
            selected = filedialog.askopenfilenames(
                title=f"Select file(s) (so far: {len(filepaths)}) - Cancel when done",
                filetypes=[
                    ("JPK Files", "*.jpk *.jpk-qi-image *.jpk-force-map"),
                    ("All Files", "*.*")
                ]
            )
            
            # Empty selection = user is done
            if not selected:
                break
            
            for fp in selected:
                if fp not in filepaths:
                    filepaths.append(fp)
                    print(f"  {len(filepaths)}. {Path(fp).name}")
    finally:
        # Always release the hidden Tk root, even if a dialog raised
        root.destroy()
    
    print(f"\n{len(filepaths)} file(s) selected")
    return filepaths


def quick_compare(filepaths: list = None, 
                  channel: str = 'height',
                  leveling: str = 'plane',
                  sort_by: str = 'name',
                  interactive: bool = False) -> plt.Figure:
    """
    Show the same channel of several files side by side.
    
    Args:
        filepaths: Paths to compare. If None, a file dialog is opened.
        channel: Channel name passed on to plot_multiple_files
        leveling: Leveling method passed on to plot_multiple_files
        sort_by: Dialog sorting - 'name', 'date', 'none'. Only used when
                 the non-interactive dialog is opened.
        interactive: If True, the dialog is opened repeatedly so the
                     selection order is kept exactly
    
    Returns:
        The created figure, or None if there are no files to show
    
    Example:
        # With file dialog (alphabetically sorted)
        quick_compare()
        
        # With interactive order
        quick_compare(interactive=True)
        
        # Sorted by date
        quick_compare(sort_by='date')
        
        # With explicit list (custom order)
        quick_compare([r"C:/Data/scan1.jpk", r"C:/Data/scan2.jpk"])
    """
    # No explicit list: ask the user
    if filepaths is None:
        filepaths = (select_files_interactive() if interactive
                     else select_files_dialog(sort_by=sort_by))
    
    # Nothing chosen / empty list given
    if not filepaths:
        print("No files selected")
        return None
    
    n_selected = len(filepaths)
    print(f"Loading {n_selected} file(s)...")
    for position, path in enumerate(filepaths, start=1):
        print(f"  {position}. {Path(path).name}")
    
    figure = plot_multiple_files(filepaths, channel_name=channel, leveling=leveling)
    plt.show()
    return figure


# =============================================================================
# MAIN FUNCTIONS
# =============================================================================

def load_jpk_image(filepath: str | Path) -> JPKImage:
    """
    Reads a JPK AFM image file via JPKTiffLoader.
    
    Args:
        filepath: Path to .jpk file
    
    Returns:
        The object returned by JPKTiffLoader.load() for this file
    """
    return JPKTiffLoader(filepath).load()


def list_channels(image: JPKImage) -> list[str]:
    """
    Describes every channel of an image as one text line.
    
    Returns:
        One string per entry in image.channels, formatted as
        "<name> (trace|retrace) - Unit: <unit>"
    """
    return [
        f"{ch.name} ({'retrace' if ch.is_retrace else 'trace'}) - Unit: {ch.unit}"
        for ch in image.channels.values()
    ]


def get_height_data(image: JPKImage, 
                    retrace: bool = False,
                    leveling: str = 'plane') -> np.ndarray:
    """
    Returns the data of the 'height' channel, optionally leveled.
    
    Args:
        image: Image to read the channel from
        retrace: Passed to image.get_channel to pick trace or retrace
        leveling: Method forwarded to apply_leveling; a falsy value
                  (None, '') skips leveling
    
    Returns:
        A copy of the channel data, leveled if requested
    
    Raises:
        ValueError: If the image has no matching height channel
    """
    height_channel = image.get_channel('height', retrace)
    if height_channel is None:
        raise ValueError("Height channel not found")
    
    # Work on a copy so the channel's own array is never handed out
    height = height_channel.data.copy()
    return apply_leveling(height, method=leveling) if leveling else height


# =============================================================================
# INTERACTIVE USAGE
# =============================================================================

def quick_view(filepath: str | Path, channel: str = 'height', leveling: str = 'plane'):
    """
    Loads a JPK file, prints a short summary and shows one channel.
    
    Args:
        filepath: Path to the JPK file
        channel: Channel name passed on to plot_afm_image
        leveling: Leveling method passed on to plot_afm_image
    
    Returns:
        The loaded image object
    
    Example:
        quick_view(r"C:/Data/scan.jpk")
        quick_view(r"C:/Data/scan.jpk", channel='error')
    """
    image = load_jpk_image(filepath)
    
    # Scan size is stored in meters; report it in micrometers
    scan_w_um = image.width_m * 1e6
    scan_h_um = image.height_m * 1e6
    
    print(f"File: {image.filename}")
    print(f"Size: {image.pixels_x} × {image.pixels_y} Pixel")
    print(f"Scan: {scan_w_um:.2f} × {scan_h_um:.2f} µm")
    print("\nChannels:")
    for description in list_channels(image):
        print(f"  - {description}")
    
    plot_afm_image(image, channel_name=channel, leveling=leveling)
    plt.show()
    
    return image


# =============================================================================
# COMMAND LINE
# =============================================================================

if __name__ == "__main__":
    import argparse
    
    # Command line interface: positional files plus display/selection options
    parser = argparse.ArgumentParser(description='JPK AFM Viewer')
    parser.add_argument('files', nargs='*', help='JPK file(s)')
    parser.add_argument('-c', '--channel', default='height', help='Channel')
    parser.add_argument('-l', '--leveling', default='plane', 
                        choices=['plane', 'line', 'median', 'robust', 'line_robust', 'none'],
                        help='Leveling method (default: plane)')
    parser.add_argument('--plow', type=float, default=10, 
                        help='Lower percentile for robust leveling (default: 10)')
    parser.add_argument('--phigh', type=float, default=90, 
                        help='Upper percentile for robust leveling (default: 90)')
    parser.add_argument('-r', '--retrace', action='store_true', help='Show retrace')
    parser.add_argument('--all', action='store_true', help='Show all channels')
    parser.add_argument('--list', action='store_true', help='List channels')
    parser.add_argument('--histogram', action='store_true', help='Show histogram')
    parser.add_argument('--no-enhance', action='store_true', help='Disable contrast enhancement')
    parser.add_argument('-o', '--output', help='Output file')
    parser.add_argument('--select', action='store_true', help='Open file selection dialog')
    parser.add_argument('--interactive', '-i', action='store_true', 
                        help='Interactive file selection (for exact order)')
    parser.add_argument('--sort', choices=['name', 'date', 'none'], default='name',
                        help='Sorting: name (alphabetical), date (modification date), none')
    
    args = parser.parse_args()
    
    # File selection:
    #   - dialog if explicitly requested or if started without any argument
    #   - help + examples if options were given but no files
    #   - otherwise the files from the command line, in the given order
    if args.select or args.interactive or len(sys.argv) == 1:
        if args.interactive:
            filepaths = select_files_interactive()
        else:
            filepaths = select_files_dialog(sort_by=args.sort)
        if not filepaths:
            print("No files selected")
            sys.exit(0)
    elif not args.files:
        parser.print_help()
        # Use the real script name instead of a hard-coded (stale) one
        prog = parser.prog
        print("\n\nExamples:")
        print(f"  python {prog} scan.jpk")
        print(f"  python {prog} scan1.jpk scan2.jpk scan3.jpk")
        print(f"  python {prog} --select")
        print(f"  python {prog} --select --sort date")
        print(f"  python {prog} -i  # Interactive for exact order")
        print(f"  python {prog} *.jpk")
        sys.exit(0)
    else:
        filepaths = args.files
    
    # 'none' on the command line means: skip leveling
    leveling = None if args.leveling == 'none' else args.leveling
    enhance = not args.no_enhance
    
    # --list: print the channels of every given file and stop (no plotting)
    if args.list:
        for fp in filepaths:
            image = load_jpk_image(fp)
            print(f"Channels in {image.filename}:")
            for ch in list_channels(image):
                print(f"  {ch}")
        sys.exit(0)
    
    # Single file
    if len(filepaths) == 1:
        image = load_jpk_image(filepaths[0])
        
        if args.all:
            fig = plot_all_channels(image, leveling=leveling, enhance_contrast=enhance,
                                    percentile_low=args.plow, percentile_high=args.phigh)
        else:
            fig = plot_afm_image(image, channel_name=args.channel,
                                retrace=args.retrace, leveling=leveling,
                                show_histogram=args.histogram,
                                enhance_contrast=enhance,
                                percentile_low=args.plow, percentile_high=args.phigh)
    
    # Multiple files
    else:
        print(f"Loading {len(filepaths)} files in the following order:")
        for i, fp in enumerate(filepaths, 1):
            print(f"  {i}. {Path(fp).name}")
        fig = plot_multiple_files(filepaths, 
                                  channel_name=args.channel,
                                  retrace=args.retrace,
                                  leveling=leveling,
                                  show_histogram=args.histogram,
                                  enhance_contrast=enhance,
                                  percentile_low=args.plow, percentile_high=args.phigh)
    
    # Either write the figure to disk or show it on screen
    if args.output:
        fig.savefig(args.output, dpi=150, bbox_inches='tight')
        print(f"Saved: {args.output}")
    else:
        plt.show()