Skip to content

Commit c0852a0

Browse files
authored
[SYCLomatic] Enable the migration of kernel function pointer with new introduced helper class: kernel_launcher and wrapper_register (#2561)
Key changes during migration: 1. For each CUDA kernel, generate a kernel_wrapper() (which actually is SYCL code to submit a kernel) if the CUDA kernel has been called indirectly. (eg. by function pointer) 2. The kernel_wrapper() is launched by launch() member function of new helper class kernel_launcher. 3. If the CUDA kernel is called by a raw pointer, then the new helper class wrapper_register is used to register the map relationship between raw pointer and real kernel_wrappper(). Note: No change for migration of direct kernel call. Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent d236fce commit c0852a0

File tree

18 files changed

+988
-217
lines changed

18 files changed

+988
-217
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ REGISTER_RULE(EventAPICallRule, PassKind::PK_Migration)
109109
REGISTER_RULE(ProfilingEnableOnDemandRule, PassKind::PK_Analysis)
110110
REGISTER_RULE(StreamAPICallRule, PassKind::PK_Migration)
111111
REGISTER_RULE(KernelCallRule, PassKind::PK_Analysis)
112+
REGISTER_RULE(KernelCallRefRule, PassKind::PK_Migration)
112113
REGISTER_RULE(DeviceFunctionDeclRule, PassKind::PK_Analysis)
113114
REGISTER_RULE(MemVarRefMigrationRule, PassKind::PK_Migration)
114115
REGISTER_RULE(ConstantMemVarMigrationRule, PassKind::PK_Migration)

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 231 additions & 93 deletions
Large diffs are not rendered by default.

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class KernelCallExpr;
8888
class DeviceFunctionInfo;
8989
class CallFunctionExpr;
9090
class DeviceFunctionDecl;
91-
class DeviceFunctionDeclInModule;
9291
class MemVarInfo;
9392
class VarInfo;
9493
class ExplicitInstantiationDecl;
@@ -239,6 +238,12 @@ struct RnnBackwardFuncInfo {
239238
std::vector<std::string> FuncArgs;
240239
};
241240

241+
struct DeviceFunctionInfoForWrapper {
242+
std::vector<std::pair<std::string, std::string>> ParametersInfo;
243+
std::vector<std::pair<std::string, std::string>> TemplateParametersInfo;
244+
std::shared_ptr<KernelCallExpr> KernelForWrapper;
245+
};
246+
242247
// <function name, Info>
243248
using HDFuncInfoMap = std::unordered_map<std::string, HostDeviceFuncInfo>;
244249
// <file path, <Offset, Info>>
@@ -1000,10 +1005,14 @@ class DpctGlobalInfo {
10001005
return Cur.get<TargetTy>();
10011006
});
10021007
}
1003-
template <class TargetTy, class NodeTy>
1008+
template <class TargetTy, class NodeTy, class... SkipNodeTy>
10041009
static auto findParent(const NodeTy *Node) {
1005-
return findAncestor<TargetTy>(
1006-
Node, [](const DynTypedNode &Cur) -> bool { return true; });
1010+
return findAncestor<TargetTy>(Node, [](const DynTypedNode &Cur) -> bool {
1011+
if ((... || Cur.get<SkipNodeTy>())) {
1012+
return false;
1013+
}
1014+
return true;
1015+
});
10071016
}
10081017

10091018
template <typename TargetTy, typename NodeTy>
@@ -1136,8 +1145,6 @@ class DpctGlobalInfo {
11361145
std::shared_ptr<DeviceFunctionDecl> insertDeviceFunctionDecl(
11371146
const FunctionDecl *Specialization, const FunctionTypeLoc &FTL,
11381147
const ParsedAttributes &Attrs, const TemplateArgumentListInfo &TAList);
1139-
std::shared_ptr<DeviceFunctionDecl>
1140-
insertDeviceFunctionDeclInModule(const FunctionDecl *FD);
11411148

11421149
// Build kernel and device function declaration replacements and store
11431150
// them.
@@ -1343,6 +1350,8 @@ class DpctGlobalInfo {
13431350
static bool useNoQueueDevice() {
13441351
return getHelperFuncPreference(HelperFuncPreference::NoQueueDevice);
13451352
}
1353+
static void setCVersionCUDALaunchUsed() { CVersionCUDALaunchUsedFlag = true; }
1354+
static bool isCVersionCUDALaunchUsed() { return CVersionCUDALaunchUsedFlag; }
13461355
static void setUseSYCLCompat(bool Flag = true) { UseSYCLCompatFlag = Flag; }
13471356
static bool useSYCLCompat() { return UseSYCLCompatFlag; }
13481357
static bool useEnqueueBarrier() {
@@ -1665,6 +1674,7 @@ class DpctGlobalInfo {
16651674
static unsigned HelperFuncPreferenceFlag;
16661675
static bool AnalysisModeFlag;
16671676
static bool UseSYCLCompatFlag;
1677+
static bool CVersionCUDALaunchUsedFlag;
16681678
static unsigned int ColorOption;
16691679
static std::unordered_map<int, std::shared_ptr<DeviceFunctionInfo>>
16701680
CubPlaceholderIndexMap;
@@ -2568,7 +2578,8 @@ class DeviceFunctionDecl {
25682578
LinkDecl(D, List, Info);
25692579
}
25702580
void setFuncInfo(std::shared_ptr<DeviceFunctionInfo> Info);
2571-
2581+
void insertWrapper();
2582+
void collectInfoForWrapper(const FunctionDecl *FD);
25722583
virtual ~DeviceFunctionDecl() = default;
25732584

25742585
protected:
@@ -2591,7 +2602,10 @@ class DeviceFunctionDecl {
25912602
bool IsDefFilePathNeeded = false;
25922603
std::vector<std::shared_ptr<TextureObjectInfo>> TextureObjectList;
25932604
FormatInfo FormatInformation;
2594-
2605+
bool HasBody = false;
2606+
size_t DeclEnd;
2607+
std::map<int, std::string> TemplateParameterDefaultValueMap;
2608+
std::map<int, std::string> ParameterDefaultValueMap;
25952609
static std::shared_ptr<DeviceFunctionInfo> &getFuncInfo(const FunctionDecl *);
25962610
static std::unordered_map<std::string, std::shared_ptr<DeviceFunctionInfo>>
25972611
FuncInfoMap;
@@ -2622,32 +2636,6 @@ class ExplicitInstantiationDecl : public DeviceFunctionDecl {
26222636
std::string getExtraParameters(LocInfo LI) override;
26232637
};
26242638

2625-
class DeviceFunctionDeclInModule : public DeviceFunctionDecl {
2626-
void insertWrapper();
2627-
bool HasBody = false;
2628-
size_t DeclEnd;
2629-
std::string FuncName;
2630-
std::vector<std::pair<std::string, std::string>> ParametersInfo;
2631-
std::shared_ptr<KernelCallExpr> Kernel;
2632-
void buildParameterInfo(const FunctionDecl *FD);
2633-
void buildWrapperInfo(const FunctionDecl *FD);
2634-
void buildCallInfo(const FunctionDecl *FD);
2635-
std::vector<std::pair<std::string, std::string>> &getParametersInfo() {
2636-
return ParametersInfo;
2637-
}
2638-
2639-
public:
2640-
DeviceFunctionDeclInModule(unsigned Offset,
2641-
const clang::tooling::UnifiedPath &FilePathIn,
2642-
const FunctionTypeLoc &FTL,
2643-
const ParsedAttributes &Attrs,
2644-
const FunctionDecl *FD);
2645-
DeviceFunctionDeclInModule(unsigned Offset,
2646-
const clang::tooling::UnifiedPath &FilePathIn,
2647-
const FunctionDecl *FD);
2648-
void emplaceReplacement() override;
2649-
};
2650-
26512639
// device function info includes parameters num, memory variable and call
26522640
// expression in the function.
26532641
class DeviceFunctionInfo {
@@ -2747,6 +2735,13 @@ class DeviceFunctionInfo {
27472735
bool isParameterReferenced(unsigned int Index);
27482736
void setParameterReferencedStatus(unsigned int Index, bool IsReferenced);
27492737
std::string getFunctionName() { return FunctionName; }
2738+
void collectInfoForWrapper(const FunctionDecl *FD);
2739+
void setModuleUsed() { ModuleUsed = true; }
2740+
bool isModuleUsed() { return ModuleUsed; }
2741+
std::shared_ptr<DeviceFunctionInfoForWrapper>
2742+
getDeviceFunctionInfoForWrapper() {
2743+
return DFInfoForWrapper;
2744+
}
27502745

27512746
private:
27522747
void mergeCalledTexObj(
@@ -2779,12 +2774,15 @@ class DeviceFunctionInfo {
27792774
bool CallGroupFunctionInControlFlow = false;
27802775
bool HasCheckedCallGroupFunctionInControlFlow = false;
27812776
OverloadedOperatorKind OO_Kind = OverloadedOperatorKind::OO_None;
2777+
bool ModuleUsed = false;
2778+
std::shared_ptr<DeviceFunctionInfoForWrapper> DFInfoForWrapper = nullptr;
27822779
};
27832780

27842781
class KernelCallExpr : public CallFunctionExpr {
27852782
public:
27862783
bool IsInMacroDefine = false;
27872784
bool NeedLambda = false;
2785+
bool IsForWrapper = false;
27882786
bool NeedDefaultRetValue = false;
27892787

27902788
private:
@@ -2857,8 +2855,10 @@ class KernelCallExpr : public CallFunctionExpr {
28572855
const std::pair<clang::tooling::UnifiedPath, unsigned> &LocInfo,
28582856
const CallExpr *, bool IsAssigned = false);
28592857
static std::shared_ptr<KernelCallExpr>
2860-
buildForWrapper(clang::tooling::UnifiedPath, const FunctionDecl *,
2861-
std::shared_ptr<DeviceFunctionInfo>);
2858+
buildForWrapper(clang::tooling::UnifiedPath, const FunctionDecl *);
2859+
void setTemplateArgsStrForWrapper(std::string Str) {
2860+
TemplateArgsStrForWrapper = std::move(Str);
2861+
}
28622862
unsigned int GridDim = 3;
28632863
unsigned int BlockDim = 3;
28642864
void setEmitSizeofWarningFlag(bool Flag) { EmitSizeofWarning = Flag; }
@@ -2963,6 +2963,7 @@ class KernelCallExpr : public CallFunctionExpr {
29632963
OuterStmtsList OuterStmts;
29642964
StmtList KernelStmts;
29652965
std::string KernelArgs;
2966+
std::string TemplateArgsStrForWrapper;
29662967
int TotalArgsSize = 0;
29672968
bool EmitSizeofWarning = false;
29682969
unsigned int SizeOfHighestDimension = 0;

0 commit comments

Comments
 (0)