Skip to content

Commit 5a49fa4

Browse files
committed
use allocator in oneMKL spmv. It still keeps no state input
1 parent eb057fd commit 5a49fa4

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969
- name: Test
7070
run: |
7171
source /opt/intel/oneapi/setvars.sh
72-
./build/test/gtest/spblas-tests
72+
ONEMKL_DEVICE_SELECTOR=*:cpu ./build/test/gtest/spblas-tests
7373
7474
macos:
7575
runs-on: 'macos-latest'

include/spblas/vendor/onemkl_sycl/spmv_impl.hpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <oneapi/mkl.hpp>
44

5+
#include "mkl_allocator.hpp"
56
#include <spblas/detail/log.hpp>
67
#include <spblas/detail/operation_info_t.hpp>
78
#include <spblas/detail/ranges.hpp>
@@ -24,28 +25,53 @@
2425

2526
namespace spblas {
2627

28+
class spmv_state_t {
29+
public:
30+
spmv_state_t() : spmv_state_t(mkl::mkl_allocator<char>{}) {}
31+
32+
spmv_state_t(sycl::queue* q) : spmv_state_t(mkl::mkl_allocator<char>{q}) {}
33+
34+
spmv_state_t(mkl::mkl_allocator<char> alloc) : alloc_(alloc) {}
35+
36+
sycl::queue* queue() {
37+
return alloc_.queue();
38+
}
39+
40+
private:
41+
mkl::mkl_allocator<char> alloc_;
42+
};
43+
2744
template <matrix A, vector X, vector Y>
2845
requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
2946
__detail::has_contiguous_range_base<X> &&
3047
__ranges::contiguous_range<Y>)
31-
void multiply(A&& a, X&& x, Y&& y) {
48+
void multiply(spmv_state_t& state, A&& a, X&& x, Y&& y) {
3249
log_trace("");
3350
auto a_base = __detail::get_ultimate_base(a);
3451
auto x_base = __detail::get_ultimate_base(x);
3552

3653
auto alpha_optional = __detail::get_scaling_factor(a, x);
3754
tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
3855

39-
sycl::queue q(sycl::cpu_selector_v);
56+
auto q_ptr = state.queue();
4057

41-
auto a_handle = __mkl::create_matrix_handle(q, a_base);
58+
auto a_handle = __mkl::create_matrix_handle(*q_ptr, a_base);
4259
auto a_transpose = __mkl::get_transpose(a);
4360

44-
oneapi::mkl::sparse::gemv(q, a_transpose, alpha, a_handle,
61+
oneapi::mkl::sparse::gemv(*q_ptr, a_transpose, alpha, a_handle,
4562
__ranges::data(x_base), 0.0, __ranges::data(y))
4663
.wait();
4764

48-
oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait();
65+
oneapi::mkl::sparse::release_matrix_handle(*q_ptr, &a_handle).wait();
66+
}
67+
68+
template <matrix A, vector X, vector Y>
69+
requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
70+
__detail::has_contiguous_range_base<X> &&
71+
__ranges::contiguous_range<Y>)
72+
void multiply(A&& a, X&& x, Y&& y) {
73+
spmv_state_t state;
74+
multiply(state, a, x, y);
4975
}
5076

5177
} // namespace spblas

0 commit comments

Comments
 (0)