2626using namespace vortex ;
2727namespace vt = tensor;
2828
29- static bool g_enable_sparse = false ;
3029// /////////////////////////////////////////////////////////////////////////////
3130
3231static void convert_row_to_col_major_4bit (uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) {
@@ -603,14 +602,6 @@ using cfg = vt::wmma_config_t<NUM_THREADS, vt::ITYPE, vt::OTYPE>;
603602using itype_t = typename vt::ITYPE::dtype;
604603using otype_t = typename vt::OTYPE::dtype;
605604
606- struct SparseMat {
607- std::vector<itype_t > values; // non-zeros
608- std::vector<uint8_t > meta; // Array of row-masks: 1 byte marks the columns
609- // of the 4 elements in the block that are non-zero.
610- // e.g. 0b0101 means 2nd and 4th elements are non-zero.
611-
612- uint32_t rows, cols; // original A dims (M × K)
613- };
614605
615606static void matmul_cpu (otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) {
616607 uint32_t subbytes = 8 / vt::ITYPE::bits;
@@ -628,54 +619,6 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t
628619 }
629620}
630621
631- /*
632- static void matmul_cpu_sparseA(
633- otype_t* C, // [M × N] output
634- const SparseMat& A, // sparse-A
635- const itype_t* B, // [K × N] dense-B
636- uint32_t N) // number of columns of B/C
637- {
638- const uint32_t M = A.rows;
639- const uint32_t K = A.cols;
640-
641- const uint32_t subbytes = 8 / vt::ITYPE::bits;
642-
643- // --- helper lambdas to index sparse arrays by row ---
644- auto row_values = [&](uint32_t m) {
645- return A.values.data() + m * (K / 2); // two kept per block
646- };
647- auto row_meta = [&](uint32_t m) {
648- return A.meta .data() + m * (K / 4);
649- };
650-
651- for (uint32_t m = 0; m < M; ++m) {
652-
653- const itype_t* Avals = row_values(m);
654- const uint8_t* Ameta = row_meta (m);
655- size_t v_idx = 0; // cursor inside values[]
656-
657- for (uint32_t n = 0; n < N; ++n) {
658- otype_t sum(0);
659- for (uint32_t blk = 0; blk < K; blk += 4) {
660- uint8_t mask = *(Ameta++);
661- assert(mask);
662- for (uint32_t i = 0; i < 4; ++i) {
663- if (mask & (1u << i)) {
664- auto a_val = Avals[v_idx++];
665- uint32_t k = blk + i; // logical K index
666- uint32_t kk = subbytes ? k * subbytes // packed-layout idx
667- : k;
668- auto b_val = data_accessor_t<vt::ITYPE>::read(
669- B, kk * N + n);
670- sum = muladd_t<vt::ITYPE, vt::OTYPE>::eval(a_val, b_val, sum);
671- }
672- }
673- }
674- data_accessor_t<vt::OTYPE>::write(C, m * N + n, sum);
675- }
676- }
677- }*/
678-
679622// /////////////////////////////////////////////////////////////////////////////
680623
681624const char *kernel_file = " kernel.vxbin" ;
@@ -696,8 +639,7 @@ std::string last_build_options;
696639
// Prints the test banner followed by the command-line usage summary to
// std::cout. Takes no arguments and returns nothing.
697640static void show_usage () {
698641 std::cout << " Vortex Sgemm TCU Test." << std::endl;
699- std::cout << " Usage: [-m: m] [-n N] [-k: K] [-s] [-h: help]" << std::endl;
700- std::cout << " -s Enable 2:4 structured sparsity " << std::endl;
642+ std::cout << " Usage: [-m: m] [-n N] [-k: K] [-h: help]" << std::endl;
701643}
702644
703645static void parse_args (int argc, char **argv) {
@@ -713,10 +655,6 @@ static void parse_args(int argc, char **argv) {
713655 case ' k' :
714656 xk = atoi (optarg);
715657 break ;
716- case ' s' :
717- g_enable_sparse = true ;
718- std::cout << " Sparse mode enabled (-s)" << std::endl;
719- break ;
720658 case ' h' :
721659 show_usage ();
722660 exit (0 );
@@ -740,73 +678,12 @@ void cleanup() {
740678}
741679
742680
743- static SparseMat pruneAndCompressMatrixA (const std::vector<itype_t >& denseA,
744- uint32_t M, uint32_t K) {
745- SparseMat out;
746- out.rows = M;
747- out.cols = K;
748- out.values .reserve (M * K / 2 ); // Select 2 values every 4 values
749- out.meta .reserve (M * K / 4 ); // 1 byte for every 4 values
750-
751- const itype_t * src = denseA.data ();
752-
753- for (uint32_t r = 0 ; r < M; ++r) {
754- for (uint32_t c = 0 ; c < K; c += 4 ) {
755- itype_t blk[4 ] = {src[r * K + c],
756- src[r * K + c + 1 ],
757- src[r * K + c + 2 ],
758- src[r * K + c + 3 ]};
759-
760- uint32_t idx[4 ] = {0 , 1 , 2 , 3 };
761- std::sort (idx, idx + 4 ,
762- [&](uint32_t a, uint32_t b) {
763- return std::abs ((int )blk[a]) < std::abs ((int )blk[b]);
764- }); // Sort the 4 elements by absolute value, ascending order
765-
766- uint8_t keep0 = idx[3 ];
767- uint8_t keep1 = idx[2 ]; // idx of largest 2 elements
768-
769- out.values .push_back (blk[keep0]);
770- out.values .push_back (blk[keep1]);
771-
772- uint8_t m = (1u << keep0) | (1u << keep1); // e.g. 0b0101
773- out.meta .push_back (m);
774- }
775- }
776- return out;
777- }
778-
779- void test_pruneA () {
780- const uint32_t M = 4 , K = 8 ;
781- std::vector<itype_t > denseA (M * K);
782- for (auto & v : denseA) v = Comparator<vt::ITYPE>::generate ();
783-
784- auto spA = pruneAndCompressMatrixA (denseA, M, K);
785-
786- std::vector<itype_t > recovered (M * K, 0 );
787- size_t v_idx = 0 , m_idx = 0 ;
788- for (uint32_t r = 0 ; r < M; ++r)
789- for (uint32_t c = 0 ; c < K; c += 4 ) {
790- uint8_t m = spA.meta [m_idx++];
791- for (uint32_t i = 0 ; i < 4 ; ++i)
792- if (m & (1u << i))
793- recovered[r * K + c + i] = spA.values [v_idx++];
794- }
795-
796- for (uint32_t i = 0 ; i < M * K; ++i)
797- assert (recovered[i] == denseA[i] || recovered[i] == 0 ); // Either the value is preserved or pruned
798- std::cout << " pruneAndCompressMatrixA passed\n " ;
799- }
800681
801682
802683int main (int argc, char *argv[]) {
803684 // parse command arguments
804685 parse_args (argc, argv);
805686
806- if (g_enable_sparse) {
807- test_pruneA (); // Test the pruning function
808- }
809-
810687 std::srand (50 );
811688
812689 // open device connection
0 commit comments