@@ -48,29 +48,29 @@ void SpecKVCacheTransfer::allocate_kv_cache(
4848 const int64_t num_layers,
4949 const std::vector<std::vector<int64_t >>& kv_cache_shape,
5050 torch::ScalarType dtype) {
51- _allocate_kv_cache (kv_caches,
52- num_layers,
53- kv_cache_shape,
54- dtype,
55- /* is_spec*/ false ,
56- k_cache_,
57- v_cache_);
51+ allocate_kv_cache_internal (kv_caches,
52+ num_layers,
53+ kv_cache_shape,
54+ dtype,
55+ /* is_spec*/ false ,
56+ k_cache_,
57+ v_cache_);
5858}
5959
6060void SpecKVCacheTransfer::allocate_kv_cache_spec (
6161 std::vector<xllm::KVCache>& kv_caches,
6262 const int64_t num_layers,
6363 const std::vector<std::vector<int64_t >>& kv_cache_shape,
6464 torch::ScalarType dtype) {
65- _allocate_kv_cache (kv_caches,
66- num_layers,
67- kv_cache_shape,
68- dtype,
69- /* is_spec*/ true ,
70- spec_k_cache_,
71- spec_v_cache_);
65+ allocate_kv_cache_internal (kv_caches,
66+ num_layers,
67+ kv_cache_shape,
68+ dtype,
69+ /* is_spec*/ true ,
70+ spec_k_cache_,
71+ spec_v_cache_);
7272}
73- void SpecKVCacheTransfer::_allocate_kv_cache (
73+ void SpecKVCacheTransfer::allocate_kv_cache_internal (
7474 std::vector<xllm::KVCache>& kv_caches,
7575 const int64_t num_layers,
7676 const std::vector<std::vector<int64_t >>& kv_cache_shape,
@@ -87,18 +87,42 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
8787 const auto & it = kScalarTypeToDtype .find (dtype);
8888 CHECK (it != kScalarTypeToDtype .cend ()) << " Unsupport data type : " << dtype;
8989 auto ge_dtype = it->second ;
90- CacheDesc k_cache_desc;
91- k_cache_desc.num_tensors = num_layers;
92- k_cache_desc.data_type = ge_dtype;
93- k_cache_desc.shape = kv_cache_shape[0 ];
94- CHECK_LDD_RET (llm_data_dist_->AllocateCache (k_cache_desc, k_cache));
9590
96- CacheDesc v_cache_desc;
97- v_cache_desc.num_tensors = num_layers;
98- v_cache_desc.data_type = ge_dtype;
99- v_cache_desc.shape = kv_cache_shape[1 ];
100- CHECK_LDD_RET (llm_data_dist_->AllocateCache (v_cache_desc, v_cache));
91+ // calculate the size of kv cache for each layer
92+ auto data_size = torch::elementSize (dtype);
93+ int64_t k_cache_size_per_layer = std::accumulate (kv_cache_shape[0 ].begin (),
94+ kv_cache_shape[0 ].end (),
95+ data_size,
96+ std::multiplies<int64_t >());
97+ int64_t v_cache_size_per_layer = std::accumulate (kv_cache_shape[1 ].begin (),
98+ kv_cache_shape[1 ].end (),
99+ data_size,
100+ std::multiplies<int64_t >());
101+
102+ // allocate device memory for kv cache
103+ std::vector<uint64_t > k_cache_addrs;
104+ std::vector<uint64_t > v_cache_addrs;
105+ k_cache_addrs.reserve (num_layers);
106+ v_cache_addrs.reserve (num_layers);
107+ k_cache.tensor_addrs .reserve (num_layers);
108+ v_cache.tensor_addrs .reserve (num_layers);
109+ for (int64_t i = 0 ; i < num_layers; ++i) {
110+ void * k_cache_buffer = nullptr ;
111+ void * v_cache_buffer = nullptr ;
112+ CHECK_LDD_RET (aclrtMalloc (
113+ &k_cache_buffer, k_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
114+ CHECK_LDD_RET (aclrtMalloc (
115+ &v_cache_buffer, v_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
101116
117+ k_cache_addrs.emplace_back (reinterpret_cast <uint64_t >(k_cache_buffer));
118+ v_cache_addrs.emplace_back (reinterpret_cast <uint64_t >(v_cache_buffer));
119+ k_cache.tensor_addrs .emplace_back (
120+ reinterpret_cast <uintptr_t >(k_cache_buffer));
121+ v_cache.tensor_addrs .emplace_back (
122+ reinterpret_cast <uintptr_t >(v_cache_buffer));
123+ }
124+
125+ // convert memory addrs to torch tensors
102126 auto k_torch_tensors =
103127 convert_to_torch_tensor (kv_cache_shape[0 ], dtype, k_cache.tensor_addrs );
104128 auto v_torch_tensors =
@@ -109,35 +133,40 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
109133 value_cache = v_torch_tensors[i];
110134 kv_caches.emplace_back (key_cache, value_cache);
111135 }
112- }
113136
114- void SpecKVCacheTransfer::allocate_embedding (
115- std::shared_ptr<EmbeddingAllocator> embedding_allocator,
116- const std::vector<int64_t >& embedding_shape,
117- torch::ScalarType dtype,
118- torch::Device device) {
119- const auto & it = kScalarTypeToDtype .find (dtype);
120- CHECK (it != kScalarTypeToDtype .cend ()) << " Unsupport data type : " << dtype;
121- auto ge_dtype = it->second ;
122- CacheDesc embed_cache_desc;
123- embed_cache_desc.num_tensors = 1 ;
124- embed_cache_desc.data_type = ge_dtype;
125- embed_cache_desc.shape = embedding_shape;
126- CHECK_LDD_RET (llm_data_dist_->AllocateCache (embed_cache_desc, embed_cache_));
137+ // register key cache
138+ CacheDesc& k_cache_desc = k_cache.cache_desc ;
139+ k_cache_desc.num_tensors = num_layers;
140+ k_cache_desc.data_type = ge_dtype;
141+ k_cache_desc.shape = kv_cache_shape[0 ];
142+ auto ret = llm_data_dist_->RegisterKvCache (
143+ k_cache_desc, k_cache_addrs, {}, k_cache.cache_id );
144+ CHECK (ret == LLM_SUCCESS)
145+ << " Register key cache failed, ret = " << std::hex << ret;
146+
147+ // register value cache
148+ CacheDesc& v_cache_desc = v_cache.cache_desc ;
149+ v_cache_desc.num_tensors = num_layers;
150+ v_cache_desc.data_type = ge_dtype;
151+ v_cache_desc.shape = kv_cache_shape[1 ];
152+ ret = llm_data_dist_->RegisterKvCache (
153+ v_cache_desc, v_cache_addrs, {}, v_cache.cache_id );
154+ CHECK (ret == LLM_SUCCESS)
155+ << " Register value cache failed, ret = " << std::hex << ret;
127156
128- embed_host_cache_.cache_desc = embed_cache_.cache_desc ;
129- embed_host_cache_.cache_desc .placement = CachePlacement::kHost ;
130- CHECK_EQ (embed_host_cache_.cache_desc .num_tensors , 1 );
131- embed_host_cache_.tensor_addrs .emplace_back (reinterpret_cast <uint64_t >(
132- embedding_allocator->get_embeddings_cache_ptr ()));
157+ LOG (INFO) << " Register KV cache success." ;
133158}
134159
135160void SpecKVCacheTransfer::free_kv_cache () {
136- llm_data_dist_->DeallocateCache (k_cache_.cache_id );
137- llm_data_dist_->DeallocateCache (v_cache_.cache_id );
138- llm_data_dist_->DeallocateCache (spec_k_cache_.cache_id );
139- llm_data_dist_->DeallocateCache (spec_v_cache_.cache_id );
140- llm_data_dist_->DeallocateCache (embed_cache_.cache_id );
161+ auto free_cache = [](const std::vector<uintptr_t >& tensor_addrs) {
162+ for (auto tensor_addr : tensor_addrs) {
163+ aclrtFree (reinterpret_cast <void *>(tensor_addr));
164+ }
165+ };
166+ free_cache (k_cache_.tensor_addrs );
167+ free_cache (v_cache_.tensor_addrs );
168+ free_cache (spec_k_cache_.tensor_addrs );
169+ free_cache (spec_v_cache_.tensor_addrs );
141170}
142171
143172bool SpecKVCacheTransfer::pull_kv_blocks (
@@ -161,80 +190,59 @@ bool SpecKVCacheTransfer::pull_kv_blocks(
161190 CHECK_LDD_RET (llm_data_dist_->PullKvBlocks (
162191 spec_v_cache_index, spec_v_cache_, src_blocks, dst_blocks));
163192
164- CacheIndex embed_cache_index{src_cluster_id, embed_cache_.cache_id };
165- CHECK_LDD_RET (llm_data_dist_->PullKvBlocks (embed_cache_index,
166- embed_cache_,
167- {src_blocks.back ()},
168- {dst_blocks.back ()}));
169193 return true ;
170194}
171195
172196bool SpecKVCacheTransfer::push_kv_blocks (
173197 std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
174198 std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
175199 bool is_spec_draft) {
176- if (!layer_synchronizer) {
177- return push_embed_blocks (merged_kv_infos);
178- }
179-
180200 if (is_spec_draft) {
181201 return push_kv_blocks_spec (merged_kv_infos, layer_synchronizer);
202+ } else {
203+ return push_kv_blocks_internal (
204+ merged_kv_infos, layer_synchronizer, num_layers_, k_cache_, v_cache_);
182205 }
183- for (int64_t layer_index = 0 ; layer_index < num_layers_; ++layer_index) {
184- // Wait for the KV cache computation of this layer to complete.
185- layer_synchronizer->synchronize_layer (layer_index);
186- // Push the KV Cache computed at this layer for all requests to the
187- // designated worker.
188- for (const auto & pair : merged_kv_infos) {
189- const KVCacheInfo& kv_info = pair.second ;
190- CacheIndex k_cache_index{kv_info.dst_cluster_id , k_cache_.cache_id };
191- CacheIndex v_cache_index{kv_info.dst_cluster_id , v_cache_.cache_id };
192- KvCacheExtParam ext_param{};
193- ext_param.src_layer_range =
194- std::pair<int32_t , int32_t >(layer_index, layer_index);
195- ext_param.dst_layer_range =
196- std::pair<int32_t , int32_t >(layer_index, layer_index);
197- ext_param.tensor_num_per_layer = 1 ;
198- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (k_cache_,
199- k_cache_index,
200- kv_info.src_blocks ,
201- kv_info.dst_blocks ,
202- ext_param));
203- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (v_cache_,
204- v_cache_index,
205- kv_info.src_blocks ,
206- kv_info.dst_blocks ,
207- ext_param));
208- }
209- }
210- return true ;
211206}
212207
213208bool SpecKVCacheTransfer::push_kv_blocks_spec (
214209 std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
215210 std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer) {
216- for (int64_t layer_index = 0 ; layer_index < spec_num_layers_; ++layer_index) {
211+ return push_kv_blocks_internal (merged_kv_infos,
212+ layer_synchronizer,
213+ spec_num_layers_,
214+ spec_k_cache_,
215+ spec_v_cache_);
216+ }
217+
218+ bool SpecKVCacheTransfer::push_kv_blocks_internal (
219+ std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
220+ std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
221+ int64_t num_layers,
222+ const Cache& k_cache,
223+ const Cache& v_cache) {
224+ for (int64_t layer_index = 0 ; layer_index < num_layers; ++layer_index) {
217225 // Wait for the KV cache computation of this layer to complete.
218226 layer_synchronizer->synchronize_layer (layer_index);
227+
219228 // Push the KV Cache computed at this layer for all requests to the
220229 // designated worker.
221230 for (const auto & pair : merged_kv_infos) {
222231 const KVCacheInfo& kv_info = pair.second ;
223- CacheIndex k_cache_index{kv_info.dst_cluster_id , spec_k_cache_.cache_id };
224- CacheIndex v_cache_index{kv_info.dst_cluster_id , spec_v_cache_.cache_id };
232+ CacheIndex k_cache_index{kv_info.dst_cluster_id , k_cache.cache_id };
233+ CacheIndex v_cache_index{kv_info.dst_cluster_id , v_cache.cache_id };
234+
225235 KvCacheExtParam ext_param{};
226- ext_param.src_layer_range =
227- std::pair<int32_t , int32_t >(layer_index, layer_index);
228- ext_param.dst_layer_range =
229- std::pair<int32_t , int32_t >(layer_index, layer_index);
236+ ext_param.src_layer_range = {layer_index, layer_index};
237+ ext_param.dst_layer_range = {layer_index, layer_index};
230238 ext_param.tensor_num_per_layer = 1 ;
231239
232- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (spec_k_cache_ ,
240+ CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (k_cache ,
233241 k_cache_index,
234242 kv_info.src_blocks ,
235243 kv_info.dst_blocks ,
236244 ext_param));
237- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (spec_v_cache_ ,
245+ CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (v_cache ,
238246 v_cache_index,
239247 kv_info.src_blocks ,
240248 kv_info.dst_blocks ,
@@ -244,24 +252,6 @@ bool SpecKVCacheTransfer::push_kv_blocks_spec(
244252 return true ;
245253}
246254
247- bool SpecKVCacheTransfer::push_embed_blocks (
248- std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos) {
249- for (const auto & pair : merged_kv_infos) {
250- const KVCacheInfo& kv_info = pair.second ;
251- CacheIndex cache_index{kv_info.dst_cluster_id , embed_cache_.cache_id };
252- KvCacheExtParam ext_param{};
253- ext_param.src_layer_range = std::pair<int32_t , int32_t >(0 , 0 );
254- ext_param.dst_layer_range = std::pair<int32_t , int32_t >(0 , 0 );
255- ext_param.tensor_num_per_layer = 1 ;
256- CHECK_LDD_RET (llm_data_dist_->PushKvBlocks (embed_cache_,
257- cache_index,
258- kv_info.src_embed_ids ,
259- kv_info.dst_embed_ids ,
260- ext_param));
261- }
262- return true ;
263- }
264-
265255folly::SemiFuture<bool > SpecKVCacheTransfer::push_kv_blocks_async (
266256 const std::vector<TransferKVInfo>& transfer_kv_infos,
267257 const ParallelArgs& parallel_args,
@@ -323,16 +313,18 @@ void SpecKVCacheTransfer::merge_kv_blocks(
323313 i < dst_tp_size * (dst_dp_rank + 1 );
324314 i += src_tp_size) {
325315 uint64_t dst_cluster_id = info.remote_instance_info .cluster_ids [i];
316+ auto & dst_addr = info.remote_instance_info .addrs [i];
326317 int64_t k_cache_id = info.remote_instance_info .k_cache_ids [i];
327318 int64_t v_cache_id = info.remote_instance_info .v_cache_ids [i];
328- std::string key = std::to_string (dst_cluster_id) + " _" +
319+ std::string key = std::to_string (dst_cluster_id) + " _" + dst_addr + " _ " +
329320 std::to_string (k_cache_id) + " _" +
330321 std::to_string (v_cache_id);
331322 // Merge all kv blocks with the same destination worker into a single
332323 // vector.
333324 if (merged_kv_infos.find (key) == merged_kv_infos.end ()) {
334325 KVCacheInfo kv_info;
335326 kv_info.dst_cluster_id = dst_cluster_id;
327+ kv_info.dst_addr = dst_addr;
336328 kv_info.dst_k_cache_id = k_cache_id;
337329 kv_info.dst_v_cache_id = v_cache_id;
338330 kv_info.src_blocks .insert (kv_info.src_blocks .end (),
@@ -341,8 +333,6 @@ void SpecKVCacheTransfer::merge_kv_blocks(
341333 kv_info.dst_blocks .insert (kv_info.dst_blocks .end (),
342334 info.remote_blocks_ids .begin (),
343335 info.remote_blocks_ids .end ());
344- kv_info.src_embed_ids .push_back (kv_info.src_blocks .back ());
345- kv_info.dst_embed_ids .push_back (kv_info.dst_blocks .back ());
346336 merged_kv_infos[key] = std::move (kv_info);
347337 } else {
348338 merged_kv_infos[key].src_blocks .insert (
@@ -353,28 +343,8 @@ void SpecKVCacheTransfer::merge_kv_blocks(
353343 merged_kv_infos[key].dst_blocks .end (),
354344 info.remote_blocks_ids .begin (),
355345 info.remote_blocks_ids .end ());
356- merged_kv_infos[key].src_embed_ids .push_back (
357- merged_kv_infos[key].src_blocks .back ());
358- merged_kv_infos[key].dst_embed_ids .push_back (
359- merged_kv_infos[key].dst_blocks .back ());
360346 }
361347 }
362348 }
363349}
364-
365- void SpecKVCacheTransfer::copy_blocks (const std::vector<int >& blocks,
366- bool h2d) {
367- std::vector<uint64_t > _blocks;
368- _blocks.reserve (blocks.size ());
369- for (const auto & block : blocks) {
370- _blocks.push_back (static_cast <uint64_t >(block));
371- }
372- if (h2d) {
373- CHECK_LDD_RET (llm_data_dist_->CopyKvBlocks (
374- embed_host_cache_, embed_cache_, _blocks, {_blocks}));
375- } else {
376- CHECK_LDD_RET (llm_data_dist_->CopyKvBlocks (
377- embed_cache_, embed_host_cache_, _blocks, {_blocks}));
378- }
379- }
380350} // namespace xllm