"""
High-level Drizzle3D pipeline.
Orchestrates: query → download → per-detector drizzle → save.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from astropy.wcs import WCS
from .accumulate import DrizzleCube, drizzle_image
from .config import Drizzle3DConfig
from .grid import build_output_wcs
from .io import save_cube
from .query import download_observations, query_observations
from .spatial import compute_spatial_mapping
from .spectral import build_z_grid
logger = logging.getLogger(__name__)
[docs]
def drizzle_detector(
    fits_paths: List[Path],
    config: Drizzle3DConfig,
    detector: int,
    output_wcs: WCS,
) -> Optional[Path]:
    """Run the drizzle pipeline for one detector.

    Parameters
    ----------
    fits_paths : list of Path
        Downloaded FITS files for this detector.
    config : Drizzle3DConfig
        Drizzle configuration.
    detector : int
        Detector number (1–6).
    output_wcs : WCS
        Output spatial WCS.

    Returns
    -------
    Path or None
        Path to the output FITS file, or None if no valid inputs. An input
        is not valid if it could not be read, had no spatial WCS, or had
        no spatial overlap with the output grid. When None is returned,
        no file is written.
    """
    # Build Z (wavelength) grid
    zgrid = build_z_grid(detector, config.z_oversample, config.z_lambda_edges)
    pixscale = config.effective_pixscale()
    # Allocate accumulation cube
    cube = DrizzleCube.create(output_wcs, pixscale, zgrid, config, detector)
    output_shape = (config.output_ny(), config.output_nx())
    # Build exclude mask from config.exclude_flags (0 => exclude nothing)
    from ..utils.helpers import create_flag_mask

    exclude_bits = create_flag_mask(config.exclude_flags)

    n_used = 0
    n_rejected = 0
    for fpath in fits_paths:
        try:
            img_data, var_data, flag_data, _zodi, spatial_wcs, spectral_wcs = _read_input_fits(
                fpath, config.subtract_zodi, static_zodi=config.static_zodi
            )
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole detector.
            logger.warning(f"Skipping {fpath.name}: {e}")
            n_rejected += 1
            continue
        if spatial_wcs is None:
            logger.warning(f"Skipping {fpath.name}: no spatial WCS")
            n_rejected += 1
            continue
        # Compute spatial mapping of input pixels onto the output grid
        valid_mask, pix_idx, out_y, out_x, f_xy = compute_spatial_mapping(
            spatial_wcs,
            img_data.shape,
            output_wcs,
            output_shape,
            config.xy_shrink,
        )
        if len(out_y) == 0:
            logger.debug(f"Skipping {fpath.name}: no spatial overlap")
            n_rejected += 1
            continue
        # Get per-pixel wavelength from spectral WCS
        lambda_c_map, delta_lambda_map = _extract_wavelength_maps(spectral_wcs, img_data.shape)
        # Build pixel exclusion mask from flags (None => no exclusion requested)
        exclude_mask = None
        if exclude_bits != 0:
            exclude_mask = (flag_data & exclude_bits) != 0
        # Accumulate
        drizzle_image(
            cube=cube,
            image=img_data,
            variance=var_data,
            flags=flag_data,
            lambda_c_map=lambda_c_map,
            delta_lambda_map=delta_lambda_map,
            pixel_idx=pix_idx,
            out_y=out_y,
            out_x=out_x,
            f_xy=f_xy,
            exclude_mask=exclude_mask,
        )
        n_used += 1

    # Honour the documented contract: do not write an empty cube.
    if n_used == 0:
        logger.warning(
            f"D{detector}: no valid inputs ({n_rejected} rejected); no cube written"
        )
        return None

    cube.n_rejected = n_rejected
    cube.finalize_masks()
    # Save
    output_path = Path(config.output_dir) / f"drizzle_D{detector}.fits"
    save_cube(cube, output_path, overwrite=config.overwrite)
    return output_path
[docs]
def drizzle(config: Drizzle3DConfig) -> Dict[int, Path]:
    """Top-level entry point: query → download → drizzle → save.

    Builds one output spatial WCS, queries IRSA for observations grouped by
    detector, then downloads (or resolves from a mirror) and drizzles each
    detector in ascending detector order.

    Parameters
    ----------
    config : Drizzle3DConfig
        Complete drizzle configuration.

    Returns
    -------
    dict
        ``{detector_id: output_path}`` for every detector whose cube was
        produced. A detector is omitted if it has no downloaded files or if
        ``drizzle_detector`` returns None. If the query finds nothing, the
        dict is empty.

    Examples
    --------
    >>> from spxquery.drizzle3d import Drizzle3DConfig, drizzle
    >>> config = Drizzle3DConfig(
    ...     center_ra=186.4536,
    ...     center_dec=33.5468,
    ...     width=30.0,
    ...     height=30.0,
    ...     detector=3,
    ... )
    >>> results = drizzle(config)
    """
    logger.info(f"Starting Drizzle3D pipeline: center=({config.center_ra}, {config.center_dec})")
    # detector == 0 selects every detector
    det_label = "all" if config.detector == 0 else f"D{config.detector}"
    logger.info(f" Region: {config.width}'×{config.height}', detector={det_label}")

    # The same output spatial WCS is shared by all detector cubes.
    output_wcs = build_output_wcs(config)

    obs_by_det = query_observations(config)
    if not obs_by_det:
        logger.warning("No observations found for the target region")
        return {}

    produced: Dict[int, Path] = {}
    for det_id in sorted(obs_by_det):
        det_obs = obs_by_det[det_id]
        logger.info(f"Processing D{det_id}: {len(det_obs)} observations")

        local_files = download_observations(
            det_obs,
            output_dir=config.output_dir,
            max_workers=config.download_workers,
            skip_existing=config.skip_existing,
            data_mirror=config.data_mirror,
        )
        if not local_files:
            logger.warning(f"D{det_id}: no files downloaded, skipping")
            continue

        cube_path = drizzle_detector(local_files, config, det_id, output_wcs)
        if cube_path is None:
            continue
        produced[det_id] = cube_path

    logger.info(f"Drizzle3D complete: {len(produced)} detector cubes produced")
    for det_id in sorted(produced):
        logger.info(f" D{det_id}: {produced[det_id]}")
    return produced
def _read_input_fits(filepath: Path, subtract_zodi: bool, static_zodi: bool = False):
    """Load one SPHEREx input file through the shared MEF reader.

    Parameters
    ----------
    filepath : Path
        Input FITS file, passed unchanged to ``read_spherex_mef``.
    subtract_zodi : bool
        If True, the returned image is the first element of the return value
        of ``subtract_zodiacal_background``. Its second element is discarded.
    static_zodi : bool, optional
        Forwarded as-is to ``subtract_zodiacal_background``.

    Returns
    -------
    tuple
        (image, variance, flags, zodi, spatial_wcs, spectral_wcs), with
        ``flags`` cast to ``np.uint32``.
    """
    from ..utils.spherex_mef import read_spherex_mef, subtract_zodiacal_background

    mef = read_spherex_mef(filepath)
    sci, var = mef.image, mef.variance
    flags_u32 = mef.flags.astype(np.uint32)
    zodi_model = mef.zodi
    wcs_xy, wcs_lambda = mef.spatial_wcs, mef.spectral_wcs

    if subtract_zodi:
        sci, _discarded = subtract_zodiacal_background(
            sci, zodi_model, flags_u32, var, static_zodi=static_zodi
        )

    return sci, var, flags_u32, zodi_model, wcs_xy, wcs_lambda
def _extract_wavelength_maps(spectral_wcs: WCS, shape) -> tuple:
    """Evaluate the spectral WCS at every pixel to get (λ_c, Δλ) maps.

    Pixel coordinates are 0-based and generated in row-major order (y slow,
    x fast). If evaluation or unit conversion raises, a warning is logged and
    both maps are filled with NaN.

    Parameters
    ----------
    spectral_wcs : WCS
        Spectral WCS (alternative 'W' key) from the input FITS.
    shape : tuple
        (ny, nx) image shape.

    Returns
    -------
    lambda_c_map : np.ndarray
        (ny, nx) float64 central wavelength [μm].
    delta_lambda_map : np.ndarray
        (ny, nx) float64 bandwidth [μm].
    """
    import astropy.units as u

    ny, nx = shape
    n_pix = ny * nx
    rows, cols = np.mgrid[0:ny, 0:nx]
    x_pix = cols.reshape(-1).astype(np.float64)
    y_pix = rows.reshape(-1).astype(np.float64)

    try:
        # First world output is treated as λ_c, second as Δλ.
        world = spectral_wcs.pixel_to_world(x_pix, y_pix)
        lam_center = world[0].to(u.micron).value
        lam_width = world[1].to(u.micron).value
    except Exception as e:
        logger.warning(f"Failed to extract wavelength maps: {e}")
        lam_center = np.full(n_pix, np.nan)
        lam_width = np.full(n_pix, np.nan)

    return (
        lam_center.reshape(ny, nx).astype(np.float64),
        lam_width.reshape(ny, nx).astype(np.float64),
    )