diff --git a/include/xsf/zeta.h b/include/xsf/zeta.h index fb0e77458..b01233ba1 100644 --- a/include/xsf/zeta.h +++ b/include/xsf/zeta.h @@ -345,11 +345,24 @@ XSF_HOST_DEVICE inline double riemann_zeta(double x) { return cephes::riemann_ze XSF_HOST_DEVICE inline float riemann_zeta(float x) { return riemann_zeta(static_cast(x)); } -XSF_HOST_DEVICE inline double zeta(double x, double q) { return cephes::zeta(x, q); } +XSF_HOST_DEVICE inline double zeta(double x, double q) { + if (q == 1.0) { + return riemann_zeta(x); + } + return cephes::zeta(x, q); +} -XSF_HOST_DEVICE inline float zeta(float x, float q) { return zeta(static_cast(x), static_cast(q)); } +XSF_HOST_DEVICE inline float zeta(float x, float q) { + if (q == 1.0f) { + return riemann_zeta(x); + } + return zeta(static_cast(x), static_cast(q)); +} XSF_HOST_DEVICE inline std::complex zeta(std::complex z, double q) { + if (q == 1.0) { + return riemann_zeta(z); + } if (z.imag() == 0.0) { return zeta(z.real(), q); } diff --git a/tests/xsf_tests/test_zeta.cpp b/tests/xsf_tests/test_zeta.cpp new file mode 100644 index 000000000..6bd1dd4e3 --- /dev/null +++ b/tests/xsf_tests/test_zeta.cpp @@ -0,0 +1,24 @@ +#include "../testing_utils.h" +#include +#include +#include + +TEST_CASE("zeta(x, q=1) matches riemann_zeta for all types", "[zeta][xsf_tests]") { + SECTION("double") { + double x = GENERATE(range(-10.0, 10.0, 0.1)); + REQUIRE(xsf::zeta(x, 1.0) == xsf::riemann_zeta(x)); + } + + SECTION("float") { + float x = GENERATE(range(-10.0f, 10.0f, 0.5f)); + REQUIRE(xsf::zeta(x, 1.0f) == xsf::riemann_zeta(x)); + } + + SECTION("complex") { + using std::complex; + double re = GENERATE(range(0.5, 5.0, 0.5)); + double im = GENERATE(range(-2.0, 2.0, 0.5)); + complex z(re, im); + REQUIRE(xsf::zeta(z, 1.0) == xsf::riemann_zeta(z)); + } +}