Skip to content

Commit fb4069c

Browse files
committed
bugfix: support disagg PD for MTP.
1 parent fd60d08 commit fb4069c

File tree

5 files changed

+115
-190
lines changed

5 files changed

+115
-190
lines changed

third_party/xllm_ops

Submodule xllm_ops updated from 103e150 to 57937f2

xllm/core/framework/kv_cache/kv_cache_transfer.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,6 @@ class KVCacheTransfer {
3636
int64_t dst_v_cache_id;
3737
std::vector<uint64_t> src_blocks;
3838
std::vector<uint64_t> dst_blocks;
39-
std::vector<uint64_t> src_embed_ids;
40-
std::vector<uint64_t> dst_embed_ids;
4139
};
4240

4341
KVCacheTransfer() = default;

xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp

Lines changed: 106 additions & 136 deletions
Original file line number | Diff line number | Diff line change
@@ -48,29 +48,29 @@ void SpecKVCacheTransfer::allocate_kv_cache(
4848
const int64_t num_layers,
4949
const std::vector<std::vector<int64_t>>& kv_cache_shape,
5050
torch::ScalarType dtype) {
51-
_allocate_kv_cache(kv_caches,
52-
num_layers,
53-
kv_cache_shape,
54-
dtype,
55-
/*is_spec*/ false,
56-
k_cache_,
57-
v_cache_);
51+
allocate_kv_cache_internal(kv_caches,
52+
num_layers,
53+
kv_cache_shape,
54+
dtype,
55+
/*is_spec*/ false,
56+
k_cache_,
57+
v_cache_);
5858
}
5959

6060
void SpecKVCacheTransfer::allocate_kv_cache_spec(
6161
std::vector<xllm::KVCache>& kv_caches,
6262
const int64_t num_layers,
6363
const std::vector<std::vector<int64_t>>& kv_cache_shape,
6464
torch::ScalarType dtype) {
65-
_allocate_kv_cache(kv_caches,
66-
num_layers,
67-
kv_cache_shape,
68-
dtype,
69-
/*is_spec*/ true,
70-
spec_k_cache_,
71-
spec_v_cache_);
65+
allocate_kv_cache_internal(kv_caches,
66+
num_layers,
67+
kv_cache_shape,
68+
dtype,
69+
/*is_spec*/ true,
70+
spec_k_cache_,
71+
spec_v_cache_);
7272
}
73-
void SpecKVCacheTransfer::_allocate_kv_cache(
73+
void SpecKVCacheTransfer::allocate_kv_cache_internal(
7474
std::vector<xllm::KVCache>& kv_caches,
7575
const int64_t num_layers,
7676
const std::vector<std::vector<int64_t>>& kv_cache_shape,
@@ -87,18 +87,42 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
8787
const auto& it = kScalarTypeToDtype.find(dtype);
8888
CHECK(it != kScalarTypeToDtype.cend()) << "Unsupport data type : " << dtype;
8989
auto ge_dtype = it->second;
90-
CacheDesc k_cache_desc;
91-
k_cache_desc.num_tensors = num_layers;
92-
k_cache_desc.data_type = ge_dtype;
93-
k_cache_desc.shape = kv_cache_shape[0];
94-
CHECK_LDD_RET(llm_data_dist_->AllocateCache(k_cache_desc, k_cache));
9590

96-
CacheDesc v_cache_desc;
97-
v_cache_desc.num_tensors = num_layers;
98-
v_cache_desc.data_type = ge_dtype;
99-
v_cache_desc.shape = kv_cache_shape[1];
100-
CHECK_LDD_RET(llm_data_dist_->AllocateCache(v_cache_desc, v_cache));
91+
// calculate the size of kv cache for each layer
92+
auto data_size = torch::elementSize(dtype);
93+
int64_t k_cache_size_per_layer = std::accumulate(kv_cache_shape[0].begin(),
94+
kv_cache_shape[0].end(),
95+
data_size,
96+
std::multiplies<int64_t>());
97+
int64_t v_cache_size_per_layer = std::accumulate(kv_cache_shape[1].begin(),
98+
kv_cache_shape[1].end(),
99+
data_size,
100+
std::multiplies<int64_t>());
101+
102+
// allocate device memory for kv cache
103+
std::vector<uint64_t> k_cache_addrs;
104+
std::vector<uint64_t> v_cache_addrs;
105+
k_cache_addrs.reserve(num_layers);
106+
v_cache_addrs.reserve(num_layers);
107+
k_cache.tensor_addrs.reserve(num_layers);
108+
v_cache.tensor_addrs.reserve(num_layers);
109+
for (int64_t i = 0; i < num_layers; ++i) {
110+
void* k_cache_buffer = nullptr;
111+
void* v_cache_buffer = nullptr;
112+
CHECK_LDD_RET(aclrtMalloc(
113+
&k_cache_buffer, k_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
114+
CHECK_LDD_RET(aclrtMalloc(
115+
&v_cache_buffer, v_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
101116

117+
k_cache_addrs.emplace_back(reinterpret_cast<uint64_t>(k_cache_buffer));
118+
v_cache_addrs.emplace_back(reinterpret_cast<uint64_t>(v_cache_buffer));
119+
k_cache.tensor_addrs.emplace_back(
120+
reinterpret_cast<uintptr_t>(k_cache_buffer));
121+
v_cache.tensor_addrs.emplace_back(
122+
reinterpret_cast<uintptr_t>(v_cache_buffer));
123+
}
124+
125+
// convert memory addrs to torch tensors
102126
auto k_torch_tensors =
103127
convert_to_torch_tensor(kv_cache_shape[0], dtype, k_cache.tensor_addrs);
104128
auto v_torch_tensors =
@@ -109,35 +133,40 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
109133
value_cache = v_torch_tensors[i];
110134
kv_caches.emplace_back(key_cache, value_cache);
111135
}
112-
}
113136

114-
void SpecKVCacheTransfer::allocate_embedding(
115-
std::shared_ptr<EmbeddingAllocator> embedding_allocator,
116-
const std::vector<int64_t>& embedding_shape,
117-
torch::ScalarType dtype,
118-
torch::Device device) {
119-
const auto& it = kScalarTypeToDtype.find(dtype);
120-
CHECK(it != kScalarTypeToDtype.cend()) << "Unsupport data type : " << dtype;
121-
auto ge_dtype = it->second;
122-
CacheDesc embed_cache_desc;
123-
embed_cache_desc.num_tensors = 1;
124-
embed_cache_desc.data_type = ge_dtype;
125-
embed_cache_desc.shape = embedding_shape;
126-
CHECK_LDD_RET(llm_data_dist_->AllocateCache(embed_cache_desc, embed_cache_));
137+
// register key cache
138+
CacheDesc& k_cache_desc = k_cache.cache_desc;
139+
k_cache_desc.num_tensors = num_layers;
140+
k_cache_desc.data_type = ge_dtype;
141+
k_cache_desc.shape = kv_cache_shape[0];
142+
auto ret = llm_data_dist_->RegisterKvCache(
143+
k_cache_desc, k_cache_addrs, {}, k_cache.cache_id);
144+
CHECK(ret == LLM_SUCCESS)
145+
<< "Register key cache failed, ret = " << std::hex << ret;
146+
147+
// register value cache
148+
CacheDesc& v_cache_desc = v_cache.cache_desc;
149+
v_cache_desc.num_tensors = num_layers;
150+
v_cache_desc.data_type = ge_dtype;
151+
v_cache_desc.shape = kv_cache_shape[1];
152+
ret = llm_data_dist_->RegisterKvCache(
153+
v_cache_desc, v_cache_addrs, {}, v_cache.cache_id);
154+
CHECK(ret == LLM_SUCCESS)
155+
<< "Register value cache failed, ret = " << std::hex << ret;
127156

128-
embed_host_cache_.cache_desc = embed_cache_.cache_desc;
129-
embed_host_cache_.cache_desc.placement = CachePlacement::kHost;
130-
CHECK_EQ(embed_host_cache_.cache_desc.num_tensors, 1);
131-
embed_host_cache_.tensor_addrs.emplace_back(reinterpret_cast<uint64_t>(
132-
embedding_allocator->get_embeddings_cache_ptr()));
157+
LOG(INFO) << "Register KV cache success.";
133158
}
134159

135160
void SpecKVCacheTransfer::free_kv_cache() {
136-
llm_data_dist_->DeallocateCache(k_cache_.cache_id);
137-
llm_data_dist_->DeallocateCache(v_cache_.cache_id);
138-
llm_data_dist_->DeallocateCache(spec_k_cache_.cache_id);
139-
llm_data_dist_->DeallocateCache(spec_v_cache_.cache_id);
140-
llm_data_dist_->DeallocateCache(embed_cache_.cache_id);
161+
auto free_cache = [](const std::vector<uintptr_t>& tensor_addrs) {
162+
for (auto tensor_addr : tensor_addrs) {
163+
aclrtFree(reinterpret_cast<void*>(tensor_addr));
164+
}
165+
};
166+
free_cache(k_cache_.tensor_addrs);
167+
free_cache(v_cache_.tensor_addrs);
168+
free_cache(spec_k_cache_.tensor_addrs);
169+
free_cache(spec_v_cache_.tensor_addrs);
141170
}
142171

143172
bool SpecKVCacheTransfer::pull_kv_blocks(
@@ -161,80 +190,59 @@ bool SpecKVCacheTransfer::pull_kv_blocks(
161190
CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(
162191
spec_v_cache_index, spec_v_cache_, src_blocks, dst_blocks));
163192

164-
CacheIndex embed_cache_index{src_cluster_id, embed_cache_.cache_id};
165-
CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(embed_cache_index,
166-
embed_cache_,
167-
{src_blocks.back()},
168-
{dst_blocks.back()}));
169193
return true;
170194
}
171195

172196
bool SpecKVCacheTransfer::push_kv_blocks(
173197
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
174198
std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
175199
bool is_spec_draft) {
176-
if (!layer_synchronizer) {
177-
return push_embed_blocks(merged_kv_infos);
178-
}
179-
180200
if (is_spec_draft) {
181201
return push_kv_blocks_spec(merged_kv_infos, layer_synchronizer);
202+
} else {
203+
return push_kv_blocks_internal(
204+
merged_kv_infos, layer_synchronizer, num_layers_, k_cache_, v_cache_);
182205
}
183-
for (int64_t layer_index = 0; layer_index < num_layers_; ++layer_index) {
184-
// Wait for the KV cache computation of this layer to complete.
185-
layer_synchronizer->synchronize_layer(layer_index);
186-
// Push the KV Cache computed at this layer for all requests to the
187-
// designated worker.
188-
for (const auto& pair : merged_kv_infos) {
189-
const KVCacheInfo& kv_info = pair.second;
190-
CacheIndex k_cache_index{kv_info.dst_cluster_id, k_cache_.cache_id};
191-
CacheIndex v_cache_index{kv_info.dst_cluster_id, v_cache_.cache_id};
192-
KvCacheExtParam ext_param{};
193-
ext_param.src_layer_range =
194-
std::pair<int32_t, int32_t>(layer_index, layer_index);
195-
ext_param.dst_layer_range =
196-
std::pair<int32_t, int32_t>(layer_index, layer_index);
197-
ext_param.tensor_num_per_layer = 1;
198-
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(k_cache_,
199-
k_cache_index,
200-
kv_info.src_blocks,
201-
kv_info.dst_blocks,
202-
ext_param));
203-
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(v_cache_,
204-
v_cache_index,
205-
kv_info.src_blocks,
206-
kv_info.dst_blocks,
207-
ext_param));
208-
}
209-
}
210-
return true;
211206
}
212207

213208
bool SpecKVCacheTransfer::push_kv_blocks_spec(
214209
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
215210
std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer) {
216-
for (int64_t layer_index = 0; layer_index < spec_num_layers_; ++layer_index) {
211+
return push_kv_blocks_internal(merged_kv_infos,
212+
layer_synchronizer,
213+
spec_num_layers_,
214+
spec_k_cache_,
215+
spec_v_cache_);
216+
}
217+
218+
bool SpecKVCacheTransfer::push_kv_blocks_internal(
219+
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
220+
std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
221+
int64_t num_layers,
222+
const Cache& k_cache,
223+
const Cache& v_cache) {
224+
for (int64_t layer_index = 0; layer_index < num_layers; ++layer_index) {
217225
// Wait for the KV cache computation of this layer to complete.
218226
layer_synchronizer->synchronize_layer(layer_index);
227+
219228
// Push the KV Cache computed at this layer for all requests to the
220229
// designated worker.
221230
for (const auto& pair : merged_kv_infos) {
222231
const KVCacheInfo& kv_info = pair.second;
223-
CacheIndex k_cache_index{kv_info.dst_cluster_id, spec_k_cache_.cache_id};
224-
CacheIndex v_cache_index{kv_info.dst_cluster_id, spec_v_cache_.cache_id};
232+
CacheIndex k_cache_index{kv_info.dst_cluster_id, k_cache.cache_id};
233+
CacheIndex v_cache_index{kv_info.dst_cluster_id, v_cache.cache_id};
234+
225235
KvCacheExtParam ext_param{};
226-
ext_param.src_layer_range =
227-
std::pair<int32_t, int32_t>(layer_index, layer_index);
228-
ext_param.dst_layer_range =
229-
std::pair<int32_t, int32_t>(layer_index, layer_index);
236+
ext_param.src_layer_range = {layer_index, layer_index};
237+
ext_param.dst_layer_range = {layer_index, layer_index};
230238
ext_param.tensor_num_per_layer = 1;
231239

232-
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(spec_k_cache_,
240+
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(k_cache,
233241
k_cache_index,
234242
kv_info.src_blocks,
235243
kv_info.dst_blocks,
236244
ext_param));
237-
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(spec_v_cache_,
245+
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(v_cache,
238246
v_cache_index,
239247
kv_info.src_blocks,
240248
kv_info.dst_blocks,
@@ -244,24 +252,6 @@ bool SpecKVCacheTransfer::push_kv_blocks_spec(
244252
return true;
245253
}
246254

247-
bool SpecKVCacheTransfer::push_embed_blocks(
248-
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos) {
249-
for (const auto& pair : merged_kv_infos) {
250-
const KVCacheInfo& kv_info = pair.second;
251-
CacheIndex cache_index{kv_info.dst_cluster_id, embed_cache_.cache_id};
252-
KvCacheExtParam ext_param{};
253-
ext_param.src_layer_range = std::pair<int32_t, int32_t>(0, 0);
254-
ext_param.dst_layer_range = std::pair<int32_t, int32_t>(0, 0);
255-
ext_param.tensor_num_per_layer = 1;
256-
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(embed_cache_,
257-
cache_index,
258-
kv_info.src_embed_ids,
259-
kv_info.dst_embed_ids,
260-
ext_param));
261-
}
262-
return true;
263-
}
264-
265255
folly::SemiFuture<bool> SpecKVCacheTransfer::push_kv_blocks_async(
266256
const std::vector<TransferKVInfo>& transfer_kv_infos,
267257
const ParallelArgs& parallel_args,
@@ -323,16 +313,18 @@ void SpecKVCacheTransfer::merge_kv_blocks(
323313
i < dst_tp_size * (dst_dp_rank + 1);
324314
i += src_tp_size) {
325315
uint64_t dst_cluster_id = info.remote_instance_info.cluster_ids[i];
316+
auto& dst_addr = info.remote_instance_info.addrs[i];
326317
int64_t k_cache_id = info.remote_instance_info.k_cache_ids[i];
327318
int64_t v_cache_id = info.remote_instance_info.v_cache_ids[i];
328-
std::string key = std::to_string(dst_cluster_id) + "_" +
319+
std::string key = std::to_string(dst_cluster_id) + "_" + dst_addr + "_" +
329320
std::to_string(k_cache_id) + "_" +
330321
std::to_string(v_cache_id);
331322
// Merge all kv blocks with the same destination worker into a single
332323
// vector.
333324
if (merged_kv_infos.find(key) == merged_kv_infos.end()) {
334325
KVCacheInfo kv_info;
335326
kv_info.dst_cluster_id = dst_cluster_id;
327+
kv_info.dst_addr = dst_addr;
336328
kv_info.dst_k_cache_id = k_cache_id;
337329
kv_info.dst_v_cache_id = v_cache_id;
338330
kv_info.src_blocks.insert(kv_info.src_blocks.end(),
@@ -341,8 +333,6 @@ void SpecKVCacheTransfer::merge_kv_blocks(
341333
kv_info.dst_blocks.insert(kv_info.dst_blocks.end(),
342334
info.remote_blocks_ids.begin(),
343335
info.remote_blocks_ids.end());
344-
kv_info.src_embed_ids.push_back(kv_info.src_blocks.back());
345-
kv_info.dst_embed_ids.push_back(kv_info.dst_blocks.back());
346336
merged_kv_infos[key] = std::move(kv_info);
347337
} else {
348338
merged_kv_infos[key].src_blocks.insert(
@@ -353,28 +343,8 @@ void SpecKVCacheTransfer::merge_kv_blocks(
353343
merged_kv_infos[key].dst_blocks.end(),
354344
info.remote_blocks_ids.begin(),
355345
info.remote_blocks_ids.end());
356-
merged_kv_infos[key].src_embed_ids.push_back(
357-
merged_kv_infos[key].src_blocks.back());
358-
merged_kv_infos[key].dst_embed_ids.push_back(
359-
merged_kv_infos[key].dst_blocks.back());
360346
}
361347
}
362348
}
363349
}
364-
365-
void SpecKVCacheTransfer::copy_blocks(const std::vector<int>& blocks,
366-
bool h2d) {
367-
std::vector<uint64_t> _blocks;
368-
_blocks.reserve(blocks.size());
369-
for (const auto& block : blocks) {
370-
_blocks.push_back(static_cast<uint64_t>(block));
371-
}
372-
if (h2d) {
373-
CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
374-
embed_host_cache_, embed_cache_, _blocks, {_blocks}));
375-
} else {
376-
CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
377-
embed_cache_, embed_host_cache_, _blocks, {_blocks}));
378-
}
379-
}
380350
} // namespace xllm

0 commit comments

Comments
 (0)