Skip to content

Commit ab17b0e

Browse files
committed
UCC/CTX: passing cuda check from tl ucp to others
1 parent 777df69 commit ab17b0e

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

src/components/tl/mlx5/mcast/tl_mlx5_mcast.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
152152
ucc_rcache_t *rcache;
153153
ucc_tl_mlx5_mcast_ctx_params_t params;
154154
ucc_base_lib_t *lib;
155+
enum ucc_tl_capabilities tl_caps;
155156
} ucc_tl_mlx5_mcast_coll_context_t;
156157

157158
typedef struct ucc_tl_mlx5_mcast_join_info_t {

src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
6868
conf_params->rx_sge = 2;
6969
conf_params->scq_moderation = 64;
7070

71+
mcast_context->tl_caps = base_context->ucc_context->tl_caps;
72+
7173
comm = (ucc_tl_mlx5_mcast_coll_comm_t*)
7274
ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) +
7375
sizeof(struct pp_packet*)*(conf_params->wsize-1),

src/components/tl/ucp/tl_ucp_context.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
194194
self);
195195

196196
self->ucp_memory_types = context_attr.memory_types;
197+
if (self->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[UCC_MEMORY_TYPE_CUDA])) {
198+
/* TL MLX5 needs this information */
199+
self->super.super.ucc_context->tl_caps |= UCC_TL_UCP_CUDA_ENABLED;
200+
}
201+
197202
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
198203
switch (params->thread_mode) {
199204
case UCC_THREAD_SINGLE:

src/core/ucc_context.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ typedef struct ucc_context_id {
3535
#define UCC_CTX_ID_EQUAL(_id1, _id2) (UCC_PROC_INFO_EQUAL((_id1).pi, (_id2).pi) \
3636
&& (_id1).seq_num == (_id2).seq_num)
3737

38+
enum ucc_tl_capabilities {
39+
/* capabilities that every TL needs to be aware of
40+
* about other TLs */
41+
UCC_TL_UCP_CUDA_ENABLED = UCC_BIT(0)
42+
};
43+
3844
enum {
3945
/* all ranks have identical set of TLs*/
4046
UCC_ADDR_STORAGE_FLAG_TLS_SYMMETRIC = UCC_BIT(0),
@@ -78,6 +84,7 @@ typedef struct ucc_context {
7884
uint64_t cl_flags;
7985
ucc_tl_team_t *service_team;
8086
int32_t throttle_progress;
87+
enum ucc_tl_capabilities tl_caps;
8188
} ucc_context_t;
8289

8390
typedef struct ucc_context_config {

0 commit comments

Comments
 (0)