"""Batch photometry pipeline — orchestrates query, download, extract, aggregate."""
import logging
import multiprocessing as mp
from pathlib import Path
from typing import List, Optional
import yaml
from ..core.config import DownloadConfig, DownloadResult, PhotometryConfig, QueryResults
from ..core.download import parallel_download, print_download_summary
from .aggregate import aggregate_lightcurves
from .config import BatchConfig, load_catalog
from .extract import run_extraction
from .query import query_region_observations
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# File name of the YAML query summary; it is written to and read from the
# root of the batch output directory.
_QUERY_SUMMARY_FILE = "query_summary.yaml"
def _py(obj):
    """Return ``obj.item()`` if *obj* has an ``item`` attribute, else *obj* itself.

    Used to turn numpy scalars into plain Python values before YAML
    serialization; objects without ``item`` pass through unchanged.
    """
    try:
        to_builtin = obj.item
    except AttributeError:
        # No ``item`` attribute: nothing to unwrap.
        return obj
    return to_builtin()
def _save_query_summary(results: QueryResults, config: BatchConfig) -> Path:
    """Write a YAML snapshot of a query and its settings under ``config.output_dir``.

    The document records, in this order: the query timestamp, the search
    region, the band/MJD filters, aggregate counts, and one entry per
    observation.

    Parameters
    ----------
    results : QueryResults
        Query outcome to serialize.
    config : BatchConfig
        Supplies the region and filter settings and the output directory.

    Returns
    -------
    Path
        Location of the written summary file.
    """
    # Keys are inserted in the order they should appear in the file
    # (the dump below uses sort_keys=False).
    summary = {"query_time": results.query_time.isoformat()}

    summary["region"] = {
        "center_ra": _py(config.center_ra),
        "center_dec": _py(config.center_dec),
        "radius_deg": _py(config.radius),
        "coverage_mode": config.coverage_mode,
    }

    band_filter = config.bands
    if config.mjd_range:
        mjd_bounds = [float(bound) for bound in config.mjd_range]
    else:
        mjd_bounds = None
    summary["filters"] = {"bands": band_filter, "mjd_range": mjd_bounds}

    summary["n_observations"] = len(results)
    summary["band_counts"] = {
        band: int(count) for band, count in results.band_counts.items()
    }
    summary["time_span_days"] = float(results.time_span_days)

    records = []
    for obs in results.observations:
        records.append(
            {
                "obs_id": obs.obs_id,
                "band": obs.band,
                "mjd": round(float(obs.mjd), 6),
                "wavelength_um": round(float(obs.wavelength_center), 4),
                "download_url": obs.download_url,
            }
        )
    summary["observations"] = records

    path = config.output_dir / _QUERY_SUMMARY_FILE
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        yaml.dump(summary, handle, default_flow_style=False, sort_keys=False)

    logger.info(f"Saved query summary to {path}")
    return path
def load_query_summary(output_dir: Path) -> dict:
    """Read back the ``query_summary.yaml`` stored in a batch output directory.

    Parameters
    ----------
    output_dir : Path
        Root batch output directory containing the YAML file.

    Returns
    -------
    dict
        The parsed YAML document.

    Raises
    ------
    FileNotFoundError
        If the summary file does not exist in *output_dir*.
    """
    summary_file = Path(output_dir) / _QUERY_SUMMARY_FILE
    if summary_file.exists():
        with open(summary_file) as stream:
            return yaml.safe_load(stream)
    raise FileNotFoundError(f"No query summary found at {summary_file}")
class BatchPipeline:
    """Multi-source batch photometry pipeline.

    Four stages: query -> download -> extract -> aggregate.
    Each stage is exposed as its own ``run_*`` method so a run can be driven
    step by step; ``run_download`` needs ``run_query`` to have been called on
    the same instance first.

    Constructing the pipeline creates the output, image and per-image
    directories named by *config* and forces the ``multiprocessing`` start
    method to ``"spawn"`` for the whole process.

    Parameters
    ----------
    config : BatchConfig
        Region, catalog, and processing configuration.
    """

    # Rough size of one full-frame image in GB; used only for the size
    # estimate printed in the download log line.
    _APPROX_IMAGE_SIZE_GB = 0.07

    def __init__(self, config: "BatchConfig"):
        self.config = config
        for directory in (config.output_dir, config.image_dir, config.per_image_dir):
            directory.mkdir(parents=True, exist_ok=True)
        # Filled in by run_query(); run_download() refuses to run while None.
        self._query_results = None

        # Ensure spawn method for clean worker environments. This is
        # best-effort: with force=True an already-set start method is
        # replaced, so a RuntimeError is not expected here; if one does
        # occur, keep whatever method is active and record it.
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            logger.debug(
                "Could not force the 'spawn' start method; keeping the current one",
                exc_info=True,
            )

    def run_query(self):
        """Stage 1: Query IRSA for full-frame images covering the region.

        The results are cached on the instance for ``run_download`` and a
        YAML summary is written to the output directory.

        Returns
        -------
        QueryResults
            Whatever ``query_region_observations`` returned for the config.
        """
        logger.info("Stage 1/4: Querying region observations")
        self._query_results = query_region_observations(self.config)
        _save_query_summary(self._query_results, self.config)
        logger.info(f"Found {len(self._query_results)} observations")
        return self._query_results

    def run_download(self, skip_existing: bool = True) -> "List[DownloadResult]":
        """Stage 2: Download full-frame FITS images (no cutouts).

        Parameters
        ----------
        skip_existing : bool
            Forwarded to ``parallel_download``.

        Returns
        -------
        list of DownloadResult
            One result per attempted download; empty when the query returned
            no observations.

        Raises
        ------
        RuntimeError
            If ``run_query`` has not been called on this instance.
        """
        if self._query_results is None:
            raise RuntimeError("Run run_query() first")
        logger.info("Stage 2/4: Downloading full-frame images")

        # No cutout parameters — download full images
        download_info = [(obs, obs.download_url) for obs in self._query_results.observations]
        if not download_info:
            logger.warning("No observations to download")
            return []

        total_files = len(download_info)
        approx_gb = total_files * self._APPROX_IMAGE_SIZE_GB
        logger.info(f"Downloading {total_files} full-frame images (~{approx_gb:.0f} GB)...")

        download_config = DownloadConfig(max_download_workers=self.config.max_download_workers)
        results = parallel_download(
            download_info,
            self.config.image_dir,
            max_workers=self.config.max_download_workers,
            skip_existing=skip_existing,
            download_config=download_config,
        )
        print_download_summary(results)

        n_success = sum(1 for r in results if r.success)
        logger.info(f"Downloaded {n_success}/{total_files} images")
        return results

    def run_aggregate(self, clean: bool = False) -> int:
        """Stage 4: Aggregate per-image CSVs into per-source light curves.

        Parameters
        ----------
        clean : bool
            Forwarded to ``aggregate_lightcurves``.

        Returns
        -------
        int
            The value returned by ``aggregate_lightcurves``, logged as the
            number of light curve files created.
        """
        logger.info("Stage 4/4: Aggregating light curves")
        n_sources = aggregate_lightcurves(
            image_csv_dir=self.config.per_image_dir,
            lightcurve_dir=self.config.lightcurve_dir,
            bucket_dir=self.config.bucket_dir,
            num_buckets=self.config.num_buckets,
            clean=clean,
            keep_bucket_files=self.config.keep_bucket_files,
        )
        logger.info(f"Created {n_sources} light curve files")
        return n_sources

    def run_all(self, skip_existing: bool = True):
        """Run all four stages sequentially.

        Parameters
        ----------
        skip_existing : bool
            Passed to the download and extract stages.

        Raises
        ------
        AttributeError
            If the instance has no ``run_extract`` stage. The check happens
            before any stage runs.
        """
        # NOTE(review): this class body does not define ``run_extract``
        # (stage 3). Check up front so the gap surfaces immediately rather
        # than after the query and the full download have already run.
        if not hasattr(self, "run_extract"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'run_extract'; "
                "stage 3 (extract) must be provided before run_all() can run"
            )

        logger.info("=" * 60)
        logger.info("Batch Photometry Pipeline")
        logger.info(f"Region: RA={self.config.center_ra}, Dec={self.config.center_dec}, "
                    f"radius={self.config.radius} deg")
        logger.info("=" * 60)

        self.run_query()
        self.run_download(skip_existing=skip_existing)
        self.run_extract(skip_existing=skip_existing)
        self.run_aggregate()
        logger.info("Pipeline complete.")
def run_batch(
    catalog: str,
    center_ra: float,
    center_dec: float,
    radius: float,
    output_dir: str = "batch_output",
    bands: Optional[List[str]] = None,
    coverage_mode: str = "any",
    max_images: int = 500,
    max_download_workers: int = 4,
    max_extract_workers: int = 12,
    skip_existing: bool = True,
    photometry_config: Optional[PhotometryConfig] = None,
) -> BatchPipeline:
    """Build a ``BatchConfig``, run every pipeline stage, and return the pipeline.

    Convenience wrapper: the arguments are packed into a ``BatchConfig``,
    a ``BatchPipeline`` is created from it, and ``run_all`` is invoked.

    Parameters
    ----------
    catalog : str
        Path to CSV with columns targetid, ra, dec.
    center_ra, center_dec : float
        Region center in degrees.
    radius : float
        Region radius in degrees.
    output_dir : str
        Root output directory.
    bands : list of str or None
        Bands to query. None = all.
    coverage_mode : str
        ``"any"`` (INTERSECTS) or ``"full"`` (CONTAINS).
    max_images : int
        Safety gate — raise if exceeded.
    max_download_workers : int
        Parallel download threads.
    max_extract_workers : int
        Parallel extraction processes.
    skip_existing : bool
        Resume mode — skip already-processed images.
    photometry_config : PhotometryConfig or None
        Override default photometry parameters; a default
        ``PhotometryConfig()`` is used when this is falsy.

    Returns
    -------
    BatchPipeline
        The pipeline instance (for inspecting results).
    """
    # Gather the keyword arguments first; both path-like strings are
    # converted to Path objects here.
    settings = {
        "catalog_path": Path(catalog),
        "center_ra": center_ra,
        "center_dec": center_dec,
        "radius": radius,
        "output_dir": Path(output_dir),
        "bands": bands,
        "coverage_mode": coverage_mode,
        "max_images": max_images,
        "max_download_workers": max_download_workers,
        "max_extract_workers": max_extract_workers,
        "photometry": photometry_config if photometry_config else PhotometryConfig(),
    }
    batch_config = BatchConfig(**settings)

    runner = BatchPipeline(batch_config)
    runner.run_all(skip_existing=skip_existing)
    return runner