@@ -24,6 +24,10 @@ namespace {
2424 CHECK (ret == LLM_SUCCESS) \
2525 << " Call LlmDataDist function failed, ret = " << std::hex << ret
2626
// Aborts through CHECK when an ACL runtime call does not return ACL_SUCCESS.
//
// The argument is evaluated exactly once and its status is cached in a local.
// The failure message streams the cached status rather than the expression
// itself, so a failing call (e.g. aclrtMalloc / aclrtFree) is never issued a
// second time while the error message is being built.
#define CHECK_ACL_RET(ret)                                              \
  do {                                                                  \
    const auto acl_ret_ = (ret);                                        \
    CHECK(acl_ret_ == ACL_SUCCESS)                                      \
        << " Call ACL function failed, ret = " << std::hex << acl_ret_; \
  } while (0)
30+
2731const std::map<torch::ScalarType, ge::DataType> kScalarTypeToDtype = {
2832 {torch::kBool , ge::DT_BOOL},
2933 {torch::kByte , ge::DT_UINT8},
@@ -48,29 +52,29 @@ void SpecKVCacheTransfer::allocate_kv_cache(
4852 const int64_t num_layers,
4953 const std::vector<std::vector<int64_t >>& kv_cache_shape,
5054 torch::ScalarType dtype) {
51- _allocate_kv_cache (kv_caches,
52- num_layers,
53- kv_cache_shape,
54- dtype,
55- /* is_spec*/ false ,
56- k_cache_,
57- v_cache_);
55+ allocate_kv_cache_internal (kv_caches,
56+ num_layers,
57+ kv_cache_shape,
58+ dtype,
59+ /* is_spec*/ false ,
60+ k_cache_,
61+ v_cache_);
5862}
5963
6064void SpecKVCacheTransfer::allocate_kv_cache_spec (
6165 std::vector<xllm::KVCache>& kv_caches,
6266 const int64_t num_layers,
6367 const std::vector<std::vector<int64_t >>& kv_cache_shape,
6468 torch::ScalarType dtype) {
65- _allocate_kv_cache (kv_caches,
66- num_layers,
67- kv_cache_shape,
68- dtype,
69- /* is_spec*/ true ,
70- spec_k_cache_,
71- spec_v_cache_);
69+ allocate_kv_cache_internal (kv_caches,
70+ num_layers,
71+ kv_cache_shape,
72+ dtype,
73+ /* is_spec*/ true ,
74+ spec_k_cache_,
75+ spec_v_cache_);
7276}
73- void SpecKVCacheTransfer::_allocate_kv_cache (
77+ void SpecKVCacheTransfer::allocate_kv_cache_internal (
7478 std::vector<xllm::KVCache>& kv_caches,
7579 const int64_t num_layers,
7680 const std::vector<std::vector<int64_t >>& kv_cache_shape,
@@ -87,18 +91,42 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
8791 const auto & it = kScalarTypeToDtype .find (dtype);
8892 CHECK (it != kScalarTypeToDtype .cend ()) << " Unsupport data type : " << dtype;
8993 auto ge_dtype = it->second ;
90- CacheDesc k_cache_desc;
91- k_cache_desc.num_tensors = num_layers;
92- k_cache_desc.data_type = ge_dtype;
93- k_cache_desc.shape = kv_cache_shape[0 ];
94- CHECK_LDD_RET (llm_data_dist_->AllocateCache (k_cache_desc, k_cache));
9594
96- CacheDesc v_cache_desc;
97- v_cache_desc.num_tensors = num_layers;
98- v_cache_desc.data_type = ge_dtype;
99- v_cache_desc.shape = kv_cache_shape[1 ];
100- CHECK_LDD_RET (llm_data_dist_->AllocateCache (v_cache_desc, v_cache));
95+ // calculate the size of kv cache for each layer
96+ auto data_size = torch::elementSize (dtype);
97+ int64_t k_cache_size_per_layer = std::accumulate (kv_cache_shape[0 ].begin (),
98+ kv_cache_shape[0 ].end (),
99+ data_size,
100+ std::multiplies<int64_t >());
101+ int64_t v_cache_size_per_layer = std::accumulate (kv_cache_shape[1 ].begin (),
102+ kv_cache_shape[1 ].end (),
103+ data_size,
104+ std::multiplies<int64_t >());
105+
106+ // allocate device memory for kv cache
107+ std::vector<uint64_t > k_cache_addrs;
108+ std::vector<uint64_t > v_cache_addrs;
109+ k_cache_addrs.reserve (num_layers);
110+ v_cache_addrs.reserve (num_layers);
111+ k_cache.tensor_addrs .reserve (num_layers);
112+ v_cache.tensor_addrs .reserve (num_layers);
113+ for (int64_t i = 0 ; i < num_layers; ++i) {
114+ void * k_cache_buffer = nullptr ;
115+ void * v_cache_buffer = nullptr ;
116+ CHECK_ACL_RET (aclrtMalloc (
117+ &k_cache_buffer, k_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
118+ CHECK_ACL_RET (aclrtMalloc (
119+ &v_cache_buffer, v_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
120+
121+ k_cache_addrs.emplace_back (reinterpret_cast <uint64_t >(k_cache_buffer));
122+ v_cache_addrs.emplace_back (reinterpret_cast <uint64_t >(v_cache_buffer));
123+ k_cache.tensor_addrs .emplace_back (
124+ reinterpret_cast <uintptr_t >(k_cache_buffer));
125+ v_cache.tensor_addrs .emplace_back (
126+ reinterpret_cast <uintptr_t >(v_cache_buffer));
127+ }
101128
129+ // convert memory addrs to torch tensors
102130 auto k_torch_tensors =
103131 convert_to_torch_tensor (kv_cache_shape[0 ], dtype, k_cache.tensor_addrs );
104132 auto v_torch_tensors =
@@ -109,35 +137,40 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
109137 value_cache = v_torch_tensors[i];
110138 kv_caches.emplace_back (key_cache, value_cache);
111139 }
112- }
113140
114- void SpecKVCacheTransfer::allocate_embedding (
115- std::shared_ptr<EmbeddingAllocator> embedding_allocator,
116- const std::vector<int64_t >& embedding_shape,
117- torch::ScalarType dtype,
118- torch::Device device) {
119- const auto & it = kScalarTypeToDtype .find (dtype);
120- CHECK (it != kScalarTypeToDtype .cend ()) << " Unsupport data type : " << dtype;
121- auto ge_dtype = it->second ;
122- CacheDesc embed_cache_desc;
123- embed_cache_desc.num_tensors = 1 ;
124- embed_cache_desc.data_type = ge_dtype;
125- embed_cache_desc.shape = embedding_shape;
126- CHECK_LDD_RET (llm_data_dist_->AllocateCache (embed_cache_desc, embed_cache_));
141+ // register key cache
142+ CacheDesc& k_cache_desc = k_cache.cache_desc ;
143+ k_cache_desc.num_tensors = num_layers;
144+ k_cache_desc.data_type = ge_dtype;
145+ k_cache_desc.shape = kv_cache_shape[0 ];
146+ auto ret = llm_data_dist_->RegisterKvCache (
147+ k_cache_desc, k_cache_addrs, {}, k_cache.cache_id );
148+ CHECK (ret == LLM_SUCCESS)
149+ << " Register key cache failed, ret = " << std::hex << ret;
150+
151+ // register value cache
152+ CacheDesc& v_cache_desc = v_cache.cache_desc ;
153+ v_cache_desc.num_tensors = num_layers;
154+ v_cache_desc.data_type = ge_dtype;
155+ v_cache_desc.shape = kv_cache_shape[1 ];
156+ ret = llm_data_dist_->RegisterKvCache (
157+ v_cache_desc, v_cache_addrs, {}, v_cache.cache_id );
158+ CHECK (ret == LLM_SUCCESS)
159+ << " Register value cache failed, ret = " << std::hex << ret;
127160
128- embed_host_cache_.cache_desc = embed_cache_.cache_desc ;
129- embed_host_cache_.cache_desc .placement = CachePlacement::kHost ;
130- CHECK_EQ (embed_host_cache_.cache_desc .num_tensors , 1 );
131- embed_host_cache_.tensor_addrs .emplace_back (reinterpret_cast <uint64_t >(
132- embedding_allocator->get_embeddings_cache_ptr ()));
161+ LOG (INFO) << " Register KV cache success." ;
133162}
134163
135164void SpecKVCacheTransfer::free_kv_cache () {
136- llm_data_dist_->DeallocateCache (k_cache_.cache_id );
137- llm_data_dist_->DeallocateCache (v_cache_.cache_id );
138- llm_data_dist_->DeallocateCache (spec_k_cache_.cache_id );
139- llm_data_dist_->DeallocateCache (spec_v_cache_.cache_id );
140- llm_data_dist_->DeallocateCache (embed_cache_.cache_id );
165+ auto free_cache = [](const std::vector<uintptr_t >& tensor_addrs) {
166+ for (auto tensor_addr : tensor_addrs) {
167+ CHECK_ACL_RET (aclrtFree (reinterpret_cast <void *>(tensor_addr)));
168+ }
169+ };
170+ free_cache (k_cache_.tensor_addrs );
171+ free_cache (v_cache_.tensor_addrs );
172+ free_cache (spec_k_cache_.tensor_addrs );
173+ free_cache (spec_v_cache_.tensor_addrs );
141174}
142175
143176bool SpecKVCacheTransfer::pull_kv_blocks (
@@ -161,80 +194,59 @@ bool SpecKVCacheTransfer::pull_kv_blocks(
161194 CHECK_LDD_RET (llm_data_dist_->PullKvBlocks (
162195 spec_v_cache_index, spec_v_cache_, src_blocks, dst_blocks));
163196
164- CacheIndex embed_cache_index{src_cluster_id, embed_cache_.cache_id };
165- CHECK_LDD_RET (llm_data_dist_->PullKvBlocks (embed_cache_index,
166- embed_cache_,
167- {src_blocks.back ()},
168- {dst_blocks.back ()}));
169197 return true ;
170198}
171199
172200bool SpecKVCacheTransfer::push_kv_blocks (
173201 std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
174202 std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
175203 bool is_spec_draft) {
176- if (!layer_synchronizer) {
177- return push_embed_blocks (merged_kv_infos);
178- }
179-
180204 if (is_spec_draft) {
181205 return push_kv_blocks_spec (merged_kv_infos, layer_synchronizer);
206+ } else {
207+ return push_kv_blocks_internal (
208+ merged_kv_infos, layer_synchronizer, num_layers_, k_cache_, v_cache_);
182209 }
183- for (int64_t layer_index = 0 ; layer_index < num_layers_; ++layer_index) {
184- // Wait for the KV cache computation of this layer to complete.
185- layer_synchronizer->synchronize_layer (layer_index);
186- // Push the KV Cache computed at this layer for all requests to the
187- // designated worker.
188- for (const auto & pair : merged_kv_infos) {
189- const KVCacheInfo& kv_info = pair.second ;
190- CacheIndex k_cache_index{kv_info.dst_cluster_id , k_cache_.cache_id };
191- CacheIndex v_cache_index{kv_info.dst_cluster_id , v_cache_.cache_id };
192- KvCacheExtParam ext_param{};
193- ext_param.src_layer_range =
194- std::pair<int32_t , int32_t >(layer_index, layer_index);
195- ext_param.dst_layer_range =
196- std::pair<int32_t , int32_t >(layer_index, layer_index);
197- ext_param.tensor_num_per_layer = 1 ;
198- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (k_cache_,
199- k_cache_index,
200- kv_info.src_blocks ,
201- kv_info.dst_blocks ,
202- ext_param));
203- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (v_cache_,
204- v_cache_index,
205- kv_info.src_blocks ,
206- kv_info.dst_blocks ,
207- ext_param));
208- }
209- }
210- return true ;
211210}
212211
213212bool SpecKVCacheTransfer::push_kv_blocks_spec (
214213 std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
215214 std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer) {
216- for (int64_t layer_index = 0 ; layer_index < spec_num_layers_; ++layer_index) {
215+ return push_kv_blocks_internal (merged_kv_infos,
216+ layer_synchronizer,
217+ spec_num_layers_,
218+ spec_k_cache_,
219+ spec_v_cache_);
220+ }
221+
222+ bool SpecKVCacheTransfer::push_kv_blocks_internal (
223+ std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
224+ std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
225+ int64_t num_layers,
226+ const Cache& k_cache,
227+ const Cache& v_cache) {
228+ for (int64_t layer_index = 0 ; layer_index < num_layers; ++layer_index) {
217229 // Wait for the KV cache computation of this layer to complete.
218230 layer_synchronizer->synchronize_layer (layer_index);
231+
219232 // Push the KV Cache computed at this layer for all requests to the
220233 // designated worker.
221234 for (const auto & pair : merged_kv_infos) {
222235 const KVCacheInfo& kv_info = pair.second ;
223- CacheIndex k_cache_index{kv_info.dst_cluster_id , spec_k_cache_.cache_id };
224- CacheIndex v_cache_index{kv_info.dst_cluster_id , spec_v_cache_.cache_id };
236+ CacheIndex k_cache_index{kv_info.dst_cluster_id , k_cache.cache_id };
237+ CacheIndex v_cache_index{kv_info.dst_cluster_id , v_cache.cache_id };
238+
225239 KvCacheExtParam ext_param{};
226- ext_param.src_layer_range =
227- std::pair<int32_t , int32_t >(layer_index, layer_index);
228- ext_param.dst_layer_range =
229- std::pair<int32_t , int32_t >(layer_index, layer_index);
240+ ext_param.src_layer_range = {layer_index, layer_index};
241+ ext_param.dst_layer_range = {layer_index, layer_index};
230242 ext_param.tensor_num_per_layer = 1 ;
231243
232- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (spec_k_cache_ ,
244+ CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (k_cache ,
233245 k_cache_index,
234246 kv_info.src_blocks ,
235247 kv_info.dst_blocks ,
236248 ext_param));
237- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (spec_v_cache_ ,
249+ CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (v_cache ,
238250 v_cache_index,
239251 kv_info.src_blocks ,
240252 kv_info.dst_blocks ,
@@ -244,24 +256,6 @@ bool SpecKVCacheTransfer::push_kv_blocks_spec(
244256 return true ;
245257}
246258
247- bool SpecKVCacheTransfer::push_embed_blocks (
248- std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos) {
249- for (const auto & pair : merged_kv_infos) {
250- const KVCacheInfo& kv_info = pair.second ;
251- CacheIndex cache_index{kv_info.dst_cluster_id , embed_cache_.cache_id };
252- KvCacheExtParam ext_param{};
253- ext_param.src_layer_range = std::pair<int32_t , int32_t >(0 , 0 );
254- ext_param.dst_layer_range = std::pair<int32_t , int32_t >(0 , 0 );
255- ext_param.tensor_num_per_layer = 1 ;
256- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (embed_cache_,
257- cache_index,
258- kv_info.src_embed_ids ,
259- kv_info.dst_embed_ids ,
260- ext_param));
261- }
262- return true ;
263- }
264-
265259folly::SemiFuture<bool > SpecKVCacheTransfer::push_kv_blocks_async (
266260 const std::vector<TransferKVInfo>& transfer_kv_infos,
267261 const ParallelArgs& parallel_args,
@@ -323,16 +317,18 @@ void SpecKVCacheTransfer::merge_kv_blocks(
323317 i < dst_tp_size * (dst_dp_rank + 1 );
324318 i += src_tp_size) {
325319 uint64_t dst_cluster_id = info.remote_instance_info .cluster_ids [i];
320+ auto & dst_addr = info.remote_instance_info .addrs [i];
326321 int64_t k_cache_id = info.remote_instance_info .k_cache_ids [i];
327322 int64_t v_cache_id = info.remote_instance_info .v_cache_ids [i];
328- std::string key = std::to_string (dst_cluster_id) + " _" +
323+ std::string key = std::to_string (dst_cluster_id) + " _" + dst_addr + " _ " +
329324 std::to_string (k_cache_id) + " _" +
330325 std::to_string (v_cache_id);
331326 // Merge all kv blocks with the same destination worker into a single
332327 // vector.
333328 if (merged_kv_infos.find (key) == merged_kv_infos.end ()) {
334329 KVCacheInfo kv_info;
335330 kv_info.dst_cluster_id = dst_cluster_id;
331+ kv_info.dst_addr = dst_addr;
336332 kv_info.dst_k_cache_id = k_cache_id;
337333 kv_info.dst_v_cache_id = v_cache_id;
338334 kv_info.src_blocks .insert (kv_info.src_blocks .end (),
@@ -341,8 +337,6 @@ void SpecKVCacheTransfer::merge_kv_blocks(
341337 kv_info.dst_blocks .insert (kv_info.dst_blocks .end (),
342338 info.remote_blocks_ids .begin (),
343339 info.remote_blocks_ids .end ());
344- kv_info.src_embed_ids .push_back (kv_info.src_blocks .back ());
345- kv_info.dst_embed_ids .push_back (kv_info.dst_blocks .back ());
346340 merged_kv_infos[key] = std::move (kv_info);
347341 } else {
348342 merged_kv_infos[key].src_blocks .insert (
@@ -353,28 +347,8 @@ void SpecKVCacheTransfer::merge_kv_blocks(
353347 merged_kv_infos[key].dst_blocks .end (),
354348 info.remote_blocks_ids .begin (),
355349 info.remote_blocks_ids .end ());
356- merged_kv_infos[key].src_embed_ids .push_back (
357- merged_kv_infos[key].src_blocks .back ());
358- merged_kv_infos[key].dst_embed_ids .push_back (
359- merged_kv_infos[key].dst_blocks .back ());
360350 }
361351 }
362352 }
363353}
364-
365- void SpecKVCacheTransfer::copy_blocks (const std::vector<int >& blocks,
366- bool h2d) {
367- std::vector<uint64_t > _blocks;
368- _blocks.reserve (blocks.size ());
369- for (const auto & block : blocks) {
370- _blocks.push_back (static_cast <uint64_t >(block));
371- }
372- if (h2d) {
373- CHECK_LDD_RET (llm_data_dist_->CopyKvBlocks (
374- embed_host_cache_, embed_cache_, _blocks, {_blocks}));
375- } else {
376- CHECK_LDD_RET (llm_data_dist_->CopyKvBlocks (
377- embed_cache_, embed_host_cache_, _blocks, {_blocks}));
378- }
379- }
380354} // namespace xllm
0 commit comments