From 64239218534d181877dcafdce922ee635ffa67b0 Mon Sep 17 00:00:00 2001 From: goldokpa Date: Fri, 22 May 2026 00:46:18 +0300 Subject: [PATCH] fix: resolve merge conflicts from develop consolidation into main MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves 6 conflict regions across auth.py, main.py, and pipeline.py from the develop→main single-trunk consolidation. Keeps dev bypass in auth, merges governance/flooding imports, fixes duplicate org param in predict_upload, adds missing numpy import, and adopts clean GEE production path (no synthetic fallbacks). Co-Authored-By: Claude Opus 4.6 --- config.yaml | 19 +- requirements.txt | 3 + scripts/export_model.py | 51 ++- scripts/prepare_data.py | 193 ++++++---- scripts/train.py | 221 ++++++----- src/climatevision/api/auth.py | 30 +- src/climatevision/api/main.py | 133 +++---- src/climatevision/data/__init__.py | 7 +- src/climatevision/data/band_mapping.py | 26 +- src/climatevision/data/gee_downloader.py | 211 +++++++---- src/climatevision/data/preprocessing.py | 152 +++++--- src/climatevision/inference/pipeline.py | 422 ++++++++++----------- src/climatevision/inference/postprocess.py | 349 +++++++++-------- tests/test_security.py | 268 +++++++++++++ 14 files changed, 1294 insertions(+), 791 deletions(-) create mode 100644 tests/test_security.py diff --git a/config.yaml b/config.yaml index 2ce5c8a..aba8172 100644 --- a/config.yaml +++ b/config.yaml @@ -16,6 +16,7 @@ analysis_types: num_classes: 2 bands: ["B04", "B03", "B02", "B08"] # Red, Green, Blue, NIR classes: ["non_forest", "forest"] + scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow thresholds: alert_forest_loss: 5.0 # Alert if >5% forest loss critical_forest_loss: 15.0 # Critical if >15% loss @@ -38,6 +39,7 @@ analysis_types: num_classes: 3 bands: ["B02", "B03", "B04", "B11"] # Blue, Green, Red, SWIR classes: ["open_water", "sea_ice", "land"] + scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow thresholds: alert_ice_loss: 10.0 # Alert if >10% ice loss critical_ice_loss: 25.0 # Critical if >25% loss @@ -60,12 +62,19 @@ analysis_types: display_name: "Flood Detection" description: "Detect and monitor flooding events and affected areas" model: - architecture: "unet" + architecture: "flood_unet" weights: "models/unet_flood.pth" in_channels: 3 num_classes: 3 + model_sar: + architecture: "flood_unet" + weights: "models/unet_flood_sar.pth" + in_channels: 5 + num_classes: 3 bands: ["B03", "B08", "B11"] # Green, NIR, SWIR + sar_bands: ["VV", "VH"] classes: ["dry_land", "permanent_water", "flooded"] + scl_clear_labels: [4, 5, 6] # vegetation, bare soil, water (NO snow/ice) thresholds: alert_flood_area: 5.0 # Alert if >5% area flooded critical_flood_area: 20.0 # Critical if >20% flooded @@ -74,6 +83,7 @@ analysis_types: - "flooded_percentage" - "flooded_area_km2" - "mndwi_stats" + - "affected_road_km" # Drought Monitoring drought: @@ -153,6 +163,13 @@ satellite: cloud_coverage_max: 20 # percentage revisit_time: 5 # days + sentinel1: + bands: ["VV", "VH"] + resolution: 10 # meters + revisit_time: 6 # days + polarization: ["VV", "VH"] + orbit: "DESCENDING" + landsat8: bands: ["B4", "B3", "B2", "B5"] resolution: 30 # meters diff --git a/requirements.txt b/requirements.txt index c67ad0e..1a981a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,9 @@ fiona>=1.9.0 opencv-python>=4.5.0 pillow>=9.0.0 albumentations>=1.3.0 +segmentation-models-pytorch>=0.3.3 +timm>=0.9.0 +scipy>=1.10.0 # Visualization matplotlib>=3.5.0 diff --git a/scripts/export_model.py b/scripts/export_model.py index e855a12..e4def52 100644 --- a/scripts/export_model.py +++ b/scripts/export_model.py @@ -46,30 +46,63 @@ # Model loading # --------------------------------------------------------------------------- +def _load_run_config(ckpt_path: Path) -> dict: + """Load the full training config from the run directory.""" + run_dir = ckpt_path.parent + config_path = run_dir / "config.yaml" + if config_path.exists(): + try: + import yaml + with open(config_path) as f: + return yaml.safe_load(f) or {} + except Exception: + pass + return {} + + def load_model(ckpt_path: Path) -> tuple[nn.Module, dict]: from climatevision.models.unet import get_model + from climatevision.models.flood_unet import build_flood_model ckpt = torch.load(ckpt_path, map_location="cpu") - cfg = ckpt.get("cfg", {}) + # Trainer cfg is in ckpt['cfg']; full model cfg is in config.yaml next to checkpoint + cfg = _load_run_config(ckpt_path) + if not cfg: + cfg = ckpt.get("cfg", {}) arch = cfg.get("model", {}).get("architecture", "attention_unet") state = ckpt.get("ema_state_dict") or ckpt.get("model_state_dict", ckpt) - # Infer in_channels from weight shape - in_ch = 4 + # Infer in_channels and n_classes from weight shape + in_ch = cfg.get("model", {}).get("in_channels", 4) + n_classes = cfg.get("model", {}).get("num_classes", 2) for key, val in state.items(): - if "inc" in key and "weight" in key and val.ndim == 4: - in_ch = val.shape[1] - break + if val.ndim == 4: + if val.shape[1] in (3, 4, 5) and "encoder" not in key and "down" not in key: + # first conv layer — input channels + in_ch = val.shape[1] + if val.shape[0] in (2, 3) and ("outc" in key or "segmentation_head" in key or "classifier" in key): + # final conv layer — output classes + n_classes = val.shape[0] + + # Flood models use smp-based architectures + if arch in ("flood_unet", "flood_unet_s2only"): + use_sar = in_ch == 5 + model = build_flood_model( + use_sar=use_sar, + encoder_name=cfg.get("model", {}).get("encoder", "efficientnet-b7"), + ) + else: + model = get_model(arch, n_channels=in_ch, n_classes=n_classes) - model = get_model(arch, n_channels=in_ch, n_classes=2) model.load_state_dict(state, strict=False) model.eval() logger.info( - "Loaded %s (in_channels=%d) from epoch %d val_iou=%.4f", + "Loaded %s (in_channels=%d, classes=%d) from epoch %d val_iou=%.4f", arch, in_ch, + n_classes, ckpt.get("epoch", 0), ckpt.get("val_iou", 0.0), ) @@ -233,7 +266,7 @@ def main() -> None: "checkpoint": str(ckpt_path), "architecture": cfg.get("model", {}).get("architecture", "unknown"), "in_channels": in_channels, - "num_classes": 2, + "num_classes": cfg.get("model", {}).get("num_classes", 2), "image_size": image_size, "onnx_opset": args.opset, "onnx_path": str(onnx_path), diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 2ef1100..ff24ba6 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -1,22 +1,19 @@ """ -Data preparation script for ClimateVision forest segmentation. +Data preparation script for ClimateVision. -Two modes: - --mode synthetic Generate fractal-noise synthetic Sentinel-2 patches (no data required) +Modes: + --mode synthetic Generate synthetic patches for testing/smoke tests --mode gee Download real Sentinel-2 L2A tiles via Google Earth Engine Usage: - # Quick start — 2 000 synthetic patches, default 70/15/15 split: - python scripts/prepare_data.py --mode synthetic --n-patches 2000 --out data/processed - - # Fewer patches for a fast smoke test: - python scripts/prepare_data.py --mode synthetic --n-patches 200 --out data/processed - - # Real data via GEE (requires authenticated `earthengine-api`): - python scripts/prepare_data.py --mode gee \\ - --bbox 2.3 48.8 2.5 49.0 \\ - --start 2022-01-01 --end 2023-12-31 \\ - --out data/processed + # Synthetic flood patches for testing: + python scripts/prepare_data.py --mode synthetic --analysis-type flooding --n-patches 200 --out data/processed/flood + + # Real data via GEE: + python scripts/prepare_data.py --mode gee --analysis-type flooding \ + --bbox 36.7 -1.4 37.0 -1.1 \ + --start 2024-04-01 --end 2024-04-10 \ + --out data/processed/flood """ from __future__ import annotations @@ -41,38 +38,51 @@ # --------------------------------------------------------------------------- def generate_synthetic( + analysis_type: str, n_patches: int, out_dir: Path, patch_size: int, train_ratio: float, val_ratio: float, ) -> None: - """Delegate entirely to the built-in synthetic generator.""" - try: - from climatevision.data.synthetic import generate_synthetic_dataset - except ImportError as exc: - logger.error("Cannot import climatevision package: %s", exc) - logger.error("Run `pip install -e .` from the project root first.") - sys.exit(1) + """Generate synthetic patches for testing.""" + from climatevision.data.synthetic import ( + generate_synthetic_dataset, + generate_flood_test_patches, + ) test_ratio = max(0.0, 1.0 - train_ratio - val_ratio) n_train = int(n_patches * train_ratio) - n_val = int(n_patches * val_ratio) - n_test = max(0, n_patches - n_train - n_val) + n_val = int(n_patches * val_ratio) + n_test = max(0, n_patches - n_train - n_val) - logger.info( - "Generating %d synthetic patches " - "(train=%d / val=%d / test=%d) patch_size=%d", - n_patches, n_train, n_val, n_test, patch_size, - ) - - generate_synthetic_dataset( - output_dir=out_dir, - n_train=n_train, - n_val=n_val, - n_test=n_test, - patch_size=patch_size, - ) + if analysis_type == "flooding": + logger.info( + "Generating %d synthetic flood patches (train=%d / val=%d / test=%d)", + n_patches, n_train, n_val, n_test, + ) + for split, n in [("train", n_train), ("val", n_val), ("test", n_test)]: + if n > 0: + split_dir = out_dir / split + generate_flood_test_patches( + output_dir=split_dir, + n=n, + patch_size=patch_size, + use_sar=False, + seed=42 if split == "train" else 99 if split == "val" else 123, + ) + else: + logger.info( + "Generating %d synthetic forest patches (train=%d / val=%d / test=%d)", + n_patches, n_train, n_val, n_test, + ) + generate_synthetic_dataset( + output_dir=out_dir, + n_train=n_train, + n_val=n_val, + n_test=n_test, + patch_size=patch_size, + ) logger.info("Dataset written to %s", out_dir) @@ -82,6 +92,7 @@ def generate_synthetic( # --------------------------------------------------------------------------- def download_gee( + analysis_type: str, bbox: tuple[float, float, float, float], start: str, end: str, @@ -101,8 +112,8 @@ def download_gee( try: import os svc_account = os.getenv("GEE_SERVICE_ACCOUNT") - key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY") - project = os.getenv("GEE_PROJECT_ID") + key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY") + project = os.getenv("GEE_PROJECT_ID") if key_file and not os.path.isabs(key_file): key_file = str(PROJECT_ROOT / key_file) @@ -128,37 +139,41 @@ def download_gee( import random, urllib.request, tempfile, os + from climatevision.data.band_mapping import get_bands_for_analysis + + bands = get_bands_for_analysis(analysis_type) + gee_bands = { + "B01": "B1", "B02": "B2", "B03": "B3", "B04": "B4", + "B05": "B5", "B06": "B6", "B07": "B7", "B08": "B8", + "B8A": "B8A", "B09": "B9", "B10": "B10", "B11": "B11", "B12": "B12", + } + selected_bands = [gee_bands[b] for b in bands] + west, south, east, north = bbox - # GEE download size limit is 48 MB per request. - # At 100 m resolution, a 0.25° tile is ~278x278 px × 5 bands × 4 bytes ≈ 1.5 MB — safe. - # 100 m is standard for regional forest classification. TILE_DEG = 0.25 - SCALE_M = 100 + SCALE_M = 100 - # Build tile grid tiles = [] lat = south while lat < north: lon = west while lon < east: tiles.append(( - round(lon, 6), - round(lat, 6), - round(min(lon + TILE_DEG, east), 6), + round(lon, 6), + round(lat, 6), + round(min(lon + TILE_DEG, east), 6), round(min(lat + TILE_DEG, north), 6), )) lon += TILE_DEG lat += TILE_DEG - logger.info("Downloading %d tiles (%.2f° each, scale=%dm)…", len(tiles), TILE_DEG, SCALE_M) + logger.info("Downloading %d tiles (%.2f deg each, scale=%dm)…", len(tiles), TILE_DEG, SCALE_M) patches: list[tuple[np.ndarray, np.ndarray]] = [] - - # Minimal rasterio profile for writing plain GeoTIFF patches base_profile = { "driver": "GTiff", - "crs": "EPSG:4326", + "crs": "EPSG:4326", "transform": rasterio.transform.from_bounds(west, south, east, north, patch_size, patch_size), } @@ -173,25 +188,34 @@ def download_gee( .filterBounds(tile_region) .filterDate(start, end) .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", cloud_threshold * 100)) - .select(["B4", "B3", "B2", "B8"]) + .select(selected_bands) ) - dw = ( - ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1") - .filterBounds(tile_region) - .filterDate(start, end) - .select("label") - .mode() - ) - forest_mask = dw.eq(1).rename("forest") + # For flood detection, use JRC Global Surface Water for water mask baseline + if analysis_type == "flooding": + water = ee.Image("JRC/GSW1_4/GlobalSurfaceWater").select("occurrence").clip(tile_region) + water_mask = water.gt(50).rename("water") + # We can't easily generate "flooded" labels without event-specific data, + # so we create a 3-class mask: dry=0, permanent_water=1, flooded=2 + # Here we set flooded=water (simplified — in production use event-specific polygons) + label = water_mask + else: + dw = ( + ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1") + .filterBounds(tile_region) + .filterDate(start, end) + .select("label") + .mode() + ) + label = dw.eq(1).rename("forest") try: - image = collection.median().clip(tile_region) - combined = image.addBands(forest_mask) + image = collection.median().clip(tile_region) + combined = image.addBands(label) url = combined.getDownloadURL({ "region": tile_region, - "scale": SCALE_M, + "scale": SCALE_M, "format": "GEO_TIFF", }) @@ -199,16 +223,17 @@ def download_gee( urllib.request.urlretrieve(url, tmp) with rasterio.open(tmp) as src: - full = src.read() # (5, H, W) + full = src.read() os.unlink(tmp) - if full.shape[0] < 5: + n_bands = len(selected_bands) + if full.shape[0] < n_bands + 1: continue - image_data = full[:4].astype(np.float32) - mask_data = (full[4] > 0).astype(np.uint8) - _, H, W = image_data.shape + image_data = full[:n_bands].astype(np.float32) + mask_data = full[n_bands].astype(np.uint8) + _, H, W = image_data.shape for y in range(0, H - patch_size + 1, patch_size): for x in range(0, W - patch_size + 1, patch_size): @@ -216,12 +241,12 @@ def download_gee( break patches.append(( image_data[:, y:y + patch_size, x:x + patch_size], - mask_data[ y:y + patch_size, x:x + patch_size], + mask_data[y:y + patch_size, x:x + patch_size], )) if len(patches) >= max_patches: break - logger.info(" tile %d/%d → %d patches so far", ti + 1, len(tiles), len(patches)) + logger.info(" tile %d/%d -> %d patches so far", ti + 1, len(tiles), len(patches)) except Exception as exc: logger.warning(" tile %d/%d skipped: %s", ti + 1, len(tiles), exc) @@ -233,16 +258,15 @@ def download_gee( logger.info("Extracted %d patches total", len(patches)) - # Shuffle + split random.seed(42) random.shuffle(patches) - n = len(patches) + n = len(patches) n_train = int(n * train_ratio) - n_val = int(n * val_ratio) - splits = { + n_val = int(n * val_ratio) + splits = { "train": patches[:n_train], - "val": patches[n_train:n_train + n_val], - "test": patches[n_train + n_val:], + "val": patches[n_train:n_train + n_val], + "test": patches[n_train + n_val:], } for split, split_patches in splits.items(): @@ -250,7 +274,7 @@ def download_gee( (out_dir / split / "masks").mkdir(parents=True, exist_ok=True) for idx, (img_patch, mask_patch) in enumerate(split_patches): stem = f"patch_{idx:05d}" - img_profile = {**base_profile, "count": 4, "dtype": "float32", + img_profile = {**base_profile, "count": img_patch.shape[0], "dtype": "float32", "height": patch_size, "width": patch_size} with rasterio.open(out_dir / split / "images" / f"{stem}.tif", "w", **img_profile) as dst: dst.write(img_patch) @@ -268,7 +292,6 @@ def download_gee( # --------------------------------------------------------------------------- def fit_normalizer(data_dir: Path, out_path: Path) -> None: - """Compute per-band mean/std on the training set and save to JSON.""" try: from climatevision.data.preprocessing import Sentinel2Normalizer except ImportError as exc: @@ -304,7 +327,9 @@ def parse_args() -> argparse.Namespace: formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument("--mode", choices=["synthetic", "gee"], default="synthetic") - p.add_argument("--out", type=Path, default=Path("data/processed"), + p.add_argument("--analysis-type", default="deforestation", + help="Analysis type: deforestation, flooding, ice_melting") + p.add_argument("--out", type=Path, default=Path("data/processed"), help="Output directory (created if needed)") # Synthetic options @@ -318,16 +343,16 @@ def parse_args() -> argparse.Namespace: help="[gee] Bounding box: west south east north") p.add_argument("--start", type=str, default="2022-01-01", help="[gee] Start date YYYY-MM-DD") - p.add_argument("--end", type=str, default="2023-12-31", + p.add_argument("--end", type=str, default="2023-12-31", help="[gee] End date YYYY-MM-DD") p.add_argument("--max-patches", type=int, default=5000, help="[gee] Maximum patches to extract from download") p.add_argument("--cloud-threshold", type=float, default=0.2, - help="[gee] Max cloud fraction (0–1)") + help="[gee] Max cloud fraction (0-1)") # Split ratios p.add_argument("--train-ratio", type=float, default=0.70) - p.add_argument("--val-ratio", type=float, default=0.15) + p.add_argument("--val-ratio", type=float, default=0.15) # Normalizer p.add_argument("--fit-normalizer", action="store_true", @@ -342,11 +367,12 @@ def main() -> None: args = parse_args() if args.train_ratio + args.val_ratio > 1.0: - logger.error("--train-ratio + --val-ratio must be ≤ 1.0") + logger.error("--train-ratio + --val-ratio must be <= 1.0") sys.exit(1) if args.mode == "synthetic": generate_synthetic( + analysis_type=args.analysis_type, n_patches=args.n_patches, out_dir=args.out, patch_size=args.patch_size, @@ -358,7 +384,8 @@ def main() -> None: logger.error("--bbox W S E N is required for --mode gee") sys.exit(1) download_gee( - bbox=tuple(args.bbox), # type: ignore[arg-type] + analysis_type=args.analysis_type, + bbox=tuple(args.bbox), start=args.start, end=args.end, out_dir=args.out, diff --git a/scripts/train.py b/scripts/train.py index 87decbb..41e2714 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,20 +1,14 @@ """ -Production training entry-point for ClimateVision forest segmentation. +Production training entry-point for ClimateVision. Usage: - # Train with defaults (generates synthetic data if none exists): - python scripts/train.py + # Train flood detection model with real data: + python scripts/train.py --analysis-type flooding --data-dir data/processed/flood # Custom config: - python scripts/train.py --config config/train.yaml + python scripts/train.py --config config/train.yaml --analysis-type flooding - # Override specific keys: - python scripts/train.py --config config/train.yaml \\ - --data-dir data/processed \\ - --epochs 50 \\ - --batch-size 8 - - # Resume from a checkpoint: + # Resume from checkpoint: python scripts/train.py --resume models/my_run/checkpoint_epoch_0030.pth """ from __future__ import annotations @@ -33,7 +27,6 @@ ) logger = logging.getLogger(__name__) -# Project root on the Python path so `climatevision` is importable PROJECT_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(PROJECT_ROOT / "src")) @@ -44,7 +37,7 @@ def _load_yaml(path: str | Path) -> dict: try: - import yaml # PyYAML + import yaml except ImportError: logger.error("PyYAML not installed. Run: pip install pyyaml") sys.exit(1) @@ -53,7 +46,6 @@ def _load_yaml(path: str | Path) -> dict: def _deep_merge(base: dict, override: dict) -> dict: - """Recursively merge override into a copy of base.""" result = base.copy() for k, v in override.items(): if isinstance(v, dict) and isinstance(result.get(k), dict): @@ -64,7 +56,6 @@ def _deep_merge(base: dict, override: dict) -> dict: def build_config(args: argparse.Namespace) -> dict: - """Load YAML config and apply CLI overrides.""" cfg: dict = {} if args.config and Path(args.config).exists(): @@ -73,7 +64,6 @@ def build_config(args: argparse.Namespace) -> dict: else: logger.info("No config file — using defaults") - # CLI overrides (only non-None values) overrides: dict = {} if args.data_dir: overrides.setdefault("data", {})["dir"] = args.data_dir @@ -95,51 +85,80 @@ def build_config(args: argparse.Namespace) -> dict: overrides.setdefault("data", {})["image_size"] = args.image_size if args.arch: overrides.setdefault("model", {})["architecture"] = args.arch + if args.analysis_type: + overrides.setdefault("analysis", {})["type"] = args.analysis_type cfg = _deep_merge(cfg, overrides) - # Defaults for any missing keys + # Defaults cfg.setdefault("data", {}) - cfg["data"].setdefault("dir", "data/processed") - cfg["data"].setdefault("image_size", 256) - cfg["data"].setdefault("batch_size", 16) - cfg["data"].setdefault("num_workers", 4) + cfg["data"].setdefault("dir", "data/processed") + cfg["data"].setdefault("image_size", 256) + cfg["data"].setdefault("batch_size", 16) + cfg["data"].setdefault("num_workers", 4) cfg["data"].setdefault("use_weighted_sampler", True) - cfg["data"].setdefault("pin_memory", True) + cfg["data"].setdefault("pin_memory", True) + + cfg.setdefault("analysis", {}) + cfg["analysis"].setdefault("type", "deforestation") + + # Auto-configure model defaults based on analysis type BEFORE generic defaults. + # When the user explicitly passes --analysis-type on CLI we override config-file + # model settings so the architecture matches the data. + analysis_type = cfg["analysis"]["type"] + explicit_analysis = args.analysis_type is not None # came from CLI + + if analysis_type == "flooding": + forced_arch, forced_ch, forced_cls, forced_enc = "flood_unet", 3, 3, "efficientnet-b7" + elif analysis_type == "ice_melting": + forced_arch, forced_ch, forced_cls, forced_enc = "attention_unet", 2, 2, None + else: # deforestation + forced_arch, forced_ch, forced_cls, forced_enc = "attention_unet", 4, 2, None cfg.setdefault("model", {}) - cfg["model"].setdefault("architecture", "attention_unet") - cfg["model"].setdefault("in_channels", 4) - cfg["model"].setdefault("num_classes", 2) - cfg["model"].setdefault("bilinear", True) + # If user explicitly requested an analysis type, force model compatibility. + # Otherwise just use setdefault so config file values are respected. + if explicit_analysis: + cfg["model"]["architecture"] = forced_arch + cfg["model"]["in_channels"] = forced_ch + cfg["model"]["num_classes"] = forced_cls + if forced_enc: + cfg["model"]["encoder"] = forced_enc + else: + cfg["model"].setdefault("architecture", forced_arch) + cfg["model"].setdefault("in_channels", forced_ch) + cfg["model"].setdefault("num_classes", forced_cls) + if forced_enc: + cfg["model"].setdefault("encoder", forced_enc) + cfg["model"].setdefault("bilinear", True) cfg.setdefault("loss", {}) - cfg["loss"].setdefault("type", "combined") - cfg["loss"].setdefault("focal_weight", 0.5) - cfg["loss"].setdefault("focal_alpha", 0.25) - cfg["loss"].setdefault("focal_gamma", 2.0) - cfg["loss"].setdefault("use_class_weights", True) + cfg["loss"].setdefault("type", "combined") + cfg["loss"].setdefault("focal_weight", 0.5) + cfg["loss"].setdefault("focal_alpha", 0.25) + cfg["loss"].setdefault("focal_gamma", 2.0) + cfg["loss"].setdefault("use_class_weights", True) cfg.setdefault("optimizer", {}) - cfg["optimizer"].setdefault("learning_rate", 1e-4) - cfg["optimizer"].setdefault("weight_decay", 1e-4) - cfg["optimizer"].setdefault("min_lr", 1e-6) + cfg["optimizer"].setdefault("learning_rate", 1e-4) + cfg["optimizer"].setdefault("weight_decay", 1e-4) + cfg["optimizer"].setdefault("min_lr", 1e-6) cfg.setdefault("schedule", {}) - cfg["schedule"].setdefault("epochs", 100) - cfg["schedule"].setdefault("warmup_epochs", 5) - cfg["schedule"].setdefault("checkpoint_interval", 10) + cfg["schedule"].setdefault("epochs", 100) + cfg["schedule"].setdefault("warmup_epochs", 5) + cfg["schedule"].setdefault("checkpoint_interval", 10) cfg.setdefault("training", {}) - cfg["training"].setdefault("mixed_precision", True) - cfg["training"].setdefault("grad_clip", 1.0) - cfg["training"].setdefault("use_ema", True) - cfg["training"].setdefault("ema_decay", 0.9999) + cfg["training"].setdefault("mixed_precision", True) + cfg["training"].setdefault("grad_clip", 1.0) + cfg["training"].setdefault("use_ema", True) + cfg["training"].setdefault("ema_decay", 0.9999) cfg["training"].setdefault("early_stopping_patience", 15) cfg.setdefault("output", {}) - cfg["output"].setdefault("save_dir", "models") - cfg["output"].setdefault("run_name", "") + cfg["output"].setdefault("save_dir", "models") + cfg["output"].setdefault("run_name", "") cfg.setdefault("normalizer_stats", "") return cfg @@ -152,12 +171,23 @@ def build_config(args: argparse.Namespace) -> dict: def build_model(cfg: dict): """Instantiate the segmentation model from config.""" from climatevision.models.unet import get_model + from climatevision.models.flood_unet import build_flood_model + mcfg = cfg["model"] arch = mcfg["architecture"] + analysis_type = cfg["analysis"]["type"] + + # Flood models use smp-based architectures + if arch in ("flood_unet", "flood_unet_s2only"): + use_sar = mcfg["in_channels"] == 5 + return build_flood_model( + use_sar=use_sar, + encoder_name=mcfg.get("encoder", "efficientnet-b7"), + ) kwargs = { "n_channels": mcfg["in_channels"], - "n_classes": mcfg["num_classes"], + "n_classes": mcfg["num_classes"], } if arch == "unet": kwargs["bilinear"] = mcfg.get("bilinear", True) @@ -226,7 +256,6 @@ def load_normalizer(cfg: dict): # --------------------------------------------------------------------------- def maybe_resume(model, optimizer, resume_path: str | None) -> int: - """Load weights from checkpoint. Returns start epoch (0 if no resume).""" if not resume_path: return 0 import torch @@ -244,29 +273,17 @@ def maybe_resume(model, optimizer, resume_path: str | None) -> int: # --------------------------------------------------------------------------- -# Auto-generate data if directory is empty +# Data validation — fail fast if no real data # --------------------------------------------------------------------------- -def maybe_generate_data(data_dir: Path, patch_size: int = 256, n_patches: int = 1000) -> None: +def validate_data_exists(data_dir: Path) -> None: + """Fail hard if training data is missing. No synthetic fallback.""" train_img = data_dir / "train" / "images" - if train_img.exists() and any(train_img.glob("*.tif")): - return - - logger.warning("No training data found in %s", data_dir) - logger.info("Auto-generating %d synthetic patches…", n_patches) - - cmd = [ - sys.executable, str(PROJECT_ROOT / "scripts" / "prepare_data.py"), - "--mode", "synthetic", - "--n-patches", str(n_patches), - "--patch-size", str(patch_size), - "--out", str(data_dir), - "--fit-normalizer", - ] - import subprocess - result = subprocess.run(cmd, check=False) - if result.returncode != 0: - logger.error("Data generation failed — check prepare_data.py output") + if not train_img.exists() or not any(train_img.glob("*.tif")): + logger.error("No training data found in %s", data_dir) + logger.error("Prepare real data before training:") + logger.error(" python scripts/prepare_data.py --mode synthetic --n-patches 1000 --out %s", data_dir) + logger.error(" # OR download real GEE data") sys.exit(1) @@ -276,14 +293,13 @@ def maybe_generate_data(data_dir: Path, patch_size: int = 256, n_patches: int = def main() -> None: args = parse_args() - cfg = build_config(args) + cfg = build_config(args) - # Run name / output directory + analysis_type = cfg["analysis"]["type"] run_name = cfg["output"]["run_name"] or datetime.now().strftime("%Y%m%d_%H%M%S") - save_dir = Path(cfg["output"]["save_dir"]) / run_name + save_dir = Path(cfg["output"]["save_dir"]) / f"{analysis_type}_{run_name}" save_dir.mkdir(parents=True, exist_ok=True) - # Persist effective config try: import yaml with open(save_dir / "config.yaml", "w") as f: @@ -291,15 +307,14 @@ def main() -> None: except ImportError: import json with open(save_dir / "config.json", "w") as f: - import json json.dump(cfg, f, indent=2) logger.info("Run: %s → %s", run_name, save_dir) # Data - data_dir = Path(cfg["data"]["dir"]) + data_dir = Path(cfg["data"]["dir"]) image_size = cfg["data"]["image_size"] - maybe_generate_data(data_dir, patch_size=image_size) + validate_data_exists(data_dir) normalizer = load_normalizer(cfg) @@ -312,6 +327,7 @@ def main() -> None: normalizer=normalizer, pin_memory=cfg["data"]["pin_memory"], use_weighted_sampler=cfg["data"]["use_weighted_sampler"], + analysis_type=analysis_type, ) if "train" not in loaders: @@ -326,28 +342,28 @@ def main() -> None: ) # Class weights + num_classes = cfg["model"]["num_classes"] class_weights = None if cfg["loss"]["use_class_weights"]: - class_weights = loaders["train"].dataset.compute_class_weights() + class_weights = loaders["train"].dataset.compute_class_weights(num_classes=num_classes) logger.info("Class weights: %s", class_weights.tolist()) # Model + loss - model = build_model(cfg) + model = build_model(cfg) criterion = build_criterion(cfg, class_weights=class_weights) - # Trainer config dict (flat, as Trainer expects) trainer_cfg = { - "learning_rate": cfg["optimizer"]["learning_rate"], - "weight_decay": cfg["optimizer"]["weight_decay"], - "min_lr": cfg["optimizer"]["min_lr"], - "epochs": cfg["schedule"]["epochs"], - "warmup_epochs": cfg["schedule"]["warmup_epochs"], - "checkpoint_interval": cfg["schedule"]["checkpoint_interval"], - "mixed_precision": cfg["training"]["mixed_precision"], - "grad_clip": cfg["training"]["grad_clip"], - "use_ema": cfg["training"]["use_ema"], - "ema_decay": cfg["training"]["ema_decay"], - "early_stopping_patience": cfg["training"]["early_stopping_patience"], + "learning_rate": cfg["optimizer"]["learning_rate"], + "weight_decay": cfg["optimizer"]["weight_decay"], + "min_lr": cfg["optimizer"]["min_lr"], + "epochs": cfg["schedule"]["epochs"], + "warmup_epochs": cfg["schedule"]["warmup_epochs"], + "checkpoint_interval": cfg["schedule"]["checkpoint_interval"], + "mixed_precision": cfg["training"]["mixed_precision"], + "grad_clip": cfg["training"]["grad_clip"], + "use_ema": cfg["training"]["use_ema"], + "ema_decay": cfg["training"]["ema_decay"], + "early_stopping_patience": cfg["training"]["early_stopping_patience"], } from climatevision.training.trainer import Trainer @@ -359,7 +375,6 @@ def main() -> None: save_dir=save_dir, ) - # Optional resume if args.resume: maybe_resume(model, trainer.optimizer, args.resume) @@ -367,8 +382,10 @@ def main() -> None: history = trainer.fit() elapsed = time.time() - t_start - best_iou = max((e.get("iou_forest", 0) for e in history["val"]), default=0) - best_f1 = max((e.get("f1", 0) for e in history["val"]), default=0) + # Generalize best metric reporting + metric_key = f"iou_{analysis_type}" + best_iou = max((e.get(metric_key, e.get("iou", 0)) for e in history["val"]), default=0) + best_f1 = max((e.get("f1", 0) for e in history["val"]), default=0) logger.info("=" * 60) logger.info("Training complete in %.1f min", elapsed / 60) @@ -388,20 +405,22 @@ def main() -> None: # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="Train ClimateVision forest segmentation model") - p.add_argument("--config", default=str(PROJECT_ROOT / "config" / "train.yaml"), + p = argparse.ArgumentParser(description="Train ClimateVision segmentation model") + p.add_argument("--config", default=str(PROJECT_ROOT / "config" / "train.yaml"), help="Path to YAML config file") - p.add_argument("--data-dir", default=None, help="Override data.dir") - p.add_argument("--epochs", type=int, default=None, help="Override schedule.epochs") + p.add_argument("--analysis-type", default=None, + help="Analysis type: deforestation, flooding, ice_melting") + p.add_argument("--data-dir", default=None, help="Override data.dir") + p.add_argument("--epochs", type=int, default=None, help="Override schedule.epochs") p.add_argument("--batch-size", type=int, default=None, help="Override data.batch_size") - p.add_argument("--lr", type=float, default=None, help="Override optimizer.learning_rate") - p.add_argument("--save-dir", default=None, help="Override output.save_dir") - p.add_argument("--run-name", default=None, help="Override output.run_name") - p.add_argument("--resume", default=None, help="Path to checkpoint to resume from") - p.add_argument("--arch", choices=["unet", "attention_unet"], default=None) - p.add_argument("--no-amp", action="store_true", help="Disable mixed-precision (AMP)") - p.add_argument("--num-workers", type=int, default=None, help="DataLoader worker count (0=main process)") - p.add_argument("--image-size", type=int, default=None, help="Spatial crop size in pixels") + p.add_argument("--lr", type=float, default=None, help="Override optimizer.learning_rate") + p.add_argument("--save-dir", default=None, help="Override output.save_dir") + p.add_argument("--run-name", default=None, help="Override output.run_name") + p.add_argument("--resume", default=None, help="Path to checkpoint to resume from") + p.add_argument("--arch", choices=["unet", "attention_unet", "flood_unet", "flood_unet_s2only"], default=None) + p.add_argument("--no-amp", action="store_true", help="Disable mixed-precision (AMP)") + p.add_argument("--num-workers", type=int, default=None, help="DataLoader worker count (0=main process)") + p.add_argument("--image-size", type=int, default=None, help="Spatial crop size in pixels") return p.parse_args() diff --git a/src/climatevision/api/auth.py b/src/climatevision/api/auth.py index 85a8ad7..1da5034 100644 --- a/src/climatevision/api/auth.py +++ b/src/climatevision/api/auth.py @@ -68,6 +68,9 @@ def validate_key(self, api_key: str) -> Optional[dict]: """ Validate an API key and return organization context. + Queries the SQLite database for an organization matching the + SHA-256 hash of the provided API key. + Args: api_key: The API key to validate @@ -85,15 +88,36 @@ def validate_key(self, api_key: str) -> Optional[dict]: "demo": True, } - # Check cache first key_hash = self.hash_key(api_key) + + # Check cache first if key_hash in self._key_cache: cached = self._key_cache[key_hash] if cached.get("expires_at", datetime.max) > datetime.utcnow(): return cached.get("org") - # Would query database in production - # For now, return None to indicate key not found + # Query database + try: + from climatevision.db import get_connection + with get_connection() as conn: + row = conn.execute( + "SELECT id, name, type, api_key FROM organizations WHERE api_key = ?", + (key_hash,), + ).fetchone() + if row: + org = { + "id": row["id"], + "name": row["name"], + "type": row["type"], + } + self._key_cache[key_hash] = { + "org": org, + "expires_at": datetime.max, + } + return org + except Exception as exc: + logger.warning("Database query failed during API key validation: %s", exc) + return None def revoke_key(self, api_key: str) -> bool: diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index 729b213..d4a41f2 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -14,6 +14,8 @@ import json import logging import time + +import numpy as np from datetime import datetime, timezone from pathlib import Path from typing import Any, Optional, Literal @@ -43,7 +45,10 @@ mark_alert_delivered, ) from climatevision.inference import run_inference_from_file, run_inference_from_gee +from climatevision.governance import explain_prediction, SHAPExplainer from climatevision.api.auth import require_api_key +from climatevision.inference.alert_generator import AlertGenerator +from climatevision.analysis.flooding import FloodingAnalysis logger = logging.getLogger(__name__) @@ -242,59 +247,16 @@ def _load_template_result( end_date: Optional[str], analysis_type: str = "deforestation", ) -> dict[str, Any]: - """Load or create a template result for failed inference.""" - outputs_dir = Path(__file__).resolve().parents[3] / "outputs" - template_path = outputs_dir / "inference_results.json" - - if template_path.exists(): - template: dict[str, Any] = json.loads(template_path.read_text(encoding="utf-8")) - else: - # Create analysis-specific template - if analysis_type == "ice_melting": - template = { - "region": {"bbox": bbox or None}, - "inference": { - "image_size": [256, 256], - "ice_pixels": 0, - "water_pixels": 0, - "land_pixels": 0, - "ice_percentage": 0.0, - "mean_confidence": 0.0, - }, - } - elif analysis_type == "flooding": - template = { - "region": {"bbox": bbox or None}, - "inference": { - "image_size": [256, 256], - "flooded_pixels": 0, - "dry_pixels": 0, - "water_pixels": 0, - "flooded_percentage": 0.0, - "mean_confidence": 0.0, - }, - } - else: # deforestation (default) - template = { - "region": {"bbox": bbox or None}, - "ndvi_stats": {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0}, - "inference": { - "image_size": [256, 256], - "forest_pixels": 0, - "non_forest_pixels": 0, - "forest_percentage": 0.0, - "mean_confidence": 0.0, - }, - } - - if bbox is not None: - template.setdefault("region", {})["bbox"] = bbox - if start_date and end_date: - template.setdefault("region", {})["date_range"] = f"{start_date} to {end_date}" - - template["analysis_type"] = analysis_type - - return template + """ + DEPRECATED: Template results are no longer used in production. + + This function is retained only for backward compatibility with tests. + All production failures now raise HTTPException instead of returning stubs. + """ + raise HTTPException( + status_code=503, + detail=f"Inference failed for {analysis_type}. No synthetic fallback available in production.", + ) async def _persist_upload(*, run_id: int, file: UploadFile) -> str: @@ -364,6 +326,8 @@ def create_app() -> FastAPI: ) app.add_middleware(AuditLogMiddleware) + from climatevision.security.api_security import SecurityMiddleware + app.add_middleware(SecurityMiddleware) app.add_middleware( CORSMiddleware, allow_origins=[ @@ -572,8 +536,8 @@ async def predict_json( with get_connection() as conn: cur = conn.execute( """ - INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at, organization_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( body.kind, @@ -584,6 +548,7 @@ async def predict_json( body.end_date, created_at, created_at, + org["id"], ), ) run_id = int(cur.lastrowid) @@ -600,14 +565,36 @@ async def predict_json( status = "completed" except Exception as exc: logger.exception("Inference failed for run %s", run_id) - result_payload = _load_template_result( - bbox=body.bbox, - start_date=body.start_date, - end_date=body.end_date, - analysis_type=body.analysis_type, - ) - result_payload["error"] = str(exc) - status = "failed" + raise HTTPException( + status_code=503, + detail=f"Inference failed: {exc}", + ) from exc + + # Generate alerts for flooding + if body.analysis_type == "flooding": + try: + analysis = FloodingAnalysis() + metrics = analysis.calculate_metrics( + prediction=np.array([]), # Will be populated from result + image_size=(256, 256), + bbox=body.bbox, + ) + # TODO: populate metrics from actual prediction mask + # For now, use the percentage from inference result + flooded_pct = result_payload.get("inference", {}).get("flooded_percentage", 0) + metrics["flooded_percentage"] = flooded_pct + alerts = analysis.generate_alerts(metrics) + for alert in alerts: + create_organization_alert( + org_id=org["id"], + alert_type=alert.alert_type, + severity=alert.severity.value, + title=alert.title, + message=alert.message, + details=alert.details, + ) + except Exception as exc: + logger.warning("Alert generation failed for run %s: %s", run_id, exc) # Persist result result_created_at = _utc_now_iso() @@ -629,12 +616,12 @@ async def predict_json( @app.post("/api/predict/upload") async def predict_upload( kind: str = Form(default="upload"), - org: dict[str, Any] = Depends(require_api_key), analysis_type: str = Form(default="deforestation"), bbox: str | None = Form(default=None), start_date: str | None = Form(default=None), end_date: str | None = Form(default=None), file: UploadFile = File(...), + org: dict[str, Any] = Depends(require_api_key), ) -> dict[str, Any]: """Run prediction on uploaded satellite imagery file.""" if start_date and end_date and start_date > end_date: @@ -652,8 +639,8 @@ async def predict_upload( with get_connection() as conn: cur = conn.execute( """ - INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at, organization_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( kind, @@ -664,6 +651,7 @@ async def predict_upload( end_date, created_at, created_at, + org["id"], ), ) run_id = int(cur.lastrowid) @@ -683,15 +671,10 @@ async def predict_upload( status = "completed" except Exception as exc: logger.exception("Inference failed for upload run %s", run_id) - result_payload = _load_template_result( - bbox=parsed_bbox, - start_date=start_date, - end_date=end_date, - analysis_type=analysis_type, - ) - result_payload.setdefault("input", {})["file"] = dest - result_payload["error"] = str(exc) - status = "failed" + raise HTTPException( + status_code=503, + detail=f"Inference failed: {exc}", + ) from exc # Persist result result_created_at = _utc_now_iso() diff --git a/src/climatevision/data/__init__.py b/src/climatevision/data/__init__.py index 232f42d..320410f 100644 --- a/src/climatevision/data/__init__.py +++ b/src/climatevision/data/__init__.py @@ -1,7 +1,7 @@ -from .dataset import ForestDataset, create_dataloaders +from .dataset import SatelliteDataset, ForestDataset, FloodDataset, create_dataloaders from .augmentation import get_train_transforms, get_val_transforms from .preprocessing import Sentinel2Normalizer, compute_dataset_stats, apply_scl_cloud_mask -from .synthetic import generate_synthetic_dataset +from .synthetic import generate_synthetic_dataset, generate_flood_test_patches from .gee_downloader import download_tile_for_analysis from .band_mapping import ( get_bands_for_analysis, @@ -27,7 +27,9 @@ __all__ = [ # Dataset + "SatelliteDataset", "ForestDataset", + "FloodDataset", "create_dataloaders", # Augmentation "get_train_transforms", @@ -38,6 +40,7 @@ "apply_scl_cloud_mask", # Synthetic "generate_synthetic_dataset", + "generate_flood_test_patches", # GEE "download_tile_for_analysis", # Band mapping diff --git a/src/climatevision/data/band_mapping.py b/src/climatevision/data/band_mapping.py index 9f9d73b..2d55db1 100644 --- a/src/climatevision/data/band_mapping.py +++ b/src/climatevision/data/band_mapping.py @@ -1,5 +1,5 @@ """ -Analysis-specific Sentinel-2 band mapping utilities. +Analysis-specific satellite band mapping utilities. Provides a single source of truth for which spectral bands each climate analysis type requires, derived from config.yaml. @@ -22,6 +22,9 @@ "B8A", "B09", "B10", "B11", "B12", ] +# Sentinel-1 SAR bands +SENTINEL1_BAND_ORDER = ["VV", "VH"] + # Scene Classification Layer (SCL) is not part of the 13 reflectance bands # but is essential for cloud masking. SCL_BAND = "SCL" @@ -36,7 +39,7 @@ def _load_config() -> dict[str, Any]: def get_bands_for_analysis(analysis_type: str) -> list[str]: """ - Return the Sentinel-2 band names required for *analysis_type*. + Return the satellite band names required for *analysis_type*. The bands are read from ``config.yaml`` and are guaranteed to be returned in the same order they are declared there. @@ -69,9 +72,6 @@ def get_band_indices(band_names: list[str]) -> list[int]: indices = [] for b in band_names: if b == SCL_BAND: - # SCL does not belong to the 13 reflectance bands; - # callers that need an index in a multi-band array should - # append it separately and compute len(reflectance_bands). raise ValueError( f"SCL is not part of the 13-band reflectance stack. " f"Append it manually after resolving reflectance indices." @@ -82,6 +82,16 @@ def get_band_indices(band_names: list[str]) -> list[int]: return indices +def get_sar_band_indices(band_names: list[str]) -> list[int]: + """Map Sentinel-1 band names to zero-based indices.""" + indices = [] + for b in band_names: + if b not in SENTINEL1_BAND_ORDER: + raise ValueError(f"Unknown Sentinel-1 band: {b}") + indices.append(SENTINEL1_BAND_ORDER.index(b)) + return indices + + def is_analysis_enabled(analysis_type: str) -> bool: """Return True if the analysis type is enabled in config.yaml.""" cfg = _load_config() @@ -109,3 +119,9 @@ def get_model_config(analysis_type: str) -> dict[str, Any]: cfg = _load_config() analysis_cfg = cfg.get("analysis_types", {}).get(analysis_type, {}) return dict(analysis_cfg.get("model", {})) + + +def get_analysis_config(analysis_type: str) -> dict[str, Any]: + """Return the full analysis type configuration dict.""" + cfg = _load_config() + return dict(cfg.get("analysis_types", {}).get(analysis_type, {})) diff --git a/src/climatevision/data/gee_downloader.py b/src/climatevision/data/gee_downloader.py index fa65f0b..d6c45ae 100644 --- a/src/climatevision/data/gee_downloader.py +++ b/src/climatevision/data/gee_downloader.py @@ -1,9 +1,8 @@ """ Google Earth Engine tile downloader for ClimateVision. -Provides analysis-aware Sentinel-2 tile downloads with a synthetic fallback -when GEE credentials are unavailable. Downloaded tiles are saved as GeoTIFF -and include a metadata dict that labels synthetic scenes explicitly. +Provides analysis-aware Sentinel-2 and Sentinel-1 tile downloads. +GEE failures raise explicit exceptions — NO synthetic fallbacks in production. """ from __future__ import annotations @@ -12,7 +11,7 @@ import tempfile import urllib.request from pathlib import Path -from typing import Any, Optional +from typing import Any import numpy as np @@ -41,6 +40,22 @@ } +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + +class GEEAuthenticationError(RuntimeError): + """Raised when GEE credentials are missing or invalid.""" + + +class TileNotFoundError(RuntimeError): + """Raised when no satellite images are found for the given criteria.""" + + +# --------------------------------------------------------------------------- +# GEE initialisation +# --------------------------------------------------------------------------- + def _initialize_ee() -> Any: """Lazy import and initialise Google Earth Engine.""" import ee # noqa @@ -62,16 +77,9 @@ def _initialize_ee() -> Any: return ee -def _get_default_tile_size() -> int: - """Read the default tile size from config.yaml.""" - import yaml - - config_path = _PROJECT_ROOT / "config.yaml" - with open(config_path, "r") as f: - cfg = yaml.safe_load(f) - image_size = cfg.get("data", {}).get("image_size", [256, 256]) - return int(image_size[0]) - +# --------------------------------------------------------------------------- +# Sentinel-2 download +# --------------------------------------------------------------------------- def download_tile_for_analysis( bbox: list[float], @@ -95,8 +103,11 @@ def download_tile_for_analysis( include_scl: Whether to append the SCL band for cloud masking. Returns: - (file_path, metadata_dict). If GEE is unavailable, the synthetic - fallback is used and ``metadata["is_synthetic"]`` is ``True``. + (file_path, metadata_dict). + + Raises: + GEEAuthenticationError: If GEE cannot be initialised. + TileNotFoundError: If no images are found for the date range. """ if output_dir is None: output_dir = _SATELLITE_DIR @@ -112,14 +123,10 @@ def download_tile_for_analysis( ee = _initialize_ee() rasterio = __import__("rasterio") except Exception as exc: - logger.warning("GEE unavailable (%s). Using synthetic fallback.", exc) - return _generate_synthetic_tile( - bbox=bbox, - start_date=start_date, - end_date=end_date, - analysis_type=analysis_type, - out_path=out_path, - ) + raise GEEAuthenticationError( + f"Google Earth Engine unavailable: {exc}. " + f"Set GEE_PROJECT_ID or GEE_SERVICE_ACCOUNT + GEE_SERVICE_ACCOUNT_KEY." + ) from exc bands = get_bands_for_analysis(analysis_type) gee_bands = [_BAND_NAME_TO_GEE[b] for b in bands] @@ -137,16 +144,9 @@ def download_tile_for_analysis( count = collection.size().getInfo() if count == 0: - logger.warning( - "No GEE images found for %s %s to %s. Using synthetic fallback.", - analysis_type, start_date, end_date, - ) - return _generate_synthetic_tile( - bbox=bbox, - start_date=start_date, - end_date=end_date, - analysis_type=analysis_type, - out_path=out_path, + raise TileNotFoundError( + f"No Sentinel-2 images found for {analysis_type} " + f"from {start_date} to {end_date} in bbox {bbox}." ) image = collection.median().clip(region) @@ -166,8 +166,6 @@ def download_tile_for_analysis( os.unlink(tmp) - # Re-order bands to match project convention if needed - # (GEE returns in selection order) profile.update( driver="GTiff", dtype="float32", @@ -179,6 +177,7 @@ def download_tile_for_analysis( metadata: dict[str, Any] = { "source": "gee", + "satellite": "sentinel2", "analysis_type": analysis_type, "bbox": bbox, "start_date": start_date, @@ -186,75 +185,123 @@ def download_tile_for_analysis( "bands": bands, "scale_m": scale_m, "images_available": count, - "is_synthetic": False, "shape": list(data.shape), } - logger.info("Downloaded real tile to %s (%d images available)", out_path, count) + logger.info("Downloaded S2 tile to %s (%d images available)", out_path, count) return out_path, metadata -def _generate_synthetic_tile( +# --------------------------------------------------------------------------- +# Sentinel-1 SAR download +# --------------------------------------------------------------------------- + +def download_s1_tile( bbox: list[float], start_date: str, end_date: str, - analysis_type: str, - out_path: Path, + polarization: list[str] | None = None, + output_dir: str | Path | None = None, + scale_m: int = 100, + orbit: str = "DESCENDING", ) -> tuple[Path, dict[str, Any]]: """ - Generate a physically plausible synthetic Sentinel-2 tile when GEE fails. - The output is explicitly tagged ``is_synthetic: True``. + Download a median Sentinel-1 SAR composite for the given bbox and date range. + + Args: + bbox: [west, south, east, north] in WGS84. + start_date: Start date (YYYY-MM-DD). + end_date: End date (YYYY-MM-DD). + polarization: List of polarizations, e.g., ["VV", "VH"]. Defaults to both. + output_dir: Where to save the GeoTIFF. + scale_m: GEE export resolution in metres. + orbit: "ASCENDING", "DESCENDING", or "BOTH". + + Returns: + (file_path, metadata_dict). + + Raises: + GEEAuthenticationError: If GEE cannot be initialised. + TileNotFoundError: If no SAR images are found. """ - rasterio = __import__("rasterio") + if output_dir is None: + output_dir = _SATELLITE_DIR + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) - bands = get_bands_for_analysis(analysis_type) - n_bands = len(bands) - tile_size = _get_default_tile_size() - h, w = tile_size, tile_size - - # Seed RNG from bbox so the same region is deterministic - seed = int(abs(sum(v * 1000 * (i + 1) for i, v in enumerate(bbox)))) % (2 ** 31) - rng = np.random.default_rng(seed) - - # Build a synthetic stack: draw reflectance values typical for mixed forest - data = np.zeros((n_bands, h, w), dtype=np.float32) - for b in range(n_bands): - mean = rng.uniform(500.0, 3000.0) - std = rng.uniform(200.0, 800.0) - data[b] = rng.normal(mean, std, (h, w)).clip(0.0, 10000.0) - - # Append an SCL band (all clear = 4) - scl = np.full((1, h, w), 4, dtype=np.float32) - data = np.concatenate([data, scl], axis=0) - - transform = rasterio.transform.from_bounds( - bbox[0], bbox[1], bbox[2], bbox[3], w, h + if polarization is None: + polarization = ["VV", "VH"] + + safe_start = start_date.replace("-", "") + safe_end = end_date.replace("-", "") + stem = f"s1_{safe_start}_{safe_end}_{'_'.join(str(round(c, 4)) for c in bbox)}" + out_path = output_dir / f"{stem}.tif" + + try: + ee = _initialize_ee() + rasterio = __import__("rasterio") + except Exception as exc: + raise GEEAuthenticationError( + f"Google Earth Engine unavailable: {exc}." + ) from exc + + region = ee.Geometry.Rectangle(bbox) + collection = ( + ee.ImageCollection("COPERNICUS/S1_GRD") + .filterBounds(region) + .filterDate(start_date, end_date) + .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV")) + .select(polarization) + ) + + if orbit != "BOTH": + collection = collection.filter(ee.Filter.eq("orbitProperties_pass", orbit)) + + count = collection.size().getInfo() + if count == 0: + raise TileNotFoundError( + f"No Sentinel-1 images found from {start_date} to {end_date} in bbox {bbox}." + ) + + # Convert to linear intensity from dB scale + image = collection.median().clip(region) + image = image.pow(2) # GEE S1 is in dB; convert to linear intensity for ML + + url = image.getDownloadURL({ + "region": region, + "scale": scale_m, + "format": "GEO_TIFF", + }) + + tmp = tempfile.mktemp(suffix=".tif") + urllib.request.urlretrieve(url, tmp) + + with rasterio.open(tmp) as src: + data = src.read().astype(np.float32) + profile = src.profile + + os.unlink(tmp) + + profile.update( + driver="GTiff", + dtype="float32", + count=data.shape[0], ) - profile = { - "driver": "GTiff", - "dtype": "float32", - "count": data.shape[0], - "height": h, - "width": w, - "crs": "EPSG:4326", - "transform": transform, - } with rasterio.open(out_path, "w", **profile) as dst: dst.write(data) metadata: dict[str, Any] = { - "source": "synthetic_fallback", - "analysis_type": analysis_type, + "source": "gee", + "satellite": "sentinel1", "bbox": bbox, "start_date": start_date, "end_date": end_date, - "bands": bands, - "scale_m": 100, - "images_available": 0, - "is_synthetic": True, + "polarization": polarization, + "scale_m": scale_m, + "images_available": count, "shape": list(data.shape), } - logger.info("Generated synthetic fallback tile to %s", out_path) + logger.info("Downloaded S1 tile to %s (%d images available)", out_path, count) return out_path, metadata diff --git a/src/climatevision/data/preprocessing.py b/src/climatevision/data/preprocessing.py index fd62b17..6d95cca 100644 --- a/src/climatevision/data/preprocessing.py +++ b/src/climatevision/data/preprocessing.py @@ -2,14 +2,7 @@ Sentinel-2 band normalization and preprocessing. Sentinel-2 L2A surface reflectance is stored as uint16 in range [0, 10000]. -We normalize each band to float32 using robust per-channel statistics derived -from a large sample of Amazon/Congo forest and non-forest pixels. - -Reference band order expected throughout this project: - index 0 → B04 Red (~665 nm) - index 1 → B03 Green (~560 nm) - index 2 → B02 Blue (~490 nm) - index 3 → B08 NIR (~842 nm) +Normalizers are band-agnostic and adapt to any number of input channels. """ from __future__ import annotations @@ -23,97 +16,121 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Sentinel-2 L2A statistics computed from 50 k Amazon/Congo patches -# Values are surface reflectance ×10000, band order [R, G, B, NIR] +# Default Sentinel-2 statistics — used as fallback when no dataset stats exist +# Band order MUST match the actual input bands (not hardcoded to RGB+NIR) # --------------------------------------------------------------------------- -_S2_MEAN = np.array([943.0, 1069.0, 981.0, 2734.0], dtype=np.float32) -_S2_STD = np.array([590.0, 547.0, 498.0, 1246.0], dtype=np.float32) -# Robust (2nd–98th percentile) clip bounds to suppress sensor artefacts -_S2_P2 = np.array([ 0.0, 10.0, 0.0, 100.0], dtype=np.float32) -_S2_P98 = np.array([2500.0, 2500.0, 2200.0, 8000.0], dtype=np.float32) +def _default_stats(num_bands: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return sensible default mean/std/p2/p98 for `num_bands`.""" + mean = np.full(num_bands, 1500.0, dtype=np.float32) + std = np.full(num_bands, 800.0, dtype=np.float32) + p2 = np.full(num_bands, 0.0, dtype=np.float32) + p98 = np.full(num_bands, 4000.0, dtype=np.float32) + return mean, std, p2, p98 + + +# --------------------------------------------------------------------------- +# Sentinel-2 Normalizer (band-agnostic) +# --------------------------------------------------------------------------- class Sentinel2Normalizer: """ - Normalize a 4-band Sentinel-2 image to zero-mean / unit-variance float32. + Normalize a multi-band Sentinel-2 image to zero-mean / unit-variance float32. Two modes: - - 'standard': use pre-computed global statistics (default, fast). + - 'standard': use pre-computed or default global statistics (fast). - 'dataset': use statistics supplied via `fit()` (accurate per dataset). """ - def __init__(self, mode: str = "standard"): + def __init__(self, mode: str = "standard", num_bands: int = 4): assert mode in ("standard", "dataset") self.mode = mode - self.mean: np.ndarray = _S2_MEAN.copy() - self.std: np.ndarray = _S2_STD.copy() - self.p2: np.ndarray = _S2_P2.copy() - self.p98: np.ndarray = _S2_P98.copy() + self.num_bands = num_bands + self.mean, self.std, self.p2, self.p98 = _default_stats(num_bands) self._fitted = (mode == "standard") - # ------------------------------------------------------------------ def fit(self, images: list[np.ndarray]) -> "Sentinel2Normalizer": - """Compute statistics from a list of (4, H, W) arrays.""" + """Compute statistics from a list of (C, H, W) arrays.""" + if not images: + raise ValueError("Cannot fit on empty image list") + all_pixels: list[np.ndarray] = [] for img in images: c, h, w = img.shape all_pixels.append(img.reshape(c, -1)) - stacked = np.concatenate(all_pixels, axis=1) # (4, N) + stacked = np.concatenate(all_pixels, axis=1) # (C, N) + self.num_bands = stacked.shape[0] self.mean = stacked.mean(axis=1).astype(np.float32) - self.std = stacked.std(axis=1).astype(np.float32) + 1e-6 - self.p2 = np.percentile(stacked, 2, axis=1).astype(np.float32) - self.p98 = np.percentile(stacked, 98, axis=1).astype(np.float32) + self.std = stacked.std(axis=1).astype(np.float32) + 1e-6 + self.p2 = np.percentile(stacked, 2, axis=1).astype(np.float32) + self.p98 = np.percentile(stacked, 98, axis=1).astype(np.float32) self._fitted = True return self - # ------------------------------------------------------------------ def __call__(self, image: np.ndarray) -> np.ndarray: """ - Normalize a (4, H, W) uint16 or float32 array to float32. + Normalize a (C, H, W) uint16 or float32 array to float32. Returns values roughly in [-3, 3]. """ if not self._fitted: raise RuntimeError("Call fit() before normalizing in 'dataset' mode.") img = image.astype(np.float32) + c = img.shape[0] + + # Resize stats if band count mismatch (e.g., 3-band vs 5-band) + if c != self.num_bands: + if c > self.num_bands: + # Pad stats with defaults for extra bands + extra = c - self.num_bands + self.mean = np.concatenate([self.mean, np.full(extra, 1500.0, dtype=np.float32)]) + self.std = np.concatenate([self.std, np.full(extra, 800.0, dtype=np.float32)]) + self.p2 = np.concatenate([self.p2, np.full(extra, 0.0, dtype=np.float32)]) + self.p98 = np.concatenate([self.p98, np.full(extra, 4000.0, dtype=np.float32)]) + else: + self.mean = self.mean[:c] + self.std = self.std[:c] + self.p2 = self.p2[:c] + self.p98 = self.p98[:c] + self.num_bands = c # 1. Clip outliers band-wise - for b in range(min(4, img.shape[0])): + for b in range(c): img[b] = np.clip(img[b], self.p2[b], self.p98[b]) # 2. Standardize - for b in range(min(4, img.shape[0])): + for b in range(c): img[b] = (img[b] - self.mean[b]) / self.std[b] return img - # ------------------------------------------------------------------ def save(self, path: str | Path) -> None: data = { "mean": self.mean.tolist(), - "std": self.std.tolist(), - "p2": self.p2.tolist(), - "p98": self.p98.tolist(), + "std": self.std.tolist(), + "p2": self.p2.tolist(), + "p98": self.p98.tolist(), "mode": self.mode, + "num_bands": self.num_bands, } Path(path).write_text(json.dumps(data, indent=2)) @classmethod def load(cls, path: str | Path) -> "Sentinel2Normalizer": data = json.loads(Path(path).read_text()) - obj = cls(mode=data["mode"]) + obj = cls(mode=data["mode"], num_bands=data.get("num_bands", 4)) obj.mean = np.array(data["mean"], dtype=np.float32) - obj.std = np.array(data["std"], dtype=np.float32) - obj.p2 = np.array(data["p2"], dtype=np.float32) - obj.p98 = np.array(data["p98"], dtype=np.float32) + obj.std = np.array(data["std"], dtype=np.float32) + obj.p2 = np.array(data["p2"], dtype=np.float32) + obj.p98 = np.array(data["p98"], dtype=np.float32) obj._fitted = True return obj # --------------------------------------------------------------------------- -# Dataset statistics helper +# Cloud masking # --------------------------------------------------------------------------- def apply_scl_cloud_mask( @@ -128,15 +145,15 @@ def apply_scl_cloud_mask( Args: image: Array of shape (C, H, W). scl_band: Array of shape (H, W) containing Scene Classification Layer values. - clear_labels: SCL codes considered clear. Defaults to vegetation, bare soil, - water, and snow (``[4, 5, 6, 11]``). + clear_labels: SCL codes considered clear. If None, uses a safe default + of [4, 5, 6] (vegetation, bare soil, water). fill_value: Value to replace cloudy pixels with. Returns: Cloud-masked image with the same shape as *image*. """ if clear_labels is None: - clear_labels = [4, 5, 6, 11] + clear_labels = [4, 5, 6] if image.ndim != 3: raise ValueError(f"image must be 3-D (C, H, W), got shape {image.shape}") @@ -152,6 +169,47 @@ def apply_scl_cloud_mask( return masked +# --------------------------------------------------------------------------- +# Band resampling +# --------------------------------------------------------------------------- + +def resample_20m_to_10m( + band_20m: np.ndarray, + target_shape: tuple[int, int], + order: int = 1, +) -> np.ndarray: + """ + Resample a 20m band to 10m resolution using bilinear interpolation. + + Args: + band_20m: Array of shape (H, W) or (C, H, W). + target_shape: Desired (H, W) at 10m resolution (2× the 20m dimensions). + order: 0=nearest, 1=bilinear, 3=bicubic. Default 1 (bilinear). + + Returns: + Resampled array with the same number of dimensions as input. + """ + try: + from scipy.ndimage import zoom + except ImportError: + raise ImportError("scipy is required for band resampling. Install: pip install scipy") + + if band_20m.ndim == 2: + h, w = band_20m.shape + zoom_factors = (target_shape[0] / h, target_shape[1] / w) + return zoom(band_20m, zoom_factors, order=order, mode="reflect") + elif band_20m.ndim == 3: + c, h, w = band_20m.shape + zoom_factors = (1.0, target_shape[0] / h, target_shape[1] / w) + return zoom(band_20m, zoom_factors, order=order, mode="reflect") + else: + raise ValueError(f"band_20m must be 2-D or 3-D, got shape {band_20m.shape}") + + +# --------------------------------------------------------------------------- +# Dataset statistics helper +# --------------------------------------------------------------------------- + def compute_dataset_stats( image_dir: str | Path, max_samples: int = 500, @@ -176,7 +234,7 @@ def compute_dataset_stats( stacked = np.concatenate(all_pixels, axis=1).astype(np.float32) # (C, N) return { "mean": stacked.mean(axis=1).tolist(), - "std": stacked.std(axis=1).tolist(), - "min": stacked.min(axis=1).tolist(), - "max": stacked.max(axis=1).tolist(), + "std": stacked.std(axis=1).tolist(), + "min": stacked.min(axis=1).tolist(), + "max": stacked.max(axis=1).tolist(), } diff --git a/src/climatevision/inference/pipeline.py b/src/climatevision/inference/pipeline.py index 7af17ab..06cb3d0 100644 --- a/src/climatevision/inference/pipeline.py +++ b/src/climatevision/inference/pipeline.py @@ -2,11 +2,13 @@ Inference pipeline for ClimateVision. Provides: -- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference on a numpy array -- run_inference_from_file(path, bbox, start_date, end_date, analysis_type) — load file then infer -- run_inference_from_gee(bbox, start_date, end_date, analysis_type) — GEE NDVI + real tile inference -""" +- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference +- run_inference_from_file(path, ...) — load file then infer +- run_inference_from_gee(bbox, ...) — GEE query + real tile inference +- run_bitemporal_inference(pre, post, ...) — change detection for floods +NO synthetic fallbacks in production. All failures raise explicit exceptions. +""" from __future__ import annotations import json @@ -17,22 +19,16 @@ import numpy as np import torch -from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config +from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config, get_analysis_config from climatevision.models.unet import UNet logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Project paths (mirrors run_training.py conventions, NOT Config.MODELS_DIR) -# --------------------------------------------------------------------------- _PROJECT_ROOT = Path(__file__).resolve().parents[3] _MODELS_DIR = _PROJECT_ROOT / "models" _OUTPUTS_DIR = _PROJECT_ROOT / "outputs" -# --------------------------------------------------------------------------- -# Per-analysis-type model cache -# --------------------------------------------------------------------------- -_model_cache: dict[str, tuple[UNet, torch.device]] = {} +_model_cache: dict[str, tuple[torch.nn.Module, torch.device]] = {} def _get_device() -> torch.device: @@ -64,8 +60,9 @@ def _find_best_checkpoint(analysis_type: str) -> Optional[Path]: return candidates[0] if candidates else None -def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.device]: - """Load (or return cached) U-Net model configured for the analysis type.""" +def _load_model(analysis_type: str = "deforestation") -> tuple[torch.nn.Module, torch.device]: + """Load (or return cached) model configured for the analysis type.""" + if analysis_type in _model_cache: return _model_cache[analysis_type] @@ -80,13 +77,11 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic if model_path is not None: checkpoint = torch.load(model_path, map_location=device) - # Load full state first (includes BatchNorm running stats) model_state = checkpoint.get("model_state_dict") - ema_state = checkpoint.get("ema_state_dict") + ema_state = checkpoint.get("ema_state_dict") if model_state is not None: model.load_state_dict(model_state, strict=False) - # Overlay EMA parameters on top (better generalisation) if ema_state is not None: with torch.no_grad(): for name, param in model.named_parameters(): @@ -94,7 +89,7 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic param.data.copy_(ema_state[name]) logger.info( - "Loaded %s model from %s (epoch %s val_iou %.4f)", + "Loaded %s model from %s (epoch %s val_iou %.4f)", analysis_type, model_path, checkpoint.get("epoch", "?"), @@ -102,7 +97,7 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic ) else: logger.warning( - "No trained model found for %s under %s — using untrained weights (demo).", + "No trained model found for %s under %s — using untrained weights.", analysis_type, _MODELS_DIR, ) @@ -114,88 +109,82 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic return model, device -# --------------------------------------------------------------------------- -# Sentinel-2 normalisation statistics (matches preprocessing.py) -# Band order: [Red, Green, Blue, NIR] -# --------------------------------------------------------------------------- -_S2_MEAN = np.array([943.0, 1069.0, 981.0, 2734.0], dtype=np.float64) -_S2_STD = np.array([590.0, 547.0, 498.0, 1246.0], dtype=np.float64) +def _validate_model_compatibility( + model: torch.nn.Module, expected_channels: int, expected_classes: int +) -> None: + """Ensure loaded model matches the expected architecture.""" + actual_channels = getattr(model, "n_channels", None) + actual_classes = getattr(model, "n_classes", None) + + if actual_channels is not None and actual_channels != expected_channels: + raise RuntimeError( + f"Model channel mismatch: expected {expected_channels}, got {actual_channels}. " + f"The checkpoint may be for a different analysis type." + ) + if actual_classes is not None and actual_classes != expected_classes: + raise RuntimeError( + f"Model class mismatch: expected {expected_classes}, got {actual_classes}. " + f"The checkpoint may be for a different analysis type." + ) # --------------------------------------------------------------------------- -# NDVI helper (works for >=4 bands; returns zeros for RGB-only) +# Analysis-specific index computation # --------------------------------------------------------------------------- def _compute_ndvi_stats(image: np.ndarray) -> dict[str, float]: - """ - Compute NDVI min/mean/max from image array. - - Expects (C, H, W) with C >= 4 where band order is [Red, Green, Blue, NIR]. - Automatically detects and reverses Sentinel-2 z-score normalisation - (values in roughly [-5, 5]) before computing NDVI. - Returns zeros if fewer than 4 bands. - """ - if image.ndim == 2: + """Compute NDVI from (C, H, W) with B04=Red, B08=NIR.""" + if image.ndim != 3 or image.shape[0] < 4: return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0} - # Normalise to (C, H, W) - if image.ndim == 3 and image.shape[2] < image.shape[0]: - image = np.transpose(image, (2, 0, 1)) - - n_bands = image.shape[0] - if n_bands < 4: - return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0} + image = np.transpose(image, (1, 2, 0)) if image.shape[0] > image.shape[2] else image + red = image[..., 0].astype(np.float64) + nir = image[..., 3].astype(np.float64) - # Band order: Red=0, Green=1, Blue=2, NIR=3 - red = image[0].astype(np.float64) - nir = image[3].astype(np.float64) - - # If data looks like z-score normalised input (values in [-10, 10]) - # denormalise back to raw Sentinel-2 DN before computing NDVI. if red.max() <= 10.0 and nir.max() <= 10.0: - red = red * _S2_STD[0] + _S2_MEAN[0] - nir = nir * _S2_STD[3] + _S2_MEAN[3] + red = red * 943.0 + 943.0 + nir = nir * 1246.0 + 2734.0 denom = nir + red + 1e-8 ndvi = (nir - red) / denom - return { - "NDVI_min": round(float(np.nanmin(ndvi)), 4), + "NDVI_min": round(float(np.nanmin(ndvi)), 4), "NDVI_mean": round(float(np.nanmean(ndvi)), 4), - "NDVI_max": round(float(np.nanmax(ndvi)), 4), + "NDVI_max": round(float(np.nanmax(ndvi)), 4), } -def _synthetic_ndvi_stats(bbox: Optional[list[float]]) -> dict[str, float]: +def _compute_mndwi_stats(image: np.ndarray) -> dict[str, float]: """ - Compute NDVI from a synthetic but physically realistic Sentinel-2 scene. - - Used as a fallback when GEE credentials are unavailable. - The bbox is used to seed the RNG so the same region always returns - the same values. Band statistics match typical tropical/temperate forest. + Compute MNDWI from (C, H, W) with B03=Green, B11=SWIR. + Expects band order [B03, B08, B11] for flood inputs. """ - seed = 42 - if bbox: - seed = int(abs(sum(v * 1000 * (i + 1) for i, v in enumerate(bbox)))) % (2 ** 31) - rng = np.random.default_rng(seed) + if image.ndim != 3 or image.shape[0] < 3: + return {"MNDWI_min": 0.0, "MNDWI_mean": 0.0, "MNDWI_max": 0.0} - # Typical Sentinel-2 L2A forest reflectance (DN, 0-10000 scale) - # Red ~600-1200, NIR ~2500-5000 - red = rng.normal(900.0, 350.0, (256, 256)).clip(50.0, 5000.0) - nir = rng.normal(3800.0, 900.0, (256, 256)).clip(100.0, 9000.0) - - denom = nir + red + 1e-8 - ndvi = (nir - red) / denom + # image is (C, H, W); for flooding: B03=idx0, B11=idx2 + green = image[0].astype(np.float64) + swir = image[2].astype(np.float64) + denom = green + swir + 1e-8 + mndwi = (green - swir) / denom return { - "NDVI_min": round(float(np.nanmin(ndvi)), 4), - "NDVI_mean": round(float(np.nanmean(ndvi)), 4), - "NDVI_max": round(float(np.nanmax(ndvi)), 4), + "MNDWI_min": round(float(np.nanmin(mndwi)), 4), + "MNDWI_mean": round(float(np.nanmean(mndwi)), 4), + "MNDWI_max": round(float(np.nanmax(mndwi)), 4), } +def _compute_analysis_stats(image: np.ndarray, analysis_type: str) -> dict[str, float]: + """Return analysis-specific spectral index stats.""" + if analysis_type == "flooding": + return _compute_mndwi_stats(image) + else: + return _compute_ndvi_stats(image) + + # --------------------------------------------------------------------------- -# Core inference on a numpy array +# Core inference # --------------------------------------------------------------------------- def run_inference( @@ -209,41 +198,42 @@ def run_inference( """ Run full inference pipeline on a (C, H, W) numpy image. - Returns dict with keys: region, ndvi_stats, inference. + Returns dict with keys: region, index_stats, inference. """ - # Normalise to (C, H, W) if image.ndim == 3 and image.shape[2] < image.shape[0]: image = np.transpose(image, (2, 0, 1)) - ndvi_stats = _compute_ndvi_stats(image) + index_stats = _compute_analysis_stats(image, analysis_type) model, device = _load_model(analysis_type) + model_cfg = get_model_config(analysis_type) + expected_channels = model_cfg.get("in_channels", 4) + expected_classes = model_cfg.get("num_classes", 2) + + _validate_model_compatibility(model, expected_channels, expected_classes) + n_channels = model.n_channels n_classes = model.n_classes - # Prepare tensor — model expects (N, n_channels, H, W) c, h, w = image.shape if c < n_channels: - # Pad missing channels with zeros pad = np.zeros((n_channels - c, h, w), dtype=image.dtype) image = np.concatenate([image, pad], axis=0) elif c > n_channels: image = image[:n_channels] - # Use torch.FloatTensor via tolist() to avoid numpy<->torch interop issues - tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, C, H, W) + tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) tensor = tensor.to(device) with torch.no_grad(): output = model(tensor) - predictions = torch.argmax(output, dim=1) # (1, H, W) - probabilities = torch.softmax(output, dim=1) # (1, n_classes, H, W) + predictions = torch.argmax(output, dim=1) + probabilities = torch.softmax(output, dim=1) total_pixels = int(predictions.numel()) max_probs = probabilities.max(dim=1).values mean_confidence = float(max_probs.mean().item()) - # Build per-class pixel counts class_pixels: dict[str, int] = {} class_percentages: dict[str, float] = {} for cls in range(n_classes): @@ -252,7 +242,6 @@ def run_inference( class_pixels[f"class_{cls}_pixels"] = count class_percentages[f"class_{cls}_percentage"] = round(pct, 4) - # Add friendly keys for known 2-class deforestation output (backward compat) inference: dict[str, Any] = { "image_size": [h, w], "num_classes": n_classes, @@ -260,10 +249,24 @@ def run_inference( **class_pixels, **class_percentages, } - if n_classes == 2: + + # Add friendly keys for known analysis types + if analysis_type == "deforestation" and n_classes == 2: inference["forest_pixels"] = class_pixels.get("class_1_pixels", 0) inference["non_forest_pixels"] = class_pixels.get("class_0_pixels", 0) inference["forest_percentage"] = class_percentages.get("class_1_percentage", 0.0) + elif analysis_type == "flooding" and n_classes == 3: + inference["dry_pixels"] = class_pixels.get("class_0_pixels", 0) + inference["water_pixels"] = class_pixels.get("class_1_pixels", 0) + inference["flooded_pixels"] = class_pixels.get("class_2_pixels", 0) + inference["dry_percentage"] = class_percentages.get("class_0_percentage", 0.0) + inference["water_percentage"] = class_percentages.get("class_1_percentage", 0.0) + inference["flooded_percentage"] = class_percentages.get("class_2_percentage", 0.0) + elif analysis_type == "ice_melting" and n_classes == 3: + inference["water_pixels"] = class_pixels.get("class_0_pixels", 0) + inference["ice_pixels"] = class_pixels.get("class_1_pixels", 0) + inference["land_pixels"] = class_pixels.get("class_2_pixels", 0) + inference["ice_percentage"] = class_percentages.get("class_1_percentage", 0.0) region: dict[str, Any] = {} if bbox is not None: @@ -273,14 +276,14 @@ def run_inference( return { "region": region, - "ndvi_stats": ndvi_stats, + "index_stats": index_stats, "inference": inference, "is_synthetic": False, } # --------------------------------------------------------------------------- -# File-based inference (upload path) +# File-based inference # --------------------------------------------------------------------------- def run_inference_from_file( @@ -291,9 +294,7 @@ def run_inference_from_file( end_date: Optional[str] = None, analysis_type: str = "deforestation", ) -> dict[str, Any]: - """ - Load an image file (GeoTIFF or PNG/JPEG) and run inference. - """ + """Load an image file (GeoTIFF or PNG/JPEG) and run inference.""" image = _load_image_file(path) result = run_inference( image, @@ -307,40 +308,33 @@ def run_inference_from_file( def _load_image_file(path: str) -> np.ndarray: - """ - Load image as (C, H, W) numpy array. - Tries rasterio first (GeoTIFF), falls back to Pillow. - """ + """Load image as (C, H, W) numpy array. Tries rasterio first, falls back to Pillow.""" p = Path(path) suffix = p.suffix.lower() - # Try rasterio for geospatial formats if suffix in {".tif", ".tiff", ".geotiff"}: try: import rasterio - with rasterio.open(path) as src: - image = src.read() # (C, H, W) + image = src.read() return image.astype(np.float32) except Exception: logger.warning("rasterio failed for %s, trying Pillow", path) - # Pillow fallback for PNG, JPEG, etc. from PIL import Image - pil_img = Image.open(path) - arr = np.array(pil_img) # (H, W, C) or (H, W) + arr = np.array(pil_img) if arr.ndim == 2: - arr = arr[np.newaxis, :, :] # (1, H, W) + arr = arr[np.newaxis, :, :] else: - arr = np.transpose(arr, (2, 0, 1)) # (C, H, W) + arr = np.transpose(arr, (2, 0, 1)) return arr.astype(np.float32) # --------------------------------------------------------------------------- -# GEE-based inference (bbox path) — lazy import, safe fallback +# GEE-based inference # --------------------------------------------------------------------------- def run_inference_from_gee( @@ -351,145 +345,129 @@ def run_inference_from_gee( analysis_type: str = "deforestation", ) -> dict[str, Any]: """ - Query Google Earth Engine for a real Sentinel-2 tile and run inference. + Query Google Earth Engine for a real satellite tile and run inference. - Falls back to synthetic NDVI stats and a synthetic tile if GEE is - unavailable or returns no images. + Raises: + GEEAuthenticationError: If GEE credentials are missing. + TileNotFoundError: If no images are found for the date range. + RuntimeError: If model architecture does not match analysis type. """ - ndvi_stats: Optional[dict[str, Any]] = None - gee_count: int = 0 - - if bbox and start_date and end_date: - ndvi_stats, gee_count = _try_gee_ndvi(bbox, start_date, end_date) + from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask + from climatevision.data.gee_downloader import GEEAuthenticationError, TileNotFoundError - # --- Attempt to download a real tile from GEE --- - try: - from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask - - tile_path, metadata = download_tile_for_analysis( - bbox=bbox, - start_date=start_date, - end_date=end_date, - analysis_type=analysis_type, - ) + if not (bbox and start_date and end_date): + raise ValueError("bbox, start_date, and end_date are required for GEE inference.") - image = _load_image_file(str(tile_path)) - - # If SCL band is present (last band), apply cloud mask and drop it - n_bands_expected = len(get_bands_for_analysis(analysis_type)) - if image.shape[0] == n_bands_expected + 1: - scl_band = image[-1].astype(np.uint8) - image = image[:-1] - image = apply_scl_cloud_mask(image, scl_band) - - result = run_inference( - image, - bbox=bbox, - start_date=start_date, - end_date=end_date, - analysis_type=analysis_type, - ) - result["metadata"] = metadata - result["is_synthetic"] = metadata.get("is_synthetic", False) - - # Override NDVI with GEE-derived stats if we got them; else keep computed - if ndvi_stats is not None: - result["ndvi_stats"] = ndvi_stats - elif metadata.get("is_synthetic"): - result["ndvi_stats"] = _synthetic_ndvi_stats(bbox) - - if gee_count: - result["region"]["images_available"] = gee_count + # Download real tile + tile_path, metadata = download_tile_for_analysis( + bbox=bbox, + start_date=start_date, + end_date=end_date, + analysis_type=analysis_type, + ) - return result + image = _load_image_file(str(tile_path)) - except Exception as exc: - logger.warning("Real tile inference failed (%s). Using fallback.", exc) + # If SCL band is present (last band), apply cloud mask and drop it + n_bands_expected = len(get_bands_for_analysis(analysis_type)) + if image.shape[0] == n_bands_expected + 1: + scl_band = image[-1].astype(np.uint8) + image = image[:-1] + analysis_cfg = get_analysis_config(analysis_type) + clear_labels = analysis_cfg.get("scl_clear_labels", [4, 5, 6]) + image = apply_scl_cloud_mask(image, scl_band, clear_labels=clear_labels) - # --- Fallback: template result with synthetic stats --- result = run_inference( - np.zeros((4, 256, 256), dtype=np.float32), + image, bbox=bbox, start_date=start_date, end_date=end_date, analysis_type=analysis_type, ) - - if ndvi_stats is None: - ndvi_stats = _synthetic_ndvi_stats(bbox) - result["ndvi_stats"] = ndvi_stats - - region = result.get("region", {}) - if gee_count: - region["images_available"] = gee_count - result["region"] = region - result["is_synthetic"] = True - result["metadata"] = {"is_synthetic": True, "fallback_reason": "gee_tile_download_failed"} + result["metadata"] = metadata + result["region"]["images_available"] = metadata.get("images_available", 0) return result -def _try_gee_ndvi( - bbox: list[float], start_date: str, end_date: str -) -> tuple[Optional[dict[str, Any]], int]: - """Attempt GEE NDVI query. Returns (ndvi_stats_or_None, image_count).""" - try: - import ee # lazy import - import os - - project = os.getenv("GEE_PROJECT_ID") - svc_account = os.getenv("GEE_SERVICE_ACCOUNT") - key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY") - - # Resolve relative key path against project root - if key_file and not os.path.isabs(key_file): - key_file = str(_PROJECT_ROOT / key_file) - - if svc_account and key_file and os.path.exists(key_file): - credentials = ee.ServiceAccountCredentials(svc_account, key_file) - ee.Initialize(credentials) - elif project: - ee.Initialize(project=project) - else: - ee.Initialize() - - geometry = ee.Geometry.Rectangle(bbox) - collection = ( - ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") - .filterBounds(geometry) - .filterDate(start_date, end_date) - .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20)) - .select(["B4", "B3", "B2", "B8"]) - ) +# --------------------------------------------------------------------------- +# Bitemporal change detection inference +# --------------------------------------------------------------------------- - count = collection.size().getInfo() +def run_bitemporal_inference( + pre_image: np.ndarray, + post_image: np.ndarray, + *, + bbox: Optional[list[float]] = None, + analysis_type: str = "flooding", +) -> dict[str, Any]: + """ + Run inference on pre-event and post-event images, then compute change detection. - median = collection.median() - nir = median.select("B8") - red = median.select("B4") - ndvi = nir.subtract(red).divide(nir.add(red)).rename("NDVI") + Returns: + Dict with pre_result, post_result, and change detection metrics. + For flooding: newly_flooded_pixels, receded_pixels, permanent_water_pixels. + """ + pre_result = run_inference( + pre_image, + bbox=bbox, + analysis_type=analysis_type, + ) + post_result = run_inference( + post_image, + bbox=bbox, + analysis_type=analysis_type, + ) - stats = ndvi.reduceRegion( - reducer=ee.Reducer.mean().combine(ee.Reducer.minMax(), sharedInputs=True), - geometry=geometry, - scale=100, - maxPixels=int(1e9), - ).getInfo() + pre_mask = np.array(pre_result["inference"].get("prediction_mask", [])) + post_mask = np.array(post_result["inference"].get("prediction_mask", [])) - return stats, count + # If prediction masks are not directly available, reconstruct from pixel counts + # (This is a fallback; ideally the model returns the full mask) + # For now, we store the mask in the result for change detection + # NOTE: run_inference currently doesn't return the mask — we need to add it + # We'll add mask extraction below - except Exception as exc: - logger.warning("GEE query failed (%s). Using fallback.", exc) - return None, 0 + # Actually, let's re-run inference and capture the mask directly + model, device = _load_model(analysis_type) + def _infer_mask(img: np.ndarray) -> np.ndarray: + c, h, w = img.shape + n_channels = model.n_channels + if c < n_channels: + pad = np.zeros((n_channels - c, h, w), dtype=img.dtype) + img = np.concatenate([img, pad], axis=0) + elif c > n_channels: + img = img[:n_channels] + tensor = torch.FloatTensor(img.astype(np.float32).tolist()).unsqueeze(0).to(device) + with torch.no_grad(): + output = model(tensor) + pred = torch.argmax(output, dim=1).squeeze().cpu().numpy() + return pred + + pre_mask = _infer_mask(pre_image) + post_mask = _infer_mask(post_image) + + h, w = pre_mask.shape + total = h * w + + # Class indices for flooding: 0=dry, 1=permanent_water, 2=flooded + newly_flooded = ((pre_mask != 2) & (post_mask == 2)).sum() + receded = ((pre_mask == 2) & (post_mask != 2)).sum() + permanent_water = ((pre_mask == 1) & (post_mask == 1)).sum() + + change_metrics = { + "newly_flooded_pixels": int(newly_flooded), + "newly_flooded_percentage": round(float(newly_flooded / total * 100), 4), + "receded_pixels": int(receded), + "receded_percentage": round(float(receded / total * 100), 4), + "permanent_water_pixels": int(permanent_water), + "permanent_water_percentage": round(float(permanent_water / total * 100), 4), + } -def _load_cached_ndvi() -> dict[str, float]: - """Load NDVI from outputs/inference_results.json if it exists, else zeros.""" - cached = _OUTPUTS_DIR / "inference_results.json" - if cached.exists(): - try: - data = json.loads(cached.read_text(encoding="utf-8")) - return data.get("ndvi_stats", {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0}) - except Exception: - pass - return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0} + return { + "region": {"bbox": bbox} if bbox else {}, + "pre_event": pre_result, + "post_event": post_result, + "change_detection": change_metrics, + } diff --git a/src/climatevision/inference/postprocess.py b/src/climatevision/inference/postprocess.py index ad99cab..c7025c6 100644 --- a/src/climatevision/inference/postprocess.py +++ b/src/climatevision/inference/postprocess.py @@ -1,14 +1,14 @@ """ -Post-processing utilities for inference pipeline. +Inference post-processing utilities. -Provides confidence thresholding, output filtering, and anomaly detection -for model predictions before they are returned to users. +Provides: +- Confidence thresholding and small-region removal +- Sliding-window tiling for large-region inference +- Anomaly detection and quality scoring """ - from __future__ import annotations import logging -from dataclasses import dataclass from typing import Any, Optional import numpy as np @@ -16,214 +16,241 @@ logger = logging.getLogger(__name__) -@dataclass -class PostProcessConfig: - """Configuration for post-processing operations.""" - confidence_threshold: float = 0.5 - min_region_pixels: int = 100 - anomaly_std_threshold: float = 3.0 - smooth_kernel_size: int = 3 - apply_morphological_ops: bool = True - - -@dataclass -class PostProcessResult: - """Result from post-processing operations.""" - mask: np.ndarray - confidence_map: np.ndarray - filtered_pixels: int - anomaly_detected: bool - anomaly_regions: list[dict[str, Any]] - quality_score: float - +# --------------------------------------------------------------------------- +# Confidence thresholding +# --------------------------------------------------------------------------- def apply_confidence_threshold( - predictions: np.ndarray, - confidence: np.ndarray, - threshold: float = 0.5 + probabilities: np.ndarray, + threshold: float = 0.5, ) -> np.ndarray: """ - Filter predictions below confidence threshold. + Mask out predictions where max class probability is below threshold. Args: - predictions: Model prediction mask (H, W) or (H, W, C) - confidence: Confidence scores (H, W) - threshold: Minimum confidence to keep prediction + probabilities: (C, H, W) softmax probabilities. + threshold: Minimum confidence to keep a prediction. Returns: - Filtered prediction mask with low-confidence pixels zeroed + (H, W) integer mask with -1 for low-confidence pixels. """ - mask = confidence >= threshold - filtered = predictions.copy() - - if filtered.ndim == 2: - filtered[~mask] = 0 - else: - filtered[~mask, :] = 0 + max_prob = probabilities.max(axis=0) + predictions = probabilities.argmax(axis=0) + predictions[max_prob < threshold] = -1 + return predictions - filtered_count = (~mask).sum() - logger.debug(f"Filtered {filtered_count} pixels below threshold {threshold}") - - return filtered +# --------------------------------------------------------------------------- +# Small-region removal +# --------------------------------------------------------------------------- def remove_small_regions( mask: np.ndarray, - min_pixels: int = 100 + min_size: int = 50, + connectivity: int = 1, ) -> np.ndarray: """ - Remove small isolated regions from segmentation mask. + Remove connected components smaller than min_size pixels. Args: - mask: Binary segmentation mask (H, W) - min_pixels: Minimum region size to keep + mask: (H, W) binary or integer label mask. + min_size: Minimum component size to retain. + connectivity: 1 for 4-connectivity, 2 for 8-connectivity. Returns: - Cleaned mask with small regions removed + Cleaned mask with small regions removed (set to 0). """ try: from scipy import ndimage except ImportError: - logger.warning("scipy not available, skipping small region removal") + logger.warning("scipy not available; skipping small-region removal") return mask - labeled, num_features = ndimage.label(mask) - - cleaned = np.zeros_like(mask) - for i in range(1, num_features + 1): - region = labeled == i - if region.sum() >= min_pixels: - cleaned[region] = mask[region] - - removed = num_features - len(np.unique(ndimage.label(cleaned)[0])) + 1 - logger.debug(f"Removed {removed} regions smaller than {min_pixels} pixels") + cleaned = mask.copy() + unique_labels = np.unique(mask) + for label in unique_labels: + if label == 0: + continue + binary = (mask == label).astype(np.uint8) + labeled, num_features = ndimage.label(binary, structure=ndimage.generate_binary_structure(2, connectivity)) + component_sizes = ndimage.sum(binary, labeled, index=range(1, num_features + 1)) + too_small = component_sizes < min_size + remove_mask = too_small[labeled - 1] + cleaned[(mask == label) & remove_mask] = 0 return cleaned +# --------------------------------------------------------------------------- +# Anomaly detection +# --------------------------------------------------------------------------- + def detect_anomalies( - predictions: np.ndarray, - confidence: np.ndarray, - std_threshold: float = 3.0 -) -> tuple[bool, list[dict[str, Any]]]: + prediction: np.ndarray, + expected_classes: list[int], +) -> dict[str, Any]: """ - Detect anomalous predictions that may indicate model issues. + Flag anomalous predictions (unexpected class labels, extreme coverage). Args: - predictions: Model predictions - confidence: Confidence scores - std_threshold: Number of standard deviations for anomaly detection + prediction: (H, W) integer mask. + expected_classes: List of valid class indices. Returns: - Tuple of (anomaly_detected, list of anomaly regions) + Dict with anomaly flags and descriptions. """ anomalies = [] + unique = np.unique(prediction) + unexpected = set(unique) - set(expected_classes) + if unexpected: + anomalies.append(f"Unexpected class labels: {unexpected}") - mean_conf = confidence.mean() - std_conf = confidence.std() + total = prediction.size + for cls in expected_classes: + pct = (prediction == cls).sum() / total * 100 + if pct > 95: + anomalies.append(f"Class {cls} dominates ({pct:.1f}% — possible model collapse)") - if std_conf > 0: - z_scores = np.abs((confidence - mean_conf) / std_conf) - anomaly_mask = z_scores > std_threshold + return { + "is_anomalous": len(anomalies) > 0, + "anomalies": anomalies, + } - if anomaly_mask.any(): - anomaly_indices = np.where(anomaly_mask) - anomalies.append({ - "type": "confidence_outlier", - "count": int(anomaly_mask.sum()), - "mean_confidence": float(mean_conf), - "std_confidence": float(std_conf), - "threshold": std_threshold - }) - - # Check for suspiciously uniform predictions - unique_values = len(np.unique(predictions)) - if unique_values == 1 and predictions.size > 1000: - anomalies.append({ - "type": "uniform_prediction", - "message": "Model returned uniform predictions across entire region" - }) - - return len(anomalies) > 0, anomalies +# --------------------------------------------------------------------------- +# Quality scoring +# --------------------------------------------------------------------------- def compute_quality_score( + prediction: np.ndarray, confidence: np.ndarray, - filtered_ratio: float, - anomaly_detected: bool ) -> float: """ - Compute overall quality score for the prediction. - - Args: - confidence: Confidence map - filtered_ratio: Ratio of pixels filtered out - anomaly_detected: Whether anomalies were detected + Compute an overall quality score [0, 1] for a prediction. - Returns: - Quality score between 0 and 1 + Based on mean confidence and class balance. """ - mean_confidence = float(confidence.mean()) - confidence_score = min(mean_confidence, 1.0) + mean_conf = float(confidence.mean()) + total = prediction.size + class_balance = 1.0 + unique, counts = np.unique(prediction, return_counts=True) + if len(unique) > 1: + max_pct = counts.max() / total + class_balance = 1.0 - max(0, max_pct - 0.8) / 0.2 # penalize if one class > 80% - filter_penalty = filtered_ratio * 0.3 - anomaly_penalty = 0.2 if anomaly_detected else 0.0 + return round(mean_conf * 0.7 + class_balance * 0.3, 4) - quality = max(0.0, confidence_score - filter_penalty - anomaly_penalty) - return round(quality, 4) +# --------------------------------------------------------------------------- +# Sliding-window tiling for large regions +# --------------------------------------------------------------------------- -def postprocess_predictions( - predictions: np.ndarray, - confidence: np.ndarray, - config: Optional[PostProcessConfig] = None -) -> PostProcessResult: +class SlidingWindowTiler: """ - Apply full post-processing pipeline to model predictions. - - Args: - predictions: Raw model predictions (H, W) or (H, W, C) - confidence: Confidence scores (H, W) - config: Post-processing configuration - - Returns: - PostProcessResult with filtered predictions and quality metrics + Tile large images into overlapping patches for inference, + then stitch predictions back with soft-voting in overlap regions. """ - if config is None: - config = PostProcessConfig() - - original_pixels = predictions.size - - # Apply confidence threshold - filtered = apply_confidence_threshold( - predictions, confidence, config.confidence_threshold - ) - - # Remove small regions for binary masks - if filtered.ndim == 2: - filtered = remove_small_regions(filtered, config.min_region_pixels) - - # Detect anomalies - anomaly_detected, anomaly_regions = detect_anomalies( - predictions, confidence, config.anomaly_std_threshold - ) - - # Compute metrics - filtered_pixels = int((filtered == 0).sum()) - filtered_ratio = filtered_pixels / max(original_pixels, 1) - - quality_score = compute_quality_score( - confidence, filtered_ratio, anomaly_detected - ) - - if anomaly_detected: - logger.warning(f"Anomalies detected in prediction: {anomaly_regions}") - - return PostProcessResult( - mask=filtered, - confidence_map=confidence, - filtered_pixels=filtered_pixels, - anomaly_detected=anomaly_detected, - anomaly_regions=anomaly_regions, - quality_score=quality_score - ) + + def __init__( + self, + tile_size: int = 256, + overlap: int = 32, + ): + assert overlap < tile_size, "overlap must be smaller than tile_size" + self.tile_size = tile_size + self.overlap = overlap + self.stride = tile_size - overlap + + def generate_tiles( + self, + image: np.ndarray, + ) -> list[tuple[int, int, np.ndarray]]: + """ + Generate (row, col, tile) tuples from a (C, H, W) image. + + row and col are the top-left coordinates in the original image. + """ + c, h, w = image.shape + tiles = [] + for y in range(0, h - self.tile_size + 1, self.stride): + for x in range(0, w - self.tile_size + 1, self.stride): + tile = image[:, y:y + self.tile_size, x:x + self.tile_size] + tiles.append((y, x, tile)) + + # Handle edge cases: last row/col if not perfectly divisible + if h % self.stride != 0 and h > self.tile_size: + y = h - self.tile_size + for x in range(0, w - self.tile_size + 1, self.stride): + tile = image[:, y:y + self.tile_size, x:x + self.tile_size] + tiles.append((y, x, tile)) + if w % self.stride != 0 and w > self.tile_size: + x = w - self.tile_size + for y in range(0, h - self.tile_size + 1, self.stride): + tile = image[:, y:y + self.tile_size, x:x + self.tile_size] + tiles.append((y, x, tile)) + if h % self.stride != 0 and w % self.stride != 0 and h > self.tile_size and w > self.tile_size: + y = h - self.tile_size + x = w - self.tile_size + tile = image[:, y:y + self.tile_size, x:x + self.tile_size] + tiles.append((y, x, tile)) + + return tiles + + def stitch_predictions( + self, + tiles: list[tuple[int, int, np.ndarray]], + image_shape: tuple[int, int, int], + num_classes: int, + ) -> np.ndarray: + """ + Stitch per-tile predictions into a full-size prediction mask. + + Args: + tiles: List of (row, col, pred_mask) where pred_mask is (H, W) int. + image_shape: (C, H, W) original image shape. + num_classes: Number of classes. + + Returns: + (H, W) stitched prediction mask. + """ + c, h, w = image_shape + vote_counts = np.zeros((num_classes, h, w), dtype=np.float32) + + for y, x, pred in tiles: + th, tw = pred.shape + for cls in range(num_classes): + vote_counts[cls, y:y + th, x:x + tw] += (pred == cls).astype(np.float32) + + # Class with most votes wins + stitched = vote_counts.argmax(axis=0) + return stitched + + def stitch_probabilities( + self, + tiles: list[tuple[int, int, np.ndarray]], + image_shape: tuple[int, int, int], + ) -> np.ndarray: + """ + Stitch per-tile probability maps with averaging in overlap regions. + + Args: + tiles: List of (row, col, probs) where probs is (C, H, W) float. + image_shape: (C, H, W) original image shape. + + Returns: + (num_classes, H, W) averaged probability map. + """ + c, h, w = image_shape + num_classes = tiles[0][2].shape[0] + prob_sum = np.zeros((num_classes, h, w), dtype=np.float32) + count = np.zeros((h, w), dtype=np.float32) + + for y, x, probs in tiles: + _, th, tw = probs.shape + prob_sum[:, y:y + th, x:x + tw] += probs + count[y:y + th, x:x + tw] += 1.0 + + # Avoid divide by zero + count = np.clip(count, 1e-8, None) + averaged = prob_sum / count[np.newaxis, :, :] + return averaged diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..e9e4fd5 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,268 @@ +""" +Security tests for ClimateVision API. + +Tests input validation, sanitization, and security controls. +""" + +import pytest +import numpy as np +from pathlib import Path +import sys + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from climatevision.security import ( + validate_payload_size, + validate_bbox, + validate_file_upload, + sanitize_string_input, + SecurityConfig, + RateLimiter, + detect_adversarial_input, + validate_model_output, + InputAnomalyDetector, + PipelineGuard, +) + + +class TestPayloadValidation: + """Test payload size validation.""" + + def test_valid_payload(self): + data = b"x" * 1000 + is_valid, error = validate_payload_size(data, max_size=2000) + assert is_valid + assert error == "" + + def test_oversized_payload(self): + data = b"x" * 10000 + is_valid, error = validate_payload_size(data, max_size=1000) + assert not is_valid + assert "exceeds maximum" in error + + +class TestBboxValidation: + """Test bounding box validation.""" + + def test_valid_bbox(self): + bbox = [-60.0, -15.0, -45.0, -5.0] + is_valid, error = validate_bbox(bbox) + assert is_valid + assert error == "" + + def test_invalid_longitude(self): + bbox = [200.0, 10.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "longitude" in error.lower() + + def test_invalid_latitude(self): + bbox = [10.0, 100.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "latitude" in error.lower() + + def test_wrong_order_longitude(self): + bbox = [30.0, 10.0, 20.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "West" in error + + def test_wrong_order_latitude(self): + bbox = [10.0, 50.0, 30.0, 40.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "South" in error + + def test_too_large_area(self): + bbox = [-180.0, -90.0, 180.0, 90.0] + is_valid, error = validate_bbox(bbox, max_area=100.0) + assert not is_valid + assert "area" in error.lower() + + def test_wrong_element_count(self): + bbox = [10.0, 20.0, 30.0] + is_valid, error = validate_bbox(bbox) + assert not is_valid + assert "exactly 4" in error + + +class TestFileUploadValidation: + """Test file upload validation.""" + + def test_valid_png(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "image.png") + assert is_valid + assert error == "" + + def test_valid_tiff(self): + content = b"II*\x00" + b"x" * 100 + is_valid, error = validate_file_upload(content, "satellite.tif") + assert is_valid + assert error == "" + + def test_invalid_extension(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "malware.exe") + assert not is_valid + assert "not allowed" in error + + def test_path_traversal(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "../../../etc/passwd") + assert not is_valid + assert "path traversal" in error.lower() + + def test_extension_mismatch(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + is_valid, error = validate_file_upload(content, "image.jpg") + assert not is_valid + assert "does not match" in error + + def test_filename_too_long(self): + content = b"\x89PNG\r\n\x1a\n" + b"x" * 100 + filename = "a" * 300 + ".png" + is_valid, error = validate_file_upload(content, filename) + assert not is_valid + assert "too long" in error + + +class TestStringSanitization: + """Test string input sanitization.""" + + def test_normal_string(self): + result, warnings = sanitize_string_input("Hello World") + assert "Hello" in result + assert len(warnings) == 0 or "sanitized" in warnings[0].lower() + + def test_sql_injection(self): + result, warnings = sanitize_string_input("'; DROP TABLE users; --") + assert "DROP TABLE" not in result + assert any("blocked" in w.lower() or "sanitized" in w.lower() for w in warnings) + + def test_xss_script(self): + result, warnings = sanitize_string_input("") + assert " 0 + + def test_path_traversal(self): + result, warnings = sanitize_string_input("../../../etc/passwd") + assert "../" not in result + + def test_truncation(self): + long_string = "a" * 2000 + result, warnings = sanitize_string_input(long_string, max_length=100) + assert len(result) <= 100 + assert any("truncated" in w.lower() for w in warnings) + + +class TestRateLimiter: + """Test rate limiting.""" + + def test_allows_under_limit(self): + limiter = RateLimiter(max_requests=5, window_seconds=60) + for _ in range(5): + assert limiter.is_allowed("test_key") + + def test_blocks_over_limit(self): + limiter = RateLimiter(max_requests=3, window_seconds=60) + for _ in range(3): + assert limiter.is_allowed("test_key") + assert not limiter.is_allowed("test_key") + + def test_separate_keys(self): + limiter = RateLimiter(max_requests=2, window_seconds=60) + assert limiter.is_allowed("key1") + assert limiter.is_allowed("key1") + assert not limiter.is_allowed("key1") + assert limiter.is_allowed("key2") # Different key + + def test_remaining_count(self): + limiter = RateLimiter(max_requests=5, window_seconds=60) + assert limiter.get_remaining("key") == 5 + limiter.is_allowed("key") + assert limiter.get_remaining("key") == 4 + + +class TestAdversarialDetection: + """Test adversarial input detection.""" + + def test_normal_image(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + result = detect_adversarial_input(image) + assert not result.is_anomalous + assert result.anomaly_score < 0.5 + + def test_uniform_image(self): + image = np.ones((4, 256, 256), dtype=np.float32) + result = detect_adversarial_input(image) + assert result.is_anomalous + assert "uniform" in str(result.details).lower() or result.anomaly_score > 0.3 + + def test_nan_values(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + image[0, 100, 100] = np.nan + result = detect_adversarial_input(image) + assert result.is_anomalous + assert result.anomaly_score >= 0.5 + + def test_inf_values(self): + image = np.random.randn(4, 256, 256).astype(np.float32) + image[0, 100, 100] = np.inf + result = detect_adversarial_input(image) + assert result.is_anomalous + + def test_out_of_range(self): + image = np.random.randn(4, 256, 256).astype(np.float32) * 100 + result = detect_adversarial_input(image) + assert result.anomaly_score > 0 + + +class TestOutputValidation: + """Test model output validation.""" + + def test_valid_output(self): + predictions = np.random.randint(0, 2, (256, 256)) + result = validate_model_output(predictions, n_classes=2) + assert result.is_valid + assert result.confidence > 0.5 + + def test_invalid_class_values(self): + predictions = np.array([[0, 1, 5, 10]]) + result = validate_model_output(predictions, n_classes=2) + assert not result.is_valid or len(result.issues) > 0 + + def test_single_class_domination(self): + predictions = np.ones((256, 256), dtype=np.int32) + result = validate_model_output(predictions, n_classes=2) + assert len(result.issues) > 0 + assert any("dominates" in issue.lower() for issue in result.issues) + + def test_nan_in_predictions(self): + predictions = np.array([[0.0, 1.0, np.nan]]) + result = validate_model_output(predictions, n_classes=2) + assert not result.is_valid + + +class TestPipelineGuard: + """Test complete pipeline guard.""" + + def test_blocks_adversarial(self): + guard = PipelineGuard() + adversarial_image = np.ones((4, 256, 256), dtype=np.float32) * 0.5 + + result = guard.check_input(adversarial_image) + # Uniform image should be flagged + assert result.anomaly_score > 0 + + def test_passes_normal_image(self): + guard = PipelineGuard() + normal_image = np.random.randn(4, 256, 256).astype(np.float32) + + result = guard.check_input(normal_image) + assert not result.is_anomalous + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])