Source code for spxquery.drizzle3d.spatial

"""
Spatial (XY) drizzle kernel for Drizzle3D.

Computes the mapping from input image pixels to output grid pixels,
including pixfrac shrinkage and fractional overlap areas.

Uses vectorised numpy operations for efficiency on full-frame images.
"""

import logging
from typing import Tuple

import numpy as np
from astropy.wcs import WCS

logger = logging.getLogger(__name__)


def compute_spatial_mapping(
    input_wcs: WCS,
    input_shape: Tuple[int, int],
    output_wcs: WCS,
    output_shape: Tuple[int, int],
    xy_shrink: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute spatial overlap mapping from input pixels to output grid.

    Every input pixel centre is projected onto the output grid and its weight
    is shared bilinearly between the (up to) four output pixels whose centres
    surround the projected position.

    Pixel convention: integer pixel coordinates denote pixel *centres*, both
    for the input grid built here and for the coordinates returned by
    ``output_wcs.world_to_pixel_values``.  The bilinear distance is therefore
    measured to the integer coordinate of each output pixel.  An input pixel
    that lands exactly on an output pixel centre gives all of its weight to
    that pixel, and the weights of an input pixel whose 2x2 neighbourhood lies
    fully inside the output grid sum to ``xy_shrink**2``.

    Parameters
    ----------
    input_wcs : WCS
        Input image spatial WCS.  Must provide ``pixel_to_world`` returning an
        object with ``ra.deg`` and ``dec.deg``.
    input_shape : tuple of int
        (ny_in, nx_in) shape of the input image.
    output_wcs : WCS
        Output grid spatial WCS (2-D tangent plane).
    output_shape : tuple of int
        (ny_out, nx_out) shape of the output grid.
    xy_shrink : float
        Droplet shrink factor (0, 1].  1.0 = full pixel footprint.  It enters
        only as an overall ``xy_shrink**2`` scaling of the bilinear weights;
        the footprint itself is not narrowed.

    Returns
    -------
    valid_mask : np.ndarray (bool)
        (ny_in, nx_in) — True for pixels whose projected centre is finite and
        lies within the output grid extended by a 2-pixel margin.
    pixel_idx : np.ndarray (int64)
        Flat index into the input image for each output contribution.
        Use ``rows, cols = np.unravel_index(pixel_idx, input_shape)``
        to recover (y, x) coordinates.
    out_y : np.ndarray (int32)
        Output Y pixel indices.
    out_x : np.ndarray (int32)
        Output X pixel indices.
    f_xy : np.ndarray (float64)
        Spatial overlap fractions.

    Notes
    -----
    When nothing overlaps the output grid a warning is logged and the four
    contribution arrays are returned empty.
    """
    ny_in, nx_in = input_shape
    ny_out, nx_out = output_shape

    # Input pixel centres sit at integer coordinates (0-based).
    xx, yy = np.meshgrid(
        np.arange(nx_in, dtype=np.float64),
        np.arange(ny_in, dtype=np.float64),
    )

    # Input pixel centres -> sky -> continuous output pixel coordinates.
    sky = input_wcs.pixel_to_world(xx.ravel(), yy.ravel())
    out_xf, out_yf = output_wcs.world_to_pixel_values(sky.ra.deg, sky.dec.deg)
    out_xf = np.asarray(out_xf, dtype=np.float64)
    out_yf = np.asarray(out_yf, dtype=np.float64)

    # Keep pixels that land on the output grid, with a margin so that pixels
    # just outside can still feed the edge pixels.
    margin = 2.0
    valid = (
        np.isfinite(out_xf)
        & np.isfinite(out_yf)
        & (out_xf >= -margin)
        & (out_xf < nx_out + margin)
        & (out_yf >= -margin)
        & (out_yf < ny_out + margin)
    )
    valid_mask = valid.reshape(ny_in, nx_in)

    # Built once; distinct arrays so callers never receive aliased outputs.
    empty_result = (
        valid_mask,
        np.array([], dtype=np.int64),
        np.array([], dtype=np.int32),
        np.array([], dtype=np.int32),
        np.array([], dtype=np.float64),
    )

    if not np.any(valid):
        logger.warning("No input pixels overlap the output grid")
        return empty_result

    valid_idx = np.where(valid)[0]  # flat indices into the input image
    vx = out_xf[valid]
    vy = out_yf[valid]
    n_valid = len(valid_idx)

    # Lower-left pixel of the 2x2 neighbourhood whose centres bracket (vx, vy).
    ox0 = np.floor(vx).astype(np.int32)
    oy0 = np.floor(vy).astype(np.int32)

    # Bilinear weights 1 - |v - centre| for the two bracketing centres.
    # With centres at ox0 and ox0 + 1 these are (1 - frac, frac): both are
    # non-negative and sum to one, so no weight is lost.
    frac_x = vx - ox0
    frac_y = vy - oy0
    weights_x = (1.0 - frac_x, frac_x)
    weights_y = (1.0 - frac_y, frac_y)
    shrink_sq = xy_shrink**2

    pix_idx_list = []
    out_y_list = []
    out_x_list = []
    f_xy_list = []

    for dy, wy in enumerate(weights_y):
        oy = oy0 + dy
        for dx, wx in enumerate(weights_x):
            ox = ox0 + dx
            weight = wx * wy * shrink_sq

            # Drop numerically-zero weights and targets off the output grid.
            keep = (weight > 1e-10) & (ox >= 0) & (ox < nx_out) & (oy >= 0) & (oy < ny_out)
            if not np.any(keep):
                continue

            pix_idx_list.append(valid_idx[keep])
            out_y_list.append(oy[keep])
            out_x_list.append(ox[keep])
            f_xy_list.append(weight[keep])

    if not pix_idx_list:
        logger.warning("No valid spatial overlaps after bilinear distribution")
        return empty_result

    pixel_idx_arr = np.concatenate(pix_idx_list)
    out_y_arr = np.concatenate(out_y_list)
    out_x_arr = np.concatenate(out_x_list)
    f_xy_arr = np.concatenate(f_xy_list)

    logger.debug(
        "Spatial mapping: %d valid input pixels → %d output contributions",
        n_valid,
        len(out_y_arr),
    )

    return valid_mask, pixel_idx_arr, out_y_arr, out_x_arr, f_xy_arr
def compute_spectral_overlaps(
    lambda_c: float,
    delta_lambda: float,
    z_shrink: float,
    z_edges: np.ndarray,
    z_centers: np.ndarray,
    z_widths: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Share one input pixel's bandpass among the output wavelength bins.

    The pixel is modelled as a top-hat (boxcar) of width
    ``delta_lambda * z_shrink`` centred on ``lambda_c``.  Each output bin
    receives a fraction proportional to the length of its intersection with
    that top-hat; the fractions are then rescaled to sum to one, so a pixel
    that is partly off the grid still distributes unit weight over the bins
    it touches.

    Parameters
    ----------
    lambda_c : float
        Input pixel central wavelength [μm].
    delta_lambda : float
        Input pixel bandwidth [μm].
    z_shrink : float
        Spectral droplet shrink factor.
    z_edges : np.ndarray
        (n_z+1,) output Z bin edges [μm].
    z_centers : np.ndarray
        (n_z,) output Z bin centers [μm].  Not used by the computation.
    z_widths : np.ndarray
        (n_z,) output Z bin widths [μm].  Not used by the computation.

    Returns
    -------
    z_idx : np.ndarray (int32)
        Indices of overlapping Z bins.
    f_z : np.ndarray (float64)
        Fractional overlap for each Z bin, summing to ~1.  Both arrays are
        empty when the shrunken width is not positive or no bin is touched.
    """
    kernel_width = delta_lambda * z_shrink
    if kernel_width <= 0:
        # Degenerate (zero or negative width) kernel: nothing to distribute.
        return np.array([], dtype=np.int32), np.array([], dtype=np.float64)

    half = 0.5 * delta_lambda * z_shrink
    window_start = lambda_c - half
    window_stop = lambda_c + half

    # Intersection of the window with every bin [edge_i, edge_{i+1}];
    # a non-positive length means the bin is not touched.
    left = np.maximum(window_start, z_edges[:-1])
    right = np.minimum(window_stop, z_edges[1:])
    shared = right - left

    touched = np.nonzero(shared > 0)[0]
    if touched.size == 0:
        return np.array([], dtype=np.int32), np.array([], dtype=np.float64)

    fractions = shared[touched] / kernel_width

    # Rescale so the fractions sum to one (matters near the grid boundaries).
    total = fractions.sum()
    if total > 0:
        fractions /= total

    return touched.astype(np.int32), fractions