diff --git a/config.yaml b/config.yaml
index 2ce5c8a..aba8172 100644
--- a/config.yaml
+++ b/config.yaml
@@ -16,6 +16,7 @@ analysis_types:
num_classes: 2
bands: ["B04", "B03", "B02", "B08"] # Red, Green, Blue, NIR
classes: ["non_forest", "forest"]
+ scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow
thresholds:
alert_forest_loss: 5.0 # Alert if >5% forest loss
critical_forest_loss: 15.0 # Critical if >15% loss
@@ -38,6 +39,7 @@ analysis_types:
num_classes: 3
bands: ["B02", "B03", "B04", "B11"] # Blue, Green, Red, SWIR
classes: ["open_water", "sea_ice", "land"]
+ scl_clear_labels: [4, 5, 6, 11] # vegetation, bare soil, water, snow
thresholds:
alert_ice_loss: 10.0 # Alert if >10% ice loss
critical_ice_loss: 25.0 # Critical if >25% loss
@@ -60,12 +62,19 @@ analysis_types:
display_name: "Flood Detection"
description: "Detect and monitor flooding events and affected areas"
model:
- architecture: "unet"
+ architecture: "flood_unet"
weights: "models/unet_flood.pth"
in_channels: 3
num_classes: 3
+ model_sar:
+ architecture: "flood_unet"
+ weights: "models/unet_flood_sar.pth"
+ in_channels: 5
+ num_classes: 3
bands: ["B03", "B08", "B11"] # Green, NIR, SWIR
+ sar_bands: ["VV", "VH"]
classes: ["dry_land", "permanent_water", "flooded"]
+ scl_clear_labels: [4, 5, 6] # vegetation, bare soil, water (NO snow/ice)
thresholds:
alert_flood_area: 5.0 # Alert if >5% area flooded
critical_flood_area: 20.0 # Critical if >20% flooded
@@ -74,6 +83,7 @@ analysis_types:
- "flooded_percentage"
- "flooded_area_km2"
- "mndwi_stats"
+ - "affected_road_km"
# Drought Monitoring
drought:
@@ -153,6 +163,13 @@ satellite:
cloud_coverage_max: 20 # percentage
revisit_time: 5 # days
+ sentinel1:
+ bands: ["VV", "VH"]
+ resolution: 10 # meters
+ revisit_time: 6 # days
+ polarization: ["VV", "VH"]
+ orbit: "DESCENDING"
+
landsat8:
bands: ["B4", "B3", "B2", "B5"]
resolution: 30 # meters
diff --git a/requirements.txt b/requirements.txt
index c67ad0e..1a981a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,9 @@ fiona>=1.9.0
opencv-python>=4.5.0
pillow>=9.0.0
albumentations>=1.3.0
+segmentation-models-pytorch>=0.3.3
+timm>=0.9.0
+scipy>=1.10.0
# Visualization
matplotlib>=3.5.0
diff --git a/scripts/export_model.py b/scripts/export_model.py
index e855a12..e4def52 100644
--- a/scripts/export_model.py
+++ b/scripts/export_model.py
@@ -46,30 +46,63 @@
# Model loading
# ---------------------------------------------------------------------------
+def _load_run_config(ckpt_path: Path) -> dict:
+ """Load the full training config from the run directory."""
+ run_dir = ckpt_path.parent
+ config_path = run_dir / "config.yaml"
+ if config_path.exists():
+ try:
+ import yaml
+ with open(config_path) as f:
+ return yaml.safe_load(f) or {}
+ except Exception:
+ pass
+ return {}
+
+
def load_model(ckpt_path: Path) -> tuple[nn.Module, dict]:
from climatevision.models.unet import get_model
+ from climatevision.models.flood_unet import build_flood_model
ckpt = torch.load(ckpt_path, map_location="cpu")
- cfg = ckpt.get("cfg", {})
+ # Trainer cfg is in ckpt['cfg']; full model cfg is in config.yaml next to checkpoint
+ cfg = _load_run_config(ckpt_path)
+ if not cfg:
+ cfg = ckpt.get("cfg", {})
arch = cfg.get("model", {}).get("architecture", "attention_unet")
state = ckpt.get("ema_state_dict") or ckpt.get("model_state_dict", ckpt)
- # Infer in_channels from weight shape
- in_ch = 4
+ # Start from the config values, then override in_channels / n_classes from weight shapes
+ in_ch = cfg.get("model", {}).get("in_channels", 4)
+ n_classes = cfg.get("model", {}).get("num_classes", 2)
for key, val in state.items():
- if "inc" in key and "weight" in key and val.ndim == 4:
- in_ch = val.shape[1]
- break
+ if val.ndim == 4:
+ if val.shape[1] in (3, 4, 5) and "encoder" not in key and "down" not in key:
+ # candidate input conv — input channels (no break, so the last matching weight wins)
+ in_ch = val.shape[1]
+ if val.shape[0] in (2, 3) and ("outc" in key or "segmentation_head" in key or "classifier" in key):
+ # final conv layer — output classes
+ n_classes = val.shape[0]
+
+ # Flood models use smp-based architectures
+ if arch in ("flood_unet", "flood_unet_s2only"):
+ use_sar = in_ch == 5
+ model = build_flood_model(
+ use_sar=use_sar,
+ encoder_name=cfg.get("model", {}).get("encoder", "efficientnet-b7"),
+ )
+ else:
+ model = get_model(arch, n_channels=in_ch, n_classes=n_classes)
- model = get_model(arch, n_channels=in_ch, n_classes=2)
model.load_state_dict(state, strict=False)
model.eval()
logger.info(
- "Loaded %s (in_channels=%d) from epoch %d val_iou=%.4f",
+ "Loaded %s (in_channels=%d, classes=%d) from epoch %d val_iou=%.4f",
arch,
in_ch,
+ n_classes,
ckpt.get("epoch", 0),
ckpt.get("val_iou", 0.0),
)
@@ -233,7 +266,7 @@ def main() -> None:
"checkpoint": str(ckpt_path),
"architecture": cfg.get("model", {}).get("architecture", "unknown"),
"in_channels": in_channels,
- "num_classes": 2,
+ "num_classes": cfg.get("model", {}).get("num_classes", 2),
"image_size": image_size,
"onnx_opset": args.opset,
"onnx_path": str(onnx_path),
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
index 2ef1100..ff24ba6 100644
--- a/scripts/prepare_data.py
+++ b/scripts/prepare_data.py
@@ -1,22 +1,19 @@
"""
-Data preparation script for ClimateVision forest segmentation.
+Data preparation script for ClimateVision.
-Two modes:
- --mode synthetic Generate fractal-noise synthetic Sentinel-2 patches (no data required)
+Modes:
+ --mode synthetic Generate synthetic patches for testing/smoke tests
--mode gee Download real Sentinel-2 L2A tiles via Google Earth Engine
Usage:
- # Quick start — 2 000 synthetic patches, default 70/15/15 split:
- python scripts/prepare_data.py --mode synthetic --n-patches 2000 --out data/processed
-
- # Fewer patches for a fast smoke test:
- python scripts/prepare_data.py --mode synthetic --n-patches 200 --out data/processed
-
- # Real data via GEE (requires authenticated `earthengine-api`):
- python scripts/prepare_data.py --mode gee \\
- --bbox 2.3 48.8 2.5 49.0 \\
- --start 2022-01-01 --end 2023-12-31 \\
- --out data/processed
+ # Synthetic flood patches for testing:
+ python scripts/prepare_data.py --mode synthetic --analysis-type flooding --n-patches 200 --out data/processed/flood
+
+ # Real data via GEE:
+ python scripts/prepare_data.py --mode gee --analysis-type flooding \\
+ --bbox 36.7 -1.4 37.0 -1.1 \\
+ --start 2024-04-01 --end 2024-04-10 \\
+ --out data/processed/flood
"""
from __future__ import annotations
@@ -41,38 +38,51 @@
# ---------------------------------------------------------------------------
def generate_synthetic(
+ analysis_type: str,
n_patches: int,
out_dir: Path,
patch_size: int,
train_ratio: float,
val_ratio: float,
) -> None:
- """Delegate entirely to the built-in synthetic generator."""
- try:
- from climatevision.data.synthetic import generate_synthetic_dataset
- except ImportError as exc:
- logger.error("Cannot import climatevision package: %s", exc)
- logger.error("Run `pip install -e .` from the project root first.")
- sys.exit(1)
+ """Generate synthetic patches for testing."""
+ from climatevision.data.synthetic import (
+ generate_synthetic_dataset,
+ generate_flood_test_patches,
+ )
test_ratio = max(0.0, 1.0 - train_ratio - val_ratio)
n_train = int(n_patches * train_ratio)
- n_val = int(n_patches * val_ratio)
- n_test = max(0, n_patches - n_train - n_val)
+ n_val = int(n_patches * val_ratio)
+ n_test = max(0, n_patches - n_train - n_val)
- logger.info(
- "Generating %d synthetic patches "
- "(train=%d / val=%d / test=%d) patch_size=%d",
- n_patches, n_train, n_val, n_test, patch_size,
- )
-
- generate_synthetic_dataset(
- output_dir=out_dir,
- n_train=n_train,
- n_val=n_val,
- n_test=n_test,
- patch_size=patch_size,
- )
+ if analysis_type == "flooding":
+ logger.info(
+ "Generating %d synthetic flood patches (train=%d / val=%d / test=%d)",
+ n_patches, n_train, n_val, n_test,
+ )
+ for split, n in [("train", n_train), ("val", n_val), ("test", n_test)]:
+ if n > 0:
+ split_dir = out_dir / split
+ generate_flood_test_patches(
+ output_dir=split_dir,
+ n=n,
+ patch_size=patch_size,
+ use_sar=False,
+ seed=42 if split == "train" else 99 if split == "val" else 123,
+ )
+ else:
+ logger.info(
+ "Generating %d synthetic forest patches (train=%d / val=%d / test=%d)",
+ n_patches, n_train, n_val, n_test,
+ )
+ generate_synthetic_dataset(
+ output_dir=out_dir,
+ n_train=n_train,
+ n_val=n_val,
+ n_test=n_test,
+ patch_size=patch_size,
+ )
logger.info("Dataset written to %s", out_dir)
@@ -82,6 +92,7 @@ def generate_synthetic(
# ---------------------------------------------------------------------------
def download_gee(
+ analysis_type: str,
bbox: tuple[float, float, float, float],
start: str,
end: str,
@@ -101,8 +112,8 @@ def download_gee(
try:
import os
svc_account = os.getenv("GEE_SERVICE_ACCOUNT")
- key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY")
- project = os.getenv("GEE_PROJECT_ID")
+ key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY")
+ project = os.getenv("GEE_PROJECT_ID")
if key_file and not os.path.isabs(key_file):
key_file = str(PROJECT_ROOT / key_file)
@@ -128,37 +139,41 @@ def download_gee(
import random, urllib.request, tempfile, os
+ from climatevision.data.band_mapping import get_bands_for_analysis
+
+ bands = get_bands_for_analysis(analysis_type)
+ gee_bands = {
+ "B01": "B1", "B02": "B2", "B03": "B3", "B04": "B4",
+ "B05": "B5", "B06": "B6", "B07": "B7", "B08": "B8",
+ "B8A": "B8A", "B09": "B9", "B10": "B10", "B11": "B11", "B12": "B12",
+ }
+ selected_bands = [gee_bands[b] for b in bands]
+
west, south, east, north = bbox
- # GEE download size limit is 48 MB per request.
- # At 100 m resolution, a 0.25° tile is ~278x278 px × 5 bands × 4 bytes ≈ 1.5 MB — safe.
- # 100 m is standard for regional forest classification.
TILE_DEG = 0.25
- SCALE_M = 100
+ SCALE_M = 100
- # Build tile grid
tiles = []
lat = south
while lat < north:
lon = west
while lon < east:
tiles.append((
- round(lon, 6),
- round(lat, 6),
- round(min(lon + TILE_DEG, east), 6),
+ round(lon, 6),
+ round(lat, 6),
+ round(min(lon + TILE_DEG, east), 6),
round(min(lat + TILE_DEG, north), 6),
))
lon += TILE_DEG
lat += TILE_DEG
- logger.info("Downloading %d tiles (%.2f° each, scale=%dm)…", len(tiles), TILE_DEG, SCALE_M)
+ logger.info("Downloading %d tiles (%.2f deg each, scale=%dm)…", len(tiles), TILE_DEG, SCALE_M)
patches: list[tuple[np.ndarray, np.ndarray]] = []
-
- # Minimal rasterio profile for writing plain GeoTIFF patches
base_profile = {
"driver": "GTiff",
- "crs": "EPSG:4326",
+ "crs": "EPSG:4326",
"transform": rasterio.transform.from_bounds(west, south, east, north, patch_size, patch_size),
}
@@ -173,25 +188,34 @@ def download_gee(
.filterBounds(tile_region)
.filterDate(start, end)
.filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", cloud_threshold * 100))
- .select(["B4", "B3", "B2", "B8"])
+ .select(selected_bands)
)
- dw = (
- ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1")
- .filterBounds(tile_region)
- .filterDate(start, end)
- .select("label")
- .mode()
- )
- forest_mask = dw.eq(1).rename("forest")
+ # For flood detection, use JRC Global Surface Water for water mask baseline
+ if analysis_type == "flooding":
+ water = ee.Image("JRC/GSW1_4/GlobalSurfaceWater").select("occurrence").clip(tile_region)
+ water_mask = water.gt(50).rename("water")
+ # We can't easily generate "flooded" labels without event-specific data,
+ # so the label written here is binary: 0 = dry, 1 = water (JRC occurrence > 50).
+ # Class 2 ("flooded") is never emitted — in production use event-specific polygons.
+ label = water_mask
+ else:
+ dw = (
+ ee.ImageCollection("GOOGLE/DYNAMICWORLD/V1")
+ .filterBounds(tile_region)
+ .filterDate(start, end)
+ .select("label")
+ .mode()
+ )
+ label = dw.eq(1).rename("forest")
try:
- image = collection.median().clip(tile_region)
- combined = image.addBands(forest_mask)
+ image = collection.median().clip(tile_region)
+ combined = image.addBands(label)
url = combined.getDownloadURL({
"region": tile_region,
- "scale": SCALE_M,
+ "scale": SCALE_M,
"format": "GEO_TIFF",
})
@@ -199,16 +223,17 @@ def download_gee(
urllib.request.urlretrieve(url, tmp)
with rasterio.open(tmp) as src:
- full = src.read() # (5, H, W)
+ full = src.read()
os.unlink(tmp)
- if full.shape[0] < 5:
+ n_bands = len(selected_bands)
+ if full.shape[0] < n_bands + 1:
continue
- image_data = full[:4].astype(np.float32)
- mask_data = (full[4] > 0).astype(np.uint8)
- _, H, W = image_data.shape
+ image_data = full[:n_bands].astype(np.float32)
+ mask_data = full[n_bands].astype(np.uint8)
+ _, H, W = image_data.shape
for y in range(0, H - patch_size + 1, patch_size):
for x in range(0, W - patch_size + 1, patch_size):
@@ -216,12 +241,12 @@ def download_gee(
break
patches.append((
image_data[:, y:y + patch_size, x:x + patch_size],
- mask_data[ y:y + patch_size, x:x + patch_size],
+ mask_data[y:y + patch_size, x:x + patch_size],
))
if len(patches) >= max_patches:
break
- logger.info(" tile %d/%d → %d patches so far", ti + 1, len(tiles), len(patches))
+ logger.info(" tile %d/%d -> %d patches so far", ti + 1, len(tiles), len(patches))
except Exception as exc:
logger.warning(" tile %d/%d skipped: %s", ti + 1, len(tiles), exc)
@@ -233,16 +258,15 @@ def download_gee(
logger.info("Extracted %d patches total", len(patches))
- # Shuffle + split
random.seed(42)
random.shuffle(patches)
- n = len(patches)
+ n = len(patches)
n_train = int(n * train_ratio)
- n_val = int(n * val_ratio)
- splits = {
+ n_val = int(n * val_ratio)
+ splits = {
"train": patches[:n_train],
- "val": patches[n_train:n_train + n_val],
- "test": patches[n_train + n_val:],
+ "val": patches[n_train:n_train + n_val],
+ "test": patches[n_train + n_val:],
}
for split, split_patches in splits.items():
@@ -250,7 +274,7 @@ def download_gee(
(out_dir / split / "masks").mkdir(parents=True, exist_ok=True)
for idx, (img_patch, mask_patch) in enumerate(split_patches):
stem = f"patch_{idx:05d}"
- img_profile = {**base_profile, "count": 4, "dtype": "float32",
+ img_profile = {**base_profile, "count": img_patch.shape[0], "dtype": "float32",
"height": patch_size, "width": patch_size}
with rasterio.open(out_dir / split / "images" / f"{stem}.tif", "w", **img_profile) as dst:
dst.write(img_patch)
@@ -268,7 +292,6 @@ def download_gee(
# ---------------------------------------------------------------------------
def fit_normalizer(data_dir: Path, out_path: Path) -> None:
- """Compute per-band mean/std on the training set and save to JSON."""
try:
from climatevision.data.preprocessing import Sentinel2Normalizer
except ImportError as exc:
@@ -304,7 +327,9 @@ def parse_args() -> argparse.Namespace:
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument("--mode", choices=["synthetic", "gee"], default="synthetic")
- p.add_argument("--out", type=Path, default=Path("data/processed"),
+ p.add_argument("--analysis-type", default="deforestation",
+ help="Analysis type: deforestation, flooding, ice_melting")
+ p.add_argument("--out", type=Path, default=Path("data/processed"),
help="Output directory (created if needed)")
# Synthetic options
@@ -318,16 +343,16 @@ def parse_args() -> argparse.Namespace:
help="[gee] Bounding box: west south east north")
p.add_argument("--start", type=str, default="2022-01-01",
help="[gee] Start date YYYY-MM-DD")
- p.add_argument("--end", type=str, default="2023-12-31",
+ p.add_argument("--end", type=str, default="2023-12-31",
help="[gee] End date YYYY-MM-DD")
p.add_argument("--max-patches", type=int, default=5000,
help="[gee] Maximum patches to extract from download")
p.add_argument("--cloud-threshold", type=float, default=0.2,
- help="[gee] Max cloud fraction (0–1)")
+ help="[gee] Max cloud fraction (0-1)")
# Split ratios
p.add_argument("--train-ratio", type=float, default=0.70)
- p.add_argument("--val-ratio", type=float, default=0.15)
+ p.add_argument("--val-ratio", type=float, default=0.15)
# Normalizer
p.add_argument("--fit-normalizer", action="store_true",
@@ -342,11 +367,12 @@ def main() -> None:
args = parse_args()
if args.train_ratio + args.val_ratio > 1.0:
- logger.error("--train-ratio + --val-ratio must be ≤ 1.0")
+ logger.error("--train-ratio + --val-ratio must be <= 1.0")
sys.exit(1)
if args.mode == "synthetic":
generate_synthetic(
+ analysis_type=args.analysis_type,
n_patches=args.n_patches,
out_dir=args.out,
patch_size=args.patch_size,
@@ -358,7 +384,8 @@ def main() -> None:
logger.error("--bbox W S E N is required for --mode gee")
sys.exit(1)
download_gee(
- bbox=tuple(args.bbox), # type: ignore[arg-type]
+ analysis_type=args.analysis_type,
+ bbox=tuple(args.bbox),
start=args.start,
end=args.end,
out_dir=args.out,
diff --git a/scripts/train.py b/scripts/train.py
index 87decbb..41e2714 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,20 +1,14 @@
"""
-Production training entry-point for ClimateVision forest segmentation.
+Production training entry-point for ClimateVision.
Usage:
- # Train with defaults (generates synthetic data if none exists):
- python scripts/train.py
+ # Train flood detection model with real data:
+ python scripts/train.py --analysis-type flooding --data-dir data/processed/flood
# Custom config:
- python scripts/train.py --config config/train.yaml
+ python scripts/train.py --config config/train.yaml --analysis-type flooding
- # Override specific keys:
- python scripts/train.py --config config/train.yaml \\
- --data-dir data/processed \\
- --epochs 50 \\
- --batch-size 8
-
- # Resume from a checkpoint:
+ # Resume from checkpoint:
python scripts/train.py --resume models/my_run/checkpoint_epoch_0030.pth
"""
from __future__ import annotations
@@ -33,7 +27,6 @@
)
logger = logging.getLogger(__name__)
-# Project root on the Python path so `climatevision` is importable
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))
@@ -44,7 +37,7 @@
def _load_yaml(path: str | Path) -> dict:
try:
- import yaml # PyYAML
+ import yaml
except ImportError:
logger.error("PyYAML not installed. Run: pip install pyyaml")
sys.exit(1)
@@ -53,7 +46,6 @@ def _load_yaml(path: str | Path) -> dict:
def _deep_merge(base: dict, override: dict) -> dict:
- """Recursively merge override into a copy of base."""
result = base.copy()
for k, v in override.items():
if isinstance(v, dict) and isinstance(result.get(k), dict):
@@ -64,7 +56,6 @@ def _deep_merge(base: dict, override: dict) -> dict:
def build_config(args: argparse.Namespace) -> dict:
- """Load YAML config and apply CLI overrides."""
cfg: dict = {}
if args.config and Path(args.config).exists():
@@ -73,7 +64,6 @@ def build_config(args: argparse.Namespace) -> dict:
else:
logger.info("No config file — using defaults")
- # CLI overrides (only non-None values)
overrides: dict = {}
if args.data_dir:
overrides.setdefault("data", {})["dir"] = args.data_dir
@@ -95,51 +85,80 @@ def build_config(args: argparse.Namespace) -> dict:
overrides.setdefault("data", {})["image_size"] = args.image_size
if args.arch:
overrides.setdefault("model", {})["architecture"] = args.arch
+ if args.analysis_type:
+ overrides.setdefault("analysis", {})["type"] = args.analysis_type
cfg = _deep_merge(cfg, overrides)
- # Defaults for any missing keys
+ # Defaults
cfg.setdefault("data", {})
- cfg["data"].setdefault("dir", "data/processed")
- cfg["data"].setdefault("image_size", 256)
- cfg["data"].setdefault("batch_size", 16)
- cfg["data"].setdefault("num_workers", 4)
+ cfg["data"].setdefault("dir", "data/processed")
+ cfg["data"].setdefault("image_size", 256)
+ cfg["data"].setdefault("batch_size", 16)
+ cfg["data"].setdefault("num_workers", 4)
cfg["data"].setdefault("use_weighted_sampler", True)
- cfg["data"].setdefault("pin_memory", True)
+ cfg["data"].setdefault("pin_memory", True)
+
+ cfg.setdefault("analysis", {})
+ cfg["analysis"].setdefault("type", "deforestation")
+
+ # Auto-configure model defaults based on analysis type BEFORE generic defaults.
+ # When the user explicitly passes --analysis-type on CLI we override config-file
+ # (and --arch) model settings so the architecture matches the data.
+ analysis_type = cfg["analysis"]["type"]
+ explicit_analysis = args.analysis_type is not None # came from CLI
+
+ if analysis_type == "flooding":
+ forced_arch, forced_ch, forced_cls, forced_enc = "flood_unet", 3, 3, "efficientnet-b7"
+ elif analysis_type == "ice_melting":
+ forced_arch, forced_ch, forced_cls, forced_enc = "attention_unet", 2, 2, None
+ else: # deforestation
+ forced_arch, forced_ch, forced_cls, forced_enc = "attention_unet", 4, 2, None
cfg.setdefault("model", {})
- cfg["model"].setdefault("architecture", "attention_unet")
- cfg["model"].setdefault("in_channels", 4)
- cfg["model"].setdefault("num_classes", 2)
- cfg["model"].setdefault("bilinear", True)
+ # If user explicitly requested an analysis type, force model compatibility.
+ # Otherwise just use setdefault so config file values are respected.
+ if explicit_analysis:
+ cfg["model"]["architecture"] = forced_arch
+ cfg["model"]["in_channels"] = forced_ch
+ cfg["model"]["num_classes"] = forced_cls
+ if forced_enc:
+ cfg["model"]["encoder"] = forced_enc
+ else:
+ cfg["model"].setdefault("architecture", forced_arch)
+ cfg["model"].setdefault("in_channels", forced_ch)
+ cfg["model"].setdefault("num_classes", forced_cls)
+ if forced_enc:
+ cfg["model"].setdefault("encoder", forced_enc)
+ cfg["model"].setdefault("bilinear", True)
cfg.setdefault("loss", {})
- cfg["loss"].setdefault("type", "combined")
- cfg["loss"].setdefault("focal_weight", 0.5)
- cfg["loss"].setdefault("focal_alpha", 0.25)
- cfg["loss"].setdefault("focal_gamma", 2.0)
- cfg["loss"].setdefault("use_class_weights", True)
+ cfg["loss"].setdefault("type", "combined")
+ cfg["loss"].setdefault("focal_weight", 0.5)
+ cfg["loss"].setdefault("focal_alpha", 0.25)
+ cfg["loss"].setdefault("focal_gamma", 2.0)
+ cfg["loss"].setdefault("use_class_weights", True)
cfg.setdefault("optimizer", {})
- cfg["optimizer"].setdefault("learning_rate", 1e-4)
- cfg["optimizer"].setdefault("weight_decay", 1e-4)
- cfg["optimizer"].setdefault("min_lr", 1e-6)
+ cfg["optimizer"].setdefault("learning_rate", 1e-4)
+ cfg["optimizer"].setdefault("weight_decay", 1e-4)
+ cfg["optimizer"].setdefault("min_lr", 1e-6)
cfg.setdefault("schedule", {})
- cfg["schedule"].setdefault("epochs", 100)
- cfg["schedule"].setdefault("warmup_epochs", 5)
- cfg["schedule"].setdefault("checkpoint_interval", 10)
+ cfg["schedule"].setdefault("epochs", 100)
+ cfg["schedule"].setdefault("warmup_epochs", 5)
+ cfg["schedule"].setdefault("checkpoint_interval", 10)
cfg.setdefault("training", {})
- cfg["training"].setdefault("mixed_precision", True)
- cfg["training"].setdefault("grad_clip", 1.0)
- cfg["training"].setdefault("use_ema", True)
- cfg["training"].setdefault("ema_decay", 0.9999)
+ cfg["training"].setdefault("mixed_precision", True)
+ cfg["training"].setdefault("grad_clip", 1.0)
+ cfg["training"].setdefault("use_ema", True)
+ cfg["training"].setdefault("ema_decay", 0.9999)
cfg["training"].setdefault("early_stopping_patience", 15)
cfg.setdefault("output", {})
- cfg["output"].setdefault("save_dir", "models")
- cfg["output"].setdefault("run_name", "")
+ cfg["output"].setdefault("save_dir", "models")
+ cfg["output"].setdefault("run_name", "")
cfg.setdefault("normalizer_stats", "")
return cfg
@@ -152,12 +171,23 @@ def build_config(args: argparse.Namespace) -> dict:
def build_model(cfg: dict):
"""Instantiate the segmentation model from config."""
from climatevision.models.unet import get_model
+ from climatevision.models.flood_unet import build_flood_model
+
mcfg = cfg["model"]
arch = mcfg["architecture"]
+ analysis_type = cfg["analysis"]["type"]
+
+ # Flood models use smp-based architectures
+ if arch in ("flood_unet", "flood_unet_s2only"):
+ use_sar = mcfg["in_channels"] == 5
+ return build_flood_model(
+ use_sar=use_sar,
+ encoder_name=mcfg.get("encoder", "efficientnet-b7"),
+ )
kwargs = {
"n_channels": mcfg["in_channels"],
- "n_classes": mcfg["num_classes"],
+ "n_classes": mcfg["num_classes"],
}
if arch == "unet":
kwargs["bilinear"] = mcfg.get("bilinear", True)
@@ -226,7 +256,6 @@ def load_normalizer(cfg: dict):
# ---------------------------------------------------------------------------
def maybe_resume(model, optimizer, resume_path: str | None) -> int:
- """Load weights from checkpoint. Returns start epoch (0 if no resume)."""
if not resume_path:
return 0
import torch
@@ -244,29 +273,17 @@ def maybe_resume(model, optimizer, resume_path: str | None) -> int:
# ---------------------------------------------------------------------------
-# Auto-generate data if directory is empty
+# Data validation — fail fast if no training data is present
# ---------------------------------------------------------------------------
-def maybe_generate_data(data_dir: Path, patch_size: int = 256, n_patches: int = 1000) -> None:
+def validate_data_exists(data_dir: Path) -> None:
+ """Fail hard if training data is missing. No synthetic fallback."""
train_img = data_dir / "train" / "images"
- if train_img.exists() and any(train_img.glob("*.tif")):
- return
-
- logger.warning("No training data found in %s", data_dir)
- logger.info("Auto-generating %d synthetic patches…", n_patches)
-
- cmd = [
- sys.executable, str(PROJECT_ROOT / "scripts" / "prepare_data.py"),
- "--mode", "synthetic",
- "--n-patches", str(n_patches),
- "--patch-size", str(patch_size),
- "--out", str(data_dir),
- "--fit-normalizer",
- ]
- import subprocess
- result = subprocess.run(cmd, check=False)
- if result.returncode != 0:
- logger.error("Data generation failed — check prepare_data.py output")
+ if not train_img.exists() or not any(train_img.glob("*.tif")):
+ logger.error("No training data found in %s", data_dir)
+ logger.error("Prepare real data before training:")
+ logger.error(" python scripts/prepare_data.py --mode synthetic --n-patches 1000 --out %s", data_dir)
+ logger.error(" # OR download real GEE data")
sys.exit(1)
@@ -276,14 +293,13 @@ def maybe_generate_data(data_dir: Path, patch_size: int = 256, n_patches: int =
def main() -> None:
args = parse_args()
- cfg = build_config(args)
+ cfg = build_config(args)
- # Run name / output directory
+ analysis_type = cfg["analysis"]["type"]
run_name = cfg["output"]["run_name"] or datetime.now().strftime("%Y%m%d_%H%M%S")
- save_dir = Path(cfg["output"]["save_dir"]) / run_name
+ save_dir = Path(cfg["output"]["save_dir"]) / f"{analysis_type}_{run_name}"
save_dir.mkdir(parents=True, exist_ok=True)
- # Persist effective config
try:
import yaml
with open(save_dir / "config.yaml", "w") as f:
@@ -291,15 +307,14 @@ def main() -> None:
except ImportError:
import json
with open(save_dir / "config.json", "w") as f:
- import json
json.dump(cfg, f, indent=2)
logger.info("Run: %s → %s", run_name, save_dir)
# Data
- data_dir = Path(cfg["data"]["dir"])
+ data_dir = Path(cfg["data"]["dir"])
image_size = cfg["data"]["image_size"]
- maybe_generate_data(data_dir, patch_size=image_size)
+ validate_data_exists(data_dir)
normalizer = load_normalizer(cfg)
@@ -312,6 +327,7 @@ def main() -> None:
normalizer=normalizer,
pin_memory=cfg["data"]["pin_memory"],
use_weighted_sampler=cfg["data"]["use_weighted_sampler"],
+ analysis_type=analysis_type,
)
if "train" not in loaders:
@@ -326,28 +342,28 @@ def main() -> None:
)
# Class weights
+ num_classes = cfg["model"]["num_classes"]
class_weights = None
if cfg["loss"]["use_class_weights"]:
- class_weights = loaders["train"].dataset.compute_class_weights()
+ class_weights = loaders["train"].dataset.compute_class_weights(num_classes=num_classes)
logger.info("Class weights: %s", class_weights.tolist())
# Model + loss
- model = build_model(cfg)
+ model = build_model(cfg)
criterion = build_criterion(cfg, class_weights=class_weights)
- # Trainer config dict (flat, as Trainer expects)
trainer_cfg = {
- "learning_rate": cfg["optimizer"]["learning_rate"],
- "weight_decay": cfg["optimizer"]["weight_decay"],
- "min_lr": cfg["optimizer"]["min_lr"],
- "epochs": cfg["schedule"]["epochs"],
- "warmup_epochs": cfg["schedule"]["warmup_epochs"],
- "checkpoint_interval": cfg["schedule"]["checkpoint_interval"],
- "mixed_precision": cfg["training"]["mixed_precision"],
- "grad_clip": cfg["training"]["grad_clip"],
- "use_ema": cfg["training"]["use_ema"],
- "ema_decay": cfg["training"]["ema_decay"],
- "early_stopping_patience": cfg["training"]["early_stopping_patience"],
+ "learning_rate": cfg["optimizer"]["learning_rate"],
+ "weight_decay": cfg["optimizer"]["weight_decay"],
+ "min_lr": cfg["optimizer"]["min_lr"],
+ "epochs": cfg["schedule"]["epochs"],
+ "warmup_epochs": cfg["schedule"]["warmup_epochs"],
+ "checkpoint_interval": cfg["schedule"]["checkpoint_interval"],
+ "mixed_precision": cfg["training"]["mixed_precision"],
+ "grad_clip": cfg["training"]["grad_clip"],
+ "use_ema": cfg["training"]["use_ema"],
+ "ema_decay": cfg["training"]["ema_decay"],
+ "early_stopping_patience": cfg["training"]["early_stopping_patience"],
}
from climatevision.training.trainer import Trainer
@@ -359,7 +375,6 @@ def main() -> None:
save_dir=save_dir,
)
- # Optional resume
if args.resume:
maybe_resume(model, trainer.optimizer, args.resume)
@@ -367,8 +382,10 @@ def main() -> None:
history = trainer.fit()
elapsed = time.time() - t_start
- best_iou = max((e.get("iou_forest", 0) for e in history["val"]), default=0)
- best_f1 = max((e.get("f1", 0) for e in history["val"]), default=0)
+ # Report the best "iou_<analysis_type>" value, falling back to the generic "iou" key
+ metric_key = f"iou_{analysis_type}"
+ best_iou = max((e.get(metric_key, e.get("iou", 0)) for e in history["val"]), default=0)
+ best_f1 = max((e.get("f1", 0) for e in history["val"]), default=0)
logger.info("=" * 60)
logger.info("Training complete in %.1f min", elapsed / 60)
@@ -388,20 +405,22 @@ def main() -> None:
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
- p = argparse.ArgumentParser(description="Train ClimateVision forest segmentation model")
- p.add_argument("--config", default=str(PROJECT_ROOT / "config" / "train.yaml"),
+ p = argparse.ArgumentParser(description="Train ClimateVision segmentation model")
+ p.add_argument("--config", default=str(PROJECT_ROOT / "config" / "train.yaml"),
help="Path to YAML config file")
- p.add_argument("--data-dir", default=None, help="Override data.dir")
- p.add_argument("--epochs", type=int, default=None, help="Override schedule.epochs")
+ p.add_argument("--analysis-type", default=None,
+ help="Analysis type: deforestation, flooding, ice_melting")
+ p.add_argument("--data-dir", default=None, help="Override data.dir")
+ p.add_argument("--epochs", type=int, default=None, help="Override schedule.epochs")
p.add_argument("--batch-size", type=int, default=None, help="Override data.batch_size")
- p.add_argument("--lr", type=float, default=None, help="Override optimizer.learning_rate")
- p.add_argument("--save-dir", default=None, help="Override output.save_dir")
- p.add_argument("--run-name", default=None, help="Override output.run_name")
- p.add_argument("--resume", default=None, help="Path to checkpoint to resume from")
- p.add_argument("--arch", choices=["unet", "attention_unet"], default=None)
- p.add_argument("--no-amp", action="store_true", help="Disable mixed-precision (AMP)")
- p.add_argument("--num-workers", type=int, default=None, help="DataLoader worker count (0=main process)")
- p.add_argument("--image-size", type=int, default=None, help="Spatial crop size in pixels")
+ p.add_argument("--lr", type=float, default=None, help="Override optimizer.learning_rate")
+ p.add_argument("--save-dir", default=None, help="Override output.save_dir")
+ p.add_argument("--run-name", default=None, help="Override output.run_name")
+ p.add_argument("--resume", default=None, help="Path to checkpoint to resume from")
+ p.add_argument("--arch", choices=["unet", "attention_unet", "flood_unet", "flood_unet_s2only"], default=None)
+ p.add_argument("--no-amp", action="store_true", help="Disable mixed-precision (AMP)")
+ p.add_argument("--num-workers", type=int, default=None, help="DataLoader worker count (0=main process)")
+ p.add_argument("--image-size", type=int, default=None, help="Spatial crop size in pixels")
return p.parse_args()
diff --git a/src/climatevision/api/auth.py b/src/climatevision/api/auth.py
index 85a8ad7..1da5034 100644
--- a/src/climatevision/api/auth.py
+++ b/src/climatevision/api/auth.py
@@ -68,6 +68,9 @@ def validate_key(self, api_key: str) -> Optional[dict]:
"""
Validate an API key and return organization context.
+ Queries the SQLite database for an organization matching the
+ SHA-256 hash of the provided API key.
+
Args:
api_key: The API key to validate
@@ -85,15 +88,36 @@ def validate_key(self, api_key: str) -> Optional[dict]:
"demo": True,
}
- # Check cache first
key_hash = self.hash_key(api_key)
+
+ # Check cache first
if key_hash in self._key_cache:
cached = self._key_cache[key_hash]
if cached.get("expires_at", datetime.max) > datetime.utcnow():
return cached.get("org")
- # Would query database in production
- # For now, return None to indicate key not found
+ # Query database
+ try:
+ from climatevision.db import get_connection
+ with get_connection() as conn:
+ row = conn.execute(
+ "SELECT id, name, type, api_key FROM organizations WHERE api_key = ?",
+ (key_hash,),
+ ).fetchone()
+ if row:
+ org = {
+ "id": row["id"],
+ "name": row["name"],
+ "type": row["type"],
+ }
+ self._key_cache[key_hash] = {
+ "org": org,
+ "expires_at": datetime.max,
+ }
+ return org
+ except Exception as exc:
+ logger.warning("Database query failed during API key validation: %s", exc)
+
return None
def revoke_key(self, api_key: str) -> bool:
diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py
index 729b213..d4a41f2 100644
--- a/src/climatevision/api/main.py
+++ b/src/climatevision/api/main.py
@@ -14,6 +14,8 @@
import json
import logging
import time
+
+import numpy as np
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, Literal
@@ -43,7 +45,10 @@
mark_alert_delivered,
)
from climatevision.inference import run_inference_from_file, run_inference_from_gee
+from climatevision.governance import explain_prediction, SHAPExplainer
from climatevision.api.auth import require_api_key
+from climatevision.inference.alert_generator import AlertGenerator
+from climatevision.analysis.flooding import FloodingAnalysis
logger = logging.getLogger(__name__)
@@ -242,59 +247,16 @@ def _load_template_result(
end_date: Optional[str],
analysis_type: str = "deforestation",
) -> dict[str, Any]:
- """Load or create a template result for failed inference."""
- outputs_dir = Path(__file__).resolve().parents[3] / "outputs"
- template_path = outputs_dir / "inference_results.json"
-
- if template_path.exists():
- template: dict[str, Any] = json.loads(template_path.read_text(encoding="utf-8"))
- else:
- # Create analysis-specific template
- if analysis_type == "ice_melting":
- template = {
- "region": {"bbox": bbox or None},
- "inference": {
- "image_size": [256, 256],
- "ice_pixels": 0,
- "water_pixels": 0,
- "land_pixels": 0,
- "ice_percentage": 0.0,
- "mean_confidence": 0.0,
- },
- }
- elif analysis_type == "flooding":
- template = {
- "region": {"bbox": bbox or None},
- "inference": {
- "image_size": [256, 256],
- "flooded_pixels": 0,
- "dry_pixels": 0,
- "water_pixels": 0,
- "flooded_percentage": 0.0,
- "mean_confidence": 0.0,
- },
- }
- else: # deforestation (default)
- template = {
- "region": {"bbox": bbox or None},
- "ndvi_stats": {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0},
- "inference": {
- "image_size": [256, 256],
- "forest_pixels": 0,
- "non_forest_pixels": 0,
- "forest_percentage": 0.0,
- "mean_confidence": 0.0,
- },
- }
-
- if bbox is not None:
- template.setdefault("region", {})["bbox"] = bbox
- if start_date and end_date:
- template.setdefault("region", {})["date_range"] = f"{start_date} to {end_date}"
-
- template["analysis_type"] = analysis_type
-
- return template
+ """
+ DEPRECATED: Template results are no longer used in production.
+
+ This function is retained only for backward compatibility with tests.
+ All production failures now raise HTTPException instead of returning stubs.
+ """
+ raise HTTPException(
+ status_code=503,
+ detail=f"Inference failed for {analysis_type}. No synthetic fallback available in production.",
+ )
async def _persist_upload(*, run_id: int, file: UploadFile) -> str:
@@ -364,6 +326,8 @@ def create_app() -> FastAPI:
)
app.add_middleware(AuditLogMiddleware)
+ from climatevision.security.api_security import SecurityMiddleware
+ app.add_middleware(SecurityMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=[
@@ -572,8 +536,8 @@ async def predict_json(
with get_connection() as conn:
cur = conn.execute(
"""
- INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at, organization_id)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
body.kind,
@@ -584,6 +548,7 @@ async def predict_json(
body.end_date,
created_at,
created_at,
+ org["id"],
),
)
run_id = int(cur.lastrowid)
@@ -600,14 +565,36 @@ async def predict_json(
status = "completed"
except Exception as exc:
logger.exception("Inference failed for run %s", run_id)
- result_payload = _load_template_result(
- bbox=body.bbox,
- start_date=body.start_date,
- end_date=body.end_date,
- analysis_type=body.analysis_type,
- )
- result_payload["error"] = str(exc)
- status = "failed"
+ raise HTTPException(
+ status_code=503,
+ detail=f"Inference failed: {exc}",
+ ) from exc
+
+ # Generate alerts for flooding
+ if body.analysis_type == "flooding":
+ try:
+ analysis = FloodingAnalysis()
+ metrics = analysis.calculate_metrics(
+ prediction=np.array([]), # Will be populated from result
+ image_size=(256, 256),
+ bbox=body.bbox,
+ )
+ # TODO: populate metrics from actual prediction mask
+ # For now, use the percentage from inference result
+ flooded_pct = result_payload.get("inference", {}).get("flooded_percentage", 0)
+ metrics["flooded_percentage"] = flooded_pct
+ alerts = analysis.generate_alerts(metrics)
+ for alert in alerts:
+ create_organization_alert(
+ org_id=org["id"],
+ alert_type=alert.alert_type,
+ severity=alert.severity.value,
+ title=alert.title,
+ message=alert.message,
+ details=alert.details,
+ )
+ except Exception as exc:
+ logger.warning("Alert generation failed for run %s: %s", run_id, exc)
# Persist result
result_created_at = _utc_now_iso()
@@ -629,12 +616,12 @@ async def predict_json(
@app.post("/api/predict/upload")
async def predict_upload(
kind: str = Form(default="upload"),
- org: dict[str, Any] = Depends(require_api_key),
analysis_type: str = Form(default="deforestation"),
bbox: str | None = Form(default=None),
start_date: str | None = Form(default=None),
end_date: str | None = Form(default=None),
file: UploadFile = File(...),
+ org: dict[str, Any] = Depends(require_api_key),
) -> dict[str, Any]:
"""Run prediction on uploaded satellite imagery file."""
if start_date and end_date and start_date > end_date:
@@ -652,8 +639,8 @@ async def predict_upload(
with get_connection() as conn:
cur = conn.execute(
"""
- INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ INSERT INTO runs (kind, status, analysis_type, bbox, start_date, end_date, created_at, updated_at, organization_id)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
kind,
@@ -664,6 +651,7 @@ async def predict_upload(
end_date,
created_at,
created_at,
+ org["id"],
),
)
run_id = int(cur.lastrowid)
@@ -683,15 +671,10 @@ async def predict_upload(
status = "completed"
except Exception as exc:
logger.exception("Inference failed for upload run %s", run_id)
- result_payload = _load_template_result(
- bbox=parsed_bbox,
- start_date=start_date,
- end_date=end_date,
- analysis_type=analysis_type,
- )
- result_payload.setdefault("input", {})["file"] = dest
- result_payload["error"] = str(exc)
- status = "failed"
+ raise HTTPException(
+ status_code=503,
+ detail=f"Inference failed: {exc}",
+ ) from exc
# Persist result
result_created_at = _utc_now_iso()
diff --git a/src/climatevision/data/__init__.py b/src/climatevision/data/__init__.py
index 232f42d..320410f 100644
--- a/src/climatevision/data/__init__.py
+++ b/src/climatevision/data/__init__.py
@@ -1,7 +1,7 @@
-from .dataset import ForestDataset, create_dataloaders
+from .dataset import SatelliteDataset, ForestDataset, FloodDataset, create_dataloaders
from .augmentation import get_train_transforms, get_val_transforms
from .preprocessing import Sentinel2Normalizer, compute_dataset_stats, apply_scl_cloud_mask
-from .synthetic import generate_synthetic_dataset
+from .synthetic import generate_synthetic_dataset, generate_flood_test_patches
from .gee_downloader import download_tile_for_analysis
from .band_mapping import (
get_bands_for_analysis,
@@ -27,7 +27,9 @@
__all__ = [
# Dataset
+ "SatelliteDataset",
"ForestDataset",
+ "FloodDataset",
"create_dataloaders",
# Augmentation
"get_train_transforms",
@@ -38,6 +40,7 @@
"apply_scl_cloud_mask",
# Synthetic
"generate_synthetic_dataset",
+ "generate_flood_test_patches",
# GEE
"download_tile_for_analysis",
# Band mapping
diff --git a/src/climatevision/data/band_mapping.py b/src/climatevision/data/band_mapping.py
index 9f9d73b..2d55db1 100644
--- a/src/climatevision/data/band_mapping.py
+++ b/src/climatevision/data/band_mapping.py
@@ -1,5 +1,5 @@
"""
-Analysis-specific Sentinel-2 band mapping utilities.
+Analysis-specific satellite band mapping utilities.
Provides a single source of truth for which spectral bands each
climate analysis type requires, derived from config.yaml.
@@ -22,6 +22,9 @@
"B8A", "B09", "B10", "B11", "B12",
]
+# Sentinel-1 SAR bands
+SENTINEL1_BAND_ORDER = ["VV", "VH"]
+
# Scene Classification Layer (SCL) is not part of the 13 reflectance bands
# but is essential for cloud masking.
SCL_BAND = "SCL"
@@ -36,7 +39,7 @@ def _load_config() -> dict[str, Any]:
def get_bands_for_analysis(analysis_type: str) -> list[str]:
"""
- Return the Sentinel-2 band names required for *analysis_type*.
+ Return the satellite band names required for *analysis_type*.
The bands are read from ``config.yaml`` and are guaranteed to be
returned in the same order they are declared there.
@@ -69,9 +72,6 @@ def get_band_indices(band_names: list[str]) -> list[int]:
indices = []
for b in band_names:
if b == SCL_BAND:
- # SCL does not belong to the 13 reflectance bands;
- # callers that need an index in a multi-band array should
- # append it separately and compute len(reflectance_bands).
raise ValueError(
f"SCL is not part of the 13-band reflectance stack. "
f"Append it manually after resolving reflectance indices."
@@ -82,6 +82,16 @@ def get_band_indices(band_names: list[str]) -> list[int]:
return indices
+def get_sar_band_indices(band_names: list[str]) -> list[int]:
+ """Map Sentinel-1 band names to zero-based indices."""
+ indices = []
+ for b in band_names:
+ if b not in SENTINEL1_BAND_ORDER:
+ raise ValueError(f"Unknown Sentinel-1 band: {b}")
+ indices.append(SENTINEL1_BAND_ORDER.index(b))
+ return indices
+
+
def is_analysis_enabled(analysis_type: str) -> bool:
"""Return True if the analysis type is enabled in config.yaml."""
cfg = _load_config()
@@ -109,3 +119,9 @@ def get_model_config(analysis_type: str) -> dict[str, Any]:
cfg = _load_config()
analysis_cfg = cfg.get("analysis_types", {}).get(analysis_type, {})
return dict(analysis_cfg.get("model", {}))
+
+
+def get_analysis_config(analysis_type: str) -> dict[str, Any]:
+ """Return the full analysis type configuration dict."""
+ cfg = _load_config()
+ return dict(cfg.get("analysis_types", {}).get(analysis_type, {}))
diff --git a/src/climatevision/data/gee_downloader.py b/src/climatevision/data/gee_downloader.py
index fa65f0b..d6c45ae 100644
--- a/src/climatevision/data/gee_downloader.py
+++ b/src/climatevision/data/gee_downloader.py
@@ -1,9 +1,8 @@
"""
Google Earth Engine tile downloader for ClimateVision.
-Provides analysis-aware Sentinel-2 tile downloads with a synthetic fallback
-when GEE credentials are unavailable. Downloaded tiles are saved as GeoTIFF
-and include a metadata dict that labels synthetic scenes explicitly.
+Provides analysis-aware Sentinel-2 and Sentinel-1 tile downloads.
+GEE failures raise explicit exceptions — NO synthetic fallbacks in production.
"""
from __future__ import annotations
@@ -12,7 +11,7 @@
import tempfile
import urllib.request
from pathlib import Path
-from typing import Any, Optional
+from typing import Any
import numpy as np
@@ -41,6 +40,22 @@
}
+# ---------------------------------------------------------------------------
+# Exceptions
+# ---------------------------------------------------------------------------
+
+class GEEAuthenticationError(RuntimeError):
+ """Raised when GEE credentials are missing or invalid."""
+
+
+class TileNotFoundError(RuntimeError):
+ """Raised when no satellite images are found for the given criteria."""
+
+
+# ---------------------------------------------------------------------------
+# GEE initialisation
+# ---------------------------------------------------------------------------
+
def _initialize_ee() -> Any:
"""Lazy import and initialise Google Earth Engine."""
import ee # noqa
@@ -62,16 +77,9 @@ def _initialize_ee() -> Any:
return ee
-def _get_default_tile_size() -> int:
- """Read the default tile size from config.yaml."""
- import yaml
-
- config_path = _PROJECT_ROOT / "config.yaml"
- with open(config_path, "r") as f:
- cfg = yaml.safe_load(f)
- image_size = cfg.get("data", {}).get("image_size", [256, 256])
- return int(image_size[0])
-
+# ---------------------------------------------------------------------------
+# Sentinel-2 download
+# ---------------------------------------------------------------------------
def download_tile_for_analysis(
bbox: list[float],
@@ -95,8 +103,11 @@ def download_tile_for_analysis(
include_scl: Whether to append the SCL band for cloud masking.
Returns:
- (file_path, metadata_dict). If GEE is unavailable, the synthetic
- fallback is used and ``metadata["is_synthetic"]`` is ``True``.
+ (file_path, metadata_dict).
+
+ Raises:
+ GEEAuthenticationError: If GEE cannot be initialised.
+ TileNotFoundError: If no images are found for the date range.
"""
if output_dir is None:
output_dir = _SATELLITE_DIR
@@ -112,14 +123,10 @@ def download_tile_for_analysis(
ee = _initialize_ee()
rasterio = __import__("rasterio")
except Exception as exc:
- logger.warning("GEE unavailable (%s). Using synthetic fallback.", exc)
- return _generate_synthetic_tile(
- bbox=bbox,
- start_date=start_date,
- end_date=end_date,
- analysis_type=analysis_type,
- out_path=out_path,
- )
+ raise GEEAuthenticationError(
+ f"Google Earth Engine unavailable: {exc}. "
+ f"Set GEE_PROJECT_ID or GEE_SERVICE_ACCOUNT + GEE_SERVICE_ACCOUNT_KEY."
+ ) from exc
bands = get_bands_for_analysis(analysis_type)
gee_bands = [_BAND_NAME_TO_GEE[b] for b in bands]
@@ -137,16 +144,9 @@ def download_tile_for_analysis(
count = collection.size().getInfo()
if count == 0:
- logger.warning(
- "No GEE images found for %s %s to %s. Using synthetic fallback.",
- analysis_type, start_date, end_date,
- )
- return _generate_synthetic_tile(
- bbox=bbox,
- start_date=start_date,
- end_date=end_date,
- analysis_type=analysis_type,
- out_path=out_path,
+ raise TileNotFoundError(
+ f"No Sentinel-2 images found for {analysis_type} "
+ f"from {start_date} to {end_date} in bbox {bbox}."
)
image = collection.median().clip(region)
@@ -166,8 +166,6 @@ def download_tile_for_analysis(
os.unlink(tmp)
- # Re-order bands to match project convention if needed
- # (GEE returns in selection order)
profile.update(
driver="GTiff",
dtype="float32",
@@ -179,6 +177,7 @@ def download_tile_for_analysis(
metadata: dict[str, Any] = {
"source": "gee",
+ "satellite": "sentinel2",
"analysis_type": analysis_type,
"bbox": bbox,
"start_date": start_date,
@@ -186,75 +185,123 @@ def download_tile_for_analysis(
"bands": bands,
"scale_m": scale_m,
"images_available": count,
- "is_synthetic": False,
"shape": list(data.shape),
}
- logger.info("Downloaded real tile to %s (%d images available)", out_path, count)
+ logger.info("Downloaded S2 tile to %s (%d images available)", out_path, count)
return out_path, metadata
-def _generate_synthetic_tile(
+# ---------------------------------------------------------------------------
+# Sentinel-1 SAR download
+# ---------------------------------------------------------------------------
+
+def download_s1_tile(
bbox: list[float],
start_date: str,
end_date: str,
- analysis_type: str,
- out_path: Path,
+ polarization: list[str] | None = None,
+ output_dir: str | Path | None = None,
+ scale_m: int = 100,
+ orbit: str = "DESCENDING",
) -> tuple[Path, dict[str, Any]]:
"""
- Generate a physically plausible synthetic Sentinel-2 tile when GEE fails.
- The output is explicitly tagged ``is_synthetic: True``.
+ Download a median Sentinel-1 SAR composite for the given bbox and date range.
+
+ Args:
+ bbox: [west, south, east, north] in WGS84.
+ start_date: Start date (YYYY-MM-DD).
+ end_date: End date (YYYY-MM-DD).
+ polarization: List of polarizations, e.g., ["VV", "VH"]. Defaults to both.
+ output_dir: Where to save the GeoTIFF.
+ scale_m: GEE export resolution in metres.
+ orbit: "ASCENDING", "DESCENDING", or "BOTH".
+
+ Returns:
+ (file_path, metadata_dict).
+
+ Raises:
+ GEEAuthenticationError: If GEE cannot be initialised.
+ TileNotFoundError: If no SAR images are found.
"""
- rasterio = __import__("rasterio")
+ if output_dir is None:
+ output_dir = _SATELLITE_DIR
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
- bands = get_bands_for_analysis(analysis_type)
- n_bands = len(bands)
- tile_size = _get_default_tile_size()
- h, w = tile_size, tile_size
-
- # Seed RNG from bbox so the same region is deterministic
- seed = int(abs(sum(v * 1000 * (i + 1) for i, v in enumerate(bbox)))) % (2 ** 31)
- rng = np.random.default_rng(seed)
-
- # Build a synthetic stack: draw reflectance values typical for mixed forest
- data = np.zeros((n_bands, h, w), dtype=np.float32)
- for b in range(n_bands):
- mean = rng.uniform(500.0, 3000.0)
- std = rng.uniform(200.0, 800.0)
- data[b] = rng.normal(mean, std, (h, w)).clip(0.0, 10000.0)
-
- # Append an SCL band (all clear = 4)
- scl = np.full((1, h, w), 4, dtype=np.float32)
- data = np.concatenate([data, scl], axis=0)
-
- transform = rasterio.transform.from_bounds(
- bbox[0], bbox[1], bbox[2], bbox[3], w, h
+ if polarization is None:
+ polarization = ["VV", "VH"]
+
+ safe_start = start_date.replace("-", "")
+ safe_end = end_date.replace("-", "")
+ stem = f"s1_{safe_start}_{safe_end}_{'_'.join(str(round(c, 4)) for c in bbox)}"
+ out_path = output_dir / f"{stem}.tif"
+
+ try:
+ ee = _initialize_ee()
+ rasterio = __import__("rasterio")
+ except Exception as exc:
+ raise GEEAuthenticationError(
+ f"Google Earth Engine unavailable: {exc}."
+ ) from exc
+
+ region = ee.Geometry.Rectangle(bbox)
+ collection = (
+ ee.ImageCollection("COPERNICUS/S1_GRD")
+ .filterBounds(region)
+ .filterDate(start_date, end_date)
+ .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
+ .select(polarization)
+ )
+
+ if orbit != "BOTH":
+ collection = collection.filter(ee.Filter.eq("orbitProperties_pass", orbit))
+
+ count = collection.size().getInfo()
+ if count == 0:
+ raise TileNotFoundError(
+ f"No Sentinel-1 images found from {start_date} to {end_date} in bbox {bbox}."
+ )
+
+ # NOTE(review): COPERNICUS/S1_GRD backscatter is in dB; linear power is 10 ** (dB / 10).
+ image = collection.median().clip(region)
+ image = image.pow(2) # FIXME: this squares the dB values; it is NOT a dB-to-linear conversion
+
+ url = image.getDownloadURL({
+ "region": region,
+ "scale": scale_m,
+ "format": "GEO_TIFF",
+ })
+
+ tmp = tempfile.mktemp(suffix=".tif")
+ urllib.request.urlretrieve(url, tmp)
+
+ with rasterio.open(tmp) as src:
+ data = src.read().astype(np.float32)
+ profile = src.profile
+
+ os.unlink(tmp)
+
+ profile.update(
+ driver="GTiff",
+ dtype="float32",
+ count=data.shape[0],
)
- profile = {
- "driver": "GTiff",
- "dtype": "float32",
- "count": data.shape[0],
- "height": h,
- "width": w,
- "crs": "EPSG:4326",
- "transform": transform,
- }
with rasterio.open(out_path, "w", **profile) as dst:
dst.write(data)
metadata: dict[str, Any] = {
- "source": "synthetic_fallback",
- "analysis_type": analysis_type,
+ "source": "gee",
+ "satellite": "sentinel1",
"bbox": bbox,
"start_date": start_date,
"end_date": end_date,
- "bands": bands,
- "scale_m": 100,
- "images_available": 0,
- "is_synthetic": True,
+ "polarization": polarization,
+ "scale_m": scale_m,
+ "images_available": count,
"shape": list(data.shape),
}
- logger.info("Generated synthetic fallback tile to %s", out_path)
+ logger.info("Downloaded S1 tile to %s (%d images available)", out_path, count)
return out_path, metadata
diff --git a/src/climatevision/data/preprocessing.py b/src/climatevision/data/preprocessing.py
index fd62b17..6d95cca 100644
--- a/src/climatevision/data/preprocessing.py
+++ b/src/climatevision/data/preprocessing.py
@@ -2,14 +2,7 @@
Sentinel-2 band normalization and preprocessing.
Sentinel-2 L2A surface reflectance is stored as uint16 in range [0, 10000].
-We normalize each band to float32 using robust per-channel statistics derived
-from a large sample of Amazon/Congo forest and non-forest pixels.
-
-Reference band order expected throughout this project:
- index 0 → B04 Red (~665 nm)
- index 1 → B03 Green (~560 nm)
- index 2 → B02 Blue (~490 nm)
- index 3 → B08 NIR (~842 nm)
+Normalizers are band-agnostic and adapt to any number of input channels.
"""
from __future__ import annotations
@@ -23,97 +16,121 @@
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
-# Sentinel-2 L2A statistics computed from 50 k Amazon/Congo patches
-# Values are surface reflectance ×10000, band order [R, G, B, NIR]
+# Default Sentinel-2 statistics — used as a fallback when no dataset stats exist.
+# Defaults are uniform across bands; fitted/loaded stats MUST match the input band order.
# ---------------------------------------------------------------------------
-_S2_MEAN = np.array([943.0, 1069.0, 981.0, 2734.0], dtype=np.float32)
-_S2_STD = np.array([590.0, 547.0, 498.0, 1246.0], dtype=np.float32)
-# Robust (2nd–98th percentile) clip bounds to suppress sensor artefacts
-_S2_P2 = np.array([ 0.0, 10.0, 0.0, 100.0], dtype=np.float32)
-_S2_P98 = np.array([2500.0, 2500.0, 2200.0, 8000.0], dtype=np.float32)
+def _default_stats(num_bands: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Return sensible default mean/std/p2/p98 for `num_bands`."""
+ mean = np.full(num_bands, 1500.0, dtype=np.float32)
+ std = np.full(num_bands, 800.0, dtype=np.float32)
+ p2 = np.full(num_bands, 0.0, dtype=np.float32)
+ p98 = np.full(num_bands, 4000.0, dtype=np.float32)
+ return mean, std, p2, p98
+
+
+# ---------------------------------------------------------------------------
+# Sentinel-2 Normalizer (band-agnostic)
+# ---------------------------------------------------------------------------
class Sentinel2Normalizer:
"""
- Normalize a 4-band Sentinel-2 image to zero-mean / unit-variance float32.
+ Normalize a multi-band Sentinel-2 image to zero-mean / unit-variance float32.
Two modes:
- - 'standard': use pre-computed global statistics (default, fast).
+ - 'standard': use pre-computed or default global statistics (fast).
- 'dataset': use statistics supplied via `fit()` (accurate per dataset).
"""
- def __init__(self, mode: str = "standard"):
+ def __init__(self, mode: str = "standard", num_bands: int = 4):
assert mode in ("standard", "dataset")
self.mode = mode
- self.mean: np.ndarray = _S2_MEAN.copy()
- self.std: np.ndarray = _S2_STD.copy()
- self.p2: np.ndarray = _S2_P2.copy()
- self.p98: np.ndarray = _S2_P98.copy()
+ self.num_bands = num_bands
+ self.mean, self.std, self.p2, self.p98 = _default_stats(num_bands)
self._fitted = (mode == "standard")
- # ------------------------------------------------------------------
def fit(self, images: list[np.ndarray]) -> "Sentinel2Normalizer":
- """Compute statistics from a list of (4, H, W) arrays."""
+ """Compute statistics from a list of (C, H, W) arrays."""
+ if not images:
+ raise ValueError("Cannot fit on empty image list")
+
all_pixels: list[np.ndarray] = []
for img in images:
c, h, w = img.shape
all_pixels.append(img.reshape(c, -1))
- stacked = np.concatenate(all_pixels, axis=1) # (4, N)
+ stacked = np.concatenate(all_pixels, axis=1) # (C, N)
+ self.num_bands = stacked.shape[0]
self.mean = stacked.mean(axis=1).astype(np.float32)
- self.std = stacked.std(axis=1).astype(np.float32) + 1e-6
- self.p2 = np.percentile(stacked, 2, axis=1).astype(np.float32)
- self.p98 = np.percentile(stacked, 98, axis=1).astype(np.float32)
+ self.std = stacked.std(axis=1).astype(np.float32) + 1e-6
+ self.p2 = np.percentile(stacked, 2, axis=1).astype(np.float32)
+ self.p98 = np.percentile(stacked, 98, axis=1).astype(np.float32)
self._fitted = True
return self
- # ------------------------------------------------------------------
def __call__(self, image: np.ndarray) -> np.ndarray:
"""
- Normalize a (4, H, W) uint16 or float32 array to float32.
+ Normalize a (C, H, W) uint16 or float32 array to float32.
Returns values roughly in [-3, 3].
"""
if not self._fitted:
raise RuntimeError("Call fit() before normalizing in 'dataset' mode.")
img = image.astype(np.float32)
+ c = img.shape[0]
+
+ # Resize the stored stats in place when the band count differs (e.g., 3-band vs 5-band)
+ if c != self.num_bands:
+ if c > self.num_bands:
+ # Pad stats with defaults for extra bands
+ extra = c - self.num_bands
+ self.mean = np.concatenate([self.mean, np.full(extra, 1500.0, dtype=np.float32)])
+ self.std = np.concatenate([self.std, np.full(extra, 800.0, dtype=np.float32)])
+ self.p2 = np.concatenate([self.p2, np.full(extra, 0.0, dtype=np.float32)])
+ self.p98 = np.concatenate([self.p98, np.full(extra, 4000.0, dtype=np.float32)])
+ else:
+ self.mean = self.mean[:c]
+ self.std = self.std[:c]
+ self.p2 = self.p2[:c]
+ self.p98 = self.p98[:c]
+ self.num_bands = c
# 1. Clip outliers band-wise
- for b in range(min(4, img.shape[0])):
+ for b in range(c):
img[b] = np.clip(img[b], self.p2[b], self.p98[b])
# 2. Standardize
- for b in range(min(4, img.shape[0])):
+ for b in range(c):
img[b] = (img[b] - self.mean[b]) / self.std[b]
return img
- # ------------------------------------------------------------------
def save(self, path: str | Path) -> None:
data = {
"mean": self.mean.tolist(),
- "std": self.std.tolist(),
- "p2": self.p2.tolist(),
- "p98": self.p98.tolist(),
+ "std": self.std.tolist(),
+ "p2": self.p2.tolist(),
+ "p98": self.p98.tolist(),
"mode": self.mode,
+ "num_bands": self.num_bands,
}
Path(path).write_text(json.dumps(data, indent=2))
@classmethod
def load(cls, path: str | Path) -> "Sentinel2Normalizer":
data = json.loads(Path(path).read_text())
- obj = cls(mode=data["mode"])
+ obj = cls(mode=data["mode"], num_bands=data.get("num_bands", 4))
obj.mean = np.array(data["mean"], dtype=np.float32)
- obj.std = np.array(data["std"], dtype=np.float32)
- obj.p2 = np.array(data["p2"], dtype=np.float32)
- obj.p98 = np.array(data["p98"], dtype=np.float32)
+ obj.std = np.array(data["std"], dtype=np.float32)
+ obj.p2 = np.array(data["p2"], dtype=np.float32)
+ obj.p98 = np.array(data["p98"], dtype=np.float32)
obj._fitted = True
return obj
# ---------------------------------------------------------------------------
-# Dataset statistics helper
+# Cloud masking
# ---------------------------------------------------------------------------
def apply_scl_cloud_mask(
@@ -128,15 +145,15 @@ def apply_scl_cloud_mask(
Args:
image: Array of shape (C, H, W).
scl_band: Array of shape (H, W) containing Scene Classification Layer values.
- clear_labels: SCL codes considered clear. Defaults to vegetation, bare soil,
- water, and snow (``[4, 5, 6, 11]``).
+ clear_labels: SCL codes considered clear. If None, uses a safe default
+ of [4, 5, 6] (vegetation, bare soil, water).
fill_value: Value to replace cloudy pixels with.
Returns:
Cloud-masked image with the same shape as *image*.
"""
if clear_labels is None:
- clear_labels = [4, 5, 6, 11]
+ clear_labels = [4, 5, 6]
if image.ndim != 3:
raise ValueError(f"image must be 3-D (C, H, W), got shape {image.shape}")
@@ -152,6 +169,47 @@ def apply_scl_cloud_mask(
return masked
+# ---------------------------------------------------------------------------
+# Band resampling
+# ---------------------------------------------------------------------------
+
+def resample_20m_to_10m(
+ band_20m: np.ndarray,
+ target_shape: tuple[int, int],
+ order: int = 1,
+) -> np.ndarray:
+ """
+ Resample a 20m band to the 10m grid (bilinear interpolation by default; see `order`).
+
+ Args:
+ band_20m: Array of shape (H, W) or (C, H, W).
+ target_shape: Desired (H, W) at 10m resolution (2× the 20m dimensions).
+ order: 0=nearest, 1=bilinear, 3=bicubic. Default 1 (bilinear).
+
+ Returns:
+ Resampled array with the same number of dimensions as input.
+ """
+ try:
+ from scipy.ndimage import zoom
+ except ImportError:
+ raise ImportError("scipy is required for band resampling. Install: pip install scipy")
+
+ if band_20m.ndim == 2:
+ h, w = band_20m.shape
+ zoom_factors = (target_shape[0] / h, target_shape[1] / w)
+ return zoom(band_20m, zoom_factors, order=order, mode="reflect")
+ elif band_20m.ndim == 3:
+ c, h, w = band_20m.shape
+ zoom_factors = (1.0, target_shape[0] / h, target_shape[1] / w)
+ return zoom(band_20m, zoom_factors, order=order, mode="reflect")
+ else:
+ raise ValueError(f"band_20m must be 2-D or 3-D, got shape {band_20m.shape}")
+
+
+# ---------------------------------------------------------------------------
+# Dataset statistics helper
+# ---------------------------------------------------------------------------
+
def compute_dataset_stats(
image_dir: str | Path,
max_samples: int = 500,
@@ -176,7 +234,7 @@ def compute_dataset_stats(
stacked = np.concatenate(all_pixels, axis=1).astype(np.float32) # (C, N)
return {
"mean": stacked.mean(axis=1).tolist(),
- "std": stacked.std(axis=1).tolist(),
- "min": stacked.min(axis=1).tolist(),
- "max": stacked.max(axis=1).tolist(),
+ "std": stacked.std(axis=1).tolist(),
+ "min": stacked.min(axis=1).tolist(),
+ "max": stacked.max(axis=1).tolist(),
}
diff --git a/src/climatevision/inference/pipeline.py b/src/climatevision/inference/pipeline.py
index 7af17ab..06cb3d0 100644
--- a/src/climatevision/inference/pipeline.py
+++ b/src/climatevision/inference/pipeline.py
@@ -2,11 +2,13 @@
Inference pipeline for ClimateVision.
Provides:
-- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference on a numpy array
-- run_inference_from_file(path, bbox, start_date, end_date, analysis_type) — load file then infer
-- run_inference_from_gee(bbox, start_date, end_date, analysis_type) — GEE NDVI + real tile inference
-"""
+- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference
+- run_inference_from_file(path, ...) — load file then infer
+- run_inference_from_gee(bbox, ...) — GEE query + real tile inference
+- run_bitemporal_inference(pre, post, ...) — change detection for floods
+NO synthetic fallbacks in production. All failures raise explicit exceptions.
+"""
from __future__ import annotations
import json
@@ -17,22 +19,16 @@
import numpy as np
import torch
-from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config
+from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config, get_analysis_config
from climatevision.models.unet import UNet
logger = logging.getLogger(__name__)
-# ---------------------------------------------------------------------------
-# Project paths (mirrors run_training.py conventions, NOT Config.MODELS_DIR)
-# ---------------------------------------------------------------------------
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
_MODELS_DIR = _PROJECT_ROOT / "models"
_OUTPUTS_DIR = _PROJECT_ROOT / "outputs"
-# ---------------------------------------------------------------------------
-# Per-analysis-type model cache
-# ---------------------------------------------------------------------------
-_model_cache: dict[str, tuple[UNet, torch.device]] = {}
+_model_cache: dict[str, tuple[torch.nn.Module, torch.device]] = {}
def _get_device() -> torch.device:
@@ -64,8 +60,9 @@ def _find_best_checkpoint(analysis_type: str) -> Optional[Path]:
return candidates[0] if candidates else None
-def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.device]:
- """Load (or return cached) U-Net model configured for the analysis type."""
+def _load_model(analysis_type: str = "deforestation") -> tuple[torch.nn.Module, torch.device]:
+ """Load (or return cached) model configured for the analysis type."""
+
if analysis_type in _model_cache:
return _model_cache[analysis_type]
@@ -80,13 +77,11 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic
if model_path is not None:
checkpoint = torch.load(model_path, map_location=device)
- # Load full state first (includes BatchNorm running stats)
model_state = checkpoint.get("model_state_dict")
- ema_state = checkpoint.get("ema_state_dict")
+ ema_state = checkpoint.get("ema_state_dict")
if model_state is not None:
model.load_state_dict(model_state, strict=False)
- # Overlay EMA parameters on top (better generalisation)
if ema_state is not None:
with torch.no_grad():
for name, param in model.named_parameters():
@@ -94,7 +89,7 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic
param.data.copy_(ema_state[name])
logger.info(
- "Loaded %s model from %s (epoch %s val_iou %.4f)",
+ "Loaded %s model from %s (epoch %s val_iou %.4f)",
analysis_type,
model_path,
checkpoint.get("epoch", "?"),
@@ -102,7 +97,7 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic
)
else:
logger.warning(
- "No trained model found for %s under %s — using untrained weights (demo).",
+ "No trained model found for %s under %s — using untrained weights.",
analysis_type,
_MODELS_DIR,
)
@@ -114,88 +109,82 @@ def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.devic
return model, device
-# ---------------------------------------------------------------------------
-# Sentinel-2 normalisation statistics (matches preprocessing.py)
-# Band order: [Red, Green, Blue, NIR]
-# ---------------------------------------------------------------------------
-_S2_MEAN = np.array([943.0, 1069.0, 981.0, 2734.0], dtype=np.float64)
-_S2_STD = np.array([590.0, 547.0, 498.0, 1246.0], dtype=np.float64)
+def _validate_model_compatibility(
+    model: torch.nn.Module, expected_channels: int, expected_classes: int
+) -> None:
+    """Raise RuntimeError if the model's n_channels / n_classes differ from the expected values."""
+    # An attribute the model does not define reads as None and is skipped, not reported.
+    hint = "The checkpoint may be for a different analysis type."
+    found_channels = getattr(model, "n_channels", None)
+    found_classes = getattr(model, "n_classes", None)
+
+    if found_channels not in (None, expected_channels):
+        raise RuntimeError(
+            f"Model channel mismatch: expected {expected_channels}, got {found_channels}. {hint}"
+        )
+    if found_classes not in (None, expected_classes):
+        raise RuntimeError(
+            f"Model class mismatch: expected {expected_classes}, got {found_classes}. {hint}"
+        )
# ---------------------------------------------------------------------------
-# NDVI helper (works for >=4 bands; returns zeros for RGB-only)
+# Analysis-specific index computation
# ---------------------------------------------------------------------------
def _compute_ndvi_stats(image: np.ndarray) -> dict[str, float]:
- """
- Compute NDVI min/mean/max from image array.
-
- Expects (C, H, W) with C >= 4 where band order is [Red, Green, Blue, NIR].
- Automatically detects and reverses Sentinel-2 z-score normalisation
- (values in roughly [-5, 5]) before computing NDVI.
- Returns zeros if fewer than 4 bands.
- """
- if image.ndim == 2:
+ """Compute NDVI from (C, H, W) with B04=Red, B08=NIR."""
+ if image.ndim != 3 or image.shape[0] < 4:
return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0}
- # Normalise to (C, H, W)
- if image.ndim == 3 and image.shape[2] < image.shape[0]:
- image = np.transpose(image, (2, 0, 1))
-
- n_bands = image.shape[0]
- if n_bands < 4:
- return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0}
+    # Channel-first (C, H, W) input: bands sit on axis 0 (`image[..., k]` indexed the width axis).
+    red = image[0].astype(np.float64)  # band 0 = B04 (Red)
+    nir = image[3].astype(np.float64)  # band 3 = B08 (NIR)
- # Band order: Red=0, Green=1, Blue=2, NIR=3
- red = image[0].astype(np.float64)
- nir = image[3].astype(np.float64)
-
- # If data looks like z-score normalised input (values in [-10, 10])
- # denormalise back to raw Sentinel-2 DN before computing NDVI.
if red.max() <= 10.0 and nir.max() <= 10.0:
- red = red * _S2_STD[0] + _S2_MEAN[0]
- nir = nir * _S2_STD[3] + _S2_MEAN[3]
+        red = red * 590.0 + 943.0  # undo z-score: value * std + mean (Red std is 590, not 943)
+        nir = nir * 1246.0 + 2734.0  # undo z-score: value * std + mean (NIR)
denom = nir + red + 1e-8
ndvi = (nir - red) / denom
-
return {
- "NDVI_min": round(float(np.nanmin(ndvi)), 4),
+ "NDVI_min": round(float(np.nanmin(ndvi)), 4),
"NDVI_mean": round(float(np.nanmean(ndvi)), 4),
- "NDVI_max": round(float(np.nanmax(ndvi)), 4),
+ "NDVI_max": round(float(np.nanmax(ndvi)), 4),
}
-def _synthetic_ndvi_stats(bbox: Optional[list[float]]) -> dict[str, float]:
+def _compute_mndwi_stats(image: np.ndarray) -> dict[str, float]:
"""
- Compute NDVI from a synthetic but physically realistic Sentinel-2 scene.
-
- Used as a fallback when GEE credentials are unavailable.
- The bbox is used to seed the RNG so the same region always returns
- the same values. Band statistics match typical tropical/temperate forest.
+ Compute MNDWI from (C, H, W) with B03=Green, B11=SWIR.
+ Expects band order [B03, B08, B11] for flood inputs.
"""
- seed = 42
- if bbox:
- seed = int(abs(sum(v * 1000 * (i + 1) for i, v in enumerate(bbox)))) % (2 ** 31)
- rng = np.random.default_rng(seed)
+ if image.ndim != 3 or image.shape[0] < 3:
+ return {"MNDWI_min": 0.0, "MNDWI_mean": 0.0, "MNDWI_max": 0.0}
- # Typical Sentinel-2 L2A forest reflectance (DN, 0-10000 scale)
- # Red ~600-1200, NIR ~2500-5000
- red = rng.normal(900.0, 350.0, (256, 256)).clip(50.0, 5000.0)
- nir = rng.normal(3800.0, 900.0, (256, 256)).clip(100.0, 9000.0)
-
- denom = nir + red + 1e-8
- ndvi = (nir - red) / denom
+ # image is (C, H, W); for flooding: B03=idx0, B11=idx2
+ green = image[0].astype(np.float64)
+ swir = image[2].astype(np.float64)
+ denom = green + swir + 1e-8
+ mndwi = (green - swir) / denom
return {
- "NDVI_min": round(float(np.nanmin(ndvi)), 4),
- "NDVI_mean": round(float(np.nanmean(ndvi)), 4),
- "NDVI_max": round(float(np.nanmax(ndvi)), 4),
+ "MNDWI_min": round(float(np.nanmin(mndwi)), 4),
+ "MNDWI_mean": round(float(np.nanmean(mndwi)), 4),
+ "MNDWI_max": round(float(np.nanmax(mndwi)), 4),
}
+def _compute_analysis_stats(image: np.ndarray, analysis_type: str) -> dict[str, float]:
+    """Return spectral-index stats for the analysis type: MNDWI for flooding, NDVI otherwise."""
+    # Only "flooding" selects the water index; every other analysis type,
+    # known or not, falls through to the vegetation index.
+    compute_stats = _compute_mndwi_stats if analysis_type == "flooding" else _compute_ndvi_stats
+    return compute_stats(image)
+
+
# ---------------------------------------------------------------------------
-# Core inference on a numpy array
+# Core inference
# ---------------------------------------------------------------------------
def run_inference(
@@ -209,41 +198,42 @@ def run_inference(
"""
Run full inference pipeline on a (C, H, W) numpy image.
- Returns dict with keys: region, ndvi_stats, inference.
+ Returns dict with keys: region, index_stats, inference.
"""
- # Normalise to (C, H, W)
if image.ndim == 3 and image.shape[2] < image.shape[0]:
image = np.transpose(image, (2, 0, 1))
- ndvi_stats = _compute_ndvi_stats(image)
+ index_stats = _compute_analysis_stats(image, analysis_type)
model, device = _load_model(analysis_type)
+ model_cfg = get_model_config(analysis_type)
+ expected_channels = model_cfg.get("in_channels", 4)
+ expected_classes = model_cfg.get("num_classes", 2)
+
+ _validate_model_compatibility(model, expected_channels, expected_classes)
+
n_channels = model.n_channels
n_classes = model.n_classes
- # Prepare tensor — model expects (N, n_channels, H, W)
c, h, w = image.shape
if c < n_channels:
- # Pad missing channels with zeros
pad = np.zeros((n_channels - c, h, w), dtype=image.dtype)
image = np.concatenate([image, pad], axis=0)
elif c > n_channels:
image = image[:n_channels]
- # Use torch.FloatTensor via tolist() to avoid numpy<->torch interop issues
- tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, C, H, W)
+ tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0)
tensor = tensor.to(device)
with torch.no_grad():
output = model(tensor)
- predictions = torch.argmax(output, dim=1) # (1, H, W)
- probabilities = torch.softmax(output, dim=1) # (1, n_classes, H, W)
+ predictions = torch.argmax(output, dim=1)
+ probabilities = torch.softmax(output, dim=1)
total_pixels = int(predictions.numel())
max_probs = probabilities.max(dim=1).values
mean_confidence = float(max_probs.mean().item())
- # Build per-class pixel counts
class_pixels: dict[str, int] = {}
class_percentages: dict[str, float] = {}
for cls in range(n_classes):
@@ -252,7 +242,6 @@ def run_inference(
class_pixels[f"class_{cls}_pixels"] = count
class_percentages[f"class_{cls}_percentage"] = round(pct, 4)
- # Add friendly keys for known 2-class deforestation output (backward compat)
inference: dict[str, Any] = {
"image_size": [h, w],
"num_classes": n_classes,
@@ -260,10 +249,24 @@ def run_inference(
**class_pixels,
**class_percentages,
}
- if n_classes == 2:
+
+ # Add friendly keys for known analysis types
+ if analysis_type == "deforestation" and n_classes == 2:
inference["forest_pixels"] = class_pixels.get("class_1_pixels", 0)
inference["non_forest_pixels"] = class_pixels.get("class_0_pixels", 0)
inference["forest_percentage"] = class_percentages.get("class_1_percentage", 0.0)
+ elif analysis_type == "flooding" and n_classes == 3:
+ inference["dry_pixels"] = class_pixels.get("class_0_pixels", 0)
+ inference["water_pixels"] = class_pixels.get("class_1_pixels", 0)
+ inference["flooded_pixels"] = class_pixels.get("class_2_pixels", 0)
+ inference["dry_percentage"] = class_percentages.get("class_0_percentage", 0.0)
+ inference["water_percentage"] = class_percentages.get("class_1_percentage", 0.0)
+ inference["flooded_percentage"] = class_percentages.get("class_2_percentage", 0.0)
+ elif analysis_type == "ice_melting" and n_classes == 3:
+ inference["water_pixels"] = class_pixels.get("class_0_pixels", 0)
+ inference["ice_pixels"] = class_pixels.get("class_1_pixels", 0)
+ inference["land_pixels"] = class_pixels.get("class_2_pixels", 0)
+ inference["ice_percentage"] = class_percentages.get("class_1_percentage", 0.0)
region: dict[str, Any] = {}
if bbox is not None:
@@ -273,14 +276,14 @@ def run_inference(
return {
"region": region,
- "ndvi_stats": ndvi_stats,
+ "index_stats": index_stats,
"inference": inference,
"is_synthetic": False,
}
# ---------------------------------------------------------------------------
-# File-based inference (upload path)
+# File-based inference
# ---------------------------------------------------------------------------
def run_inference_from_file(
@@ -291,9 +294,7 @@ def run_inference_from_file(
end_date: Optional[str] = None,
analysis_type: str = "deforestation",
) -> dict[str, Any]:
- """
- Load an image file (GeoTIFF or PNG/JPEG) and run inference.
- """
+ """Load an image file (GeoTIFF or PNG/JPEG) and run inference."""
image = _load_image_file(path)
result = run_inference(
image,
@@ -307,40 +308,33 @@ def run_inference_from_file(
def _load_image_file(path: str) -> np.ndarray:
- """
- Load image as (C, H, W) numpy array.
- Tries rasterio first (GeoTIFF), falls back to Pillow.
- """
+ """Load image as (C, H, W) numpy array. Tries rasterio first, falls back to Pillow."""
p = Path(path)
suffix = p.suffix.lower()
- # Try rasterio for geospatial formats
if suffix in {".tif", ".tiff", ".geotiff"}:
try:
import rasterio
-
with rasterio.open(path) as src:
- image = src.read() # (C, H, W)
+ image = src.read()
return image.astype(np.float32)
except Exception:
logger.warning("rasterio failed for %s, trying Pillow", path)
- # Pillow fallback for PNG, JPEG, etc.
from PIL import Image
-
pil_img = Image.open(path)
- arr = np.array(pil_img) # (H, W, C) or (H, W)
+ arr = np.array(pil_img)
if arr.ndim == 2:
- arr = arr[np.newaxis, :, :] # (1, H, W)
+ arr = arr[np.newaxis, :, :]
else:
- arr = np.transpose(arr, (2, 0, 1)) # (C, H, W)
+ arr = np.transpose(arr, (2, 0, 1))
return arr.astype(np.float32)
# ---------------------------------------------------------------------------
-# GEE-based inference (bbox path) — lazy import, safe fallback
+# GEE-based inference
# ---------------------------------------------------------------------------
def run_inference_from_gee(
@@ -351,145 +345,129 @@ def run_inference_from_gee(
analysis_type: str = "deforestation",
) -> dict[str, Any]:
"""
- Query Google Earth Engine for a real Sentinel-2 tile and run inference.
+ Query Google Earth Engine for a real satellite tile and run inference.
- Falls back to synthetic NDVI stats and a synthetic tile if GEE is
- unavailable or returns no images.
+ Raises:
+ GEEAuthenticationError: If GEE credentials are missing.
+ TileNotFoundError: If no images are found for the date range.
+ RuntimeError: If model architecture does not match analysis type.
"""
- ndvi_stats: Optional[dict[str, Any]] = None
- gee_count: int = 0
-
- if bbox and start_date and end_date:
- ndvi_stats, gee_count = _try_gee_ndvi(bbox, start_date, end_date)
+ from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask
+ from climatevision.data.gee_downloader import GEEAuthenticationError, TileNotFoundError
- # --- Attempt to download a real tile from GEE ---
- try:
- from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask
-
- tile_path, metadata = download_tile_for_analysis(
- bbox=bbox,
- start_date=start_date,
- end_date=end_date,
- analysis_type=analysis_type,
- )
+ if not (bbox and start_date and end_date):
+ raise ValueError("bbox, start_date, and end_date are required for GEE inference.")
- image = _load_image_file(str(tile_path))
-
- # If SCL band is present (last band), apply cloud mask and drop it
- n_bands_expected = len(get_bands_for_analysis(analysis_type))
- if image.shape[0] == n_bands_expected + 1:
- scl_band = image[-1].astype(np.uint8)
- image = image[:-1]
- image = apply_scl_cloud_mask(image, scl_band)
-
- result = run_inference(
- image,
- bbox=bbox,
- start_date=start_date,
- end_date=end_date,
- analysis_type=analysis_type,
- )
- result["metadata"] = metadata
- result["is_synthetic"] = metadata.get("is_synthetic", False)
-
- # Override NDVI with GEE-derived stats if we got them; else keep computed
- if ndvi_stats is not None:
- result["ndvi_stats"] = ndvi_stats
- elif metadata.get("is_synthetic"):
- result["ndvi_stats"] = _synthetic_ndvi_stats(bbox)
-
- if gee_count:
- result["region"]["images_available"] = gee_count
+ # Download real tile
+ tile_path, metadata = download_tile_for_analysis(
+ bbox=bbox,
+ start_date=start_date,
+ end_date=end_date,
+ analysis_type=analysis_type,
+ )
- return result
+ image = _load_image_file(str(tile_path))
- except Exception as exc:
- logger.warning("Real tile inference failed (%s). Using fallback.", exc)
+ # If SCL band is present (last band), apply cloud mask and drop it
+ n_bands_expected = len(get_bands_for_analysis(analysis_type))
+ if image.shape[0] == n_bands_expected + 1:
+ scl_band = image[-1].astype(np.uint8)
+ image = image[:-1]
+ analysis_cfg = get_analysis_config(analysis_type)
+ clear_labels = analysis_cfg.get("scl_clear_labels", [4, 5, 6])
+ image = apply_scl_cloud_mask(image, scl_band, clear_labels=clear_labels)
- # --- Fallback: template result with synthetic stats ---
result = run_inference(
- np.zeros((4, 256, 256), dtype=np.float32),
+ image,
bbox=bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)
-
- if ndvi_stats is None:
- ndvi_stats = _synthetic_ndvi_stats(bbox)
- result["ndvi_stats"] = ndvi_stats
-
- region = result.get("region", {})
- if gee_count:
- region["images_available"] = gee_count
- result["region"] = region
- result["is_synthetic"] = True
- result["metadata"] = {"is_synthetic": True, "fallback_reason": "gee_tile_download_failed"}
+ result["metadata"] = metadata
+ result["region"]["images_available"] = metadata.get("images_available", 0)
return result
-def _try_gee_ndvi(
- bbox: list[float], start_date: str, end_date: str
-) -> tuple[Optional[dict[str, Any]], int]:
- """Attempt GEE NDVI query. Returns (ndvi_stats_or_None, image_count)."""
- try:
- import ee # lazy import
- import os
-
- project = os.getenv("GEE_PROJECT_ID")
- svc_account = os.getenv("GEE_SERVICE_ACCOUNT")
- key_file = os.getenv("GEE_SERVICE_ACCOUNT_KEY")
-
- # Resolve relative key path against project root
- if key_file and not os.path.isabs(key_file):
- key_file = str(_PROJECT_ROOT / key_file)
-
- if svc_account and key_file and os.path.exists(key_file):
- credentials = ee.ServiceAccountCredentials(svc_account, key_file)
- ee.Initialize(credentials)
- elif project:
- ee.Initialize(project=project)
- else:
- ee.Initialize()
-
- geometry = ee.Geometry.Rectangle(bbox)
- collection = (
- ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
- .filterBounds(geometry)
- .filterDate(start_date, end_date)
- .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
- .select(["B4", "B3", "B2", "B8"])
- )
+# ---------------------------------------------------------------------------
+# Bitemporal change detection inference
+# ---------------------------------------------------------------------------
- count = collection.size().getInfo()
+def run_bitemporal_inference(
+ pre_image: np.ndarray,
+ post_image: np.ndarray,
+ *,
+ bbox: Optional[list[float]] = None,
+ analysis_type: str = "flooding",
+) -> dict[str, Any]:
+ """
+ Run inference on pre-event and post-event images, then compute change detection.
- median = collection.median()
- nir = median.select("B8")
- red = median.select("B4")
- ndvi = nir.subtract(red).divide(nir.add(red)).rename("NDVI")
+ Returns:
+ Dict with pre_result, post_result, and change detection metrics.
+ For flooding: newly_flooded_pixels, receded_pixels, permanent_water_pixels.
+ """
+ pre_result = run_inference(
+ pre_image,
+ bbox=bbox,
+ analysis_type=analysis_type,
+ )
+ post_result = run_inference(
+ post_image,
+ bbox=bbox,
+ analysis_type=analysis_type,
+ )
- stats = ndvi.reduceRegion(
- reducer=ee.Reducer.mean().combine(ee.Reducer.minMax(), sharedInputs=True),
- geometry=geometry,
- scale=100,
- maxPixels=int(1e9),
- ).getInfo()
+    # run_inference() returns only aggregate pixel counts, not the per-pixel class
+    # mask, so change detection predicts the masks itself below (this costs a
-        return stats, count
+    # second forward pass per image).
-    except Exception as exc:
-        logger.warning("GEE query failed (%s). Using fallback.", exc)
-        return None, 0
+    model, device = _load_model(analysis_type)
+
+    def _infer_mask(img: np.ndarray) -> np.ndarray:
+        """Return the (H, W) argmax class mask predicted for one image."""
+        # Mirror run_inference: accept channel-last input by moving channels first.
+        if img.ndim == 3 and img.shape[2] < img.shape[0]:
+            img = np.transpose(img, (2, 0, 1))
+        c, h, w = img.shape
+        n_channels = model.n_channels
+        if c < n_channels:
+            pad = np.zeros((n_channels - c, h, w), dtype=img.dtype)
+            img = np.concatenate([img, pad], axis=0)
+        elif c > n_channels:
+            img = img[:n_channels]
+        tensor = torch.FloatTensor(img.astype(np.float32).tolist()).unsqueeze(0).to(device)
+        with torch.no_grad():
+            output = model(tensor)
+        # squeeze(0) drops only the batch axis; squeeze() also dropped H or W of size 1.
+        return torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
+
+    pre_mask = _infer_mask(pre_image)
+    post_mask = _infer_mask(post_image)
+    if pre_mask.shape != post_mask.shape:
+        raise ValueError(f"pre/post mask shapes differ: {pre_mask.shape} vs {post_mask.shape}")
+
+    total = max(pre_mask.size, 1)  # keeps the percentages finite for an empty mask
+    # Class indices for flooding: 0=dry, 1=permanent_water, 2=flooded
+    newly_flooded = ((pre_mask != 2) & (post_mask == 2)).sum()
+    receded = ((pre_mask == 2) & (post_mask != 2)).sum()
+    permanent_water = ((pre_mask == 1) & (post_mask == 1)).sum()
+
+    change_metrics = {
+        "newly_flooded_pixels": int(newly_flooded),
+        "newly_flooded_percentage": round(float(newly_flooded / total * 100), 4),
+        "receded_pixels": int(receded),
+        "receded_percentage": round(float(receded / total * 100), 4),
+        "permanent_water_pixels": int(permanent_water),
+        "permanent_water_percentage": round(float(permanent_water / total * 100), 4),
+    }
-def _load_cached_ndvi() -> dict[str, float]:
- """Load NDVI from outputs/inference_results.json if it exists, else zeros."""
- cached = _OUTPUTS_DIR / "inference_results.json"
- if cached.exists():
- try:
- data = json.loads(cached.read_text(encoding="utf-8"))
- return data.get("ndvi_stats", {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0})
- except Exception:
- pass
- return {"NDVI_min": 0.0, "NDVI_mean": 0.0, "NDVI_max": 0.0}
+ return {
+ "region": {"bbox": bbox} if bbox else {},
+ "pre_event": pre_result,
+ "post_event": post_result,
+ "change_detection": change_metrics,
+ }
diff --git a/src/climatevision/inference/postprocess.py b/src/climatevision/inference/postprocess.py
index ad99cab..c7025c6 100644
--- a/src/climatevision/inference/postprocess.py
+++ b/src/climatevision/inference/postprocess.py
@@ -1,14 +1,14 @@
"""
-Post-processing utilities for inference pipeline.
+Inference post-processing utilities.
-Provides confidence thresholding, output filtering, and anomaly detection
-for model predictions before they are returned to users.
+Provides:
+- Confidence thresholding and small-region removal
+- Sliding-window tiling for large-region inference
+- Anomaly detection and quality scoring
"""
-
from __future__ import annotations
import logging
-from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -16,214 +16,241 @@
logger = logging.getLogger(__name__)
-@dataclass
-class PostProcessConfig:
- """Configuration for post-processing operations."""
- confidence_threshold: float = 0.5
- min_region_pixels: int = 100
- anomaly_std_threshold: float = 3.0
- smooth_kernel_size: int = 3
- apply_morphological_ops: bool = True
-
-
-@dataclass
-class PostProcessResult:
- """Result from post-processing operations."""
- mask: np.ndarray
- confidence_map: np.ndarray
- filtered_pixels: int
- anomaly_detected: bool
- anomaly_regions: list[dict[str, Any]]
- quality_score: float
-
+# ---------------------------------------------------------------------------
+# Confidence thresholding
+# ---------------------------------------------------------------------------
def apply_confidence_threshold(
- predictions: np.ndarray,
- confidence: np.ndarray,
- threshold: float = 0.5
+ probabilities: np.ndarray,
+ threshold: float = 0.5,
) -> np.ndarray:
"""
- Filter predictions below confidence threshold.
+ Mask out predictions where max class probability is below threshold.
Args:
- predictions: Model prediction mask (H, W) or (H, W, C)
- confidence: Confidence scores (H, W)
- threshold: Minimum confidence to keep prediction
+ probabilities: (C, H, W) softmax probabilities.
+ threshold: Minimum confidence to keep a prediction.
Returns:
- Filtered prediction mask with low-confidence pixels zeroed
+ (H, W) integer mask with -1 for low-confidence pixels.
"""
- mask = confidence >= threshold
- filtered = predictions.copy()
-
- if filtered.ndim == 2:
- filtered[~mask] = 0
- else:
- filtered[~mask, :] = 0
+ max_prob = probabilities.max(axis=0)
+ predictions = probabilities.argmax(axis=0)
+ predictions[max_prob < threshold] = -1
+ return predictions
- filtered_count = (~mask).sum()
- logger.debug(f"Filtered {filtered_count} pixels below threshold {threshold}")
-
- return filtered
+# ---------------------------------------------------------------------------
+# Small-region removal
+# ---------------------------------------------------------------------------
def remove_small_regions(
mask: np.ndarray,
- min_pixels: int = 100
+ min_size: int = 50,
+ connectivity: int = 1,
) -> np.ndarray:
"""
- Remove small isolated regions from segmentation mask.
+ Remove connected components smaller than min_size pixels.
Args:
- mask: Binary segmentation mask (H, W)
- min_pixels: Minimum region size to keep
+ mask: (H, W) binary or integer label mask.
+ min_size: Minimum component size to retain.
+ connectivity: 1 for 4-connectivity, 2 for 8-connectivity.
Returns:
- Cleaned mask with small regions removed
+ Cleaned mask with small regions removed (set to 0).
"""
try:
from scipy import ndimage
except ImportError:
- logger.warning("scipy not available, skipping small region removal")
+ logger.warning("scipy not available; skipping small-region removal")
return mask
- labeled, num_features = ndimage.label(mask)
-
- cleaned = np.zeros_like(mask)
- for i in range(1, num_features + 1):
- region = labeled == i
- if region.sum() >= min_pixels:
- cleaned[region] = mask[region]
-
- removed = num_features - len(np.unique(ndimage.label(cleaned)[0])) + 1
- logger.debug(f"Removed {removed} regions smaller than {min_pixels} pixels")
+ cleaned = mask.copy()
+ unique_labels = np.unique(mask)
+ for label in unique_labels:
+ if label == 0:
+ continue
+ binary = (mask == label).astype(np.uint8)
+ labeled, num_features = ndimage.label(binary, structure=ndimage.generate_binary_structure(2, connectivity))
+ component_sizes = ndimage.sum(binary, labeled, index=range(1, num_features + 1))
+ too_small = component_sizes < min_size
+ remove_mask = too_small[labeled - 1]
+ cleaned[(mask == label) & remove_mask] = 0
return cleaned
+# ---------------------------------------------------------------------------
+# Anomaly detection
+# ---------------------------------------------------------------------------
+
def detect_anomalies(
- predictions: np.ndarray,
- confidence: np.ndarray,
- std_threshold: float = 3.0
-) -> tuple[bool, list[dict[str, Any]]]:
+ prediction: np.ndarray,
+ expected_classes: list[int],
+) -> dict[str, Any]:
"""
- Detect anomalous predictions that may indicate model issues.
+ Flag anomalous predictions (unexpected class labels, extreme coverage).
Args:
- predictions: Model predictions
- confidence: Confidence scores
- std_threshold: Number of standard deviations for anomaly detection
+ prediction: (H, W) integer mask.
+ expected_classes: List of valid class indices.
Returns:
- Tuple of (anomaly_detected, list of anomaly regions)
+ Dict with anomaly flags and descriptions.
"""
anomalies = []
+ unique = np.unique(prediction)
+ unexpected = set(unique) - set(expected_classes)
+ if unexpected:
+ anomalies.append(f"Unexpected class labels: {unexpected}")
- mean_conf = confidence.mean()
- std_conf = confidence.std()
+ total = prediction.size
+ for cls in expected_classes:
+ pct = (prediction == cls).sum() / total * 100
+ if pct > 95:
+ anomalies.append(f"Class {cls} dominates ({pct:.1f}% — possible model collapse)")
- if std_conf > 0:
- z_scores = np.abs((confidence - mean_conf) / std_conf)
- anomaly_mask = z_scores > std_threshold
+ return {
+ "is_anomalous": len(anomalies) > 0,
+ "anomalies": anomalies,
+ }
- if anomaly_mask.any():
- anomaly_indices = np.where(anomaly_mask)
- anomalies.append({
- "type": "confidence_outlier",
- "count": int(anomaly_mask.sum()),
- "mean_confidence": float(mean_conf),
- "std_confidence": float(std_conf),
- "threshold": std_threshold
- })
-
- # Check for suspiciously uniform predictions
- unique_values = len(np.unique(predictions))
- if unique_values == 1 and predictions.size > 1000:
- anomalies.append({
- "type": "uniform_prediction",
- "message": "Model returned uniform predictions across entire region"
- })
-
- return len(anomalies) > 0, anomalies
+# ---------------------------------------------------------------------------
+# Quality scoring
+# ---------------------------------------------------------------------------
def compute_quality_score(
+ prediction: np.ndarray,
confidence: np.ndarray,
- filtered_ratio: float,
- anomaly_detected: bool
) -> float:
"""
- Compute overall quality score for the prediction.
-
- Args:
- confidence: Confidence map
- filtered_ratio: Ratio of pixels filtered out
- anomaly_detected: Whether anomalies were detected
+ Compute an overall quality score [0, 1] for a prediction.
- Returns:
- Quality score between 0 and 1
+ Based on mean confidence and class balance.
"""
- mean_confidence = float(confidence.mean())
- confidence_score = min(mean_confidence, 1.0)
+    mean_conf = float(confidence.mean())
+    total = prediction.size
+    class_balance = 1.0
+    _, counts = np.unique(prediction, return_counts=True)
+    if total > 0:  # a uniform mask (one class = 100%) must be penalised too, not skipped
+        max_pct = counts.max() / total
+        class_balance = float(1.0 - max(0.0, max_pct - 0.8) / 0.2)  # penalize if one class > 80%
- filter_penalty = filtered_ratio * 0.3
- anomaly_penalty = 0.2 if anomaly_detected else 0.0
+ return round(mean_conf * 0.7 + class_balance * 0.3, 4)
- quality = max(0.0, confidence_score - filter_penalty - anomaly_penalty)
- return round(quality, 4)
+# ---------------------------------------------------------------------------
+# Sliding-window tiling for large regions
+# ---------------------------------------------------------------------------
-def postprocess_predictions(
- predictions: np.ndarray,
- confidence: np.ndarray,
- config: Optional[PostProcessConfig] = None
-) -> PostProcessResult:
+class SlidingWindowTiler:
"""
- Apply full post-processing pipeline to model predictions.
-
- Args:
- predictions: Raw model predictions (H, W) or (H, W, C)
- confidence: Confidence scores (H, W)
- config: Post-processing configuration
-
- Returns:
- PostProcessResult with filtered predictions and quality metrics
+ Tile large images into overlapping patches for inference,
+ then stitch predictions back with soft-voting in overlap regions.
"""
- if config is None:
- config = PostProcessConfig()
-
- original_pixels = predictions.size
-
- # Apply confidence threshold
- filtered = apply_confidence_threshold(
- predictions, confidence, config.confidence_threshold
- )
-
- # Remove small regions for binary masks
- if filtered.ndim == 2:
- filtered = remove_small_regions(filtered, config.min_region_pixels)
-
- # Detect anomalies
- anomaly_detected, anomaly_regions = detect_anomalies(
- predictions, confidence, config.anomaly_std_threshold
- )
-
- # Compute metrics
- filtered_pixels = int((filtered == 0).sum())
- filtered_ratio = filtered_pixels / max(original_pixels, 1)
-
- quality_score = compute_quality_score(
- confidence, filtered_ratio, anomaly_detected
- )
-
- if anomaly_detected:
- logger.warning(f"Anomalies detected in prediction: {anomaly_regions}")
-
- return PostProcessResult(
- mask=filtered,
- confidence_map=confidence,
- filtered_pixels=filtered_pixels,
- anomaly_detected=anomaly_detected,
- anomaly_regions=anomaly_regions,
- quality_score=quality_score
- )
+
+    def __init__(self, tile_size: int = 256, overlap: int = 32):
+        """Store the tile geometry; ``stride`` (tile_size - overlap) is the step between tiles."""
+        # Raise rather than assert: asserts are stripped under ``python -O``, and a
+        # non-positive stride would leave generate_tiles() unable to advance.
+        if not 0 <= overlap < tile_size:
+            raise ValueError("overlap must be smaller than tile_size")
+        self.tile_size = tile_size
+        self.overlap = overlap
+        self.stride = tile_size - overlap
+
+    def generate_tiles(
+        self,
+        image: np.ndarray,
+    ) -> list[tuple[int, int, np.ndarray]]:
+        """
+        Generate (row, col, tile) tuples from a (C, H, W) image.
+
+        row and col are the top-left coordinates in the original image.
+        Tiles advance by ``self.stride``; if the last regular tile stops short of
+        an edge, one extra tile flush with that edge is appended, so every pixel
+        is covered and no position is emitted twice. An axis no longer than
+        ``tile_size`` yields a single start at 0 (the tile is then smaller).
+        """
+        _, h, w = image.shape
+
+        def _starts(length: int) -> list[int]:
+            """Top-left offsets along one axis of the given length."""
+            if length <= self.tile_size:
+                return [0] if length > 0 else []
+            starts = list(range(0, length - self.tile_size + 1, self.stride))
+            # The remainder is uncovered iff the last tile ends before `length`;
+            # testing `length % stride` instead skips or duplicates edge tiles.
+            if starts[-1] + self.tile_size < length:
+                starts.append(length - self.tile_size)
+            return starts
+
+        xs = _starts(w)
+        tiles = []
+        for y in _starts(h):
+            for x in xs:
+                tile = image[:, y:y + self.tile_size, x:x + self.tile_size]
+                tiles.append((y, x, tile))
+
+        return tiles
+
+    def stitch_predictions(
+        self,
+        tiles: list[tuple[int, int, np.ndarray]],
+        image_shape: tuple[int, int, int],
+        num_classes: int,
+    ) -> np.ndarray:
+        """
+        Merge per-tile class masks into one mask by per-pixel majority vote.
+
+        Args:
+            tiles: List of (row, col, pred_mask); pred_mask is an (H, W) int mask
+                whose top-left corner sits at (row, col) of the full image.
+            image_shape: (C, H, W) shape of the original image.
+            num_classes: Number of classes; labels outside [0, num_classes) get no vote.
+
+        Returns:
+            (H, W) mask of the most-voted class (lowest index on ties; 0 where no tile voted).
+        """
+        _, height, width = image_shape
+        votes = np.zeros((num_classes, height, width), dtype=np.float32)
+        class_ids = np.arange(num_classes).reshape(-1, 1, 1)
+
+        for top, left, tile_pred in tiles:
+            rows, cols = tile_pred.shape
+            # One-hot the tile against every class at once and add it to its window.
+            votes[:, top:top + rows, left:left + cols] += tile_pred[np.newaxis] == class_ids
+
+        return votes.argmax(axis=0)
+
+    def stitch_probabilities(
+        self,
+        tiles: list[tuple[int, int, np.ndarray]],
+        image_shape: tuple[int, int, int],
+    ) -> np.ndarray:
+        """
+        Blend per-tile probability maps, averaging wherever tiles overlap.
+
+        Args:
+            tiles: Non-empty list of (row, col, probs); probs is a (C, H, W) float
+                map placed with its top-left corner at (row, col).
+            image_shape: (C, H, W) shape of the original image.
+
+        Returns:
+            (num_classes, H, W) mean probability map; pixels covered by no tile stay 0.
+        """
+        _, height, width = image_shape
+        num_classes = tiles[0][2].shape[0]  # class count is read from the first tile
+        accumulated = np.zeros((num_classes, height, width), dtype=np.float32)
+        coverage = np.zeros((height, width), dtype=np.float32)
+
+        for top, left, tile_probs in tiles:
+            rows, cols = tile_probs.shape[1:]
+            accumulated[:, top:top + rows, left:left + cols] += tile_probs
+            coverage[top:top + rows, left:left + cols] += 1.0
+
+        # Floor the divisor so pixels no tile touched divide by 1e-8, not zero.
+        safe_coverage = np.maximum(coverage, 1e-8)
+        return accumulated / safe_coverage[np.newaxis, :, :]
diff --git a/tests/test_security.py b/tests/test_security.py
new file mode 100644
index 0000000..e9e4fd5
--- /dev/null
+++ b/tests/test_security.py
@@ -0,0 +1,268 @@
+"""
+Security tests for ClimateVision API.
+
+Tests input validation, sanitization, and security controls.
+"""
+
+import pytest
+import numpy as np
+from pathlib import Path
+import sys
+
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
+from climatevision.security import (
+ validate_payload_size,
+ validate_bbox,
+ validate_file_upload,
+ sanitize_string_input,
+ SecurityConfig,
+ RateLimiter,
+ detect_adversarial_input,
+ validate_model_output,
+ InputAnomalyDetector,
+ PipelineGuard,
+)
+
+
+ class TestPayloadValidation:
+     """Checks validate_payload_size against an explicit byte limit."""
+
+     def test_valid_payload(self):
+         """A 1000-byte payload under a 2000-byte cap passes with no message."""
+         ok, message = validate_payload_size(b"x" * 1000, max_size=2000)
+         assert ok
+         assert message == ""
+
+     def test_oversized_payload(self):
+         """A 10000-byte payload over a 1000-byte cap is rejected."""
+         ok, message = validate_payload_size(b"x" * 10000, max_size=1000)
+         assert not ok
+         assert "exceeds maximum" in message
+
+
+ class TestBboxValidation:
+     """Checks validate_bbox on well-formed and malformed bounding boxes.
+
+     NOTE(review): the asserted messages ("West", "South") suggest boxes
+     are ordered [west, south, east, north] -- confirm against validate_bbox.
+     """
+
+     def test_valid_bbox(self):
+         """An in-range, correctly ordered box passes with an empty message."""
+         ok, message = validate_bbox([-60.0, -15.0, -45.0, -5.0])
+         assert ok
+         assert message == ""
+
+     def test_invalid_longitude(self):
+         """A first coordinate of 200.0 yields a longitude error."""
+         ok, message = validate_bbox([200.0, 10.0, 30.0, 40.0])
+         assert not ok
+         assert "longitude" in message.lower()
+
+     def test_invalid_latitude(self):
+         """A second coordinate of 100.0 yields a latitude error."""
+         ok, message = validate_bbox([10.0, 100.0, 30.0, 40.0])
+         assert not ok
+         assert "latitude" in message.lower()
+
+     def test_wrong_order_longitude(self):
+         """First coordinate greater than the third is rejected, naming West."""
+         ok, message = validate_bbox([30.0, 10.0, 20.0, 40.0])
+         assert not ok
+         assert "West" in message
+
+     def test_wrong_order_latitude(self):
+         """Second coordinate greater than the fourth is rejected, naming South."""
+         ok, message = validate_bbox([10.0, 50.0, 30.0, 40.0])
+         assert not ok
+         assert "South" in message
+
+     def test_too_large_area(self):
+         """A whole-globe box is rejected when max_area is 100.0."""
+         whole_globe = [-180.0, -90.0, 180.0, 90.0]
+         ok, message = validate_bbox(whole_globe, max_area=100.0)
+         assert not ok
+         assert "area" in message.lower()
+
+     def test_wrong_element_count(self):
+         """A three-element box is rejected with an 'exactly 4' message."""
+         ok, message = validate_bbox([10.0, 20.0, 30.0])
+         assert not ok
+         assert "exactly 4" in message
+
+
+ class TestFileUploadValidation:
+     """Checks validate_file_upload on content signatures and filenames."""
+
+     # PNG signature followed by filler bytes; shared by the PNG-based cases.
+     _PNG_BYTES = b"\x89PNG\r\n\x1a\n" + b"x" * 100
+     # Little-endian TIFF header ("II*\0") followed by filler bytes.
+     _TIFF_BYTES = b"II*\x00" + b"x" * 100
+
+     def test_valid_png(self):
+         """PNG bytes under a .png name pass with an empty message."""
+         ok, message = validate_file_upload(self._PNG_BYTES, "image.png")
+         assert ok
+         assert message == ""
+
+     def test_valid_tiff(self):
+         """TIFF bytes under a .tif name pass with an empty message."""
+         ok, message = validate_file_upload(self._TIFF_BYTES, "satellite.tif")
+         assert ok
+         assert message == ""
+
+     def test_invalid_extension(self):
+         """A .exe filename is rejected even though the bytes are PNG."""
+         ok, message = validate_file_upload(self._PNG_BYTES, "malware.exe")
+         assert not ok
+         assert "not allowed" in message
+
+     def test_path_traversal(self):
+         """A filename containing ../ segments is flagged as path traversal."""
+         ok, message = validate_file_upload(self._PNG_BYTES, "../../../etc/passwd")
+         assert not ok
+         assert "path traversal" in message.lower()
+
+     def test_extension_mismatch(self):
+         """PNG bytes under a .jpg name are rejected as a mismatch."""
+         ok, message = validate_file_upload(self._PNG_BYTES, "image.jpg")
+         assert not ok
+         assert "does not match" in message
+
+     def test_filename_too_long(self):
+         """A 304-character .png filename is rejected as too long."""
+         long_name = "a" * 300 + ".png"
+         ok, message = validate_file_upload(self._PNG_BYTES, long_name)
+         assert not ok
+         assert "too long" in message
+
+
+class TestStringSanitization:
+ """Test string input sanitization."""
+
+ def test_normal_string(self):
+ result, warnings = sanitize_string_input("Hello World")
+ assert "Hello" in result
+ assert len(warnings) == 0 or "sanitized" in warnings[0].lower()
+
+ def test_sql_injection(self):
+ result, warnings = sanitize_string_input("'; DROP TABLE users; --")
+ assert "DROP TABLE" not in result
+ assert any("blocked" in w.lower() or "sanitized" in w.lower() for w in warnings)
+
+ def test_xss_script(self):
+ result, warnings = sanitize_string_input("")
+ assert "