@@ -88,7 +88,6 @@ class KernelCallExpr;
8888class DeviceFunctionInfo ;
8989class CallFunctionExpr ;
9090class DeviceFunctionDecl ;
91- class DeviceFunctionDeclInModule ;
9291class MemVarInfo ;
9392class VarInfo ;
9493class ExplicitInstantiationDecl ;
@@ -239,6 +238,12 @@ struct RnnBackwardFuncInfo {
239238 std::vector<std::string> FuncArgs;
240239};
241240
241+ struct DeviceFunctionInfoForWrapper {
242+ std::vector<std::pair<std::string, std::string>> ParametersInfo;
243+ std::vector<std::pair<std::string, std::string>> TemplateParametersInfo;
244+ std::shared_ptr<KernelCallExpr> KernelForWrapper;
245+ };
246+
242247// <function name, Info>
243248using HDFuncInfoMap = std::unordered_map<std::string, HostDeviceFuncInfo>;
244249// <file path, <Offset, Info>>
@@ -1000,10 +1005,14 @@ class DpctGlobalInfo {
10001005 return Cur.get <TargetTy>();
10011006 });
10021007 }
1003- template <class TargetTy , class NodeTy >
1008+ template <class TargetTy , class NodeTy , class ... SkipNodeTy >
10041009 static auto findParent (const NodeTy *Node) {
1005- return findAncestor<TargetTy>(
1006- Node, [](const DynTypedNode &Cur) -> bool { return true ; });
1010+ return findAncestor<TargetTy>(Node, [](const DynTypedNode &Cur) -> bool {
1011+ if ((... || Cur.get <SkipNodeTy>())) {
1012+ return false ;
1013+ }
1014+ return true ;
1015+ });
10071016 }
10081017
10091018 template <typename TargetTy, typename NodeTy>
@@ -1136,8 +1145,6 @@ class DpctGlobalInfo {
11361145 std::shared_ptr<DeviceFunctionDecl> insertDeviceFunctionDecl (
11371146 const FunctionDecl *Specialization, const FunctionTypeLoc &FTL,
11381147 const ParsedAttributes &Attrs, const TemplateArgumentListInfo &TAList);
1139- std::shared_ptr<DeviceFunctionDecl>
1140- insertDeviceFunctionDeclInModule (const FunctionDecl *FD);
11411148
11421149 // Build kernel and device function declaration replacements and store
11431150 // them.
@@ -1343,6 +1350,8 @@ class DpctGlobalInfo {
13431350 static bool useNoQueueDevice () {
13441351 return getHelperFuncPreference (HelperFuncPreference::NoQueueDevice);
13451352 }
1353+ static void setCVersionCUDALaunchUsed () { CVersionCUDALaunchUsedFlag = true ; }
1354+ static bool isCVersionCUDALaunchUsed () { return CVersionCUDALaunchUsedFlag; }
13461355 static void setUseSYCLCompat (bool Flag = true ) { UseSYCLCompatFlag = Flag; }
13471356 static bool useSYCLCompat () { return UseSYCLCompatFlag; }
13481357 static bool useEnqueueBarrier () {
@@ -1665,6 +1674,7 @@ class DpctGlobalInfo {
16651674 static unsigned HelperFuncPreferenceFlag;
16661675 static bool AnalysisModeFlag;
16671676 static bool UseSYCLCompatFlag;
1677+ static bool CVersionCUDALaunchUsedFlag;
16681678 static unsigned int ColorOption;
16691679 static std::unordered_map<int , std::shared_ptr<DeviceFunctionInfo>>
16701680 CubPlaceholderIndexMap;
@@ -2568,7 +2578,8 @@ class DeviceFunctionDecl {
25682578 LinkDecl (D, List, Info);
25692579 }
25702580 void setFuncInfo (std::shared_ptr<DeviceFunctionInfo> Info);
2571-
2581+ void insertWrapper ();
2582+ void collectInfoForWrapper (const FunctionDecl *FD);
25722583 virtual ~DeviceFunctionDecl () = default ;
25732584
25742585protected:
@@ -2591,7 +2602,10 @@ class DeviceFunctionDecl {
25912602 bool IsDefFilePathNeeded = false ;
25922603 std::vector<std::shared_ptr<TextureObjectInfo>> TextureObjectList;
25932604 FormatInfo FormatInformation;
2594-
2605+ bool HasBody = false ;
2606+ size_t DeclEnd;
2607+ std::map<int , std::string> TemplateParameterDefaultValueMap;
2608+ std::map<int , std::string> ParameterDefaultValueMap;
25952609 static std::shared_ptr<DeviceFunctionInfo> &getFuncInfo (const FunctionDecl *);
25962610 static std::unordered_map<std::string, std::shared_ptr<DeviceFunctionInfo>>
25972611 FuncInfoMap;
@@ -2622,32 +2636,6 @@ class ExplicitInstantiationDecl : public DeviceFunctionDecl {
26222636 std::string getExtraParameters (LocInfo LI) override ;
26232637};
26242638
2625- class DeviceFunctionDeclInModule : public DeviceFunctionDecl {
2626- void insertWrapper ();
2627- bool HasBody = false ;
2628- size_t DeclEnd;
2629- std::string FuncName;
2630- std::vector<std::pair<std::string, std::string>> ParametersInfo;
2631- std::shared_ptr<KernelCallExpr> Kernel;
2632- void buildParameterInfo (const FunctionDecl *FD);
2633- void buildWrapperInfo (const FunctionDecl *FD);
2634- void buildCallInfo (const FunctionDecl *FD);
2635- std::vector<std::pair<std::string, std::string>> &getParametersInfo () {
2636- return ParametersInfo;
2637- }
2638-
2639- public:
2640- DeviceFunctionDeclInModule (unsigned Offset,
2641- const clang::tooling::UnifiedPath &FilePathIn,
2642- const FunctionTypeLoc &FTL,
2643- const ParsedAttributes &Attrs,
2644- const FunctionDecl *FD);
2645- DeviceFunctionDeclInModule (unsigned Offset,
2646- const clang::tooling::UnifiedPath &FilePathIn,
2647- const FunctionDecl *FD);
2648- void emplaceReplacement () override ;
2649- };
2650-
26512639// device function info includes parameters num, memory variable and call
26522640// expression in the function.
26532641class DeviceFunctionInfo {
@@ -2747,6 +2735,13 @@ class DeviceFunctionInfo {
27472735 bool isParameterReferenced (unsigned int Index);
27482736 void setParameterReferencedStatus (unsigned int Index, bool IsReferenced);
27492737 std::string getFunctionName () { return FunctionName; }
2738+ void collectInfoForWrapper (const FunctionDecl *FD);
2739+ void setModuleUsed () { ModuleUsed = true ; }
2740+ bool isModuleUsed () { return ModuleUsed; }
2741+ std::shared_ptr<DeviceFunctionInfoForWrapper>
2742+ getDeviceFunctionInfoForWrapper () {
2743+ return DFInfoForWrapper;
2744+ }
27502745
27512746private:
27522747 void mergeCalledTexObj (
@@ -2779,12 +2774,15 @@ class DeviceFunctionInfo {
27792774 bool CallGroupFunctionInControlFlow = false ;
27802775 bool HasCheckedCallGroupFunctionInControlFlow = false ;
27812776 OverloadedOperatorKind OO_Kind = OverloadedOperatorKind::OO_None;
2777+ bool ModuleUsed = false ;
2778+ std::shared_ptr<DeviceFunctionInfoForWrapper> DFInfoForWrapper = nullptr ;
27822779};
27832780
27842781class KernelCallExpr : public CallFunctionExpr {
27852782public:
27862783 bool IsInMacroDefine = false ;
27872784 bool NeedLambda = false ;
2785+ bool IsForWrapper = false ;
27882786 bool NeedDefaultRetValue = false ;
27892787
27902788private:
@@ -2857,8 +2855,10 @@ class KernelCallExpr : public CallFunctionExpr {
28572855 const std::pair<clang::tooling::UnifiedPath, unsigned > &LocInfo,
28582856 const CallExpr *, bool IsAssigned = false );
28592857 static std::shared_ptr<KernelCallExpr>
2860- buildForWrapper (clang::tooling::UnifiedPath, const FunctionDecl *,
2861- std::shared_ptr<DeviceFunctionInfo>);
2858+ buildForWrapper (clang::tooling::UnifiedPath, const FunctionDecl *);
2859+ void setTemplateArgsStrForWrapper (std::string Str) {
2860+ TemplateArgsStrForWrapper = std::move (Str);
2861+ }
28622862 unsigned int GridDim = 3 ;
28632863 unsigned int BlockDim = 3 ;
28642864 void setEmitSizeofWarningFlag (bool Flag) { EmitSizeofWarning = Flag; }
@@ -2963,6 +2963,7 @@ class KernelCallExpr : public CallFunctionExpr {
29632963 OuterStmtsList OuterStmts;
29642964 StmtList KernelStmts;
29652965 std::string KernelArgs;
2966+ std::string TemplateArgsStrForWrapper;
29662967 int TotalArgsSize = 0 ;
29672968 bool EmitSizeofWarning = false ;
29682969 unsigned int SizeOfHighestDimension = 0 ;
0 commit comments