"""
Enhanced SPHEREx Multi-Extension FITS (MEF) file handling with utility methods.
This module provides the SPHERExMEF class and related utilities for working with
SPHEREx spectral imaging data, including:
- Unit conversion (MJy/sr to μJy/arcsec² or other units)
- PSF extraction from 121-zone (11×11) oversampled grid
- WCS coordinate transformation wrappers
- Image cutout extraction
- Zodiacal background subtraction
"""
import logging
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import astropy.units as u
import numpy as np
from astropy import log as astropy_log
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
logger = logging.getLogger(__name__)
[docs]
@contextmanager
def suppress_astropy_info():
"""
Context manager to temporarily suppress astropy INFO messages and FITS warnings.
This is needed because SPHEREx provides non-standard WCS headers that trigger
harmless but annoying INFO messages about SIP distortion coefficients and
warnings about redundant SCAMP distortion parameters.
"""
original_level = astropy_log.level
astropy_log.setLevel("WARNING")
try:
# Suppress specific FITSFixedWarning about redundant SCAMP distortion parameters
with warnings.catch_warnings():
# Catch the warning by message pattern, regardless of category
warnings.filterwarnings("ignore", message=".*Removed redundant SCAMP distortion parameters.*")
warnings.filterwarnings("ignore", message=".*because SIP parameters are also present.*")
yield
finally:
astropy_log.setLevel(original_level)
[docs]
@dataclass
class SPHERExMEF:
"""
Container for SPHEREx Multi-Extension FITS data with enhanced methods.
Attributes
----------
filepath : Path
Path to original FITS file
image : np.ndarray
Calibrated flux (units specified by image_unit)
flags : np.ndarray
Bitmap flags for pixel quality
variance : np.ndarray
Variance array (units: image_unit²)
zodi : np.ndarray
Zodiacal light model (same units as image)
psf : np.ndarray
PSF cube (101×101×121) - 11×11 zones, each oversampled 10× (101×101 pixels)
spatial_wcs : WCS
Primary astrometric WCS (RA/Dec)
spectral_wcs : WCS
Alternative spectral WCS (wavelength/bandwidth)
header : fits.Header
Primary IMAGE extension header
psf_header : fits.Header
PSF extension header (contains XCTR_*, YCTR_*, OVERSAMP)
obs_id : str
Observation ID
detector : int
Detector number
mjd : float
Modified Julian Date (average of MJD-BEG and MJD-END)
image_unit : str
Units of image data (e.g., 'MJy/sr', 'uJy/arcsec2')
native_unit : str
Original unit from FITS file (always 'MJy/sr')
_psf_zone_centers : Optional[Dict[int, Tuple[float, float]]]
Cached PSF zone center coordinates (zone_id -> (x, y))
_psf_oversamp : Optional[int]
Cached PSF oversampling factor from header
"""
filepath: Path
image: np.ndarray
flags: np.ndarray
variance: np.ndarray
zodi: np.ndarray
psf: np.ndarray
spatial_wcs: WCS
spectral_wcs: WCS
header: fits.Header
psf_header: fits.Header
obs_id: str
detector: int
mjd: float
psf_fwhm: float = 6.0
image_unit: str = "MJy/sr"
native_unit: str = "MJy/sr"
_psf_zone_centers: Optional[Dict[int, Tuple[float, float]]] = None
_psf_oversamp: Optional[int] = None
@property
def image_zodi_subtracted(self) -> np.ndarray:
"""Return zodiacal light subtracted image with amplitude scaling."""
corrected_image, _ = subtract_zodiacal_background(self.image, self.zodi, self.flags, self.variance)
return corrected_image
@property
def error(self) -> np.ndarray:
"""Return error array (sqrt of variance)."""
return np.sqrt(self.variance)
@property
def psf_oversamp(self) -> int:
"""
Get PSF oversampling factor from header.
Returns
-------
int
Oversampling factor (typically 10 for SPHEREx)
"""
if self._psf_oversamp is None:
self._psf_oversamp = self.psf_header.get("OVERSAMP", 10)
return self._psf_oversamp
@property
def psf_pixel_scale(self) -> float:
"""
Get PSF pixel scale in arcsec/pixel using oversampling factor from header.
The PSF pixel scale is the detector pixel scale divided by the oversampling factor,
which is read from the OVERSAMP keyword in the PSF header.
Returns
-------
float
PSF pixel scale in arcsec/pixel
"""
# Get detector pixel scale from WCS at image center
ny, nx = self.image.shape
detector_pixel_scale = self.get_pixel_scale(nx / 2.0, ny / 2.0, fallback=6.2)
# PSF is oversampled by factor from header (typically 10)
oversamp = self.psf_oversamp
psf_scale = detector_pixel_scale / oversamp
logger.debug(
f"PSF pixel scale: {psf_scale:.4f} arcsec/pixel (detector: {detector_pixel_scale:.4f}, oversamp: {oversamp})"
)
return psf_scale
def _load_psf_zone_centers(self) -> Dict[int, Tuple[float, float]]:
"""
Load PSF zone center coordinates from PSF header.
Reads XCTR_i and YCTR_i keywords (i=1...121) to get the center
coordinates of each PSF zone on the parent detector image.
Returns
-------
Dict[int, Tuple[float, float]]
Dictionary mapping zone_id (1-121) to (x, y) coordinates in FITS convention (1-based)
"""
if self._psf_zone_centers is not None:
return self._psf_zone_centers
import re
xctr = {}
yctr = {}
for key, val in self.psf_header.items():
# Look for keys like XCTR_1, XCTR_2, ..., XCTR_121
xm = re.match(r"XCTR_(\d+)", key)
if xm:
zone_id = int(xm.group(1))
xctr[zone_id] = float(val)
# Look for keys like YCTR_1, YCTR_2, ..., YCTR_121
ym = re.match(r"YCTR_(\d+)", key)
if ym:
zone_id = int(ym.group(1))
yctr[zone_id] = float(val)
# Build dictionary of zone centers
zone_centers = {}
for zone_id in xctr.keys():
if zone_id in yctr:
zone_centers[zone_id] = (xctr[zone_id], yctr[zone_id])
# Log appropriate message based on zone count
n_zones = len(zone_centers)
if n_zones == 1:
logger.info(f"Single-zone PSF detected (zone {list(zone_centers.keys())[0]})")
elif len(xctr) != len(yctr):
logger.warning(f"Mismatch in PSF zone counts: {len(xctr)} X centers vs {len(yctr)} Y centers")
else:
logger.debug(f"Loaded {n_zones} PSF zone centers from header")
# Cache for future use
self._psf_zone_centers = zone_centers
return zone_centers
# ==================== WCS Wrapper Methods ====================
[docs]
def world_to_pixel(self, ra: float, dec: float) -> Tuple[float, float]:
"""
Convert world coordinates (RA/Dec) to pixel coordinates.
Parameters
----------
ra : float
Right ascension in degrees
dec : float
Declination in degrees
Returns
-------
x, y : float
Pixel coordinates (0-based)
"""
coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg)
x, y = self.spatial_wcs.world_to_pixel(coord)
# Check if coordinates are within image
ny, nx = self.image.shape
if not (0 <= x < nx and 0 <= y < ny):
logger.warning(f"Coordinates ({x:.1f}, {y:.1f}) outside image bounds ({nx}, {ny})")
return float(x), float(y)
[docs]
def pixel_to_world(self, x: float, y: float) -> SkyCoord:
"""
Convert pixel coordinates to world coordinates (RA/Dec).
Parameters
----------
x, y : float
Pixel coordinates (0-based)
Returns
-------
SkyCoord
Sky coordinates with RA/Dec
"""
return self.spatial_wcs.pixel_to_world(x, y)
[docs]
def pixel_to_wavelength(self, x: float, y: float) -> Tuple[float, float]:
"""
Get wavelength and bandwidth at pixel position.
Uses the spectral WCS (alternative 'W') which provides wavelength
information via lookup table for SPHEREx's 2D spectral mapping.
Parameters
----------
x, y : float
Pixel coordinates (0-based)
Returns
-------
wavelength : float
Central wavelength in microns
bandwidth : float
Bandwidth in microns
"""
# Use spectral WCS to get wavelength info
spectral_coords = self.spectral_wcs.pixel_to_world(x, y)
# spectral_coords is a tuple of (wavelength, bandpass)
wavelength = spectral_coords[0].to(u.micron).value
bandwidth = spectral_coords[1].to(u.micron).value
return wavelength, bandwidth
[docs]
def get_pixel_scale(self, x: float, y: float, fallback: float = 6.2) -> float:
"""
Calculate the pixel scale in arcsec/pixel at a given position.
This accounts for any WCS distortions and provides the actual pixel scale
at the specified position.
Parameters
----------
x, y : float
Pixel coordinates (0-based)
fallback : float
Fallback pixel scale in arcsec/pixel if WCS fails (default: 6.2 for SPHEREx)
Returns
-------
float
Pixel scale in arcseconds per pixel
"""
try:
# Get the pixel scale from the WCS at the specified position
pixel_scales = self.spatial_wcs.proj_plane_pixel_scales() # Returns scales in degrees/pixel
# For SPHEREx, pixels should be roughly square, so take the geometric mean
pixel_scale_deg = float(np.sqrt(pixel_scales[0] * pixel_scales[1]).value)
# Convert from degrees to arcseconds
pixel_scale_arcsec = pixel_scale_deg * 3600.0
logger.debug(f"WCS pixel scale at ({x:.1f}, {y:.1f}): {pixel_scale_arcsec:.3f} arcsec/pixel")
return pixel_scale_arcsec
except Exception as e:
logger.warning(f"Failed to calculate pixel scale from WCS: {e}. Using fallback {fallback} arcsec/pixel")
return fallback
# ==================== PSF Extraction Methods ====================
[docs]
def get_psf_fwhm_estimate(self, x: float, y: float) -> float:
"""
Estimate PSF FWHM at a given position using radial profile analysis.
Uses photutils.profiles.RadialProfile to convert the 2D PSF to a 1D
radial profile, then interpolates to find the half-maximum radius.
Parameters
----------
x, y : float
Pixel coordinates (0-based)
Returns
-------
float
Estimated FWHM in arcseconds
Raises
------
Exception
If FWHM estimation fails. Caller should handle the error and
decide on fallback behavior.
Notes
-----
PSF is 101×101 pixels with 10× oversampling. The PSF center is determined
from the peak pixel position (not geometric center), and radial sampling
is adapted to the PSF size. Uses constant extrapolation at radial profile
boundaries for robust interpolation.
"""
from photutils.profiles import RadialProfile
psf = self.extract_psf_at_position(x, y)
# Find PSF center from peak pixel position (not geometric center)
# np.unravel_index returns (row, col) = (y, x) for 2D array
peak_idx = np.unravel_index(psf.argmax(), psf.shape)
center_y, center_x = peak_idx
center_xy = (float(center_x), float(center_y)) # (x, y) order for photutils
logger.debug(f"PSF peak at pixel ({center_x}, {center_y}), value = {psf[center_y, center_x]:.6e}")
# Find peak value for half-maximum calculation
peak_value = psf.max()
half_max = peak_value / 2.0
# Get PSF dimensions for max radius calculation
psf_height, psf_width = psf.shape
# Determine max radius from PSF size (use half the smaller dimension)
max_radius = min(psf_height, psf_width) / 2.0
# Create radial bins using linspace for even sampling
# Use 500 points for smooth interpolation
n_radial_bins = 500
radii = np.linspace(0, max_radius, n_radial_bins)
# Create 1D radial profile by averaging flux in concentric rings
rp = RadialProfile(psf, center_xy, radii)
radial_profile = rp.profile
radial_bins = rp.radius
# Find radius where profile = half_max using root-finding
# This is more robust than np.interp as it doesn't require monotonicity
from scipy.interpolate import interp1d
from scipy.optimize import fsolve
# Create interpolation function: f(r) -> intensity
# Use edge values for constant extrapolation beyond radial_bins range
fill_value = (radial_profile[0], radial_profile[-1])
f_interp = interp1d(radial_bins, radial_profile, kind="quadratic", fill_value=fill_value, bounds_error=False)
# Find root of: f(r) - half_max = 0
# Initial guess: 10 PSF pixels (reasonable for typical PSF core)
fwhm_radius_psf_pixels = fsolve(lambda r: f_interp(r) - half_max, x0=5.0)[0]
# Convert from PSF pixels to arcseconds
# PSF pixel scale is 1/10 of detector pixel scale (10× oversampling)
fwhm_diameter_arcsec = fwhm_radius_psf_pixels * 2.0 * self.psf_pixel_scale
logger.debug(
f"Estimated PSF FWHM at ({x:.1f}, {y:.1f}): "
f"{fwhm_diameter_arcsec:.3f} arcsec "
f"(radius: {fwhm_radius_psf_pixels:.1f} PSF pixels)"
)
return fwhm_diameter_arcsec
# ==================== Cutout Methods ====================
[docs]
def get_cutout(
self,
position: Union[Tuple[float, float], SkyCoord],
size: Union[int, Tuple[int, int]],
include_extensions: Optional[list] = None,
mode: str = "trim",
) -> Dict[str, np.ndarray]:
"""
Extract a cutout from the MEF data using astropy.nddata.Cutout2D.
Returns a dictionary with cutout arrays and updated WCS. Uses Cutout2D
for proper WCS handling, which automatically updates WCS keywords for
the cutout region.
Parameters
----------
position : tuple of float or SkyCoord
Position of cutout center. Can be either:
- Tuple (x, y) of pixel coordinates (0-based)
- SkyCoord object with RA/Dec
size : int or tuple of int
Size of cutout in pixels. If int, creates square cutout.
If tuple (height, width), creates rectangular cutout.
Note: Order is (height, width) to match numpy array shape convention.
include_extensions : list of str, optional
List of extensions to include in cutout.
Default: ['image', 'flags', 'variance', 'zodi']
Available: ['image', 'flags', 'variance', 'zodi', 'psf']
mode : str, optional
Mode for handling cutouts that extend beyond image boundaries.
Options: 'trim' (default), 'partial', 'strict'
See astropy.nddata.Cutout2D documentation for details.
Returns
-------
dict
Dictionary containing:
- Requested extension cutouts (np.ndarray)
- 'wcs': Updated WCS for the cutout region
- 'position_original': Original position in parent image (x, y)
- 'position_cutout': Position in cutout coordinates (x, y)
- 'bbox_original': Bounding box in original image (BoundingBox object)
- 'shape': Shape of cutout (height, width)
Examples
--------
>>> # Square cutout using pixel coordinates
>>> cutout = mef.get_cutout(position=(1020, 1020), size=50)
>>> cutout_image = cutout['image']
>>> cutout_wcs = cutout['wcs']
>>>
>>> # Rectangular cutout using SkyCoord
>>> from astropy.coordinates import SkyCoord
>>> import astropy.units as u
>>> coord = SkyCoord(ra=304.5*u.deg, dec=42.3*u.deg)
>>> cutout = mef.get_cutout(position=coord, size=(100, 50),
... include_extensions=['image', 'flags'])
>>>
>>> # Access cutout center in cutout pixel coordinates
>>> center_x_cutout, center_y_cutout = cutout['position_cutout']
"""
from astropy.nddata import Cutout2D
if include_extensions is None:
include_extensions = ["image", "flags", "variance", "zodi"]
# Create cutout from image using Cutout2D (handles WCS automatically)
try:
cutout_obj = Cutout2D(self.image, position=position, size=size, wcs=self.spatial_wcs, mode=mode, copy=True)
except Exception as e:
logger.error(f"Failed to create cutout: {e}")
raise
# Build result dictionary with metadata and image cutout
cutout_dict = {
"wcs": cutout_obj.wcs, # Updated WCS for cutout
"position_original": cutout_obj.position_original, # (x, y) in original image
"position_cutout": cutout_obj.position_cutout, # (x, y) in cutout
"bbox_original": cutout_obj.bbox_original, # BoundingBox in original image
"shape": cutout_obj.shape, # (height, width)
}
# Add image cutout if requested (reuse cutout_obj to avoid recalculating)
if "image" in include_extensions:
cutout_dict["image"] = cutout_obj.data
# Extract cutouts for other extensions using the same bounding box
# Use slices from cutout_obj to avoid recalculating Cutout2D
extension_map = {
"flags": self.flags,
"variance": self.variance,
"zodi": self.zodi,
}
for ext_name in include_extensions:
if ext_name == "image":
continue # Already handled above
elif ext_name in extension_map:
# Use the bounding box from cutout_obj to slice directly
cutout_ext = Cutout2D(
extension_map[ext_name], position=position, size=size, wcs=self.spatial_wcs, mode=mode, copy=True
)
cutout_dict[ext_name] = cutout_ext.data
elif ext_name == "psf":
# PSF is 3D and position-dependent across the detector
# Include full PSF cube for cutouts
cutout_dict["psf"] = self.psf.copy()
else:
logger.warning(f"Unknown extension '{ext_name}' requested for cutout, skipping")
# Log cutout info
logger.info(
f"Extracted cutout: position={position}, size={size}, "
f"shape={cutout_obj.shape}, "
f"bbox=[{cutout_obj.xmin_original}:{cutout_obj.xmax_original}, {cutout_obj.ymin_original}:{cutout_obj.ymax_original}]"
)
return cutout_dict
# ==================== File I/O Functions ====================
[docs]
def read_spherex_mef(filepath: Path, target_unit: Optional[str] = None) -> SPHERExMEF:
"""
Read SPHEREx Multi-Extension FITS file with optional unit conversion.
Parameters
----------
filepath : Path
Path to SPHEREx MEF file
target_unit : str, optional
Target unit for image data. Options:
- None: Keep native MJy/sr (default)
- 'uJy/arcsec2' or 'microJy/arcsec2': Convert to μJy/arcsec²
- 'Jy/arcsec2': Convert to Jy/arcsec²
- 'MJy/sr': No conversion (native)
Returns
-------
SPHERExMEF
Container with all MEF data. If unit conversion was applied,
image, variance, and zodi arrays are in target units.
Notes
-----
Conversion from MJy/sr to μJy/arcsec²:
1 MJy/sr = 1e6 Jy/sr = 1e6 Jy / (206265 arcsec)² = 0.02350443 μJy/arcsec²
Conversion factor: 1 MJy/sr × 0.02350443 = μJy/arcsec²
"""
logger.info(f"Reading SPHEREx MEF: {filepath}")
with fits.open(filepath) as hdulist:
# Verify expected structure
if len(hdulist) < 7:
raise ValueError(f"Expected at least 7 extensions, got {len(hdulist)}")
# Read IMAGE extension
image_hdu = hdulist["IMAGE"]
image_data = image_hdu.data.astype(np.float32)
image_header = image_hdu.header
# Verify units are as expected (MJy/sr)
bunit = image_header.get("BUNIT", "").strip().upper()
if bunit and bunit not in ["MJY/SR", "MJY / SR", "MJY SR-1", "MJY/STERADIAN"]:
logger.warning(f"Unexpected BUNIT '{bunit}' in {filepath}. Expected 'MJy/sr'")
elif bunit:
logger.debug(f"Verified BUNIT: {bunit}")
else:
logger.warning(f"Missing BUNIT header in {filepath}. Assuming MJy/sr")
# Read other extensions
flags_data = hdulist["FLAGS"].data.astype(np.int32)
variance_data = hdulist["VARIANCE"].data.astype(np.float32)
zodi_data = hdulist["ZODI"].data.astype(np.float32)
# Read PSF extension (both data and header for zone information)
psf_data = hdulist["PSF"].data.astype(np.float32)
psf_header = hdulist["PSF"].header
# Load WCS with suppressed warnings about SCAMP/SIP distortion parameters
with suppress_astropy_info():
# Load spatial WCS (primary)
spatial_wcs = WCS(image_header)
# Load spectral WCS (alternative 'W')
# Need to pass HDUList for lookup table access
spectral_wcs = WCS(header=image_header, fobj=hdulist, key="W")
# Disable SIP distortion for spectral WCS
spectral_wcs.sip = None
# Extract metadata
obs_id = image_header.get("OBSID", filepath.stem)
detector = image_header.get("DETECTOR", 0)
# Calculate MJD
t_min = image_header.get("MJD-BEG", 0)
t_max = image_header.get("MJD-END", 0)
mjd = (t_min + t_max) / 2.0
# Read PSF FWHM from header (arcseconds)
psf_fwhm = image_header.get("PSF_FWHM", None)
if psf_fwhm is None:
psf_fwhm = 6.0
logger.warning(f"PSF_FWHM not found in header of {filepath.name}, using fallback {psf_fwhm} arcsec")
# Apply unit conversion if requested
native_unit = "MJy/sr"
if target_unit is not None and target_unit.lower() not in ["mjy/sr", "mjy / sr"]:
conversion_factor = _get_unit_conversion_factor(native_unit, target_unit)
image_data = image_data * conversion_factor
variance_data = variance_data * (conversion_factor**2) # Variance scales as square
zodi_data = zodi_data * conversion_factor
final_unit = _normalize_unit_string(target_unit)
logger.info(f"Converted units: {native_unit} → {final_unit} (factor: {conversion_factor:.6e})")
else:
final_unit = native_unit
mef = SPHERExMEF(
filepath=filepath,
image=image_data,
flags=flags_data,
variance=variance_data,
zodi=zodi_data,
psf=psf_data,
psf_header=psf_header,
spatial_wcs=spatial_wcs,
spectral_wcs=spectral_wcs,
header=image_header,
obs_id=obs_id,
detector=detector,
mjd=mjd,
psf_fwhm=psf_fwhm,
image_unit=final_unit,
native_unit=native_unit,
)
logger.info(f"Loaded {obs_id}: detector {detector}, shape {image_data.shape}, units {final_unit}")
return mef
def _get_unit_conversion_factor(from_unit: str, to_unit: str) -> float:
"""
Get conversion factor between surface brightness units using astropy.units.
Parameters
----------
from_unit : str
Source unit (expected: 'MJy/sr')
to_unit : str
Target unit
Returns
-------
float
Multiplication factor to convert from_unit to to_unit
Raises
------
ValueError
If unit conversion is not supported or units are incompatible
"""
# Normalize unit strings to astropy-compatible format
from_unit_norm = from_unit.lower().replace(" ", "").replace("_", "")
to_unit_norm = to_unit.lower().replace(" ", "").replace("_", "")
# Map common variations to astropy unit strings
unit_map = {
"mjy/sr": "MJy/sr",
"mjy/steradian": "MJy/sr",
"ujy/arcsec2": "uJy/arcsec2",
"microjy/arcsec2": "uJy/arcsec2",
"jy/arcsec2": "Jy/arcsec2",
"mjy/arcsec2": "MJy/arcsec2",
}
# Get standardized unit strings
from_unit_std = unit_map.get(from_unit_norm)
to_unit_std = unit_map.get(to_unit_norm)
if from_unit_std is None:
raise ValueError(f"Unrecognized source unit: '{from_unit}'. Expected 'MJy/sr' or similar.")
if to_unit_std is None:
raise ValueError(
f"Unrecognized target unit: '{to_unit}'. Supported: uJy/arcsec2, Jy/arcsec2, MJy/arcsec2, MJy/sr"
)
# Use astropy.units for conversion
try:
from_quantity = 1.0 * u.Unit(from_unit_std)
to_quantity = from_quantity.to(u.Unit(to_unit_std))
conversion_factor = to_quantity.value
logger.debug(f"Unit conversion factor: {from_unit_std} → {to_unit_std} = {conversion_factor:.6e}")
return conversion_factor
except Exception as e:
raise ValueError(
f"Failed to convert {from_unit} → {to_unit}: {e}. "
f"Units may be incompatible (not both surface brightness units)."
)
def _normalize_unit_string(unit: str) -> str:
"""Normalize unit string for display."""
unit_lower = unit.lower().replace(" ", "")
if unit_lower in ["ujy/arcsec2", "microjy/arcsec2"]:
return "μJy/arcsec²"
elif unit_lower == "jy/arcsec2":
return "Jy/arcsec²"
elif unit_lower == "mjy/arcsec2":
return "MJy/arcsec²"
elif unit_lower in ["mjy/sr", "mjy/sr"]:
return "MJy/sr"
else:
return unit
# ==================== Legacy Compatibility Functions ====================
# These functions are kept for backward compatibility with existing code
[docs]
def get_pixel_coordinates(mef: SPHERExMEF, ra: float, dec: float) -> Tuple[float, float]:
"""
Convert RA/Dec to pixel coordinates (legacy function).
Use mef.world_to_pixel() instead for new code.
"""
return mef.world_to_pixel(ra, dec)
[docs]
def get_wavelength_at_position(mef: SPHERExMEF, x: float, y: float) -> Tuple[float, float]:
"""
Get wavelength and bandwidth at pixel position (legacy function).
Use mef.pixel_to_wavelength() instead for new code.
"""
return mef.pixel_to_wavelength(x, y)
[docs]
def get_pixel_scale_at_position(wcs: WCS, x: float, y: float, pixel_scale_fallback: float = 6.2) -> float:
"""
Calculate pixel scale at position (legacy function).
Use mef.get_pixel_scale() instead for new code.
"""
# This function is used by photometry.py, so keep it functional
try:
pixel_scales = wcs.proj_plane_pixel_scales()
pixel_scale_deg = float(np.sqrt(pixel_scales[0] * pixel_scales[1]).value)
pixel_scale_arcsec = pixel_scale_deg * 3600.0
return pixel_scale_arcsec
except Exception as e:
logger.warning(f"Failed to calculate pixel scale from WCS: {e}. Using fallback {pixel_scale_fallback}")
return pixel_scale_fallback
# ==================== Flag Handling Functions ====================
[docs]
def get_flag_info(flag_value: int) -> Dict[str, bool]:
"""
Decode flag bitmap into individual flags.
Parameters
----------
flag_value : int
Combined flag bitmap value
Returns
-------
Dict[str, bool]
Dictionary of flag names and their states
"""
# Flag definitions from SPHEREx
flags = {
"TRANSIENT": 0,
"OVERFLOW": 1,
"SUR_ERROR": 2,
"NONFUNC": 6,
"DICHROIC": 7,
"MISSING_DATA": 9,
"HOT": 10,
"COLD": 11,
"FULLSAMPLE": 12,
"PHANMISS": 14,
"NONLINEAR": 15,
"PERSIST": 17,
"OUTLIER": 19,
"SOURCE": 21,
}
flag_states = {}
for name, bit in flags.items():
flag_states[name] = bool(flag_value & (1 << bit))
return flag_states
# Pre-computed combined bitmasks for create_background_mask.
# Single bitwise AND replaces per-flag loops (11+ iterations → 1 operation).
_BAD_FLAG_BITS = 0
for _bit in (0, 1, 2, 6, 7, 9, 10, 11, 14, 15, 17, 19):
_BAD_FLAG_BITS |= 1 << _bit
_SOURCE_BIT = 1 << 21
[docs]
def create_background_mask(flags: np.ndarray, exclude_source: bool = True) -> np.ndarray:
"""
Create mask for background pixels (good for zodiacal matching).
Masks out pixels with problematic flags including non-functional pixels,
outliers, etc. By default, SOURCE-flagged pixels are also excluded.
Parameters
----------
flags : np.ndarray
Flag bitmap array
exclude_source : bool, optional
If True (default), also exclude SOURCE-flagged pixels (bit 21).
If False, SOURCE pixels are kept for local background estimation.
Returns
-------
np.ndarray
Boolean mask (True = good background pixel)
Notes
-----
Bad flag bits: 0=TRANSIENT, 1=OVERFLOW, 2=SUR_ERROR, 6=NONFUNC, 7=DICHROIC,
9=MISSING_DATA, 10=HOT, 11=COLD, 14=PHANMISS, 15=NONLINEAR, 17=PERSIST,
19=OUTLIER, optionally 21=SOURCE.
"""
bad_bits = _BAD_FLAG_BITS | (_SOURCE_BIT if exclude_source else 0)
mask = (flags & bad_bits) == 0
n_good = np.sum(mask)
n_total = mask.size
logger.info(f"Background mask: {n_good}/{n_total} ({n_good / n_total * 100:.1f}%) pixels available")
return mask
# ==================== Zodiacal Background Subtraction ====================
[docs]
def estimate_zodiacal_scaling(
image: np.ndarray, zodi: np.ndarray, mask: np.ndarray, variance: Optional[np.ndarray] = None
) -> float:
"""
Estimate scaling factor to match zodiacal model to observed background.
Uses least-squares fitting on uncontaminated pixels to find the
multiplicative factor that best matches the zodi model to the data.
Parameters
----------
image : np.ndarray
Observed image data
zodi : np.ndarray
Zodiacal model
mask : np.ndarray
Boolean mask (True = good background pixels)
variance : np.ndarray, optional
Variance array for weighted fitting
Returns
-------
float
Scaling factor for zodiacal model
"""
# Extract background pixels
image_bg = image[mask]
zodi_bg = zodi[mask]
if len(image_bg) == 0:
logger.warning("No uncontaminated pixels for zodiacal scaling - using factor 1.0")
return 1.0
# Remove pixels where zodi model is zero to avoid division issues
nonzero_mask = zodi_bg != 0
if np.sum(nonzero_mask) == 0:
logger.warning("Zodiacal model is zero everywhere - using factor 1.0")
return 1.0
image_bg = image_bg[nonzero_mask]
zodi_bg = zodi_bg[nonzero_mask]
# Weighted least squares if variance is provided
if variance is not None:
var_bg = variance[mask][nonzero_mask]
# Avoid zero/negative variance
valid_var = var_bg > 0
if np.sum(valid_var) > 0:
weights = 1.0 / var_bg[valid_var]
image_bg = image_bg[valid_var]
zodi_bg = zodi_bg[valid_var]
# Weighted least squares: scale = sum(w*img*zodi) / sum(w*zodi^2)
scale_factor = np.sum(weights * image_bg * zodi_bg) / np.sum(weights * zodi_bg**2)
else:
# Fall back to unweighted
scale_factor = np.sum(image_bg * zodi_bg) / np.sum(zodi_bg**2)
else:
# Unweighted least squares: scale = sum(img*zodi) / sum(zodi^2)
scale_factor = np.sum(image_bg * zodi_bg) / np.sum(zodi_bg**2)
logger.info(f"Zodiacal scaling factor: {scale_factor:.4f}")
return scale_factor
[docs]
def subtract_zodiacal_background(
image: np.ndarray,
zodi: np.ndarray,
flags: np.ndarray,
variance: Optional[np.ndarray] = None,
zodi_scale_min: float = 0.0,
zodi_scale_max: float = 10.0,
bg_fraction_reject_level: float = 0.5,
static_zodi: bool = False,
) -> Tuple[np.ndarray, float]:
"""
Subtract zodiacal light background from image with amplitude scaling.
Uses uncontaminated background pixels to determine the optimal
scaling factor for the zodiacal model before subtraction.
Parameters
----------
image : np.ndarray
Original image
zodi : np.ndarray
Zodiacal light model
flags : np.ndarray
Flag bitmap array
variance : np.ndarray, optional
Variance array for weighted fitting
zodi_scale_min : float
Minimum allowed zodiacal scaling factor
zodi_scale_max : float
Maximum allowed zodiacal scaling factor
bg_fraction_reject_level : float, optional
Minimum fraction of background pixels required for zodiacal estimation.
If the fraction of available background pixels is below this threshold,
the function will use a fallback scale factor of 1.0. Default is 0.5.
static_zodi : bool, optional
If True, use the zodiacal model as-is (scale=1.0) without fitting.
Skips background masking and scaling factor estimation entirely.
Returns
-------
corrected_image : np.ndarray
Background-subtracted image
scale_factor : float
Applied scaling factor for the zodiacal model
"""
if static_zodi:
scale_factor = 1.0
corrected_image = image - zodi
logger.info("Subtracted zodiacal background (static, scale=1.0)")
return corrected_image, scale_factor
# Create mask for background estimation
bg_mask = create_background_mask(flags)
# Check if enough pixels are available for zodiacal estimation
n_bg_pixels = np.sum(bg_mask)
n_total_pixels = bg_mask.size
bg_fraction = n_bg_pixels / n_total_pixels if n_total_pixels > 0 else 0.0
if bg_fraction < bg_fraction_reject_level:
logger.warning(
f"Insufficient background pixels for zodiacal estimation "
f"({n_bg_pixels}/{n_total_pixels} = {bg_fraction * 100:.1f}% < {bg_fraction_reject_level * 100:.1f}%) - using fallback scale factor 1.0"
)
scale_factor = 1.0
else:
# Estimate zodiacal scaling factor
scale_factor = estimate_zodiacal_scaling(image, zodi, bg_mask, variance)
# Validate scale factor
if scale_factor <= zodi_scale_min or scale_factor > zodi_scale_max:
logger.warning(
f"Unusual scaling factor {scale_factor:.4f} (outside [{zodi_scale_min}, {zodi_scale_max}]) - using 1.0"
)
scale_factor = 1.0
# Apply scaled subtraction
corrected_image = image - (scale_factor * zodi)
logger.info(f"Subtracted zodiacal background with scaling factor {scale_factor:.4f}")
return corrected_image, scale_factor