@@ -123,10 +123,9 @@ class E2PFromSingleBox(E2PBase):
123123 def default_name (self ):
124124 return "e2p_from_single_box"
125125
126- def get_kernel (self , max_ntargets_in_one_box ):
126+ def get_kernel (self , max_ntargets_in_one_box , max_work_items ):
127127 ncoeffs = len (self .expansion )
128128 loopy_args = self .get_loopy_args ()
129- max_work_items = min (256 , max (ncoeffs , max_ntargets_in_one_box ))
130129
131130 loopy_knl = lp .make_kernel (
132131 [
@@ -208,11 +207,16 @@ def get_kernel(self, max_ntargets_in_one_box):
208207 return loopy_knl
209208
210209 def get_optimized_kernel (self , max_ntargets_in_one_box ):
211- inner_knl , optimizations = self .get_loopy_evaluator_and_optimizations ()
212- knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box )
210+ _ , optimizations = self .get_loopy_evaluator_and_optimizations ()
211+
212+ ncoeffs = len (self .expansion )
213+ max_work_items = min (256 , max (ncoeffs , max_ntargets_in_one_box ))
214+ knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box ,
215+ max_work_items = max_work_items )
216+
213217 knl = lp .tag_inames (knl , {"itgt_box" : "g.0" })
214- knl = lp .split_iname (knl , "itgt_offset" , 256 , inner_tag = "l.0" )
215- knl = lp .split_iname (knl , "icoeff" , 256 , inner_tag = "l.0" )
218+ knl = lp .split_iname (knl , "itgt_offset" , max_work_items , inner_tag = "l.0" )
219+ knl = lp .split_iname (knl , "icoeff" , max_work_items , inner_tag = "l.0" )
216220 knl = lp .add_inames_to_insn (knl , "dummy" ,
217221 "id:fetch_init* or id:fetch_center or id:kernel_scaling" )
218222 knl = lp .add_inames_to_insn (knl , "itgt_box" , "id:kernel_scaling" )
@@ -273,10 +277,9 @@ class E2PFromCSR(E2PBase):
273277 def default_name (self ):
274278 return "e2p_from_csr"
275279
276- def get_kernel (self , max_ntargets_in_one_box ):
280+ def get_kernel (self , max_ntargets_in_one_box , max_work_items ):
277281 ncoeffs = len (self .expansion )
278282 loopy_args = self .get_loopy_args ()
279- max_work_items = min (256 , max (ncoeffs , max_ntargets_in_one_box ))
280283
281284 loopy_knl = lp .make_kernel (
282285 [
@@ -379,12 +382,16 @@ def get_kernel(self, max_ntargets_in_one_box):
379382
380383 def get_optimized_kernel (self , max_ntargets_in_one_box ):
381384 _ , optimizations = self .get_loopy_evaluator_and_optimizations ()
382- knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box )
385+ ncoeffs = len (self .expansion )
386+ max_work_items = min (256 , max (ncoeffs , max_ntargets_in_one_box ))
387+
388+ knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box ,
389+ max_work_items = max_work_items )
383390 knl = lp .tag_inames (knl , {"itgt_box" : "g.0" , "dummy" : "l.0" })
384391 knl = lp .unprivatize_temporaries_with_inames (knl ,
385392 "itgt_offset" , "result_temp" )
386- knl = lp .split_iname (knl , "itgt_offset" , 256 , inner_tag = "l.0" )
387- knl = lp .split_iname (knl , "icoeff" , 256 , inner_tag = "l.0" )
393+ knl = lp .split_iname (knl , "itgt_offset" , max_work_items , inner_tag = "l.0" )
394+ knl = lp .split_iname (knl , "icoeff" , max_work_items , inner_tag = "l.0" )
388395 knl = lp .privatize_temporaries_with_inames (knl ,
389396 "itgt_offset_outer" , "result_temp" )
390397 knl = lp .duplicate_inames (knl , "itgt_offset_outer" , "id:init_result" )
0 commit comments