|
| 1 | +""" |
| 2 | +Script for calculating HWFE. |
| 3 | +Also tells you how to move whatever elements are included |
| 4 | +""" |
| 5 | + |
| 6 | +import argparse |
| 7 | +import logging |
| 8 | +import os |
| 9 | +import sys |
| 10 | +from copy import deepcopy |
| 11 | +from functools import partial |
| 12 | + |
| 13 | +import matplotlib.pyplot as plt |
| 14 | +import numpy as np |
| 15 | +from megham.transform import ( |
| 16 | + apply_transform, |
| 17 | + decompose_affine, |
| 18 | + decompose_rotation, |
| 19 | + get_affine, |
| 20 | + get_rigid, |
| 21 | +) |
| 22 | +from tqdm import tqdm |
| 23 | + |
| 24 | +from .io import load_tracker |
| 25 | +from .transforms import coord_transform |
| 26 | + |
| 27 | +elements = ["primary", "secondary", "receiver"] |
| 28 | +hwfe_factors = { |
| 29 | + "secondary": [0.00034, 0.00076, 0.0029, 0.044, 0.042, 0.014], |
| 30 | + "receiver": [0.0, 0.0, 0.0026, 0.0, 0.0, 0.0], |
| 31 | +} |
| 32 | +mm_to_um = 1000 |
| 33 | +rad_to_arcsec = 3600 * 180 / np.pi |
| 34 | + |
| 35 | + |
| 36 | +def get_hwfe(data, get_transform, add_err=False) -> float: |
| 37 | + # Put everything in M1 coordinates |
| 38 | + data_m1 = deepcopy(data) |
| 39 | + for element in elements: |
| 40 | + dat = data_m1[element] |
| 41 | + if add_err: |
| 42 | + dat += np.nan_to_num(data_m1[f"{element}_err"]) |
| 43 | + data_m1[element] = coord_transform(dat, "opt_global", "opt_primary") |
| 44 | + data_m1[f"{element}_ref"] = coord_transform( |
| 45 | + data_m1[f"{element}_ref"], "opt_global", "opt_primary" |
| 46 | + ) |
| 47 | + |
| 48 | + # Transform for M1 perfect |
| 49 | + aff_m1, sft_m1 = get_transform( |
| 50 | + data_m1["primary"][data_m1["primary_msk"]], |
| 51 | + data_m1["primary_ref"][data_m1["primary_msk"]], |
| 52 | + method="mean", |
| 53 | + ) |
| 54 | + |
| 55 | + hwfe = 0 |
| 56 | + for element in hwfe_factors.keys(): |
| 57 | + src = data_m1[element][data_m1[f"{element}_msk"]] |
| 58 | + dst = data_m1[f"{element}_ref"][data_m1[f"{element}_msk"]] |
| 59 | + |
| 60 | + # Apply the transform to align M1 |
| 61 | + src = apply_transform(src, aff_m1, sft_m1) |
| 62 | + |
| 63 | + # Get the new transform |
| 64 | + aff, sft = get_transform(src, dst, method="mean") |
| 65 | + _, _, rot = decompose_affine(aff) |
| 66 | + rot = decompose_rotation(rot) |
| 67 | + |
| 68 | + # compute HWFE |
| 69 | + vals = np.hstack([sft * mm_to_um, rot * rad_to_arcsec]).ravel() |
| 70 | + hwfe += float(np.sum((np.array(hwfe_factors[element]) * vals) ** 2)) |
| 71 | + return np.sqrt(hwfe) |
| 72 | + |
| 73 | + |
| 74 | +def main(): |
| 75 | + parser = argparse.ArgumentParser() |
| 76 | + parser.add_argument("path", help="path to data file") |
| 77 | + parser.add_argument( |
| 78 | + "--affine", |
| 79 | + "-a", |
| 80 | + action="store_true", |
| 81 | + help="Pass to compute affine instead of rigid rotation", |
| 82 | + ) |
| 83 | + parser.add_argument( |
| 84 | + "--n_draws", |
| 85 | + "-n", |
| 86 | + default=10000, |
| 87 | + type=int, |
| 88 | + help="Number of draws from the error distribution to do when estimating HWFE error", |
| 89 | + ) |
| 90 | + parser.add_argument( |
| 91 | + "--log_level", "-l", default="INFO", help="the log level to use" |
| 92 | + ) |
| 93 | + args = parser.parse_args() |
| 94 | + logging.basicConfig() |
| 95 | + logger = logging.getLogger("lat_alignment") |
| 96 | + logger.setLevel(args.log_level.upper()) |
| 97 | + |
| 98 | + # Pick the fitter |
| 99 | + get_transform = get_rigid |
| 100 | + transform_str = "rigid" |
| 101 | + if args.affine: |
| 102 | + get_transform = partial(get_affine, force_svd=True) |
| 103 | + transform_str = "affine" |
| 104 | + # Load data |
| 105 | + logger.info("Loading data from %s", args.path) |
| 106 | + ext = os.path.splitext(args.path)[1] |
| 107 | + if ext != ".yaml": |
| 108 | + raise ValueError("Data for HWFE script must be a yaml file") |
| 109 | + data = load_tracker(args.path) |
| 110 | + |
| 111 | + # Get the transform for each element assuming no error |
| 112 | + have_err = False |
| 113 | + for element in elements: |
| 114 | + logger.info("Getting transform for %s", element) |
| 115 | + src = np.array(data[element]) |
| 116 | + dst = np.array(data[f"{element}_ref"]) |
| 117 | + if np.all(np.isnan(src)): |
| 118 | + logger.info("\tElement is all nan!, assuming it is perfect") |
| 119 | + src = dst.copy() |
| 120 | + data[element] = src |
| 121 | + have = np.all(np.isfinite(src), axis=1) |
| 122 | + have_err += np.any(np.isfinite(data[f"{element}_err"])) |
| 123 | + if np.sum(have) < 3: |
| 124 | + raise ValueError(f"Only {np.sum(have)} points found!") |
| 125 | + data[f"{element}_msk"] = have |
| 126 | + aff, sft = get_transform(src[have], dst[have], method="mean") |
| 127 | + scale, shear, rot = decompose_affine(aff) |
| 128 | + rot = decompose_rotation(rot) |
| 129 | + logger.info("\tShift is %s mm", str(sft)) |
| 130 | + logger.info("\tRotation is %s deg", str(np.rad2deg(rot))) |
| 131 | + logger.info("\tRotation is %s deg", str(rot)) |
| 132 | + logger.info("\tScale is %s", scale) |
| 133 | + logger.info("\tShear is %s", shear) |
| 134 | + |
| 135 | + # Get HWFE |
| 136 | + hwfe = get_hwfe(data, get_transform) |
| 137 | + logger.info("HWFE is %f", hwfe) |
| 138 | + |
| 139 | + # Error Propagation |
| 140 | + if not have_err: |
| 141 | + logger.info("No errors found") |
| 142 | + sys.exit() |
| 143 | + |
| 144 | + logger.info("Propagating errors") |
| 145 | + hwfe_werr = np.zeros(args.n_draws) |
| 146 | + rng = np.random.default_rng(12345) |
| 147 | + |
| 148 | + for i in tqdm(range(args.n_draws)): |
| 149 | + _data = deepcopy(data) |
| 150 | + _data["primary_err"] *= rng.normal(size=(4, 3)) |
| 151 | + _data["secondary_err"] *= rng.normal(size=(4, 3)) |
| 152 | + _data["receiver_err"] *= rng.normal(size=(4, 3)) |
| 153 | + hwfe_werr[i] = get_hwfe(_data, get_transform, True) |
| 154 | + logger.info("\tStandard deviation of error dist is %f", np.std(hwfe_werr)) |
| 155 | + plt.hist( |
| 156 | + hwfe_werr, |
| 157 | + density=True, |
| 158 | + bins="auto", |
| 159 | + label=f"Mean: {np.mean(hwfe_werr):.2f}\nSTD: {np.std(hwfe_werr):.2f}", |
| 160 | + alpha=0.7, |
| 161 | + ) |
| 162 | + plt.axvline(hwfe, label=f"Without Error: {hwfe:.2f}", color="black") |
| 163 | + plt.legend() |
| 164 | + plt.xlabel("HWFE (um-rms)") |
| 165 | + plt.ylabel("Density") |
| 166 | + plt.title(f"HWFE With Uncorrellated Error ({transform_str})") |
| 167 | + plt.savefig(os.path.splitext(args.path)[0] + f"_error_{transform_str}.png") |
0 commit comments