Skip to content

Commit c31eb82

Browse files
committed
added prune & compress for sparsity
Added row-major 2:4 structured sparsity in vx_tensor.h header. Prune selects the Top-2 magnitude values, set bottom-2 to 0. Compress takes a pruned matrix and turn it into compressed format + metadata
1 parent 92a6bb0 commit c31eb82

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

kernel/include/vx_tensor.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,44 @@ namespace detail {
121121
return *reinterpret_cast<const D*>(&result_u);
122122
}
123123
};
124+
125+
template <typename T>
126+
inline double abs_to_double(T v) {
127+
if constexpr (std::is_floating_point_v<T>) {
128+
return (v < static_cast<T>(0)) ? -static_cast<double>(v) : static_cast<double>(v);
129+
} else if constexpr (std::is_signed_v<T>) {
130+
auto wide = static_cast<long long>(v);
131+
return (wide < 0) ? static_cast<double>(-wide) : static_cast<double>(wide);
132+
} else {
133+
return static_cast<double>(v);
134+
}
135+
}
136+
137+
template <typename T>
138+
inline void select_top2(const T (&vals)[4], uint32_t &keep0, uint32_t &keep1) {
139+
uint32_t k0 = 0;
140+
uint32_t k1 = 1;
141+
double m0 = abs_to_double(vals[0]);
142+
double m1 = abs_to_double(vals[1]);
143+
if (m1 > m0) {
144+
std::swap(m0, m1);
145+
std::swap(k0, k1);
146+
}
147+
for (uint32_t i = 2; i < 4; ++i) {
148+
double mi = abs_to_double(vals[i]);
149+
if (mi > m0) {
150+
m1 = m0;
151+
k1 = k0;
152+
m0 = mi;
153+
k0 = i;
154+
} else if (mi > m1) {
155+
m1 = mi;
156+
k1 = i;
157+
}
158+
}
159+
keep0 = k0;
160+
keep1 = k1;
161+
}
124162
}
125163

126164
template <uint32_t NT, // number of threads per warp
@@ -403,7 +441,70 @@ struct wmma_context {
403441
fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7};
404442
}
405443
}
444+
// add a wmma_sparse_sync function here?
406445
};
407446

447+
template <typename T>
448+
inline bool enforce_2to4_sparsity(const T* dense, uint32_t rows, uint32_t cols, uint32_t ld,
449+
T* pruned, uint32_t ld_pruned, mem_layout layout = row_major) {
450+
constexpr uint32_t kBlock = 4;
451+
if (layout != row_major)
452+
return false;
453+
if ((cols % kBlock) != 0 || ld < cols || ld_pruned < cols)
454+
return false;
455+
456+
// Keep the top-2 magnitudes per 4-wide block, zero the rest.
457+
for (uint32_t r = 0; r < rows; ++r) {
458+
const T* row_in = dense + r * ld;
459+
T* row_out = pruned + r * ld_pruned;
460+
for (uint32_t c = 0; c < cols; c += kBlock) {
461+
T vals[kBlock] = {row_in[c + 0], row_in[c + 1], row_in[c + 2], row_in[c + 3]};
462+
uint32_t keep0, keep1;
463+
detail::select_top2(vals, keep0, keep1);
464+
for (uint32_t i = 0; i < kBlock; ++i) {
465+
row_out[c + i] = ((i == keep0) || (i == keep1)) ? vals[i] : static_cast<T>(0);
466+
}
467+
}
468+
}
469+
return true;
470+
}
471+
472+
template <typename T>
473+
inline bool compress_2to4_matrix(const T* dense, uint32_t rows, uint32_t cols, uint32_t ld,
474+
T* compressed, uint32_t ld_compressed,
475+
uint8_t* metadata, uint32_t ld_metadata,
476+
mem_layout layout = row_major) {
477+
constexpr uint32_t kBlock = 4;
478+
constexpr uint32_t kKeep = 2;
479+
if (layout != row_major)
480+
return false;
481+
if ((cols % kBlock) != 0 || ld < cols)
482+
return false;
483+
484+
const uint32_t out_cols = cols / kKeep;
485+
const uint32_t meta_cols = cols / kBlock;
486+
if (ld_compressed < out_cols || ld_metadata < meta_cols)
487+
return false;
488+
489+
for (uint32_t r = 0; r < rows; ++r) {
490+
const T* row_in = dense + r * ld;
491+
T* row_out = compressed + r * ld_compressed;
492+
uint8_t* row_meta = metadata + r * ld_metadata;
493+
for (uint32_t c = 0; c < cols; c += kBlock) {
494+
T vals[kBlock] = {row_in[c + 0], row_in[c + 1], row_in[c + 2], row_in[c + 3]};
495+
uint32_t keep0, keep1;
496+
detail::select_top2(vals, keep0, keep1);
497+
uint32_t idx0 = (keep0 < keep1) ? keep0 : keep1;
498+
uint32_t idx1 = (keep0 < keep1) ? keep1 : keep0;
499+
uint32_t out_base = (c / kBlock) * kKeep;
500+
row_out[out_base + 0] = vals[idx0];
501+
row_out[out_base + 1] = vals[idx1];
502+
// Metadata byte encodes which 2 of 4 entries were kept (low 4 bits).
503+
row_meta[c / kBlock] = static_cast<uint8_t>((1u << idx0) | (1u << idx1));
504+
}
505+
}
506+
return true;
507+
}
508+
408509
} // namespace tensor
409510
} // namespace vortex

0 commit comments

Comments
 (0)