Skip to content

Commit 3738f44

Browse files
authored
rocfft: include batch dimension in MPI sample bricks (#1785)
## Motivation

Fix the rocFFT MPI example. Currently it fails to create the FFT plan.

## Technical Details

An N-dimensional distributed FFT is specified to rocFFT as a set of bricks. However, each brick requires N+1 indexes for its strides and brick coordinates; the extra index is the batch dimension. The bricks in the example were incorrectly specified with only N indexes.

## Test Plan

Ran the example manually with the fix. The MPI unit tests cover additional cases, and already specify bricks in the correct (N+1) way.

## Test Result

The manual run was successful.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 6ce74cd commit 3738f44

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

clients/samples/mpi/rocfft_mpi_example.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ int main(int argc, char** argv)
126126
if(fftrc != rocfft_status_success)
127127
throw std::runtime_error("failed to create description");
128128

129+
// This example is unbatched, so the batch stride is not used
130+
// for anything. For batched examples, this would be the
131+
// distance in elements between consecutive batches.
132+
const size_t batch_stride = 0;
133+
129134
if(mpi_rank == 0)
130135
{
131136
std::cout << "input data decomposition:\n";
@@ -135,14 +140,14 @@ int main(int argc, char** argv)
135140
rocfft_field infield = nullptr;
136141
rocfft_field_create(&infield);
137142

138-
std::vector<size_t> inbrick_stride = {1, length[1]};
143+
std::vector<size_t> inbrick_stride = {1, length[1], batch_stride};
139144
const size_t inbrick_length1 = length[1] / (size_t)mpi_size
140145
+ ((size_t)mpi_rank < length[1] % (size_t)mpi_size ? 1 : 0);
141146
const size_t inbrick_lower1
142147
= mpi_rank * (length[1] / mpi_size) + std::min((size_t)mpi_rank, length[1] % mpi_size);
143148
const size_t inbrick_upper1 = inbrick_lower1 + inbrick_length1;
144-
std::vector<size_t> inbrick_lower = {0, inbrick_lower1};
145-
std::vector<size_t> inbrick_upper = {length[0], inbrick_upper1};
149+
std::vector<size_t> inbrick_lower = {0, inbrick_lower1, 0};
150+
std::vector<size_t> inbrick_upper = {length[0], inbrick_upper1, 1};
146151

147152
rocfft_brick inbrick = nullptr;
148153
rocfft_brick_create(&inbrick,
@@ -219,15 +224,15 @@ int main(int argc, char** argv)
219224
std::vector<void*> gpu_out = {nullptr};
220225
std::vector<size_t> outbrick_lower;
221226
std::vector<size_t> outbrick_upper;
222-
std::vector<size_t> outbrick_stride = {1, length[1]};
227+
std::vector<size_t> outbrick_stride = {1, length[1], batch_stride};
223228
{
224229
const size_t outbrick_length1 = length[1] / (size_t)mpi_size
225230
+ ((size_t)mpi_rank < length[1] % (size_t)mpi_size ? 1 : 0);
226231
const size_t outbrick_lower1
227232
= mpi_rank * (length[1] / mpi_size) + std::min((size_t)mpi_rank, length[1] % mpi_size);
228233
const size_t outbrick_upper1 = outbrick_lower1 + outbrick_length1;
229-
outbrick_lower = {0, outbrick_lower1};
230-
outbrick_upper = {length[0], outbrick_upper1};
234+
outbrick_lower = {0, outbrick_lower1, 0};
235+
outbrick_upper = {length[0], outbrick_upper1, 1};
231236

232237
const size_t memSize = length[0] * outbrick_length1 * sizeof(std::complex<double>);
233238
for(int irank = 0; irank < mpi_size; ++irank)
@@ -254,8 +259,8 @@ int main(int argc, char** argv)
254259
rocfft_field_create(&outfield);
255260

256261
rocfft_brick outbrick = nullptr;
257-
outbrick_lower = {0, outbrick_lower1};
258-
outbrick_upper = {length[0], outbrick_lower1 + outbrick_length1};
262+
outbrick_lower = {0, outbrick_lower1, 0};
263+
outbrick_upper = {length[0], outbrick_lower1 + outbrick_length1, 1};
259264
rocfft_brick_create(&outbrick,
260265
outbrick_lower.data(),
261266
outbrick_upper.data(),

0 commit comments

Comments
 (0)