Skip to content

Commit ee4c908

Browse files
authored
Merge pull request #525 from ValeevGroup/gaudel/feature/truncate-near-zero-float-printing
feat: Implement printing of near-zero floats as zero
2 parents c35044e + ef25f0f commit ee4c908

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/TiledArray/tensor/print.h

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,45 @@ namespace detail {
4040
class NDArrayPrinter {
4141
public:
4242
NDArrayPrinter(int width = 10, int precision = 6)
43-
: width(width), precision(precision) {}
43+
: width(width),
44+
precision(precision),
45+
truncate_(0.5 * std::pow(10., -precision)) {}
4446

4547
private:
4648
int width = 10;
4749
int precision = 10;
4850

51+
/// truncates (=sets to zero) small floating-point numbers
52+
class FloatTruncate {
53+
public:
54+
/// truncates numbers smaller than @p threshold
55+
FloatTruncate(double threshold) noexcept : threshold_{threshold} {}
56+
57+
[[nodiscard]] auto operator()(std::floating_point auto val) const noexcept {
58+
return std::abs(val) < threshold_ ? decltype(val){0} : val;
59+
}
60+
61+
template <typename T>
62+
requires detail::is_complex_v<T> &&
63+
std::floating_point<typename T::value_type>
64+
[[nodiscard]] auto operator()(T const& val) const noexcept {
65+
using std::imag;
66+
using std::real;
67+
return T{(*this)(real(val)), (*this)(imag(val))};
68+
}
69+
70+
template <typename T>
71+
requires(!(std::floating_point<T> || detail::is_complex_v<T>))
72+
[[nodiscard]] auto operator()(T const& val) const noexcept {
73+
return val;
74+
}
75+
76+
private:
77+
double threshold_;
78+
};
79+
80+
FloatTruncate truncate_;
81+
4982
// Helper function to recursively print the array
5083
template <typename T, typename Index = Range1::index1_type,
5184
typename Char = char, typename CharTraits = std::char_traits<Char>>

src/TiledArray/tensor/print.ipp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ void NDArrayPrinter::printArray(const T* data, const std::size_t order,
5050

5151
for (size_t i = 0; i < extents[level]; ++i) {
5252
if (level == order - 1) {
53+
auto value = truncate_(data[offset + i * strides[level]]);
5354
// At the deepest level, print the actual values
54-
os << std::fixed << std::setprecision(precision) << std::setw(width) << std::setfill(Char(' '))
55-
<< data[offset + i * strides[level]];
55+
os << std::fixed << std::setprecision(precision) << std::setw(width) << std::setfill(Char(' ')) << value;
5656
if (i < extents[level] - 1) {
5757
os << ", ";
5858
}

0 commit comments

Comments
 (0)