Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion Plugins/Gnn/include/ActsPlugins/Gnn/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <cmath>
#include <cstdint>
#include <cstring>
#include <format>
#include <fstream>
#include <functional>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -87,6 +89,10 @@ TensorPtr createTensorMemory(std::size_t nbytes, const ExecutionContext &ctx);
TensorPtr cloneTensorMemory(const TensorPtr &ptrFrom, std::size_t nbytes,
Device devFrom, const ExecutionContext &ctxTo);

/// Write a raw byte buffer to disk in the numpy NPY format, version 1.0.
///
/// @param filename path of the output file (created or truncated)
/// @param type numpy dtype descriptor string, e.g. "<f4" for little-endian
///   float32
/// @param data the raw bytes of the array (C/row-major order)
/// @param shape the 2D shape of the array
void dumpNpy(const std::string &filename, const std::string &type,
             std::span<const std::byte> data,
             const std::array<std::size_t, 2> &shape);

} // namespace detail

/// This is a very small, limited class that models a 2D tensor of arbitrary
Expand Down Expand Up @@ -140,8 +146,12 @@ class Tensor {
std::size_t nbytes() const { return size() * sizeof(T); }

/// Get the device of the tensor
Device device() const { return m_device; }
/// @return Device where tensor data is stored
Device device() const { return m_device; }

/// Save the tensor in the numpy NPY format version 1.0 to disk
void dumpNpy(const std::string &filename,
std::optional<cudaStream_t> stream = {}) const;

private:
Tensor(Shape shape, detail::TensorPtr ptr, const ExecutionContext &ctx)
Expand Down Expand Up @@ -180,4 +190,38 @@ std::pair<Tensor<std::int64_t>, std::optional<Tensor<float>>> applyEdgeLimit(
const std::optional<Tensor<float>> &edgeFeatures, std::size_t maxEdges,
std::optional<cudaStream_t> stream);

/// Save the tensor to @p filename in the numpy NPY format version 1.0.
///
/// @param filename path of the output file
/// @param stream optional CUDA stream used to stage a host copy when the
///   tensor does not live on the CPU
/// @throws std::runtime_error for element types without an NPY mapping, or
///   if the file cannot be written (reported by detail::dumpNpy)
template <Acts::Concepts::arithmetic T>
void Tensor<T>::dumpNpy(const std::string &filename,
                        std::optional<cudaStream_t> stream) const {
  // The NPY writer can only read host memory, so stage a CPU copy first if
  // the tensor lives on another device.
  std::optional<Tensor<T>> maybeCpuTensor;
  if (!device().isCpu()) {
    maybeCpuTensor = this->clone({Device::Cpu(), stream});
  }
  const Tensor<T> &toDump =
      maybeCpuTensor.has_value() ? maybeCpuTensor.value() : *this;

  // Map the element type to the corresponding little-endian NPY descr string.
  std::string typeStr;
  if constexpr (std::is_same_v<T, float>) {
    typeStr = "<f4";
  } else if constexpr (std::is_same_v<T, double>) {
    typeStr = "<f8";
  } else if constexpr (std::is_same_v<T, std::int64_t>) {
    typeStr = "<i8";
  } else if constexpr (std::is_same_v<T, std::int32_t>) {
    typeStr = "<i4";
  } else {
    throw std::runtime_error("Unsupported type for NPY dump");
  }

  // detail::dumpNpy opens the file itself and reports failures; opening an
  // extra ofstream here would only truncate the file a second time.
  detail::dumpNpy(
      filename, typeStr,
      std::as_bytes(std::span<const T>(toDump.data(), toDump.size())), m_shape);
}

} // namespace ActsPlugins
40 changes: 40 additions & 0 deletions Plugins/Gnn/src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#endif

#include <bit>
#include <cstdint>
#include <cstring>
#include <format>
#include <fstream>
#include <limits>
#include <numeric>
#include <span>
#include <stdexcept>

Expand Down Expand Up @@ -222,4 +224,42 @@ std::pair<Tensor<std::int64_t>, std::optional<Tensor<float>>> applyEdgeLimit(
std::move(newEdgeFeatureTensor)};
}

/// Write a 2D array to disk in the NPY format version 1.0.
///
/// @param filename path of the output file (created or truncated)
/// @param type numpy dtype descriptor string, e.g. "<f4"
/// @param data the raw bytes of the array (C/row-major order)
/// @param shape the 2D shape of the array
/// @throws std::runtime_error if the file cannot be opened or the header
///   does not fit into the 16-bit length field of format version 1.0
void detail::dumpNpy(const std::string &filename, const std::string &type,
                     std::span<const std::byte> data,
                     const std::array<std::size_t, 2> &shape) {
  std::ofstream ofs(filename, std::ios::binary);
  if (!ofs.is_open()) {
    throw std::runtime_error("Could not open file for writing: " + filename);
  }

  // Magic string followed by the format version (1.0)
  const char vMajor = 1;
  const char vMinor = 0;
  const std::array<char, 8> magicString = {'\x93', 'N', 'U', 'M',
                                           'P', 'Y', vMajor, vMinor};
  ofs.write(magicString.data(), magicString.size());

  // Header dictionary describing dtype, memory order and shape
  std::string dict = std::format(
      "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}", type,
      shape[0], shape[1]);

  // Pad with spaces so that magic (8 bytes) + header length (2 bytes) + dict
  // *including the terminating newline* is a multiple of 16, as the format
  // spec requires. The newline must be the last header byte, so it has to be
  // counted before computing the padding (padding to 10 + dict.size() alone
  // and then appending '\n' would leave the header misaligned by one byte).
  const std::size_t unpaddedLen = 10 + dict.size() + 1;
  const std::size_t padding = (16 - unpaddedLen % 16) % 16;
  dict.append(padding, ' ');
  dict.push_back('\n');

  // Version 1.0 stores the header length as a little-endian uint16
  static_assert(std::endian::native == std::endian::little);
  if (dict.size() > std::numeric_limits<std::uint16_t>::max()) {
    throw std::runtime_error("NPY header too large for format version 1.0");
  }
  const std::uint16_t dictLen = static_cast<std::uint16_t>(dict.size());
  ofs.write(reinterpret_cast<const char *>(&dictLen), sizeof(dictLen));

  // Write the dictionary
  ofs.write(dict.data(), dict.size());

  // Write the raw array data
  ofs.write(reinterpret_cast<const char *>(data.data()), data.size());
}

} // namespace ActsPlugins
Loading