Skip to content

Commit 5276f09

Browse files
committed
Cleanup code
1 parent df02d66 commit 5276f09

File tree

2 files changed

+18
-52
lines changed

2 files changed

+18
-52
lines changed

zoidberg/field.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,8 @@ def Rfunc(self, x, z, phi):
18881888

18891889

18901890
class EMC3(MagneticField):
1891+
"""Field based on a EMC3 grid file"""
1892+
18911893
def __init__(self, ds):
18921894
self.ds = ds
18931895
assert hasattr(ds, "emc3"), "Expected an xemc3 dataset. Is xemc3 imported?"
@@ -1899,7 +1901,6 @@ def Bxfunc(self, x, z, phi):
18991901
raise NotImplementedError("Use maybe EMC3 tracer?")
19001902

19011903
def Byfunc(self, x, z, phi):
1902-
# raise NotImplementedError("Use maybe EMC3 tracer?")
19031904
return self.Bmag(x, z, phi)
19041905

19051906
def Bzfunc(self, x, z, phi):
@@ -1914,45 +1915,20 @@ def Bmag(self, x, z, phi):
19141915
].values
19151916
nans = np.isnan(vals)
19161917
if np.any(nans):
1917-
if 1:
1918-
from scipy.interpolate import CubicSpline as CS
1919-
1920-
for i in range(x.shape[1]):
1921-
ni = nans[:, i]
1922-
if not np.any(ni):
1923-
continue
1924-
xi = x[:, i]
1925-
zi = z[:, i]
1926-
vi = vals[:, i]
1927-
si = np.zeros_like(xi)
1928-
si[1:] = np.cumsum(
1929-
np.sqrt((xi[1:] - xi[:-1]) ** 2 + (zi[1:] - zi[:-1]) ** 2)
1930-
)
1931-
interp = CS(si[~ni], vi[~ni])
1932-
vals[ni, i] = interp(si[ni])
1933-
else:
1934-
from scipy.interpolate import LinearNDInterpolator as LinInter
1935-
1936-
if phi.shape == ():
1937-
pos = np.array((x, z))
1938-
else:
1939-
pos = np.array((x, z, phi))
1940-
inter = LinInter(pos[:, ~nans].T, vals[~nans])
1941-
print(vals[~nans])
1942-
1943-
print(inter(pos[:, nans].T))
1944-
vals[nans] = inter(pos[:, nans].T)
1945-
1946-
import matplotlib.pyplot as plt
1947-
1948-
plt.figure()
1949-
# plt.pcolormesh(X, Y, Z, shading='auto')
1950-
plt.plot(*pos[:, ~nans], "ok", label="input point")
1951-
plt.plot(*pos[:, nans], "or", label="evaled")
1952-
print(*pos[:, nans].T)
1953-
plt.legend()
1954-
# plt.colorbar()
1955-
plt.axis("equal")
1956-
plt.show()
1918+
from scipy.interpolate import CubicSpline as CS
1919+
1920+
for i in range(x.shape[1]):
1921+
ni = nans[:, i]
1922+
if not np.any(ni):
1923+
continue
1924+
xi = x[:, i]
1925+
zi = z[:, i]
1926+
vi = vals[:, i]
1927+
si = np.zeros_like(xi)
1928+
si[1:] = np.cumsum(
1929+
np.sqrt((xi[1:] - xi[:-1]) ** 2 + (zi[1:] - zi[:-1]) ** 2)
1930+
)
1931+
interp = CS(si[~ni], vi[~ni])
1932+
vals[ni, i] = interp(si[ni])
19571933

19581934
return vals

zoidberg/fieldtracer.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def __init__(self, x, y):
730730

731731

732732
class EMC3FieldTracer(FieldTracer):
733-
"""A class for following magnetic field lines
733+
"""A class for following magnetic field lines provided by an EMC3 grid.
734734
735735
Parameters
736736
----------
@@ -760,9 +760,6 @@ def __init__(self, field):
760760
self.firstlast = (first, last)
761761

762762
def follow_field_lines(self, x_values, z_values, y_values, rtol=None):
763-
# assert np.all(y_values >= self.ds.phi.min().values) and np.all(
764-
# y_values <= self.ds.phi.max().values
765-
# ), f"The condition is not fulfilled: {self.ds.phi.min()} <= {y_values} <= {self.ds.phi.max()}"
766763
meshes = [self.makeMeshes(phi) for phi in y_values]
767764
assert x_values.shape == z_values.shape
768765
out = np.empty((len(y_values), *x_values.shape, 2))
@@ -796,10 +793,6 @@ def _rz_to_ab(self, rz, grid, ij):
796793
nz -= 1
797794
i, j = ij // nz, ij % nz
798795
ABCD = grid[i : i + 2, j : j + 2]
799-
# if j + 1 == nz:
800-
# ABCD[:, 1] = grid[i : i + 2, 0]
801-
# print(ABCD)
802-
# print(i, j, nz, ij)
803796
if ABCD.shape != (2, 2, 2):
804797
print(grid.shape)
805798
print(rz)
@@ -836,8 +829,6 @@ def J(albe):
836829
for i in range(100):
837830
assert np.all(np.isfinite(albe))
838831
albe = albe - np.linalg.inv(J(albe).T) @ fun(albe)
839-
# if i > 20:
840-
# print(f"Failing to converge! {albe} {fun(albe)}")
841832
res = np.sum(fun(albe) ** 2)
842833
if res < tol:
843834
return albe
@@ -870,7 +861,6 @@ def rz_to_ab(self, rz, meshes, ij, zid):
870861
for mind in 1e-3, 1e-2, 1:
871862
mind = mind**2
872863
for mesh in itertools.chain(meshes[zid:], meshes[:zid]):
873-
# print(f"checking mesh {mesh.zid}")
874864
l2 = (rz[0] - mesh.r) ** 2 + (rz[1] - mesh.z) ** 2
875865
ij = np.argmin(l2)
876866
if l2.flat[ij] < mind:

0 commit comments

Comments
 (0)