Skip to content

Commit f308d44

Browse files
authored
Merge pull request #22 from simonsobs/hwfe
Hwfe
2 parents 10042e0 + 4ec878d commit f308d44

24 files changed

+4498
-43
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
exclude: "scratch/*"
12
repos:
23
- repo: https://github.com/asottile/pyupgrade
34
rev: v3.19.0

lat_alignment/data/reference.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ primary:
55
- [[2397.31, 2656.61, 2142.25], [['CODE17', [2394.80317257, 2707.63422625, 2189.3771685 ]], ['CODE18', [2423.17510973, 2716.87329402, 2255.89994897]], ['CODE29', [2399.70887496, 2661.08669889, 2208.7234212 ]]]] # Upper right
66
- [[-2399.23, 2658.4, 2141.8], [['CODE20', [-2400.36156242, 2710.36348701, 2189.10965623]], ['CODE21', [-2424.37063619, 2711.61367883, 2252.58142288]], ['CODE30', [-2397.17221076, 2662.79003412, 2206.98632987]]]] # Upper left
77
secondary:
8-
- [[-1998.53, -3762.8, -2550.87], [['CODE41', [-1973.72503342, -3823.9597507 , -2569.77344251]], ['CODE42', [-2000.85463917, -3831.63844276, -2547.94045392]], ['CODE43', [-2023.38884744, -3822.97163457, -2572.96223343]]]]
9-
- [[1993.61, -3763.22, -2551.27], [['CODE32', [ 2018.50872848, -3822.7553573 , -2569.82040645]], ['CODE33', [ 1992.06371373, -3830.39574952, -2547.93891399]], ['CODE34', [ 1968.84490998, -3820.65426004, -2573.54021621]]]]
10-
- [[1998.7, -5497.53, 2652.38], [['CODE35', [ 1973.82124844, -5555.19077343, 2630.47057749]], ['CODE36', [ 2001.45892775, -5546.6894774 , 2609.85226912]], ['CODE37', [ 2023.52324911, -5556.41495723, 2634.32566628]]]]
11-
- [[-1995.38, -5496.5, 2651.13], [['CODE38', [-2020.37541076, -5555.87810728, 2627.84602498]], ['CODE39', [-1993.34495544, -5548.0853943 , 2607.79228317]], ['CODE40', [-1970.45168232, -5555.23316558, 2633.12777916]]]]
8+
- [[1998.7, -5497.53, 2652.38], [['CODE35', [ 1973.82124844, -5555.19077343, 2630.47057749]], ['CODE36', [ 2001.45892775, -5546.6894774 , 2609.85226912]], ['CODE37', [ 2023.52324911, -5556.41495723, 2634.32566628]]]] # Lower right
9+
- [[-1995.38, -5496.5, 2651.13], [['CODE38', [-2020.37541076, -5555.87810728, 2627.84602498]], ['CODE39', [-1993.34495544, -5548.0853943 , 2607.79228317]], ['CODE40', [-1970.45168232, -5555.23316558, 2633.12777916]]]] # Lower left
10+
- [[1993.61, -3763.22, -2551.27], [['CODE32', [ 2018.50872848, -3822.7553573 , -2569.82040645]], ['CODE33', [ 1992.06371373, -3830.39574952, -2547.93891399]], ['CODE34', [ 1968.84490998, -3820.65426004, -2573.54021621]]]] # Upper right
11+
- [[-1998.53, -3762.8, -2550.87], [['CODE41', [-1973.72503342, -3823.9597507 , -2569.77344251]], ['CODE42', [-2000.85463917, -3831.63844276, -2547.94045392]], ['CODE43', [-2023.38884744, -3822.97163457, -2572.96223343]]]] # Upper left
1212
bearing:
1313
# - [[1952.6868314671935, 4353.4578245209605, -340.67023968224555], ['CODE91']]
1414
# - [[1593.041170139037, 4358.034056279059, 1174.5278709573736], ['CODE92']]

lat_alignment/hwfe.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""
2+
Script for calculating HWFE.
3+
Also tells you how to move whatever elements are included
4+
"""
5+
6+
import argparse
7+
import logging
8+
import os
9+
import sys
10+
from copy import deepcopy
11+
from functools import partial
12+
13+
import matplotlib.pyplot as plt
14+
import numpy as np
15+
from megham.transform import (
16+
apply_transform,
17+
decompose_affine,
18+
decompose_rotation,
19+
get_affine,
20+
get_rigid,
21+
)
22+
from tqdm import tqdm
23+
24+
from .io import load_tracker
25+
from .transforms import coord_transform
26+
27+
elements = ["primary", "secondary", "receiver"]
28+
hwfe_factors = {
29+
"secondary": [0.00034, 0.00076, 0.0029, 0.044, 0.042, 0.014],
30+
"receiver": [0.0, 0.0, 0.0026, 0.0, 0.0, 0.0],
31+
}
32+
mm_to_um = 1000
33+
rad_to_arcsec = 3600 * 180 / np.pi
34+
35+
36+
def get_hwfe(data, get_transform, add_err=False) -> float:
37+
# Put everything in M1 coordinates
38+
data_m1 = deepcopy(data)
39+
for element in elements:
40+
dat = data_m1[element]
41+
if add_err:
42+
dat += np.nan_to_num(data_m1[f"{element}_err"])
43+
data_m1[element] = coord_transform(dat, "opt_global", "opt_primary")
44+
data_m1[f"{element}_ref"] = coord_transform(
45+
data_m1[f"{element}_ref"], "opt_global", "opt_primary"
46+
)
47+
48+
# Transform for M1 perfect
49+
aff_m1, sft_m1 = get_transform(
50+
data_m1["primary"][data_m1["primary_msk"]],
51+
data_m1["primary_ref"][data_m1["primary_msk"]],
52+
method="mean",
53+
)
54+
55+
hwfe = 0
56+
for element in hwfe_factors.keys():
57+
src = data_m1[element][data_m1[f"{element}_msk"]]
58+
dst = data_m1[f"{element}_ref"][data_m1[f"{element}_msk"]]
59+
60+
# Apply the transform to align M1
61+
src = apply_transform(src, aff_m1, sft_m1)
62+
63+
# Get the new transform
64+
aff, sft = get_transform(src, dst, method="mean")
65+
_, _, rot = decompose_affine(aff)
66+
rot = decompose_rotation(rot)
67+
68+
# compute HWFE
69+
vals = np.hstack([sft * mm_to_um, rot * rad_to_arcsec]).ravel()
70+
hwfe += float(np.sum((np.array(hwfe_factors[element]) * vals) ** 2))
71+
return np.sqrt(hwfe)
72+
73+
74+
def main():
75+
parser = argparse.ArgumentParser()
76+
parser.add_argument("path", help="path to data file")
77+
parser.add_argument(
78+
"--affine",
79+
"-a",
80+
action="store_true",
81+
help="Pass to compute affine instead of rigid rotation",
82+
)
83+
parser.add_argument(
84+
"--n_draws",
85+
"-n",
86+
default=10000,
87+
type=int,
88+
help="Number of draws from the error distribution to do when estimating HWFE error",
89+
)
90+
parser.add_argument(
91+
"--log_level", "-l", default="INFO", help="the log level to use"
92+
)
93+
args = parser.parse_args()
94+
logging.basicConfig()
95+
logger = logging.getLogger("lat_alignment")
96+
logger.setLevel(args.log_level.upper())
97+
98+
# Pick the fitter
99+
get_transform = get_rigid
100+
transform_str = "rigid"
101+
if args.affine:
102+
get_transform = partial(get_affine, force_svd=True)
103+
transform_str = "affine"
104+
# Load data
105+
logger.info("Loading data from %s", args.path)
106+
ext = os.path.splitext(args.path)[1]
107+
if ext != ".yaml":
108+
raise ValueError("Data for HWFE script must be a yaml file")
109+
data = load_tracker(args.path)
110+
111+
# Get the transform for each element assuming no error
112+
have_err = False
113+
for element in elements:
114+
logger.info("Getting transform for %s", element)
115+
src = np.array(data[element])
116+
dst = np.array(data[f"{element}_ref"])
117+
if np.all(np.isnan(src)):
118+
logger.info("\tElement is all nan!, assuming it is perfect")
119+
src = dst.copy()
120+
data[element] = src
121+
have = np.all(np.isfinite(src), axis=1)
122+
have_err += np.any(np.isfinite(data[f"{element}_err"]))
123+
if np.sum(have) < 3:
124+
raise ValueError(f"Only {np.sum(have)} points found!")
125+
data[f"{element}_msk"] = have
126+
aff, sft = get_transform(src[have], dst[have], method="mean")
127+
scale, shear, rot = decompose_affine(aff)
128+
rot = decompose_rotation(rot)
129+
logger.info("\tShift is %s mm", str(sft))
130+
logger.info("\tRotation is %s deg", str(np.rad2deg(rot)))
131+
logger.info("\tRotation is %s deg", str(rot))
132+
logger.info("\tScale is %s", scale)
133+
logger.info("\tShear is %s", shear)
134+
135+
# Get HWFE
136+
hwfe = get_hwfe(data, get_transform)
137+
logger.info("HWFE is %f", hwfe)
138+
139+
# Error Propagation
140+
if not have_err:
141+
logger.info("No errors found")
142+
sys.exit()
143+
144+
logger.info("Propagating errors")
145+
hwfe_werr = np.zeros(args.n_draws)
146+
rng = np.random.default_rng(12345)
147+
148+
for i in tqdm(range(args.n_draws)):
149+
_data = deepcopy(data)
150+
_data["primary_err"] *= rng.normal(size=(4, 3))
151+
_data["secondary_err"] *= rng.normal(size=(4, 3))
152+
_data["receiver_err"] *= rng.normal(size=(4, 3))
153+
hwfe_werr[i] = get_hwfe(_data, get_transform, True)
154+
logger.info("\tStandard deviation of error dist is %f", np.std(hwfe_werr))
155+
plt.hist(
156+
hwfe_werr,
157+
density=True,
158+
bins="auto",
159+
label=f"Mean: {np.mean(hwfe_werr):.2f}\nSTD: {np.std(hwfe_werr):.2f}",
160+
alpha=0.7,
161+
)
162+
plt.axvline(hwfe, label=f"Without Error: {hwfe:.2f}", color="black")
163+
plt.legend()
164+
plt.xlabel("HWFE (um-rms)")
165+
plt.ylabel("Density")
166+
plt.title(f"HWFE With Uncorrellated Error ({transform_str})")
167+
plt.savefig(os.path.splitext(args.path)[0] + f"_error_{transform_str}.png")

lat_alignment/io.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
2+
import os
23
from collections import defaultdict
4+
from importlib.resources import files
35

46
import matplotlib.pyplot as plt
57
import numpy as np
@@ -13,6 +15,79 @@
1315
logger = logging.getLogger("lat_alignment")
1416

1517

18+
def _load_tracker_yaml(path: str):
19+
with open(path) as file:
20+
dat = yaml.safe_load(file)
21+
if "reference" in dat:
22+
ref_path = dat["reference"]
23+
else:
24+
ref_path = str(files("lat_alignment.data").joinpath("reference.yaml"))
25+
with open(ref_path) as file:
26+
reference = yaml.safe_load(file)
27+
28+
null = np.zeros((4, 3)) + np.nan
29+
data = {}
30+
31+
# Add optical eliments
32+
data["primary"] = dat.get("primary", null)
33+
data["secondary"] = dat.get("secondary", null)
34+
data["receiver"] = dat.get("receiver", null)
35+
36+
# Add errors
37+
data["primary_err"] = dat.get("primary_err", null)
38+
data["secondary_err"] = dat.get("secondary_err", null)
39+
data["receiver_err"] = dat.get("receiver_err", null)
40+
41+
# Add reference
42+
data["primary_ref"] = np.array([p for p, _ in reference["primary"]])
43+
data["secondary_ref"] = np.array([p for p, _ in reference["secondary"]])
44+
data["receiver_ref"] = np.array([p for p, _ in reference["receiver"]])
45+
46+
return data
47+
48+
49+
def _load_tracker_txt(path: str):
50+
_ = path
51+
raise NotImplementedError(
52+
"Loading tracker data from a txt file not yet implemented"
53+
)
54+
55+
56+
def _load_tracker_csv(path: str):
57+
_ = path
58+
raise NotImplementedError(
59+
"Loading tracker data from a csv file not yet implemented"
60+
)
61+
62+
63+
def load_tracker(path: str):
64+
"""
65+
Load laser tracker data.
66+
TODO: This interface needs to be unified with `load_photo` so all code can use either datatype interchangibly
67+
68+
Parameters
69+
----------
70+
path : str
71+
The path to the laser tracker data.
72+
The type of data will be infered from the extension.
73+
74+
Returns
75+
-------
76+
data
77+
The tracker data.
78+
The return type will depend on the extension.
79+
TODO: Make Dataset better for this.
80+
"""
81+
ext = os.path.splitext(path)[1]
82+
if ext == ".yaml":
83+
return _load_tracker_yaml(path)
84+
elif ext == ".txt":
85+
return _load_tracker_txt(path)
86+
elif ext == ".csv":
87+
return _load_tracker_csv(path)
88+
raise ValueError(f"Invalid tracker data with extension {ext}")
89+
90+
1691
def load_photo(
1792
path: str, err_thresh: float = 2, doubles_dist: float = 10, plot: bool = True
1893
) -> Dataset:

lat_alignment/ixb.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
Functions for integrating with the Atlas Copco IxB tool
33
"""
44

5-
from importlib.resources import files
65
import argparse
76
import json
7+
import os
88
import socket
9+
import time
910
import warnings
11+
from copy import deepcopy
1012
from functools import partial
13+
from importlib.resources import files
1114
from typing import Any, Callable, Optional
12-
import os
13-
from copy import deepcopy
14-
import time
1515

1616
import numpy as np
1717
import tqdm
@@ -415,12 +415,10 @@ def main():
415415
# Load templates
416416
with open(
417417
str(files("lat_alignment.data").joinpath("Tightening_Tighten_Template.json")),
418-
"r",
419418
) as f:
420419
tighten_template = json.load(f)
421420
with open(
422421
str(files("lat_alignment.data").joinpath("Tightening_Loosen_Template.json")),
423-
"r",
424422
) as f:
425423
loosen_template = json.load(f)
426424

lat_alignment/photogrammetry.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import matplotlib.pyplot as plt
1212
import numpy as np
13-
from traitlets import ValidateHandler
1413
from megham.transform import apply_transform, decompose_rotation, get_affine, get_rigid
1514
from megham.utils import make_edm
1615
from numpy.typing import NDArray
@@ -276,7 +275,9 @@ def align_photo(
276275
pts += found_coded
277276
ref += ref_coded
278277
if len(ref) < 3:
279-
raise ValueError(f"Only {len(ref)} reference points found including codes! Can't align!")
278+
raise ValueError(
279+
f"Only {len(ref)} reference points found including codes! Can't align!"
280+
)
280281
logger.debug(
281282
"\t\tFound %d reference points in measurements with labels:\n\t\t\t%s",
282283
len(pts),
@@ -298,8 +299,10 @@ def align_photo(
298299
rot, sft = get_rigid(pts[msk], ref[msk], method="mean")
299300
if scale:
300301
triu_idx = np.triu_indices(len(pts[msk]), 1)
301-
scale_fac = np.nanmedian(make_edm(ref[msk])[triu_idx] / make_edm(pts[msk])[triu_idx])
302-
pts_scaled = pts*scale_fac
302+
scale_fac = np.nanmedian(
303+
make_edm(ref[msk])[triu_idx] / make_edm(pts[msk])[triu_idx]
304+
)
305+
pts_scaled = pts * scale_fac
303306
logger.debug("\t\tScale factor of %f applied", scale_fac)
304307

305308
new_rot, new_sft = get_rigid(pts_scaled[msk], ref[msk], method="mean")

measurements/20241212/bearing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
nest_y = 3990.4 * np.ones(len(nest_labels)) + 100
1515
nest_z = -2040.0 * np.cos(np.deg2rad(10 * nest_n))
1616
nest_model = np.column_stack([nest_x, nest_y, nest_z])
17-
nest_model = np.array([
18-
[1995.47, 3840.23, -375.91],
19-
[1562.74, 3843.45, 1293.41],
20-
[-1561.77, 3848.33, 1318.03],
21-
[-2021.09, 3849.25, -343.08]
22-
])
17+
nest_model = np.array(
18+
[
19+
[1995.47, 3840.23, -375.91],
20+
[1562.74, 3843.45, 1293.41],
21+
[-1561.77, 3848.33, 1318.03],
22+
[-2021.09, 3849.25, -343.08],
23+
]
24+
)
2325

2426
print(nest_model)
2527
nest_meas = [coords[labels == l] for l in nest_labels]

measurements/20250609/final.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
primary:
2+
- [2400.4679, -2700.7464, 4818.4219]
3+
- [-2397.7851, -2700.0503, 4821.8107]
4+
- [2397.6569, 2655.6100, 2142.3663]
5+
- [-2398.9775, 2657.7954, 2142.5339]
6+
receiver:
7+
- [-1030.2578, 6996.401, -179.2929]
8+
- [-359.6616, 6996.5164, 982.6493]
9+
- [359.1369, 6996.4168, 982.6657]
10+
- [1030.1505, 6996.2558, -179.2974]

0 commit comments

Comments
 (0)