Skip to content

Commit 3ebb8bc

Browse files
committed
[SYCLomatic][PTX] Refine migration of asm PTX instruction "lop3.b32"
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent bf01d5d commit 3ebb8bc

File tree

3 files changed

+51
-76
lines changed

3 files changed

+51
-76
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -942,82 +942,14 @@ class SYCLGen : public SYCLGenBase {
942942
return SYCLGenError();
943943
OS() << " = ";
944944

945-
std::string Op[3];
946-
for (auto Idx : llvm::seq(0, 3)) {
945+
std::string Op[4];
946+
for (auto Idx : llvm::seq(0, 4)) {
947947
if (tryEmitStmt(Op[Idx], I->getInputOperand(Idx)))
948948
return SYCLGenError();
949949
}
950950

951-
if (!isa<InlineAsmIntegerLiteral>(I->getInputOperand(3)))
952-
return SYCLGenError();
953-
unsigned Imm = dyn_cast<InlineAsmIntegerLiteral>(I->getInputOperand(3))
954-
->getValue()
955-
.getZExtValue();
956-
957-
#define EMPTY nullptr
958-
#define EMPTY4 EMPTY, EMPTY, EMPTY, EMPTY
959-
#define EMPTY16 EMPTY4, EMPTY4, EMPTY4, EMPTY4
960-
constexpr const char *FastMap[256] = {
961-
/*0x00*/ "0",
962-
// clang-format off
963-
EMPTY16, EMPTY4, EMPTY4, EMPTY,
964-
/*0x1a*/ "({0} & {1} | {2}) ^ {0}",
965-
EMPTY, EMPTY, EMPTY,
966-
/*0x1e*/ "{0} ^ ({1} | {2})",
967-
EMPTY4, EMPTY4, EMPTY4, EMPTY, EMPTY,
968-
/*0x2d*/ "~{0} ^ (~{1} & {2})",
969-
EMPTY16, EMPTY, EMPTY,
970-
/*0x40*/ "{0} & {1} & ~{2}",
971-
EMPTY16, EMPTY16, EMPTY16, EMPTY4, EMPTY, EMPTY, EMPTY,
972-
/*0x78*/ "{0} ^ ({1} & {2})",
973-
EMPTY4, EMPTY, EMPTY, EMPTY,
974-
/*0x80*/ "{0} & {1} & {2}",
975-
EMPTY16, EMPTY4, EMPTY,
976-
/*0x96*/ "{0} ^ {1} ^ {2}",
977-
EMPTY16, EMPTY4, EMPTY4, EMPTY4, EMPTY,
978-
/*0xb4*/ "{0} ^ ({1} & ~{2})",
979-
EMPTY, EMPTY, EMPTY,
980-
/*0xb8*/ "({0} ^ ({1} & ({2} ^ {0})))",
981-
EMPTY16, EMPTY4, EMPTY4, EMPTY,
982-
/*0xd2*/ "{0} ^ (~{1} & {2})",
983-
EMPTY16, EMPTY4, EMPTY,
984-
/*0xe8*/ "(({0} & ({1} | {2})) | ({1} & {2}))",
985-
EMPTY,
986-
/*0xea*/ "({0} & {1}) | {2}",
987-
EMPTY16, EMPTY, EMPTY, EMPTY,
988-
/*0xfe*/ "{0} | {1} | {2}",
989-
/*0xff*/ "uint32_t(-1)"};
990-
// clang-format on
991-
992-
#undef EMPTY16
993-
#undef EMPTY4
994-
#undef EMPTY
995-
// clang-format off
996-
constexpr const char *SlowMap[8] = {
997-
/* 0x01*/ "(~{0} & ~{1} & ~{2})",
998-
/* 0x02*/ "(~{0} & ~{1} & {2})",
999-
/* 0x04*/ "(~{0} & {1} & ~{2})",
1000-
/* 0x08*/ "(~{0} & {1} & {2})",
1001-
/* 0x10*/ "({0} & ~{1} & ~{2})",
1002-
/* 0x20*/ "({0} & ~{1} & {2})",
1003-
/* 0x40*/ "({0} & {1} & ~{2})",
1004-
/* 0x80*/ "({0} & {1} & {2})"
1005-
};
1006-
// clang-format on
1007-
1008-
if (FastMap[Imm]) {
1009-
OS() << llvm::formatv(FastMap[Imm], Op[0], Op[1], Op[2]);
1010-
} else {
1011-
SmallVector<std::string, 8> Templates;
1012-
for (auto Bit : llvm::seq(0, 8)) {
1013-
if (Imm & (1U << Bit)) {
1014-
Templates.push_back(
1015-
llvm::formatv(SlowMap[Bit], Op[0], Op[1], Op[2]).str());
1016-
}
1017-
}
1018-
1019-
OS() << llvm::join(Templates, " | ");
1020-
}
951+
OS() << MapNames::getDpctNamespace() << "lop3(" << Op[0] << ", " << Op[1]
952+
<< ", " << Op[2] << ", " << Op[3] << ")";
1021953

1022954
endstmt();
1023955
return SYCLGenSuccess();

clang/runtime/dpct-rt/include/dpct/util.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,36 @@ template <typename Func, std::size_t N> struct nth_argument_type {
12051205
using type = decltype(helper(std::declval<Func>()));
12061206
};
12071207

1208+
/// Performs a bitwise logical operation on the three input values \p a, \p b
/// and \p c, selected by the 8-bit truth table \p lut, and returns the result.
///
/// For every bit position i, the bits a[i], b[i] and c[i] form the 3-bit index
/// (a[i] << 2) | (b[i] << 1) | c[i]; bit i of the result is the bit of \p lut
/// at that index. With this ordering, e.g. lut 0x80 yields a & b & c, lut 0x40
/// yields a & b & ~c and lut 0x02 yields ~a & ~b & c.
/// \param [in] a Input value
/// \param [in] b Input value
/// \param [in] c Input value
/// \param [in] lut truth table for looking up
/// \returns The result
// Declared inline because this is a non-template function defined in a header.
inline uint32_t lop3(uint32_t a, uint32_t b, uint32_t c, uint8_t lut) {
  uint32_t result = 0;

  // Walk the 8 truth-table entries instead of the 32 bit positions: every
  // entry that is set in lut contributes one min-term, evaluated on all 32
  // bits at once.
  for (uint32_t index = 0; index < 8; ++index) {
    if (((lut >> index) & 1u) == 0)
      continue;

    // Bit 2 of the index selects a, bit 1 selects b and bit 0 selects c; a
    // cleared selector bit means the operand must be 0 at that position, so
    // its complement is used.
    const uint32_t a_term = (index & 4u) ? a : ~a;
    const uint32_t b_term = (index & 2u) ? b : ~b;
    const uint32_t c_term = (index & 1u) ? c : ~c;

    result |= a_term & b_term & c_term;
  }

  return result;
}
1237+
12081238
#ifdef _WIN32
12091239
#define DPCT_EXPORT __declspec(dllexport)
12101240
#else

clang/test/dpct/asm/lop3.cu

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,45 @@
77
// a^b^c
88
static __device__ __forceinline__ uint32_t LOP3LUT_XOR(uint32_t a, uint32_t b, uint32_t c) {
99
uint32_t d1;
10-
// CHECK: d1 = a ^ b ^ c;
10+
// CHECK: d1 = dpct::lop3(a, b, c, 0x96);
1111
asm("lop3.b32 %0, %1, %2, %3, 0x96;" : "=r"(d1) : "r"(a), "r"(b), "r"(c));
1212
return d1;
1313
}
1414

1515
// (a ^ (c & (b ^ a)))
1616
static __device__ __forceinline__ uint32_t LOP3LUT_XORAND(uint32_t a, uint32_t b, uint32_t c) {
1717
uint32_t d2;
18-
// CHECK: d2 = (a ^ (c & (b ^ a)));
18+
// CHECK: d2 = dpct::lop3(a, c, b, 0xb8);
1919
asm("lop3.b32 %0, %1, %3, %2, 0xb8;" : "=r"(d2) : "r"(a), "r"(b), "r"(c));
2020
return d2;
2121
}
2222

2323
// ((a & (b | b)) | (b & b))
2424
static __device__ __forceinline__ uint32_t LOP3LUT_ANDOR(uint32_t a, uint32_t b) {
2525
uint32_t d3;
26-
// CHECK: d3 = ((a & (b | b)) | (b & b));
26+
// CHECK: d3 = dpct::lop3(a, b, b, 0xe8);
2727
asm("lop3.b32 %0, %1, %2, %2, 0xe8;" : "=r"(d3) : "r"(a), "r"(b));
2828
return d3;
2929
}
3030

3131
#define B 3
3232
__device__ int hard(int a) {
3333
int d4;
34-
// CHECK: d4 = (~(a + B) & B & ~3) | (~(a + B) & B & 3) | ((a + B) & ~B & ~3);
34+
// CHECK: d4 = dpct::lop3((a + B), B, 3, 0x1C);
3535
asm("lop3.b32 %0, %1, %2, 3, 0x1C;" : "=r"(d4) : "r"(a + B), "r"(B));
3636
return d4;
3737
}
38+
39+
// CHECK: template <int lut, typename T> inline T lop3(T a, T b, T c) {
40+
// CHECK-NEXT: T res;
41+
// CHECK-NEXT: res = dpct::lop3(a, b, c, lut);
42+
// CHECK-NEXT: return res;
43+
// CHECK-NEXT:}
44+
template <int lut, typename T> __device__ inline T lop3(T a, T b, T c) {
45+
T res;
46+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
47+
: "=r"(res)
48+
: "r"(a), "r"(b), "r"(c), "n"(lut));
49+
return res;
50+
}
3851
// clang-format on

0 commit comments

Comments
 (0)