diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index da2d1d2e141..e60b3372433 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -168,6 +168,7 @@ OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threads); mca_coll_base_module_t *mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority); +void mca_coll_ucc_finalize_ctx(void); int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, diff --git a/ompi/mca/coll/ucc/coll_ucc_component.c b/ompi/mca/coll/ucc/coll_ucc_component.c index 4fde1e0a999..d604537ae23 100644 --- a/ompi/mca/coll/ucc/coll_ucc_component.c +++ b/ompi/mca/coll/ucc/coll_ucc_component.c @@ -256,5 +256,6 @@ static int mca_coll_ucc_open(void) static int mca_coll_ucc_close(void) { + mca_coll_ucc_finalize_ctx(); return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index 028382df344..d227685156d 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -19,7 +19,7 @@ #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/mca/pml/pml.h" -static int ucc_comm_attr_keyval; + /* * Initial query function that is invoked during MPI_INIT, allowing * this module to indicate what level of thread support it provides. @@ -129,37 +129,37 @@ static int mca_coll_ucc_progress(void) return OPAL_SUCCESS; } -static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module) +void mca_coll_ucc_finalize_ctx(void) { - if (ucc_module->comm == &ompi_mpi_comm_world.comm){ - if (OMPI_SUCCESS != ompi_attr_free_keyval(COMM_ATTR, &ucc_comm_attr_keyval, 0)) { - UCC_ERROR("ucc ompi_attr_free_keyval failed"); - } + mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; + if (!cm->libucc_initialized) { + return; } - mca_coll_ucc_module_clear(ucc_module); + UCC_VERBOSE(1, "finalizing ucc library"); + opal_progress_unregister(mca_coll_ucc_progress); + ucc_context_destroy(cm->ucc_context); + ucc_finalize(cm->ucc_lib); + OBJ_DESTRUCT(&cm->requests); + cm->libucc_initialized = false; } -/* -** Communicator free callback -*/ -static int ucc_comm_attr_del_fn(MPI_Comm comm, int keyval, void *attr_val, void *extra) +static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module) { - mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*) attr_val; - ucc_status_t status; - while(UCC_INPROGRESS == (status = ucc_team_destroy(ucc_module->ucc_team))) {} - if (ucc_module->comm == &ompi_mpi_comm_world.comm) { - if (mca_coll_ucc_component.libucc_initialized) { - UCC_VERBOSE(1,"finalizing ucc library"); - opal_progress_unregister(mca_coll_ucc_progress); - ucc_context_destroy(mca_coll_ucc_component.ucc_context); - ucc_finalize(mca_coll_ucc_component.ucc_lib); + if (ucc_module->ucc_team != NULL) { + ucc_status_t status; + while (UCC_INPROGRESS == (status = ucc_team_destroy(ucc_module->ucc_team))) {} + if (UCC_OK != status) { + UCC_ERROR("UCC team destroy failed"); } } - if (UCC_OK != status) { - UCC_ERROR("UCC team destroy failed"); - return OMPI_ERROR; + /* ucc_context_destroy needs OOB via MPI_COMM_WORLD; call it while + COMM_WORLD is still alive (module destructor fires before c_local_group + is released in ompi_comm_destruct). mca_coll_ucc_close() will call + mca_coll_ucc_finalize_ctx() as a no-op safety net if already done. */ + if (ucc_module->comm == &ompi_mpi_comm_world.comm) { + mca_coll_ucc_finalize_ctx(); } - return OMPI_SUCCESS; + mca_coll_ucc_module_clear(ucc_module); } typedef struct oob_allgather_req{ @@ -253,8 +253,6 @@ static int mca_coll_ucc_init_ctx(ompi_communicator_t* comm) { mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; char str_buf[256]; - ompi_attribute_fn_ptr_union_t del_fn; - ompi_attribute_fn_ptr_union_t copy_fn; ucc_lib_config_h lib_config; ucc_context_config_h ctx_config; ucc_thread_mode_t tm_requested; @@ -343,14 +341,6 @@ static int mca_coll_ucc_init_ctx(ompi_communicator_t* comm) } ucc_context_config_release(ctx_config); - copy_fn.attr_communicator_copy_fn = MPI_COMM_NULL_COPY_FN; - del_fn.attr_communicator_delete_fn = ucc_comm_attr_del_fn; - if (OMPI_SUCCESS != ompi_attr_create_keyval(COMM_ATTR, copy_fn, del_fn, - &ucc_comm_attr_keyval, NULL ,0, NULL)) { - UCC_ERROR("UCC comm keyval create failed"); - goto cleanup_ctx; - } - OBJ_CONSTRUCT(&cm->requests, opal_free_list_t); opal_free_list_init(&cm->requests, sizeof(mca_coll_ucc_req_t), opal_cache_line_size, OBJ_CLASS(mca_coll_ucc_req_t), @@ -362,9 +352,6 @@ static int mca_coll_ucc_init_ctx(ompi_communicator_t* comm) UCC_VERBOSE(1, "initialized ucc context"); cm->libucc_initialized = true; return OMPI_SUCCESS; -cleanup_ctx: - ucc_context_destroy(cm->ucc_context); - cleanup_lib: ucc_finalize(cm->ucc_lib); cm->ucc_enable = 0; @@ -478,7 +465,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *)module; ucc_status_t status; - int rc; ucc_team_params_t team_params = { .mask = UCC_TEAM_PARAM_FIELD_EP_MAP | UCC_TEAM_PARAM_FIELD_EP | @@ -523,13 +509,6 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, goto err; } - rc = ompi_attr_set_c(COMM_ATTR, comm, &comm->c_keyhash, - ucc_comm_attr_keyval, (void *)module, false); - if (OMPI_SUCCESS != rc) { - UCC_ERROR("ucc ompi_attr_set_c failed"); - goto err; - } - return OMPI_SUCCESS; err: