Skip to content

Commit 987c805

Browse files
committed
Sparse: applying clang-format to crsmatrix traversal
1 parent 5ec9b19 commit 987c805

3 files changed

Lines changed: 67 additions & 42 deletions

File tree

sparse/impl/KokkosSparse_CrsMatrix_traversal_impl.hpp

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@ struct crsmatrix_traversal_functor {
2626
using team_policy_type = Kokkos::TeamPolicy<execution_space>;
2727
using team_member_type = typename team_policy_type::member_type;
2828

29-
matrix_type A;
29+
matrix_type A;
3030
functor_type func;
3131
ordinal_type rows_per_team;
3232

33-
crsmatrix_traversal_functor(const matrix_type& A_, const functor_type& func_, const ordinal_type rows_per_team_)
34-
: A(A_), func(func_), rows_per_team(rows_per_team_) {}
33+
crsmatrix_traversal_functor(const matrix_type& A_, const functor_type& func_,
34+
const ordinal_type rows_per_team_)
35+
: A(A_), func(func_), rows_per_team(rows_per_team_) {}
3536

3637
// RangePolicy overload
3738
KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx) const {
38-
for(size_type entryIdx = A.graph.row_map(rowIdx); entryIdx < A.graph.row_map(rowIdx + 1); ++entryIdx) {
39+
for (size_type entryIdx = A.graph.row_map(rowIdx);
40+
entryIdx < A.graph.row_map(rowIdx + 1); ++entryIdx) {
3941
const ordinal_type colIdx = A.graph.entries(entryIdx);
40-
const value_type value = A.values(entryIdx);
42+
const value_type value = A.values(entryIdx);
4143

4244
func(rowIdx, entryIdx, colIdx, value);
4345
}
@@ -55,24 +57,27 @@ struct crsmatrix_traversal_functor {
5557
return;
5658
}
5759

58-
const ordinal_type row_length = A.graph.row_map(rowIdx + 1) - A.graph.row_map(rowIdx);
60+
const ordinal_type row_length =
61+
A.graph.row_map(rowIdx + 1) - A.graph.row_map(rowIdx);
5962
Kokkos::parallel_for(
6063
Kokkos::ThreadVectorRange(dev, row_length),
6164
[&](ordinal_type rowEntryIdx) {
62-
const size_type entryIdx = A.graph.row_map(rowIdx) + static_cast<size_type>(rowEntryIdx);
63-
const ordinal_type colIdx = A.graph.entries(entryIdx);
64-
const value_type value = A.values(entryIdx);
65+
const size_type entryIdx = A.graph.row_map(rowIdx) +
66+
static_cast<size_type>(rowEntryIdx);
67+
const ordinal_type colIdx = A.graph.entries(entryIdx);
68+
const value_type value = A.values(entryIdx);
6569

66-
func(rowIdx, entryIdx, colIdx, value);
70+
func(rowIdx, entryIdx, colIdx, value);
6771
});
6872
});
6973
}
7074
};
7175

7276
template <class execution_space>
7377
int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz,
74-
int64_t rows_per_thread, int& team_size,
75-
int& vector_length) {
78+
int64_t rows_per_thread,
79+
int& team_size,
80+
int& vector_length) {
7681
int64_t rows_per_team;
7782
int64_t nnz_per_row = nnz / numRows;
7883

@@ -129,35 +134,40 @@ int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz,
129134

130135
template <class execution_space, class crsmatrix_type, class functor_type>
131136
void crsmatrix_traversal_on_host(const execution_space& space,
132-
const crsmatrix_type& A,
133-
const functor_type& func) {
134-
137+
const crsmatrix_type& A,
138+
const functor_type& func) {
135139
// Wrap user functor with crsmatrix_traversal_functor
136-
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func(A, func, -1);
140+
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
141+
traversal_func(A, func, -1);
137142

138143
// Launch traversal kernel
139-
Kokkos::parallel_for("KokkosSparse::crsmatrix_traversal",
140-
Kokkos::RangePolicy<execution_space>(space, 0, A.numRows()),
141-
traversal_func);
144+
Kokkos::parallel_for(
145+
"KokkosSparse::crsmatrix_traversal",
146+
Kokkos::RangePolicy<execution_space>(space, 0, A.numRows()),
147+
traversal_func);
142148
}
143149

144150
template <class execution_space, class crsmatrix_type, class functor_type>
145151
void crsmatrix_traversal_on_gpu(const execution_space& space,
146-
const crsmatrix_type& A,
147-
const functor_type& func) {
148-
152+
const crsmatrix_type& A,
153+
const functor_type& func) {
149154
// Wrap user functor with crsmatrix_traversal_functor
150155
int64_t rows_per_thread = 0;
151156
int team_size = 0, vector_length = 0;
152-
const int64_t rows_per_team = crsmatrix_traversal_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
153-
const int nteams = (static_cast<int>(A.numRows()) + rows_per_team - 1) / rows_per_team;
154-
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func(A, func, rows_per_team);
157+
const int64_t rows_per_team =
158+
crsmatrix_traversal_launch_parameters<execution_space>(
159+
A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
160+
const int nteams =
161+
(static_cast<int>(A.numRows()) + rows_per_team - 1) / rows_per_team;
162+
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
163+
traversal_func(A, func, rows_per_team);
155164

156165
// Launch traversal kernel
157166
Kokkos::parallel_for("KokkosSparse::crsmatrix_traversal",
158-
Kokkos::TeamPolicy<execution_space>(space, nteams, team_size, vector_length),
159-
traversal_func);
167+
Kokkos::TeamPolicy<execution_space>(
168+
space, nteams, team_size, vector_length),
169+
traversal_func);
160170
}
161171

162-
} // Impl
163-
} // KokkosSparse
172+
} // namespace Impl
173+
} // namespace KokkosSparse

sparse/src/KokkosSparse_CrsMatrix_traversal.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@ namespace KokkosSparse {
3333
namespace Experimental {
3434

3535
template <class execution_space, class crsmatrix_type, class functor_type>
36-
void crsmatrix_traversal(const execution_space& space, const crsmatrix_type& matrix, functor_type& functor) {
37-
36+
void crsmatrix_traversal(const execution_space& space,
37+
const crsmatrix_type& matrix, functor_type& functor) {
3838
// Choose between device and host implementation
3939
if constexpr (KokkosKernels::Impl::kk_is_gpu_exec_space<execution_space>()) {
4040
KokkosSparse::Impl::crsmatrix_traversal_on_gpu(space, matrix, functor);
4141
} else {
4242
KokkosSparse::Impl::crsmatrix_traversal_on_host(space, matrix, functor);
4343
}
44-
4544
}
4645

4746
template <class crsmatrix_type, class functor_type>

sparse/unit_test/Test_Sparse_crsmatrix_traversal.hpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,15 @@ struct diag_extraction {
3737

3838
diag_view diag;
3939

40-
diag_extraction(CrsMatrix A) {
40+
diag_extraction(CrsMatrix A) {
4141
diag = diag_view("diag values", A.numRows());
4242
};
4343

44-
KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx, const size_type /*entryIdx*/, const ordinal_type colIdx, const value_type value) const {
45-
if(rowIdx == colIdx) {
44+
KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx,
45+
const size_type /*entryIdx*/,
46+
const ordinal_type colIdx,
47+
const value_type value) const {
48+
if (rowIdx == colIdx) {
4649
diag(rowIdx) = value;
4750
}
4851
}
@@ -61,7 +64,8 @@ void testCrsMatrixTraversal(int testCase) {
6164
constexpr int nx = 4, ny = 4;
6265
constexpr bool leftBC = true, rightBC = false, topBC = false, botBC = false;
6366

64-
Kokkos::View<int*[3], Kokkos::HostSpace> mat_structure("Matrix Structure", 2);
67+
Kokkos::View<int * [3], Kokkos::HostSpace> mat_structure("Matrix Structure",
68+
2);
6569
mat_structure(0, 0) = nx;
6670
mat_structure(0, 1) = (leftBC ? 1 : 0);
6771
mat_structure(0, 2) = (rightBC ? 1 : 0);
@@ -74,10 +78,22 @@ void testCrsMatrixTraversal(int testCase) {
7478

7579
Vector diag_ref("diag ref", A.numRows());
7680
auto diag_ref_h = Kokkos::create_mirror_view(diag_ref);
77-
diag_ref_h( 0) = 1; diag_ref_h( 1) = 3; diag_ref_h( 2) = 3; diag_ref_h( 3) = 2;
78-
diag_ref_h( 4) = 1; diag_ref_h( 5) = 4; diag_ref_h( 6) = 4; diag_ref_h( 7) = 3;
79-
diag_ref_h( 8) = 1; diag_ref_h( 9) = 4; diag_ref_h(10) = 4; diag_ref_h(11) = 3;
80-
diag_ref_h(12) = 1; diag_ref_h(13) = 3; diag_ref_h(14) = 3; diag_ref_h(15) = 2;
81+
diag_ref_h(0) = 1;
82+
diag_ref_h(1) = 3;
83+
diag_ref_h(2) = 3;
84+
diag_ref_h(3) = 2;
85+
diag_ref_h(4) = 1;
86+
diag_ref_h(5) = 4;
87+
diag_ref_h(6) = 4;
88+
diag_ref_h(7) = 3;
89+
diag_ref_h(8) = 1;
90+
diag_ref_h(9) = 4;
91+
diag_ref_h(10) = 4;
92+
diag_ref_h(11) = 3;
93+
diag_ref_h(12) = 1;
94+
diag_ref_h(13) = 3;
95+
diag_ref_h(14) = 3;
96+
diag_ref_h(15) = 2;
8197

8298
// Run the diagonal extraction functor
8399
// using traversal function.
@@ -91,8 +107,8 @@ void testCrsMatrixTraversal(int testCase) {
91107

92108
// Check for correctness
93109
bool matches = true;
94-
for(int rowIdx = 0; rowIdx < A.numRows(); ++rowIdx) {
95-
if(diag_ref_h(rowIdx) != diag_h(rowIdx)) matches = false;
110+
for (int rowIdx = 0; rowIdx < A.numRows(); ++rowIdx) {
111+
if (diag_ref_h(rowIdx) != diag_h(rowIdx)) matches = false;
96112
}
97113

98114
EXPECT_TRUE(matches)

0 commit comments

Comments
 (0)