@@ -26,18 +26,20 @@ struct crsmatrix_traversal_functor {
2626 using team_policy_type = Kokkos::TeamPolicy<execution_space>;
2727 using team_member_type = typename team_policy_type::member_type;
2828
29- matrix_type A;
29+ matrix_type A;
3030 functor_type func;
3131 ordinal_type rows_per_team;
3232
33- crsmatrix_traversal_functor (const matrix_type& A_, const functor_type& func_, const ordinal_type rows_per_team_)
34- : A(A_), func(func_), rows_per_team(rows_per_team_) {}
33+ crsmatrix_traversal_functor (const matrix_type& A_, const functor_type& func_,
34+ const ordinal_type rows_per_team_)
35+ : A(A_), func(func_), rows_per_team(rows_per_team_) {}
3536
3637 // RangePolicy overload
3738 KOKKOS_INLINE_FUNCTION void operator ()(const ordinal_type rowIdx) const {
38- for (size_type entryIdx = A.graph .row_map (rowIdx); entryIdx < A.graph .row_map (rowIdx + 1 ); ++entryIdx) {
39+ for (size_type entryIdx = A.graph .row_map (rowIdx);
40+ entryIdx < A.graph .row_map (rowIdx + 1 ); ++entryIdx) {
3941 const ordinal_type colIdx = A.graph .entries (entryIdx);
40- const value_type value = A.values (entryIdx);
42+ const value_type value = A.values (entryIdx);
4143
4244 func (rowIdx, entryIdx, colIdx, value);
4345 }
@@ -55,24 +57,27 @@ struct crsmatrix_traversal_functor {
5557 return ;
5658 }
5759
58- const ordinal_type row_length = A.graph .row_map (rowIdx + 1 ) - A.graph .row_map (rowIdx);
60+ const ordinal_type row_length =
61+ A.graph .row_map (rowIdx + 1 ) - A.graph .row_map (rowIdx);
5962 Kokkos::parallel_for (
6063 Kokkos::ThreadVectorRange (dev, row_length),
6164 [&](ordinal_type rowEntryIdx) {
62- const size_type entryIdx = A.graph .row_map (rowIdx) + static_cast <size_type>(rowEntryIdx);
63- const ordinal_type colIdx = A.graph .entries (entryIdx);
64- const value_type value = A.values (entryIdx);
65+ const size_type entryIdx = A.graph .row_map (rowIdx) +
66+ static_cast <size_type>(rowEntryIdx);
67+ const ordinal_type colIdx = A.graph .entries (entryIdx);
68+ const value_type value = A.values (entryIdx);
6569
66- func (rowIdx, entryIdx, colIdx, value);
70+ func (rowIdx, entryIdx, colIdx, value);
6771 });
6872 });
6973 }
7074};
7175
7276template <class execution_space >
7377int64_t crsmatrix_traversal_launch_parameters (int64_t numRows, int64_t nnz,
74- int64_t rows_per_thread, int & team_size,
75- int & vector_length) {
78+ int64_t rows_per_thread,
79+ int & team_size,
80+ int & vector_length) {
7681 int64_t rows_per_team;
7782 int64_t nnz_per_row = nnz / numRows;
7883
@@ -129,35 +134,40 @@ int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz,
129134
130135template <class execution_space , class crsmatrix_type , class functor_type >
131136void crsmatrix_traversal_on_host (const execution_space& space,
132- const crsmatrix_type& A,
133- const functor_type& func) {
134-
137+ const crsmatrix_type& A,
138+ const functor_type& func) {
135139 // Wrap user functor with crsmatrix_traversal_functor
136- crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func (A, func, -1 );
140+ crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
141+ traversal_func (A, func, -1 );
137142
138143 // Launch traversal kernel
139- Kokkos::parallel_for (" KokkosSparse::crsmatrix_traversal" ,
140- Kokkos::RangePolicy<execution_space>(space, 0 , A.numRows ()),
141- traversal_func);
144+ Kokkos::parallel_for (
145+ " KokkosSparse::crsmatrix_traversal" ,
146+ Kokkos::RangePolicy<execution_space>(space, 0 , A.numRows ()),
147+ traversal_func);
142148}
143149
144150template <class execution_space , class crsmatrix_type , class functor_type >
145151void crsmatrix_traversal_on_gpu (const execution_space& space,
146- const crsmatrix_type& A,
147- const functor_type& func) {
148-
152+ const crsmatrix_type& A,
153+ const functor_type& func) {
149154 // Wrap user functor with crsmatrix_traversal_functor
150155 int64_t rows_per_thread = 0 ;
151156 int team_size = 0 , vector_length = 0 ;
152- const int64_t rows_per_team = crsmatrix_traversal_launch_parameters<execution_space>(A.numRows (), A.nnz (), rows_per_thread, team_size, vector_length);
153- const int nteams = (static_cast <int >(A.numRows ()) + rows_per_team - 1 ) / rows_per_team;
154- crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func (A, func, rows_per_team);
157+ const int64_t rows_per_team =
158+ crsmatrix_traversal_launch_parameters<execution_space>(
159+ A.numRows (), A.nnz (), rows_per_thread, team_size, vector_length);
160+ const int nteams =
161+ (static_cast <int >(A.numRows ()) + rows_per_team - 1 ) / rows_per_team;
162+ crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
163+ traversal_func (A, func, rows_per_team);
155164
156165 // Launch traversal kernel
157166 Kokkos::parallel_for (" KokkosSparse::crsmatrix_traversal" ,
158- Kokkos::TeamPolicy<execution_space>(space, nteams, team_size, vector_length),
159- traversal_func);
167+ Kokkos::TeamPolicy<execution_space>(
168+ space, nteams, team_size, vector_length),
169+ traversal_func);
160170}
161171
162- } // Impl
163- } // KokkosSparse
172+ } // namespace Impl
173+ } // namespace KokkosSparse
0 commit comments