Source code for spxquery.visualization.plots

"""
Visualization functions for SPHEREx time-domain data.
"""

import logging
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
from astropy.stats import sigma_clip
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
from matplotlib.figure import Figure

from ..core.config import PhotometryResult

if TYPE_CHECKING:
    from ..core.config import VisualizationConfig

logger = logging.getLogger(__name__)

# Plotting configuration
WAVELENGTH_CMAP = "rainbow"  # Colormap for wavelength coding
WAVELENGTH_RANGE = (0.75, 5.0)  # SPHEREx wavelength range in microns


[docs] def calculate_smart_ylimits( y_values: List[float], percentile_range: Tuple[float, float] = (1.0, 99.0), padding_fraction: float = 0.1, ) -> Tuple[float, float]: """ Calculate smart y-axis limits based on percentiles to exclude extreme outliers. Parameters ---------- y_values : List[float] Y-axis values to analyze percentile_range : Tuple[float, float] Lower and upper percentiles to use for limits (default: 1st to 99th) padding_fraction : float Fraction of range to add as padding (default: 0.1 = 10%) Returns ------- Tuple[float, float] (ymin, ymax) limits for y-axis """ # Filter out NaN and infinite values valid_values = [v for v in y_values if np.isfinite(v)] if not valid_values: return (0, 1) # Default range if no valid data # Calculate percentile-based limits y_min = np.percentile(valid_values, percentile_range[0]) y_max = np.percentile(valid_values, percentile_range[1]) # Add padding y_range = y_max - y_min if y_range > 0: padding = y_range * padding_fraction y_min -= padding y_max += padding else: # All values are the same, add symmetric padding if y_min != 0: y_min *= 0.9 y_max *= 1.1 else: y_min = -0.1 y_max = 0.1 return (y_min, y_max)
[docs] def apply_sigma_clipping( photometry_results: List[PhotometryResult], sigma: float = 3.0, maxiters: int = 10 ) -> List[PhotometryResult]: """ Apply sigma clipping to remove outliers based on flux values. Parameters ---------- photometry_results : List[PhotometryResult] Input photometry measurements sigma : float Number of standard deviations to use for clipping maxiters : int Maximum number of clipping iterations Returns ------- List[PhotometryResult] Filtered photometry results with outliers removed """ if not photometry_results: return photometry_results # Only clip regular measurements (not upper limits) regular_measurements = [p for p in photometry_results if not p.is_upper_limit] upper_limits = [p for p in photometry_results if p.is_upper_limit] if not regular_measurements: return photometry_results # Extract flux values for clipping fluxes = np.array([p.flux for p in regular_measurements]) # Apply sigma clipping clipped_data = sigma_clip(fluxes, sigma=sigma, maxiters=maxiters) # Keep only non-clipped measurements if ma.is_masked(clipped_data): good_indices = ~clipped_data.mask else: # If no points were clipped, all points are good good_indices = np.ones(len(fluxes), dtype=bool) # Filter regular measurements filtered_regular = [regular_measurements[i] for i in range(len(regular_measurements)) if good_indices[i]] # Combine filtered regular measurements with upper limits filtered_results = filtered_regular + upper_limits logger.info( f"Sigma clipping: {len(photometry_results)} -> {len(filtered_results)} measurements " f"({len(photometry_results) - len(filtered_results)} outliers removed)" ) return filtered_results
[docs] def create_spectrum_plot( photometry_results: List[PhotometryResult], ax: Optional[Axes] = None, apply_clipping: bool = True, sigma: float = 3.0, apply_quality_filters: bool = True, sigma_threshold: float = 5.0, bad_flags_mask: Optional[int] = None, use_magnitude: bool = False, show_errorbars: bool = True, ) -> Axes: """ Create spectrum plot (wavelength vs flux), color-coded by observation date. Parameters ---------- photometry_results : List[PhotometryResult] Photometry measurements ax : plt.Axes, optional Axes to plot on. If None, current axes are used. apply_clipping : bool Whether to apply sigma clipping to remove outliers sigma : float Number of standard deviations for sigma clipping apply_quality_filters : bool Whether to classify points as good/rejected based on QC filters sigma_threshold : float Minimum SNR (flux/flux_err) for quality control bad_flags_mask : int, optional Integer mask with bad flag bits set (created by create_flag_mask) use_magnitude : bool If True, plot AB magnitude instead of flux (default: False) show_errorbars : bool If True, show errorbars (default: True) Returns ------- plt.Axes Axes with spectrum plot """ if ax is None: ax = plt.gca() # Apply sigma clipping if requested if apply_clipping: photometry_results = apply_sigma_clipping(photometry_results, sigma=sigma) # Classify points by quality if apply_quality_filters and bad_flags_mask is not None: from ..utils.helpers import classify_photometry_by_quality classified = classify_photometry_by_quality( photometry_results, sigma_threshold=sigma_threshold, bad_flags_mask=bad_flags_mask, separate_upper_limits=True, ) good_regular = classified.good_regular rejected_regular = classified.rejected_regular good_upper_limits = classified.good_upper_limits rejected_upper_limits = classified.rejected_upper_limits else: # No quality filtering - all points are "good" good_regular = [p for p in photometry_results if not p.is_upper_limit] good_upper_limits = [p for p in photometry_results if p.is_upper_limit] rejected_regular = [] rejected_upper_limits = [] # Plot good regular measurements with error bars, color-coded by date if good_regular: wavelengths = [p.wavelength for p in good_regular] bandwidths = [p.bandwidth for p in good_regular] mjds = [p.mjd for p in good_regular] # Get y values depending on magnitude or flux mode if use_magnitude: y_values = [p.mag_ab if p.mag_ab is not None else np.nan for p in good_regular] y_errors = [p.mag_ab_error if p.mag_ab_error is not None else np.nan for p in good_regular] else: y_values = [p.flux for p in good_regular] y_errors = [p.flux_error for p in good_regular] # Convert MJD to days since first observation mjd_min = min(mjds) days_since_first = [mjd - mjd_min for mjd in mjds] # Create colormap for date coding cmap = cm.get_cmap("viridis") norm = Normalize(vmin=0, vmax=max(days_since_first) if days_since_first else 1) # Two-pass plotting: errorbars first (transparent), then markers (solid) # Pass 1: Plot errorbars only (if enabled) if show_errorbars: for wl, y_val, y_err, bw, days in zip(wavelengths, y_values, y_errors, bandwidths, days_since_first): if np.isnan(y_val): continue color = cmap(norm(days)) ax.errorbar( wl, y_val, xerr=bw, yerr=y_err, fmt="none", # No marker capsize=0, linewidth=0.5, elinewidth=0.5, color=color, alpha=0.2, # Transparent errorbars zorder=1, # Behind markers ) # Pass 2: Plot markers only (solid) for wl, y_val, days in zip(wavelengths, y_values, days_since_first): if np.isnan(y_val): continue color = cmap(norm(days)) ax.plot( wl, y_val, "o", color=color, markersize=1.5, alpha=0.9, # Solid markers zorder=2, # On top of errorbars ) # Add colorbar for date sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, pad=0.02) cbar.set_label("Days since first obs", fontsize=10) # Plot rejected regular measurements as crosses if rejected_regular: wavelengths = [p.wavelength for p in rejected_regular] bandwidths = [p.bandwidth for p in rejected_regular] if use_magnitude: y_values = [p.mag_ab if p.mag_ab is not None else np.nan for p in rejected_regular] else: y_values = [p.flux for p in rejected_regular] # Filter out NaN values valid_data = [(wl, y, bw) for wl, y, bw in zip(wavelengths, y_values, bandwidths) if not np.isnan(y)] if valid_data: wavelengths, y_values, bandwidths = zip(*valid_data) ax.errorbar( wavelengths, y_values, xerr=bandwidths, yerr=None, fmt="x", markersize=2, capsize=0, linewidth=0.5, elinewidth=0.5, label="Rejected", alpha=0.5, color="gray", ) # Plot good upper limits if good_upper_limits: ul_wavelengths = [p.wavelength for p in good_upper_limits] ul_bandwidths = [p.bandwidth for p in good_upper_limits] if use_magnitude: # For magnitude, upper limit on flux becomes lower limit on magnitude # Use the stored mag_ab value (should represent the limit) ul_y_values = [p.mag_ab if p.mag_ab is not None else np.nan for p in good_upper_limits] else: ul_y_values = [p.flux + p.flux_error for p in good_upper_limits] # Upper limit value # Filter out NaN values valid_data = [(wl, y, bw) for wl, y, bw in zip(ul_wavelengths, ul_y_values, ul_bandwidths) if not np.isnan(y)] if valid_data: ul_wavelengths, ul_y_values, ul_bandwidths = zip(*valid_data) ax.errorbar( ul_wavelengths, ul_y_values, xerr=ul_bandwidths, yerr=None, fmt="v" if not use_magnitude else "^", # Flip arrow for magnitude markersize=3, capsize=0, linewidth=0.5, elinewidth=0.5, label="Upper limits" if not use_magnitude else "Lower limits (mag)", alpha=0.8, color="red", ) # Plot rejected upper limits as small crosses if rejected_upper_limits: ul_wavelengths = [p.wavelength for p in rejected_upper_limits] ul_bandwidths = [p.bandwidth for p in rejected_upper_limits] if use_magnitude: ul_y_values = [p.mag_ab if p.mag_ab is not None else np.nan for p in rejected_upper_limits] else: ul_y_values = [p.flux + p.flux_error for p in rejected_upper_limits] # Filter out NaN values valid_data = [(wl, y, bw) for wl, y, bw in zip(ul_wavelengths, ul_y_values, ul_bandwidths) if not np.isnan(y)] if valid_data: ul_wavelengths, ul_y_values, ul_bandwidths = zip(*valid_data) ax.errorbar( ul_wavelengths, ul_y_values, xerr=ul_bandwidths, yerr=None, fmt="x", markersize=2, capsize=0, linewidth=0.5, elinewidth=0.5, label="Rejected (UL)", alpha=0.5, color="lightcoral", ) # Formatting ax.set_xlabel("Wavelength (μm)", fontsize=12) if use_magnitude: ax.set_ylabel("AB Magnitude", fontsize=12) ax.invert_yaxis() # Fainter sources have higher magnitudes else: ax.set_ylabel("Flux Density (μJy)", fontsize=12) ax.set_title("SPHEREx Spectrum", fontsize=14) ax.grid(True, alpha=0.3) ax.legend() # Set x-axis limits to SPHEREx range ax.set_xlim(0.7, 5.1) # Set smart y-axis limits based on percentiles to handle outliers all_y_values = [] # Collect all y-values for limit calculation if use_magnitude: all_y_values.extend([p.mag_ab for p in good_regular if p.mag_ab is not None]) all_y_values.extend([p.mag_ab for p in rejected_regular if p.mag_ab is not None]) all_y_values.extend([p.mag_ab for p in good_upper_limits if p.mag_ab is not None]) all_y_values.extend([p.mag_ab for p in rejected_upper_limits if p.mag_ab is not None]) else: all_y_values.extend([p.flux for p in good_regular]) all_y_values.extend([p.flux for p in rejected_regular]) all_y_values.extend([p.flux + p.flux_error for p in good_upper_limits]) all_y_values.extend([p.flux + p.flux_error for p in rejected_upper_limits]) if all_y_values: y_min, y_max = calculate_smart_ylimits(all_y_values, percentile_range=(0.1, 99.9)) # For magnitude plots, axis is inverted so we need to reverse the order if use_magnitude: ax.set_ylim(y_max, y_min) # Reversed for inverted axis else: ax.set_ylim(y_min, y_max) return ax
[docs] def create_lightcurve_plot( photometry_results: List[PhotometryResult], ax: Optional[Axes] = None, apply_clipping: bool = True, sigma: float = 3.0, apply_quality_filters: bool = True, sigma_threshold: float = 5.0, bad_flags_mask: Optional[int] = None, use_magnitude: bool = False, show_errorbars: bool = True, ) -> Axes: """ Create light curve plot (time vs flux) color-coded by wavelength. Parameters ---------- photometry_results : List[PhotometryResult] Photometry measurements ax : plt.Axes, optional Axes to plot on. If None, current axes are used. apply_clipping : bool Whether to apply sigma clipping to remove outliers sigma : float Number of standard deviations for sigma clipping apply_quality_filters : bool Whether to classify points as good/rejected based on QC filters sigma_threshold : float Minimum SNR (flux/flux_err) for quality control bad_flags_mask : int, optional Integer mask with bad flag bits set (created by create_flag_mask) use_magnitude : bool If True, plot AB magnitude instead of flux (default: False) show_errorbars : bool If True, show errorbars (default: True) Returns ------- plt.Axes Axes with light curve plot """ if ax is None: ax = plt.gca() if not photometry_results: ax.text(0.5, 0.5, "No data available", ha="center", va="center", transform=ax.transAxes) return ax # Apply sigma clipping if requested if apply_clipping: photometry_results = apply_sigma_clipping(photometry_results, sigma=sigma) # Classify points by quality if apply_quality_filters and bad_flags_mask is not None: from ..utils.helpers import classify_photometry_by_quality classified = classify_photometry_by_quality( photometry_results, sigma_threshold=sigma_threshold, bad_flags_mask=bad_flags_mask, separate_upper_limits=False, # Light curve doesn't separate upper limits ) # Combine all good points (regular + upper limits if any) good_points = classified.good_regular rejected_points = classified.rejected_regular else: # No quality filtering - all points are "good" good_points = photometry_results rejected_points = [] # Get colormap for wavelength coding cmap = cm.get_cmap(WAVELENGTH_CMAP) norm = Normalize(vmin=WAVELENGTH_RANGE[0], vmax=WAVELENGTH_RANGE[1]) # Sort by MJD for proper time ordering good_sorted = sorted(good_points, key=lambda x: x.mjd) rejected_sorted = sorted(rejected_points, key=lambda x: x.mjd) # Two-pass plotting: errorbars first (transparent), then markers (solid) # Pass 1: Plot errorbars only (if enabled) if show_errorbars: for result in good_sorted: color = cmap(norm(result.wavelength)) if result.is_upper_limit: # Skip upper limits for errorbars continue else: # Plot regular measurement errorbars if use_magnitude: y_val = result.mag_ab if result.mag_ab is not None else np.nan y_err = result.mag_ab_error if result.mag_ab_error is not None else np.nan else: y_val = result.flux y_err = result.flux_error if np.isnan(y_val): continue ax.errorbar( result.mjd, y_val, yerr=y_err, fmt="none", # No marker capsize=0, linewidth=0.5, elinewidth=0.5, color=color, alpha=0.2, # Transparent errorbars zorder=1, # Behind markers ) # Pass 2: Plot markers only (solid) for result in good_sorted: color = cmap(norm(result.wavelength)) if result.is_upper_limit: # Plot upper limit if use_magnitude: y_val = result.mag_ab if result.mag_ab is not None else np.nan marker = "^" # Flip for magnitude (lower limit) else: y_val = result.flux + result.flux_error marker = "v" if np.isnan(y_val): continue ax.plot( result.mjd, y_val, marker, color=color, markersize=3, alpha=0.9, # Solid markers zorder=2, # On top of errorbars ) else: # Plot regular measurement if use_magnitude: y_val = result.mag_ab if result.mag_ab is not None else np.nan else: y_val = result.flux if np.isnan(y_val): continue ax.plot( result.mjd, y_val, "o", color=color, markersize=1.5, alpha=0.9, # Solid markers zorder=2, # On top of errorbars ) # Plot rejected points as small gray crosses for result in rejected_sorted: if use_magnitude: y_val = result.mag_ab if result.mag_ab is not None else np.nan else: y_val = result.flux # Skip if invalid magnitude if np.isnan(y_val): continue ax.plot( result.mjd, y_val, "x", color="gray", markersize=2, alpha=0.5, ) # Add colorbar for wavelength sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, pad=0.02) cbar.set_label("Wavelength (μm)", fontsize=10) # Formatting ax.set_xlabel("MJD", fontsize=12) if use_magnitude: ax.set_ylabel("AB Magnitude", fontsize=12) ax.invert_yaxis() # Fainter sources have higher magnitudes else: ax.set_ylabel("Flux Density (μJy)", fontsize=12) ax.set_title("SPHEREx Light Curve", fontsize=14) ax.grid(True, alpha=0.3) # Add some padding to x-axis (use all points for range) all_sorted = sorted(photometry_results, key=lambda x: x.mjd) mjds = [p.mjd for p in all_sorted] mjd_range = max(mjds) - min(mjds) if mjd_range > 0: ax.set_xlim(min(mjds) - 0.05 * mjd_range, max(mjds) + 0.05 * mjd_range) # Set smart y-axis limits based on percentiles to handle outliers all_y_values = [] # Collect all y-values for limit calculation for result in good_points: if use_magnitude: if result.mag_ab is not None: all_y_values.append(result.mag_ab) else: if result.is_upper_limit: all_y_values.append(result.flux + result.flux_error) else: all_y_values.append(result.flux) for result in rejected_points: if use_magnitude: if result.mag_ab is not None: all_y_values.append(result.mag_ab) else: all_y_values.append(result.flux) if all_y_values: y_min, y_max = calculate_smart_ylimits(all_y_values, percentile_range=(0.1, 99.9)) # For magnitude plots, axis is inverted so we need to reverse the order if use_magnitude: ax.set_ylim(y_max, y_min) # Reversed for inverted axis else: ax.set_ylim(y_min, y_max) return ax
[docs] def create_combined_plot( photometry_results: List[PhotometryResult], output_path: Optional[Path] = None, figsize: Optional[Tuple[float, float]] = None, apply_clipping: bool = True, sigma: Optional[float] = None, apply_quality_filters: bool = True, sigma_threshold: float = 5.0, bad_flags: Optional[List[int]] = None, use_magnitude: bool = False, show_errorbars: bool = True, visualization_config: Optional["VisualizationConfig"] = None, ) -> Figure: """ Create combined plot with spectrum and light curve. Quality control filters classify points as good or rejected: - Good points: plotted normally (filled circles) - Rejected points: plotted as small gray crosses - All points appear in the plot and are saved in CSV output Parameters ---------- photometry_results : List[PhotometryResult] Photometry measurements (all points included) output_path : Path, optional Path to save figure. If None, figure is not saved. figsize : Tuple[float, float], optional Figure size in inches (overrides visualization_config if provided) apply_clipping : bool Whether to apply sigma clipping to remove outliers sigma : float, optional Number of standard deviations for sigma clipping (overrides visualization_config if provided) apply_quality_filters : bool Whether to apply quality control filters (SNR and flags) sigma_threshold : float Minimum SNR (flux/flux_err) for quality control (default: 5.0) bad_flags : List[int], optional List of bad flag bit positions to reject Default: [0, 1, 2, 6, 7, 9, 10, 11, 14, 15, 17, 19] use_magnitude : bool If True, plot AB magnitude instead of flux (default: False) show_errorbars : bool If True, show errorbars (default: True) visualization_config : VisualizationConfig, optional Advanced visualization configuration. If None, uses defaults. Returns ------- Figure Matplotlib figure with both plots Notes ----- Priority: explicit parameters > visualization_config > defaults """ from ..core.config import VisualizationConfig # Use default config if none provided if visualization_config is None: visualization_config = VisualizationConfig() # Apply parameter priority: explicit > config > defaults if figsize is None: figsize = visualization_config.figsize if sigma is None: sigma = visualization_config.sigma_clip_sigma # Create flag mask if quality filtering is requested bad_flags_mask = None if apply_quality_filters: from ..utils.helpers import check_flag_bits, create_flag_mask if bad_flags is None: bad_flags = [0, 1, 2, 6, 7, 9, 10, 11, 14, 15, 17, 19] bad_flags_mask = create_flag_mask(bad_flags) # Log filtering statistics good_count = 0 rejected_count = 0 for p in photometry_results: snr = p.flux / p.flux_error if p.flux_error > 0 else 0.0 fails_snr = snr < sigma_threshold fails_flag = check_flag_bits(p.flag, bad_flags_mask) if fails_snr or fails_flag: rejected_count += 1 else: good_count += 1 logger.info( f"Quality filtering: {len(photometry_results)} total points " f"({good_count} good, {rejected_count} rejected - shown as crosses)" ) # Create figure with two subplots using constrained layout to handle colorbars fig, (ax1, ax2) = plt.subplots( 2, 1, figsize=figsize, gridspec_kw={"height_ratios": [1, 1]}, constrained_layout=True ) # Create spectrum plot (top) with QC classification create_spectrum_plot( photometry_results, ax1, apply_clipping=apply_clipping, sigma=sigma, apply_quality_filters=apply_quality_filters, sigma_threshold=sigma_threshold, bad_flags_mask=bad_flags_mask, use_magnitude=use_magnitude, show_errorbars=show_errorbars, ) # Create light curve plot (bottom) with QC classification create_lightcurve_plot( photometry_results, ax2, apply_clipping=apply_clipping, sigma=sigma, apply_quality_filters=apply_quality_filters, sigma_threshold=sigma_threshold, bad_flags_mask=bad_flags_mask, use_magnitude=use_magnitude, show_errorbars=show_errorbars, ) # Save if requested if output_path: output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, dpi=visualization_config.dpi, bbox_inches="tight") logger.info(f"Saved combined plot to {output_path}") return fig
[docs] def plot_summary_statistics(photometry_results: List[PhotometryResult], output_path: Optional[Path] = None) -> Figure: """ Create summary statistics plots. Parameters ---------- photometry_results : List[PhotometryResult] Photometry measurements output_path : Path, optional Path to save figure Returns ------- Figure Figure with summary plots """ fig, axes = plt.subplots(2, 2, figsize=(12, 8)) # 1. Histogram of wavelengths ax = axes[0, 0] wavelengths = [p.wavelength for p in photometry_results] ax.hist(wavelengths, bins=20, alpha=0.7, edgecolor="black") ax.set_xlabel("Wavelength (μm)") ax.set_ylabel("Count") ax.set_title("Wavelength Distribution") ax.grid(True, alpha=0.3) # 2. Band distribution ax = axes[0, 1] bands = [p.band for p in photometry_results] unique_bands, counts = np.unique(bands, return_counts=True) ax.bar(unique_bands, counts, alpha=0.7, edgecolor="black") ax.set_xlabel("Band") ax.set_ylabel("Count") ax.set_title("Observations per Band") ax.grid(True, alpha=0.3, axis="y") # 3. SNR distribution ax = axes[1, 0] snrs = [p.flux / p.flux_error for p in photometry_results if p.flux_error > 0] ax.hist(snrs, bins=20, alpha=0.7, edgecolor="black") ax.set_xlabel("Signal-to-Noise Ratio") ax.set_ylabel("Count") ax.set_title("SNR Distribution") ax.grid(True, alpha=0.3) ax.set_xlim(0, np.percentile(snrs, 95) if snrs else 10) # 4. Time coverage ax = axes[1, 1] bands_unique = sorted(set(bands)) band_colors = cm.get_cmap("rainbow")(np.linspace(0, 1, len(bands_unique))) for band, color in zip(bands_unique, band_colors): band_mjds = [p.mjd for p in photometry_results if p.band == band] ax.scatter([band] * len(band_mjds), band_mjds, alpha=0.6, s=20, color=color) ax.set_xlabel("Band") ax.set_ylabel("MJD") ax.set_title("Temporal Coverage by Band") ax.grid(True, alpha=0.3, axis="y") plt.tight_layout() if output_path: output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"Saved summary statistics to {output_path}") return fig