@@ -121,6 +121,44 @@ namespace detail {
121121 return *reinterpret_cast <const D*>(&result_u);
122122 }
123123 };
124+
template <typename T>
// Returns |v| as a double, for floating-point, signed, and unsigned T.
//
// Fix: the previous signed-integer path widened to long long and negated the
// integer (`-wide`), which is undefined behavior when v is the most negative
// value of a 64-bit type (e.g. LLONG_MIN). Converting to double FIRST and
// negating the double is always well-defined, and IEEE-754 negation is exact,
// so every other input yields the same result as before.
inline double abs_to_double (T v) {
  const double d = static_cast <double >(v);
  // Negative zero and NaN fall through unchanged, matching the old behavior.
  return (d < 0.0 ) ? -d : d;
}
136+
137+ template <typename T>
138+ inline void select_top2 (const T (&vals)[4], uint32_t &keep0, uint32_t &keep1) {
139+ uint32_t k0 = 0 ;
140+ uint32_t k1 = 1 ;
141+ double m0 = abs_to_double (vals[0 ]);
142+ double m1 = abs_to_double (vals[1 ]);
143+ if (m1 > m0) {
144+ std::swap (m0, m1);
145+ std::swap (k0, k1);
146+ }
147+ for (uint32_t i = 2 ; i < 4 ; ++i) {
148+ double mi = abs_to_double (vals[i]);
149+ if (mi > m0) {
150+ m1 = m0;
151+ k1 = k0;
152+ m0 = mi;
153+ k0 = i;
154+ } else if (mi > m1) {
155+ m1 = mi;
156+ k1 = i;
157+ }
158+ }
159+ keep0 = k0;
160+ keep1 = k1;
161+ }
124162}
125163
126164template <uint32_t NT, // number of threads per warp
@@ -403,7 +441,70 @@ struct wmma_context {
403441 fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7};
404442 }
405443 }
  // TODO: consider adding a wmma_sparse_sync() here that consumes a
  // 2:4-compressed operand plus its metadata (see compress_2to4_matrix below).
406445};
407446
447+ template <typename T>
448+ inline bool enforce_2to4_sparsity (const T* dense, uint32_t rows, uint32_t cols, uint32_t ld,
449+ T* pruned, uint32_t ld_pruned, mem_layout layout = row_major) {
450+ constexpr uint32_t kBlock = 4 ;
451+ if (layout != row_major)
452+ return false ;
453+ if ((cols % kBlock ) != 0 || ld < cols || ld_pruned < cols)
454+ return false ;
455+
456+ // Keep the top-2 magnitudes per 4-wide block, zero the rest.
457+ for (uint32_t r = 0 ; r < rows; ++r) {
458+ const T* row_in = dense + r * ld;
459+ T* row_out = pruned + r * ld_pruned;
460+ for (uint32_t c = 0 ; c < cols; c += kBlock ) {
461+ T vals[kBlock ] = {row_in[c + 0 ], row_in[c + 1 ], row_in[c + 2 ], row_in[c + 3 ]};
462+ uint32_t keep0, keep1;
463+ detail::select_top2 (vals, keep0, keep1);
464+ for (uint32_t i = 0 ; i < kBlock ; ++i) {
465+ row_out[c + i] = ((i == keep0) || (i == keep1)) ? vals[i] : static_cast <T>(0 );
466+ }
467+ }
468+ }
469+ return true ;
470+ }
471+
472+ template <typename T>
473+ inline bool compress_2to4_matrix (const T* dense, uint32_t rows, uint32_t cols, uint32_t ld,
474+ T* compressed, uint32_t ld_compressed,
475+ uint8_t * metadata, uint32_t ld_metadata,
476+ mem_layout layout = row_major) {
477+ constexpr uint32_t kBlock = 4 ;
478+ constexpr uint32_t kKeep = 2 ;
479+ if (layout != row_major)
480+ return false ;
481+ if ((cols % kBlock ) != 0 || ld < cols)
482+ return false ;
483+
484+ const uint32_t out_cols = cols / kKeep ;
485+ const uint32_t meta_cols = cols / kBlock ;
486+ if (ld_compressed < out_cols || ld_metadata < meta_cols)
487+ return false ;
488+
489+ for (uint32_t r = 0 ; r < rows; ++r) {
490+ const T* row_in = dense + r * ld;
491+ T* row_out = compressed + r * ld_compressed;
492+ uint8_t * row_meta = metadata + r * ld_metadata;
493+ for (uint32_t c = 0 ; c < cols; c += kBlock ) {
494+ T vals[kBlock ] = {row_in[c + 0 ], row_in[c + 1 ], row_in[c + 2 ], row_in[c + 3 ]};
495+ uint32_t keep0, keep1;
496+ detail::select_top2 (vals, keep0, keep1);
497+ uint32_t idx0 = (keep0 < keep1) ? keep0 : keep1;
498+ uint32_t idx1 = (keep0 < keep1) ? keep1 : keep0;
499+ uint32_t out_base = (c / kBlock ) * kKeep ;
500+ row_out[out_base + 0 ] = vals[idx0];
501+ row_out[out_base + 1 ] = vals[idx1];
502+ // Metadata byte encodes which 2 of 4 entries were kept (low 4 bits).
503+ row_meta[c / kBlock ] = static_cast <uint8_t >((1u << idx0) | (1u << idx1));
504+ }
505+ }
506+ return true ;
507+ }
508+
408509} // namespace tensor
409510} // namespace vortex
0 commit comments