From d9d3ec4dc043da199f920c07790ed153b2c12b6d Mon Sep 17 00:00:00 2001 From: fbourgey Date: Tue, 23 Dec 2025 09:36:51 -0500 Subject: [PATCH 1/2] fix zeta with q=1 --- include/xsf/zeta.h | 17 +++++++++++++++-- tests/xsf_tests/test_zeta.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 tests/xsf_tests/test_zeta.cpp diff --git a/include/xsf/zeta.h b/include/xsf/zeta.h index fb0e77458b..b01233ba1b 100644 --- a/include/xsf/zeta.h +++ b/include/xsf/zeta.h @@ -345,11 +345,24 @@ XSF_HOST_DEVICE inline double riemann_zeta(double x) { return cephes::riemann_ze XSF_HOST_DEVICE inline float riemann_zeta(float x) { return riemann_zeta(static_cast(x)); } -XSF_HOST_DEVICE inline double zeta(double x, double q) { return cephes::zeta(x, q); } +XSF_HOST_DEVICE inline double zeta(double x, double q) { + if (q == 1.0) { + return riemann_zeta(x); + } + return cephes::zeta(x, q); +} -XSF_HOST_DEVICE inline float zeta(float x, float q) { return zeta(static_cast(x), static_cast(q)); } +XSF_HOST_DEVICE inline float zeta(float x, float q) { + if (q == 1.0f) { + return riemann_zeta(x); + } + return zeta(static_cast(x), static_cast(q)); +} XSF_HOST_DEVICE inline std::complex zeta(std::complex z, double q) { + if (q == 1.0) { + return riemann_zeta(z); + } if (z.imag() == 0.0) { return zeta(z.real(), q); } diff --git a/tests/xsf_tests/test_zeta.cpp b/tests/xsf_tests/test_zeta.cpp new file mode 100644 index 0000000000..1db23e7a9b --- /dev/null +++ b/tests/xsf_tests/test_zeta.cpp @@ -0,0 +1,24 @@ +#include "../testing_utils.h" +#include +#include +#include + +TEST_CASE("zeta(x, q=1) matches riemann_zeta for all types", "[zeta][xsf_tests]") { + SECTION("double") { + double x = GENERATE(range(-10.0, 10.0, 0.1)); + REQUIRE(xsf::zeta(x, 1.0) == xsf::riemann_zeta(x)); + } + + SECTION("float") { + float x = GENERATE(range(-10.0f, 10.0f, 0.5f)); + REQUIRE(xsf::zeta(x, 1.0f) == xsf::riemann_zeta(x)); + } + + SECTION("complex") { + using std::complex; + double re = GENERATE(range(0.5, 5.0, 0.5)); + double im = GENERATE(range(-2.0, 2.0, 0.5)); + complex z(re, im); + REQUIRE(xsf::zeta(z, 1.0) == xsf::riemann_zeta(z)); + } +} From 101f767a5ad82b9b639c6a6c978b7014061d3281 Mon Sep 17 00:00:00 2001 From: fbourgey Date: Tue, 23 Dec 2025 09:39:34 -0500 Subject: [PATCH 2/2] lint --- tests/xsf_tests/test_zeta.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/xsf_tests/test_zeta.cpp b/tests/xsf_tests/test_zeta.cpp index 1db23e7a9b..6bd1dd4e35 100644 --- a/tests/xsf_tests/test_zeta.cpp +++ b/tests/xsf_tests/test_zeta.cpp @@ -1,19 +1,19 @@ #include "../testing_utils.h" +#include #include #include -#include TEST_CASE("zeta(x, q=1) matches riemann_zeta for all types", "[zeta][xsf_tests]") { SECTION("double") { double x = GENERATE(range(-10.0, 10.0, 0.1)); REQUIRE(xsf::zeta(x, 1.0) == xsf::riemann_zeta(x)); } - + SECTION("float") { float x = GENERATE(range(-10.0f, 10.0f, 0.5f)); REQUIRE(xsf::zeta(x, 1.0f) == xsf::riemann_zeta(x)); } - + SECTION("complex") { using std::complex; double re = GENERATE(range(0.5, 5.0, 0.5));