|
2 | 2 |
|
3 | 3 | #include <oneapi/mkl.hpp> |
4 | 4 |
|
| 5 | +#include "mkl_allocator.hpp" |
5 | 6 | #include <spblas/detail/log.hpp> |
6 | 7 | #include <spblas/detail/operation_info_t.hpp> |
7 | 8 | #include <spblas/detail/ranges.hpp> |
|
24 | 25 |
|
25 | 26 | namespace spblas { |
26 | 27 |
|
| 28 | +class spmv_state_t { |
| 29 | +public: |
| 30 | + spmv_state_t() : spmv_state_t(mkl::mkl_allocator<char>{}) {} |
| 31 | + |
| 32 | + spmv_state_t(sycl::queue* q) : spmv_state_t(mkl::mkl_allocator<char>{q}) {} |
| 33 | + |
| 34 | + spmv_state_t(mkl::mkl_allocator<char> alloc) : alloc_(alloc) {} |
| 35 | + |
| 36 | + sycl::queue* queue() { |
| 37 | + return alloc_.queue(); |
| 38 | + } |
| 39 | + |
| 40 | +private: |
| 41 | + mkl::mkl_allocator<char> alloc_; |
| 42 | +}; |
| 43 | + |
27 | 44 | template <matrix A, vector X, vector Y> |
28 | 45 | requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) && |
29 | 46 | __detail::has_contiguous_range_base<X> && |
30 | 47 | __ranges::contiguous_range<Y>) |
31 | | -void multiply(A&& a, X&& x, Y&& y) { |
| 48 | +void multiply(spmv_state_t& state, A&& a, X&& x, Y&& y) { |
32 | 49 | log_trace(""); |
33 | 50 | auto a_base = __detail::get_ultimate_base(a); |
34 | 51 | auto x_base = __detail::get_ultimate_base(x); |
35 | 52 |
|
36 | 53 | auto alpha_optional = __detail::get_scaling_factor(a, x); |
37 | 54 | tensor_scalar_t<A> alpha = alpha_optional.value_or(1); |
38 | 55 |
|
39 | | - sycl::queue q(sycl::cpu_selector_v); |
| 56 | + auto q_ptr = state.queue(); |
40 | 57 |
|
41 | | - auto a_handle = __mkl::create_matrix_handle(q, a_base); |
| 58 | + auto a_handle = __mkl::create_matrix_handle(*q_ptr, a_base); |
42 | 59 | auto a_transpose = __mkl::get_transpose(a); |
43 | 60 |
|
44 | | - oneapi::mkl::sparse::gemv(q, a_transpose, alpha, a_handle, |
| 61 | + oneapi::mkl::sparse::gemv(*q_ptr, a_transpose, alpha, a_handle, |
45 | 62 | __ranges::data(x_base), 0.0, __ranges::data(y)) |
46 | 63 | .wait(); |
47 | 64 |
|
48 | | - oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); |
| 65 | + oneapi::mkl::sparse::release_matrix_handle(*q_ptr, &a_handle).wait(); |
| 66 | +} |
| 67 | + |
| 68 | +template <matrix A, vector X, vector Y> |
| 69 | + requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) && |
| 70 | + __detail::has_contiguous_range_base<X> && |
| 71 | + __ranges::contiguous_range<Y>) |
| 72 | +void multiply(A&& a, X&& x, Y&& y) { |
| 73 | + spmv_state_t state; |
| 74 | + multiply(state, a, x, y); |
49 | 75 | } |
50 | 76 |
|
51 | 77 | } // namespace spblas |
0 commit comments