|
| 1 | +#pragma once |
| 2 | +#include <iterator> |
| 3 | +#include <memory> |
| 4 | +#include <spblas/vendor/onemkl_sycl/mkl_allocator.hpp> |
| 5 | +#include <sycl.hpp> |
| 6 | +#include <vector> |
| 7 | + |
| 8 | +namespace thrust { |
| 9 | + |
| 10 | +template <typename InputIt, typename OutputIt> |
| 11 | + requires(std::contiguous_iterator<InputIt> && |
| 12 | + std::contiguous_iterator<OutputIt>) |
| 13 | +OutputIt copy(InputIt first, InputIt last, OutputIt d_first) { |
| 14 | + sycl::queue queue(sycl::default_selector_v); |
| 15 | + using input_value_type = typename std::iterator_traits<InputIt>::value_type; |
| 16 | + using output_value_type = typename std::iterator_traits<OutputIt>::value_type; |
| 17 | + input_value_type* first_ptr = std::to_address(first); |
| 18 | + output_value_type* d_first_ptr = std::to_address(d_first); |
| 19 | + auto num = std::distance(first, last); |
| 20 | + queue.memcpy(d_first_ptr, first_ptr, num * sizeof(input_value_type)) |
| 21 | + .wait_and_throw(); |
| 22 | + return d_first + num; |
| 23 | +} |
| 24 | + |
| 25 | +// incompleted impl for thrust vector in oneMKL just for test usage |
| 26 | +template <typename ValueType> |
| 27 | +class device_vector { |
| 28 | +public: |
| 29 | + device_vector(std::vector<ValueType> host_vector) |
| 30 | + : alloc_{}, size_(host_vector.size()), ptr_(nullptr) { |
| 31 | + ptr_ = alloc_.allocate(size_); |
| 32 | + thrust::copy(host_vector.begin(), host_vector.end(), ptr_); |
| 33 | + } |
| 34 | + |
| 35 | + ~device_vector() { |
| 36 | + alloc_.deallocate(ptr_, size_); |
| 37 | + ptr_ = nullptr; |
| 38 | + } |
| 39 | + |
| 40 | + ValueType* begin() { |
| 41 | + return ptr_; |
| 42 | + } |
| 43 | + |
| 44 | + ValueType* end() { |
| 45 | + return ptr_ + size_; |
| 46 | + } |
| 47 | + |
| 48 | + // just to give data().get() |
| 49 | + std::shared_ptr<ValueType> data() { |
| 50 | + return std::shared_ptr<ValueType>(ptr_, [](ValueType* ptr) {}); |
| 51 | + } |
| 52 | + |
| 53 | +private: |
| 54 | + spblas::mkl::mkl_allocator<ValueType> alloc_; |
| 55 | + std::size_t size_; |
| 56 | + ValueType* ptr_; |
| 57 | +}; |
| 58 | + |
| 59 | +} // namespace thrust |
0 commit comments