Skip to content

Commit 285cb26

Browse files
committed
fwd: TA::UMTensorType -> TA::UMTensor
UMTensor will be TA::Tensor type with a UM allocator To avoid confusion, renames some existing uses of UMTensor
1 parent 595c613 commit 285cb26

File tree

3 files changed

+32
-37
lines changed

3 files changed

+32
-37
lines changed

src/TiledArray/device/btas_um_tensor.h

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -565,10 +565,9 @@ typename btasUMTensorVarray<T, Range>::value_type abs_min(
565565
}
566566

567567
/// to host for UM Array
568-
template <typename UMTensor, typename Policy>
569-
void to_host(
570-
TiledArray::DistArray<TiledArray::Tile<UMTensor>, Policy> &um_array) {
571-
auto to_host = [](TiledArray::Tile<UMTensor> &tile) {
568+
template <typename UMT, typename Policy>
569+
void to_host(TiledArray::DistArray<TiledArray::Tile<UMT>, Policy> &um_array) {
570+
auto to_host = [](TiledArray::Tile<UMT> &tile) {
572571
auto stream = device::stream_for(tile.range());
573572

574573
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
@@ -591,10 +590,9 @@ void to_host(
591590
};
592591

593592
/// to device for UM Array
594-
template <typename UMTensor, typename Policy>
595-
void to_device(
596-
TiledArray::DistArray<TiledArray::Tile<UMTensor>, Policy> &um_array) {
597-
auto to_device = [](TiledArray::Tile<UMTensor> &tile) {
593+
template <typename UMT, typename Policy>
594+
void to_device(TiledArray::DistArray<TiledArray::Tile<UMT>, Policy> &um_array) {
595+
auto to_device = [](TiledArray::Tile<UMT> &tile) {
598596
auto stream = device::stream_for(tile.range());
599597

600598
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
@@ -617,12 +615,11 @@ void to_device(
617615
};
618616

619617
/// convert array from UMTensor to TiledArray::Tensor
620-
template <typename UMTensor, typename TATensor, typename Policy>
621-
typename std::enable_if<!std::is_same<UMTensor, TATensor>::value,
618+
template <typename UMT, typename TATensor, typename Policy>
619+
typename std::enable_if<!std::is_same<UMT, TATensor>::value,
622620
TiledArray::DistArray<TATensor, Policy>>::type
623-
um_tensor_to_ta_tensor(
624-
const TiledArray::DistArray<UMTensor, Policy> &um_array) {
625-
const auto convert_tile_memcpy = [](const UMTensor &tile) {
621+
um_tensor_to_ta_tensor(const TiledArray::DistArray<UMT, Policy> &um_array) {
622+
const auto convert_tile_memcpy = [](const UMT &tile) {
626623
TATensor result(tile.tensor().range());
627624

628625
auto stream = device::stream_for(result.range());
@@ -635,7 +632,7 @@ um_tensor_to_ta_tensor(
635632
return result;
636633
};
637634

638-
const auto convert_tile_um = [](const UMTensor &tile) {
635+
const auto convert_tile_um = [](const UMT &tile) {
639636
TATensor result(tile.tensor().range());
640637
using std::begin;
641638
const auto n = tile.tensor().size();
@@ -661,29 +658,28 @@ um_tensor_to_ta_tensor(
661658
}
662659

663660
/// no-op if UMTensor is the same type as TATensor type
664-
template <typename UMTensor, typename TATensor, typename Policy>
665-
typename std::enable_if<std::is_same<UMTensor, TATensor>::value,
666-
TiledArray::DistArray<UMTensor, Policy>>::type
667-
um_tensor_to_ta_tensor(
668-
const TiledArray::DistArray<UMTensor, Policy> &um_array) {
661+
template <typename UMT, typename TATensor, typename Policy>
662+
typename std::enable_if<std::is_same<UMT, TATensor>::value,
663+
TiledArray::DistArray<UMT, Policy>>::type
664+
um_tensor_to_ta_tensor(const TiledArray::DistArray<UMT, Policy> &um_array) {
669665
return um_array;
670666
}
671667

672668
/// convert array from TiledArray::Tensor to UMTensor
673-
template <typename UMTensor, typename TATensor, typename Policy>
674-
typename std::enable_if<!std::is_same<UMTensor, TATensor>::value,
675-
TiledArray::DistArray<UMTensor, Policy>>::type
669+
template <typename UMT, typename TATensor, typename Policy>
670+
typename std::enable_if<!std::is_same<UMT, TATensor>::value,
671+
TiledArray::DistArray<UMT, Policy>>::type
676672
ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
677673
using inT = typename TATensor::value_type;
678-
using outT = typename UMTensor::value_type;
674+
using outT = typename UMT::value_type;
679675
// check if element conversion is necessary
680676
constexpr bool T_conversion = !std::is_same_v<inT, outT>;
681677

682678
// this is safe even when need to convert element types, but less efficient
683679
auto convert_tile_um = [](const TATensor &tile) {
684680
/// UMTensor must be wrapped into TA::Tile
685681

686-
using Tensor = typename UMTensor::tensor_type;
682+
using Tensor = typename UMT::tensor_type;
687683
typename Tensor::storage_type storage(tile.range().area());
688684

689685
Tensor result(tile.range(), std::move(storage));
@@ -703,7 +699,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
703699
return TiledArray::Tile<Tensor>(std::move(result));
704700
};
705701

706-
TiledArray::DistArray<UMTensor, Policy> um_array;
702+
TiledArray::DistArray<UMT, Policy> um_array;
707703
if constexpr (T_conversion) {
708704
um_array = to_new_tile_type(array, convert_tile_um);
709705
} else {
@@ -715,7 +711,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
715711
auto convert_tile_memcpy = [](const TATensor &tile) {
716712
/// UMTensor must be wrapped into TA::Tile
717713

718-
using Tensor = typename UMTensor::tensor_type;
714+
using Tensor = typename UMT::tensor_type;
719715

720716
auto stream = device::stream_for(tile.range());
721717
typename Tensor::storage_type storage;
@@ -745,10 +741,10 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
745741
}
746742

747743
/// no-op if array is the same as return type
748-
template <typename UMTensor, typename TATensor, typename Policy>
749-
typename std::enable_if<std::is_same<UMTensor, TATensor>::value,
750-
TiledArray::DistArray<UMTensor, Policy>>::type
751-
ta_tensor_to_um_tensor(const TiledArray::DistArray<UMTensor, Policy> &array) {
744+
template <typename UMT, typename TATensor, typename Policy>
745+
typename std::enable_if<std::is_same<UMT, TATensor>::value,
746+
TiledArray::DistArray<UMT, Policy>>::type
747+
ta_tensor_to_um_tensor(const TiledArray::DistArray<UMT, Policy> &array) {
752748
return array;
753749
}
754750

src/TiledArray/fwd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ using btasUMTensorVarray =
144144

145145
/// TA::Tensor with UM storage
146146
template <typename T>
147-
using UMTensorType = TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>;
147+
using UMTensor = TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>;
148148

149149
#endif // TILEDARRAY_HAS_DEVICE
150150

tests/expressions_device_um.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
using namespace TiledArray;
3636

3737
struct UMExpressionsFixture : public TiledRangeFixture {
38-
using UMTensor = TA::Tile<btasUMTensorVarray<double>>;
39-
using TArrayUMD = TiledArray::DistArray<UMTensor, TA::DensePolicy>;
38+
using UMT = TA::Tile<btasUMTensorVarray<double>>;
39+
using TArrayUMD = TiledArray::DistArray<UMT, TA::DensePolicy>;
4040

4141
UMExpressionsFixture()
4242
: a(*GlobalFixture::world, tr),
@@ -69,13 +69,12 @@ struct UMExpressionsFixture : public TiledRangeFixture {
6969
t = GlobalFixture::world->drand();
7070
}
7171

72-
static UMTensor permute_task(const UMTensor& tensor,
73-
const Permutation& perm) {
72+
static UMT permute_task(const UMT& tensor, const Permutation& perm) {
7473
return perm * tensor;
7574
}
7675

77-
static UMTensor permute_fn(const madness::Future<UMTensor>& tensor_f,
78-
const Permutation& perm) {
76+
static UMT permute_fn(const madness::Future<UMT>& tensor_f,
77+
const Permutation& perm) {
7978
return madness::add_device_task(*GlobalFixture::world, permute_task,
8079
tensor_f, perm)
8180
.get();

0 commit comments

Comments
 (0)