Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
800 changes: 417 additions & 383 deletions batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp

Large diffs are not rendered by default.

849 changes: 434 additions & 415 deletions batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ KOKKOS_INLINE_FUNCTION int SerialLU_Internal<Algo::LU::Blocked>::invoke(
const int m_abr = ib - p - mb, n_abr = jb - p - mb;

// trsm update
trsm_llu.serial_invoke(Ap, pb, n_abr, Ap + mb * as1);
trsm_run.serial_invoke(Ap, pb, m_abr, Ap + mb * as0);
trsm_llu.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, n_abr, Ap + mb * as1);
trsm_run.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, m_abr, Ap + mb * as0);

// gemm update
Impl::SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
Expand Down
4 changes: 2 additions & 2 deletions batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ KOKKOS_INLINE_FUNCTION int TeamLU_Internal<Algo::LU::Blocked>::invoke(
Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, mq_abr + nq_abr), [&](const int &ij) {
if (ij < nq_abr) {
const int j = (ij)*nb, qb = (j + nb) > n_abr ? np_abr : nb;
trsm_llu.serial_invoke(Ap, pb, qb, Ap + (j + mb) * as1);
trsm_llu.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Ap + (j + mb) * as1);
} else {
const int i = (ij - nq_abr) * nb, qb = (i + nb) > m_abr ? mp_abr : nb;
trsm_run.serial_invoke(Ap, pb, qb, Ap + (i + mb) * as0);
trsm_run.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Ap + (i + mb) * as0);
}
});
member.team_barrier();
Expand Down
120 changes: 78 additions & 42 deletions batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ struct SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tr
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(0), B_stride_1);
}
};

Expand All @@ -141,8 +141,8 @@ struct SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tr

static_assert(AViewType::rank() == 2 && BViewType::rank() == 2);
return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(0), B_stride_1);
}
};

Expand Down Expand Up @@ -214,8 +214,8 @@ struct SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tr
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(0), B_stride_1);
}
};

Expand All @@ -237,8 +237,8 @@ struct SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tr
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(0), B_stride_1);
}
};

Expand Down Expand Up @@ -310,8 +310,8 @@ struct SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsm
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

Expand All @@ -333,8 +333,8 @@ struct SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsm
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

Expand Down Expand Up @@ -405,8 +405,8 @@ struct SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsm
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

Expand All @@ -428,8 +428,8 @@ struct SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsm
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

Expand Down Expand Up @@ -501,12 +501,33 @@ struct SerialTrsm<Side::Left, Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

// [TO DO] ConjTranspose is not supported yet
template <typename ArgDiag>
struct SerialTrsm<Side::Left, Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
static_assert(AViewType::rank() == 2);
constexpr size_t B_rank = BViewType::rank();
static_assert(B_rank == 1 || B_rank == 2);

// Quick return if possible
if (B.size() == 0) return 0;

size_t B_extent_1 = B_rank == 1 ? 1 : B.extent(1);
size_t B_stride_1 = B_rank == 1 ? 1 : B.stride(1);

auto info = KokkosBatched::Impl::checkTrsmInput<Side::Left>(A, B);
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

///
/// L/U/C
Expand Down Expand Up @@ -575,8 +596,8 @@ struct SerialTrsm<Side::Left, Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(0), B_stride_1);
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(0), B_stride_1);
}
};

Expand Down Expand Up @@ -635,8 +656,8 @@ struct SerialTrsm<Side::Right, Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::T
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(1), B.stride(0));
}
};

Expand All @@ -652,8 +673,8 @@ struct SerialTrsm<Side::Right, Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::T
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(1), B.stride(0));
}
};

Expand Down Expand Up @@ -709,8 +730,8 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::T
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(1), B.stride(0));
}
};

Expand All @@ -726,8 +747,8 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::T
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1),
A.stride(0), B.data(), B.stride(1), B.stride(0));
}
};

Expand Down Expand Up @@ -784,8 +805,8 @@ struct SerialTrsm<Side::Right, Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trs
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

Expand All @@ -801,8 +822,8 @@ struct SerialTrsm<Side::Right, Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trs
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

Expand Down Expand Up @@ -858,8 +879,8 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trs
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

Expand All @@ -875,8 +896,8 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trs
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

Expand Down Expand Up @@ -933,8 +954,8 @@ struct SerialTrsm<Side::Right, Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo:
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

Expand Down Expand Up @@ -992,12 +1013,27 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo:
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(),
B.stride(1), B.stride(0));
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

// [TO DO] ConjTranspose is not supported yet
template <typename ArgDiag>
struct SerialTrsm<Side::Right, Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
static_assert(AViewType::rank() == 2 && BViewType::rank() == 2);
// Quick return if possible
if (B.extent(0) == 0 || B.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsmInput<Side::Right>(A, B);
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0),
A.stride(1), B.data(), B.stride(1), B.stride(0));
}
};

} // namespace KokkosBatched

Expand Down
Loading
Loading