Skip to content

Commit 98cea21

Browse files
committed
remove sparse logic in sgemm_tcu
I will add a new test called sgemm_tcu_sparse
1 parent 300ba90 commit 98cea21

File tree

1 file changed

+1
-124
lines changed

1 file changed

+1
-124
lines changed

tests/regression/sgemm_tcu/main.cpp

Lines changed: 1 addition & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
using namespace vortex;
2727
namespace vt = tensor;
2828

29-
static bool g_enable_sparse = false;
3029
///////////////////////////////////////////////////////////////////////////////
3130

3231
static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) {
@@ -603,14 +602,6 @@ using cfg = vt::wmma_config_t<NUM_THREADS, vt::ITYPE, vt::OTYPE>;
603602
using itype_t = typename vt::ITYPE::dtype;
604603
using otype_t = typename vt::OTYPE::dtype;
605604

606-
// Compressed form of matrix A as produced by pruneAndCompressMatrixA:
// two retained values plus one mask byte for every block of 4 consecutive
// elements in a row.
struct SparseMat {
607-
std::vector<itype_t> values; // values retained after pruning (2 per block of 4)
608-
std::vector<uint8_t> meta; // Array of row-masks: 1 byte per block of 4 elements;
609-
// bit i is set when element i of the block is retained.
610-
// e.g. 0b0101 means the 1st and 3rd elements (bits 0 and 2) are retained.
611-
612-
uint32_t rows, cols; // original A dims (M × K)
613-
};
614605

615606
static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) {
616607
uint32_t subbytes = 8 / vt::ITYPE::bits;
@@ -628,54 +619,6 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t
628619
}
629620
}
630621

631-
/*
632-
static void matmul_cpu_sparseA(
633-
otype_t* C, // [M × N] output
634-
const SparseMat& A, // sparse-A
635-
const itype_t* B, // [K × N] dense-B
636-
uint32_t N) // number of columns of B/C
637-
{
638-
const uint32_t M = A.rows;
639-
const uint32_t K = A.cols;
640-
641-
const uint32_t subbytes = 8 / vt::ITYPE::bits;
642-
643-
// --- helper lambdas to index sparse arrays by row ---
644-
auto row_values = [&](uint32_t m) {
645-
return A.values.data() + m * (K / 2); // two kept per block
646-
};
647-
auto row_meta = [&](uint32_t m) {
648-
return A.meta .data() + m * (K / 4);
649-
};
650-
651-
for (uint32_t m = 0; m < M; ++m) {
652-
653-
const itype_t* Avals = row_values(m);
654-
const uint8_t* Ameta = row_meta (m);
655-
size_t v_idx = 0; // cursor inside values[]
656-
657-
for (uint32_t n = 0; n < N; ++n) {
658-
otype_t sum(0);
659-
for (uint32_t blk = 0; blk < K; blk += 4) {
660-
uint8_t mask = *(Ameta++);
661-
assert(mask);
662-
for (uint32_t i = 0; i < 4; ++i) {
663-
if (mask & (1u << i)) {
664-
auto a_val = Avals[v_idx++];
665-
uint32_t k = blk + i; // logical K index
666-
uint32_t kk = subbytes ? k * subbytes // packed-layout idx
667-
: k;
668-
auto b_val = data_accessor_t<vt::ITYPE>::read(
669-
B, kk * N + n);
670-
sum = muladd_t<vt::ITYPE, vt::OTYPE>::eval(a_val, b_val, sum);
671-
}
672-
}
673-
}
674-
data_accessor_t<vt::OTYPE>::write(C, m * N + n, sum);
675-
}
676-
}
677-
}*/
678-
679622
///////////////////////////////////////////////////////////////////////////////
680623

681624
const char *kernel_file = "kernel.vxbin";
@@ -696,8 +639,7 @@ std::string last_build_options;
696639

697640
// Prints the test banner followed by the command-line usage summary to stdout.
static void show_usage() {
698641
std::cout << "Vortex Sgemm TCU Test." << std::endl;
699-
std::cout << "Usage: [-m: m] [-n N] [-k: K] [-s] [-h: help]" << std::endl;
700-
std::cout << " -s Enable 2:4 structured sparsity " << std::endl;
642+
std::cout << "Usage: [-m: m] [-n N] [-k: K] [-h: help]" << std::endl;
701643
}
702644

703645
static void parse_args(int argc, char **argv) {
@@ -713,10 +655,6 @@ static void parse_args(int argc, char **argv) {
713655
case 'k':
714656
xk = atoi(optarg);
715657
break;
716-
case 's':
717-
g_enable_sparse = true;
718-
std::cout << "Sparse mode enabled (-s)" << std::endl;
719-
break;
720658
case 'h':
721659
show_usage();
722660
exit(0);
@@ -740,73 +678,12 @@ void cleanup() {
740678
}
741679

742680

743-
// Prunes a dense, row-major M x K matrix to 2-out-of-4 sparsity and packs it.
//
// Each row is split into aligned blocks of 4 consecutive elements. For every
// block the two elements ranked highest by the magnitude comparison below are
// retained and appended to `values`; one mask byte is appended to `meta` with
// bit i set when element i of the block was retained.
//
// The two retained values are emitted in ascending column order so that they
// line up with a decoder that walks the mask from bit 0 to bit 3 and consumes
// one value per set bit.
//
// @param denseA  dense matrix data, indexed as denseA[r * K + c]
// @param M       number of rows
// @param K       number of columns; must be a multiple of 4
// @return        the packed matrix with rows = M and cols = K
static SparseMat pruneAndCompressMatrixA(const std::vector<itype_t>& denseA,
                                         uint32_t M, uint32_t K) {
  // Blocks are read four elements at a time; a ragged tail would read past
  // the end of the row.
  assert(K % 4 == 0);

  SparseMat out;
  out.rows = M;
  out.cols = K;

  // Widen before multiplying so large shapes do not wrap in 32 bits.
  const size_t total = static_cast<size_t>(M) * K;
  out.values.reserve(total / 2); // 2 retained values per block of 4
  out.meta.reserve(total / 4);   // 1 mask byte per block of 4

  const itype_t* src = denseA.data();

  for (uint32_t r = 0; r < M; ++r) {
    for (uint32_t c = 0; c < K; c += 4) {
      const itype_t* blk = src + static_cast<size_t>(r) * K + c;

      // Rank the four in-block positions by magnitude, ascending.
      // NOTE(review): elements are converted to int before std::abs, as in the
      // original code; for non-integer element types this may not rank by true
      // magnitude -- confirm against the configured vt::ITYPE.
      uint32_t idx[4] = {0, 1, 2, 3};
      std::sort(idx, idx + 4, [&](uint32_t a, uint32_t b) {
        return std::abs(static_cast<int>(blk[a])) < std::abs(static_cast<int>(blk[b]));
      });

      // idx[2] and idx[3] are the two retained positions; order them by
      // column so the value stream matches the bit order of the mask.
      const uint32_t lo = std::min(idx[2], idx[3]);
      const uint32_t hi = std::max(idx[2], idx[3]);

      out.values.push_back(blk[lo]);
      out.values.push_back(blk[hi]);
      out.meta.push_back(static_cast<uint8_t>((1u << lo) | (1u << hi))); // e.g. 0b0101
    }
  }
  return out;
}
778-
779-
void test_pruneA() {
780-
const uint32_t M = 4, K = 8;
781-
std::vector<itype_t> denseA(M * K);
782-
for (auto& v : denseA) v = Comparator<vt::ITYPE>::generate();
783-
784-
auto spA = pruneAndCompressMatrixA(denseA, M, K);
785-
786-
std::vector<itype_t> recovered(M * K, 0);
787-
size_t v_idx = 0, m_idx = 0;
788-
for (uint32_t r = 0; r < M; ++r)
789-
for (uint32_t c = 0; c < K; c += 4) {
790-
uint8_t m = spA.meta[m_idx++];
791-
for (uint32_t i = 0; i < 4; ++i)
792-
if (m & (1u << i))
793-
recovered[r * K + c + i] = spA.values[v_idx++];
794-
}
795-
796-
for (uint32_t i = 0; i < M * K; ++i)
797-
assert(recovered[i] == denseA[i] || recovered[i] == 0); //Either the value is preserved or pruned
798-
std::cout << "pruneAndCompressMatrixA passed\n";
799-
}
800681

801682

802683
int main(int argc, char *argv[]) {
803684
// parse command arguments
804685
parse_args(argc, argv);
805686

806-
if(g_enable_sparse) {
807-
test_pruneA(); // Test the pruning function
808-
}
809-
810687
std::srand(50);
811688

812689
// open device connection

0 commit comments

Comments
 (0)