
Commit 6493d8e

bugfix: support disagg PD for MTP. (#551)

1 parent 21962a3

4 files changed: 118 additions & 189 deletions

xllm/core/framework/kv_cache/kv_cache_transfer.h

Lines changed: 0 additions & 2 deletions
@@ -36,8 +36,6 @@ class KVCacheTransfer {
     int64_t dst_v_cache_id;
     std::vector<uint64_t> src_blocks;
     std::vector<uint64_t> dst_blocks;
-    std::vector<uint64_t> src_embed_ids;
-    std::vector<uint64_t> dst_embed_ids;
   };

   KVCacheTransfer() = default;
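For reference, a sketch of the slimmed-down KVCacheInfo struct after this change. Only the fields visible in this diff, plus dst_cluster_id, dst_addr, and dst_k_cache_id (which the .cpp below reads and writes), are shown; field order and any remaining members are assumptions.

```cpp
#include <cstdint>
#include <string>
#include <vector>

struct KVCacheInfo {
  uint64_t dst_cluster_id;           // destination cluster, part of the merge key
  std::string dst_addr;              // destination address, newly added to the key
  int64_t dst_k_cache_id;
  int64_t dst_v_cache_id;
  std::vector<uint64_t> src_blocks;  // merged source block ids
  std::vector<uint64_t> dst_blocks;  // merged destination block ids
  // src_embed_ids / dst_embed_ids are gone: embeddings no longer travel
  // through a separate cache in the disaggregated-PD MTP path.
};
```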

xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp

Lines changed: 110 additions & 136 deletions
@@ -24,6 +24,10 @@ namespace {
   CHECK(ret == LLM_SUCCESS) \
       << "Call LlmDataDist function failed, ret = " << std::hex << ret

+#define CHECK_ACL_RET(ret)  \
+  CHECK(ret == ACL_SUCCESS) \
+      << "Call ACL function failed, ret = " << std::hex << ret
+
 const std::map<torch::ScalarType, ge::DataType> kScalarTypeToDtype = {
     {torch::kBool, ge::DT_BOOL},
     {torch::kByte, ge::DT_UINT8},
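The new macro mirrors CHECK_LDD_RET above. A minimal, self-contained illustration of the pattern, with glog's CHECK replaced by a plain abort-on-failure helper and a stand-in ACL call so it compiles without the Ascend toolchain:

```cpp
#include <cstdlib>
#include <iostream>

constexpr int ACL_SUCCESS = 0;  // stand-in for the real ACL status code

// Same shape as the macro in the diff: evaluate once, log in hex, die on error.
#define CHECK_ACL_RET(ret)                                              \
  do {                                                                  \
    auto _r = (ret);                                                    \
    if (_r != ACL_SUCCESS) {                                            \
      std::cerr << "Call ACL function failed, ret = " << std::hex << _r \
                << std::endl;                                           \
      std::abort();                                                     \
    }                                                                   \
  } while (0)

int fake_acl_call() { return ACL_SUCCESS; }  // stand-in for aclrtMalloc etc.

int main() {
  CHECK_ACL_RET(fake_acl_call());  // passes silently on success
}
```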
@@ -48,29 +52,29 @@ void SpecKVCacheTransfer::allocate_kv_cache(
     const int64_t num_layers,
     const std::vector<std::vector<int64_t>>& kv_cache_shape,
     torch::ScalarType dtype) {
-  _allocate_kv_cache(kv_caches,
-                     num_layers,
-                     kv_cache_shape,
-                     dtype,
-                     /*is_spec*/ false,
-                     k_cache_,
-                     v_cache_);
+  allocate_kv_cache_internal(kv_caches,
+                             num_layers,
+                             kv_cache_shape,
+                             dtype,
+                             /*is_spec*/ false,
+                             k_cache_,
+                             v_cache_);
 }

 void SpecKVCacheTransfer::allocate_kv_cache_spec(
     std::vector<xllm::KVCache>& kv_caches,
     const int64_t num_layers,
     const std::vector<std::vector<int64_t>>& kv_cache_shape,
     torch::ScalarType dtype) {
-  _allocate_kv_cache(kv_caches,
-                     num_layers,
-                     kv_cache_shape,
-                     dtype,
-                     /*is_spec*/ true,
-                     spec_k_cache_,
-                     spec_v_cache_);
+  allocate_kv_cache_internal(kv_caches,
+                             num_layers,
+                             kv_cache_shape,
+                             dtype,
+                             /*is_spec*/ true,
+                             spec_k_cache_,
+                             spec_v_cache_);
 }
-void SpecKVCacheTransfer::_allocate_kv_cache(
+void SpecKVCacheTransfer::allocate_kv_cache_internal(
     std::vector<xllm::KVCache>& kv_caches,
     const int64_t num_layers,
     const std::vector<std::vector<int64_t>>& kv_cache_shape,
@@ -87,18 +91,42 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
   const auto& it = kScalarTypeToDtype.find(dtype);
   CHECK(it != kScalarTypeToDtype.cend()) << "Unsupport data type : " << dtype;
   auto ge_dtype = it->second;
-  CacheDesc k_cache_desc;
-  k_cache_desc.num_tensors = num_layers;
-  k_cache_desc.data_type = ge_dtype;
-  k_cache_desc.shape = kv_cache_shape[0];
-  CHECK_LDD_RET(llm_data_dist_->AllocateCache(k_cache_desc, k_cache));

-  CacheDesc v_cache_desc;
-  v_cache_desc.num_tensors = num_layers;
-  v_cache_desc.data_type = ge_dtype;
-  v_cache_desc.shape = kv_cache_shape[1];
-  CHECK_LDD_RET(llm_data_dist_->AllocateCache(v_cache_desc, v_cache));
+  // calculate the size of kv cache for each layer
+  auto data_size = torch::elementSize(dtype);
+  int64_t k_cache_size_per_layer = std::accumulate(kv_cache_shape[0].begin(),
+                                                   kv_cache_shape[0].end(),
+                                                   data_size,
+                                                   std::multiplies<int64_t>());
+  int64_t v_cache_size_per_layer = std::accumulate(kv_cache_shape[1].begin(),
+                                                   kv_cache_shape[1].end(),
+                                                   data_size,
+                                                   std::multiplies<int64_t>());
+
+  // allocate device memory for kv cache
+  std::vector<uint64_t> k_cache_addrs;
+  std::vector<uint64_t> v_cache_addrs;
+  k_cache_addrs.reserve(num_layers);
+  v_cache_addrs.reserve(num_layers);
+  k_cache.tensor_addrs.reserve(num_layers);
+  v_cache.tensor_addrs.reserve(num_layers);
+  for (int64_t i = 0; i < num_layers; ++i) {
+    void* k_cache_buffer = nullptr;
+    void* v_cache_buffer = nullptr;
+    CHECK_ACL_RET(aclrtMalloc(
+        &k_cache_buffer, k_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
+    CHECK_ACL_RET(aclrtMalloc(
+        &v_cache_buffer, v_cache_size_per_layer, ACL_MEM_MALLOC_HUGE_ONLY));
+
+    k_cache_addrs.emplace_back(reinterpret_cast<uint64_t>(k_cache_buffer));
+    v_cache_addrs.emplace_back(reinterpret_cast<uint64_t>(v_cache_buffer));
+    k_cache.tensor_addrs.emplace_back(
+        reinterpret_cast<uintptr_t>(k_cache_buffer));
+    v_cache.tensor_addrs.emplace_back(
+        reinterpret_cast<uintptr_t>(v_cache_buffer));
+  }

+  // convert memory addrs to torch tensors
   auto k_torch_tensors =
       convert_to_torch_tensor(kv_cache_shape[0], dtype, k_cache.tensor_addrs);
   auto v_torch_tensors =
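The size computation folds the element size into std::accumulate's initial value, so the product of the shape dimensions comes out directly in bytes. A runnable sketch of the same arithmetic, with a made-up shape and fp16 element size standing in for torch::elementSize(dtype):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // e.g. [num_blocks, block_size, num_kv_heads, head_dim] (illustrative values)
  std::vector<int64_t> k_shape = {1024, 128, 8, 128};
  int64_t data_size = 2;  // sizeof(fp16); the real code uses torch::elementSize(dtype)

  // bytes per layer = element size * product of all shape dims
  int64_t k_cache_size_per_layer = std::accumulate(
      k_shape.begin(), k_shape.end(), data_size, std::multiplies<int64_t>());

  std::cout << k_cache_size_per_layer << " bytes per layer\n";  // 268435456 (256 MiB)
}
```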
@@ -109,35 +137,40 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
     value_cache = v_torch_tensors[i];
     kv_caches.emplace_back(key_cache, value_cache);
   }
-}

-void SpecKVCacheTransfer::allocate_embedding(
-    std::shared_ptr<EmbeddingAllocator> embedding_allocator,
-    const std::vector<int64_t>& embedding_shape,
-    torch::ScalarType dtype,
-    torch::Device device) {
-  const auto& it = kScalarTypeToDtype.find(dtype);
-  CHECK(it != kScalarTypeToDtype.cend()) << "Unsupport data type : " << dtype;
-  auto ge_dtype = it->second;
-  CacheDesc embed_cache_desc;
-  embed_cache_desc.num_tensors = 1;
-  embed_cache_desc.data_type = ge_dtype;
-  embed_cache_desc.shape = embedding_shape;
-  CHECK_LDD_RET(llm_data_dist_->AllocateCache(embed_cache_desc, embed_cache_));
+  // register key cache
+  CacheDesc& k_cache_desc = k_cache.cache_desc;
+  k_cache_desc.num_tensors = num_layers;
+  k_cache_desc.data_type = ge_dtype;
+  k_cache_desc.shape = kv_cache_shape[0];
+  auto ret = llm_data_dist_->RegisterKvCache(
+      k_cache_desc, k_cache_addrs, {}, k_cache.cache_id);
+  CHECK(ret == LLM_SUCCESS)
+      << "Register key cache failed, ret = " << std::hex << ret;
+
+  // register value cache
+  CacheDesc& v_cache_desc = v_cache.cache_desc;
+  v_cache_desc.num_tensors = num_layers;
+  v_cache_desc.data_type = ge_dtype;
+  v_cache_desc.shape = kv_cache_shape[1];
+  ret = llm_data_dist_->RegisterKvCache(
+      v_cache_desc, v_cache_addrs, {}, v_cache.cache_id);
+  CHECK(ret == LLM_SUCCESS)
+      << "Register value cache failed, ret = " << std::hex << ret;

-  embed_host_cache_.cache_desc = embed_cache_.cache_desc;
-  embed_host_cache_.cache_desc.placement = CachePlacement::kHost;
-  CHECK_EQ(embed_host_cache_.cache_desc.num_tensors, 1);
-  embed_host_cache_.tensor_addrs.emplace_back(reinterpret_cast<uint64_t>(
-      embedding_allocator->get_embeddings_cache_ptr()));
+  LOG(INFO) << "Register KV cache success.";
 }

 void SpecKVCacheTransfer::free_kv_cache() {
-  llm_data_dist_->DeallocateCache(k_cache_.cache_id);
-  llm_data_dist_->DeallocateCache(v_cache_.cache_id);
-  llm_data_dist_->DeallocateCache(spec_k_cache_.cache_id);
-  llm_data_dist_->DeallocateCache(spec_v_cache_.cache_id);
-  llm_data_dist_->DeallocateCache(embed_cache_.cache_id);
+  auto free_cache = [](const std::vector<uintptr_t>& tensor_addrs) {
+    for (auto tensor_addr : tensor_addrs) {
+      CHECK_ACL_RET(aclrtFree(reinterpret_cast<void*>(tensor_addr)));
+    }
+  };
+  free_cache(k_cache_.tensor_addrs);
+  free_cache(v_cache_.tensor_addrs);
+  free_cache(spec_k_cache_.tensor_addrs);
+  free_cache(spec_v_cache_.tensor_addrs);
 }

 bool SpecKVCacheTransfer::pull_kv_blocks(
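Net effect of this hunk: buffers are now allocated by the engine itself (aclrtMalloc) and only registered with LlmDataDist, so ownership of the memory, and the obligation to aclrtFree it, stays with the engine instead of the transfer library. A stand-in sketch of that lifecycle; none of these helpers are the real ACL/LlmDataDist API:

```cpp
#include <cstdint>
#include <cstdlib>
#include <vector>

void* fake_device_malloc(size_t bytes) { return std::malloc(bytes); }
void fake_device_free(void* p) { std::free(p); }
void fake_register(const std::vector<uintptr_t>&) {}  // library records addrs only

int main() {
  // allocate_kv_cache_internal() analogue: engine owns one buffer per layer
  std::vector<uintptr_t> tensor_addrs;
  for (int layer = 0; layer < 4; ++layer) {
    tensor_addrs.push_back(
        reinterpret_cast<uintptr_t>(fake_device_malloc(1 << 20)));
  }
  fake_register(tensor_addrs);  // analogous to RegisterKvCache

  // free_kv_cache() analogue: the owner frees what it allocated
  for (auto addr : tensor_addrs) {
    fake_device_free(reinterpret_cast<void*>(addr));
  }
}
```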
@@ -161,80 +194,59 @@ bool SpecKVCacheTransfer::pull_kv_blocks(
   CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(
       spec_v_cache_index, spec_v_cache_, src_blocks, dst_blocks));

-  CacheIndex embed_cache_index{src_cluster_id, embed_cache_.cache_id};
-  CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(embed_cache_index,
-                                             embed_cache_,
-                                             {src_blocks.back()},
-                                             {dst_blocks.back()}));
   return true;
 }

 bool SpecKVCacheTransfer::push_kv_blocks(
     std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
     std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
     bool is_spec_draft) {
-  if (!layer_synchronizer) {
-    return push_embed_blocks(merged_kv_infos);
-  }
-
   if (is_spec_draft) {
     return push_kv_blocks_spec(merged_kv_infos, layer_synchronizer);
+  } else {
+    return push_kv_blocks_internal(
+        merged_kv_infos, layer_synchronizer, num_layers_, k_cache_, v_cache_);
   }
-  for (int64_t layer_index = 0; layer_index < num_layers_; ++layer_index) {
-    // Wait for the KV cache computation of this layer to complete.
-    layer_synchronizer->synchronize_layer(layer_index);
-    // Push the KV Cache computed at this layer for all requests to the
-    // designated worker.
-    for (const auto& pair : merged_kv_infos) {
-      const KVCacheInfo& kv_info = pair.second;
-      CacheIndex k_cache_index{kv_info.dst_cluster_id, k_cache_.cache_id};
-      CacheIndex v_cache_index{kv_info.dst_cluster_id, v_cache_.cache_id};
-      KvCacheExtParam ext_param{};
-      ext_param.src_layer_range =
-          std::pair<int32_t, int32_t>(layer_index, layer_index);
-      ext_param.dst_layer_range =
-          std::pair<int32_t, int32_t>(layer_index, layer_index);
-      ext_param.tensor_num_per_layer = 1;
-      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(k_cache_,
-                                                 k_cache_index,
-                                                 kv_info.src_blocks,
-                                                 kv_info.dst_blocks,
-                                                 ext_param));
-      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(v_cache_,
-                                                 v_cache_index,
-                                                 kv_info.src_blocks,
-                                                 kv_info.dst_blocks,
-                                                 ext_param));
-    }
-  }
-  return true;
 }

 bool SpecKVCacheTransfer::push_kv_blocks_spec(
     std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
     std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer) {
-  for (int64_t layer_index = 0; layer_index < spec_num_layers_; ++layer_index) {
+  return push_kv_blocks_internal(merged_kv_infos,
+                                 layer_synchronizer,
+                                 spec_num_layers_,
+                                 spec_k_cache_,
+                                 spec_v_cache_);
+}
+
+bool SpecKVCacheTransfer::push_kv_blocks_internal(
+    std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
+    std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
+    int64_t num_layers,
+    const Cache& k_cache,
+    const Cache& v_cache) {
+  for (int64_t layer_index = 0; layer_index < num_layers; ++layer_index) {
     // Wait for the KV cache computation of this layer to complete.
     layer_synchronizer->synchronize_layer(layer_index);
+
     // Push the KV Cache computed at this layer for all requests to the
     // designated worker.
     for (const auto& pair : merged_kv_infos) {
       const KVCacheInfo& kv_info = pair.second;
-      CacheIndex k_cache_index{kv_info.dst_cluster_id, spec_k_cache_.cache_id};
-      CacheIndex v_cache_index{kv_info.dst_cluster_id, spec_v_cache_.cache_id};
+      CacheIndex k_cache_index{kv_info.dst_cluster_id, k_cache.cache_id};
+      CacheIndex v_cache_index{kv_info.dst_cluster_id, v_cache.cache_id};
+
       KvCacheExtParam ext_param{};
-      ext_param.src_layer_range =
-          std::pair<int32_t, int32_t>(layer_index, layer_index);
-      ext_param.dst_layer_range =
-          std::pair<int32_t, int32_t>(layer_index, layer_index);
+      ext_param.src_layer_range = {layer_index, layer_index};
+      ext_param.dst_layer_range = {layer_index, layer_index};
       ext_param.tensor_num_per_layer = 1;

-      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(spec_k_cache_,
+      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(k_cache,
                                                  k_cache_index,
                                                  kv_info.src_blocks,
                                                  kv_info.dst_blocks,
                                                  ext_param));
-      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(spec_v_cache_,
+      CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(v_cache,
                                                  v_cache_index,
                                                  kv_info.src_blocks,
                                                  kv_info.dst_blocks,
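Both push paths now funnel into push_kv_blocks_internal, whose per-layer loop waits for layer i's KV to be computed and then ships it while later layers are still in flight. A toy, runnable sketch of that synchronize-then-push pipeline; the synchronizer and push are simplified stand-ins for NPULayerSynchronizerImpl and PushKvBlocks:

```cpp
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>

class LayerSynchronizer {
 public:
  void mark_done(int64_t layer) {  // called by the "compute" side
    std::lock_guard<std::mutex> lk(mu_);
    done_ = layer + 1;
    cv_.notify_all();
  }
  void synchronize_layer(int64_t layer) {  // block until layer is computed
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return done_ > layer; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int64_t done_ = 0;
};

int main() {
  const int64_t num_layers = 4;
  LayerSynchronizer sync;

  std::thread compute([&] {  // stand-in for the forward pass
    for (int64_t i = 0; i < num_layers; ++i) sync.mark_done(i);
  });

  for (int64_t layer = 0; layer < num_layers; ++layer) {
    sync.synchronize_layer(layer);
    std::cout << "push layer " << layer << "\n";  // stand-in for PushKvBlocks
  }
  compute.join();
}
```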
@@ -244,24 +256,6 @@ bool SpecKVCacheTransfer::push_kv_blocks_spec(
   return true;
 }

-bool SpecKVCacheTransfer::push_embed_blocks(
-    std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos) {
-  for (const auto& pair : merged_kv_infos) {
-    const KVCacheInfo& kv_info = pair.second;
-    CacheIndex cache_index{kv_info.dst_cluster_id, embed_cache_.cache_id};
-    KvCacheExtParam ext_param{};
-    ext_param.src_layer_range = std::pair<int32_t, int32_t>(0, 0);
-    ext_param.dst_layer_range = std::pair<int32_t, int32_t>(0, 0);
-    ext_param.tensor_num_per_layer = 1;
-    CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(embed_cache_,
-                                               cache_index,
-                                               kv_info.src_embed_ids,
-                                               kv_info.dst_embed_ids,
-                                               ext_param));
-  }
-  return true;
-}
-
 folly::SemiFuture<bool> SpecKVCacheTransfer::push_kv_blocks_async(
     const std::vector<TransferKVInfo>& transfer_kv_infos,
     const ParallelArgs& parallel_args,
@@ -323,16 +317,18 @@ void SpecKVCacheTransfer::merge_kv_blocks(
          i < dst_tp_size * (dst_dp_rank + 1);
          i += src_tp_size) {
       uint64_t dst_cluster_id = info.remote_instance_info.cluster_ids[i];
+      auto& dst_addr = info.remote_instance_info.addrs[i];
       int64_t k_cache_id = info.remote_instance_info.k_cache_ids[i];
       int64_t v_cache_id = info.remote_instance_info.v_cache_ids[i];
-      std::string key = std::to_string(dst_cluster_id) + "_" +
+      std::string key = std::to_string(dst_cluster_id) + "_" + dst_addr + "_" +
                         std::to_string(k_cache_id) + "_" +
                         std::to_string(v_cache_id);
       // Merge all kv blocks with the same destination worker into a single
       // vector.
       if (merged_kv_infos.find(key) == merged_kv_infos.end()) {
         KVCacheInfo kv_info;
         kv_info.dst_cluster_id = dst_cluster_id;
+        kv_info.dst_addr = dst_addr;
         kv_info.dst_k_cache_id = k_cache_id;
         kv_info.dst_v_cache_id = v_cache_id;
         kv_info.src_blocks.insert(kv_info.src_blocks.end(),
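The actual fix here is the merge key: it previously ignored the destination address, so two destination workers that happen to share cluster and cache ids would have their blocks merged into a single entry. A runnable sketch with made-up values showing why dst_addr keeps them apart:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::vector<uint64_t>> merged;

  struct Dst { uint64_t cluster_id; std::string addr; int64_t k_id, v_id; };
  // Two distinct workers with identical cluster/cache ids (illustrative)
  std::vector<Dst> dsts = {{1, "10.0.0.1", 7, 8}, {1, "10.0.0.2", 7, 8}};

  for (const auto& d : dsts) {
    // Key format mirrors the diff: cluster_addr_kid_vid
    std::string key = std::to_string(d.cluster_id) + "_" + d.addr + "_" +
                      std::to_string(d.k_id) + "_" + std::to_string(d.v_id);
    merged[key].push_back(42);  // stand-in block id
  }

  // With dst_addr in the key the two workers stay separate; without it,
  // both would collapse into one entry and blocks would go to one worker.
  std::cout << merged.size() << " destination entries\n";  // prints 2
}
```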
@@ -341,8 +337,6 @@ void SpecKVCacheTransfer::merge_kv_blocks(
         kv_info.dst_blocks.insert(kv_info.dst_blocks.end(),
                                   info.remote_blocks_ids.begin(),
                                   info.remote_blocks_ids.end());
-        kv_info.src_embed_ids.push_back(kv_info.src_blocks.back());
-        kv_info.dst_embed_ids.push_back(kv_info.dst_blocks.back());
         merged_kv_infos[key] = std::move(kv_info);
       } else {
         merged_kv_infos[key].src_blocks.insert(
@@ -353,28 +347,8 @@ void SpecKVCacheTransfer::merge_kv_blocks(
             merged_kv_infos[key].dst_blocks.end(),
             info.remote_blocks_ids.begin(),
             info.remote_blocks_ids.end());
-        merged_kv_infos[key].src_embed_ids.push_back(
-            merged_kv_infos[key].src_blocks.back());
-        merged_kv_infos[key].dst_embed_ids.push_back(
-            merged_kv_infos[key].dst_blocks.back());
       }
     }
   }
 }
-
-void SpecKVCacheTransfer::copy_blocks(const std::vector<int>& blocks,
-                                      bool h2d) {
-  std::vector<uint64_t> _blocks;
-  _blocks.reserve(blocks.size());
-  for (const auto& block : blocks) {
-    _blocks.push_back(static_cast<uint64_t>(block));
-  }
-  if (h2d) {
-    CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
-        embed_host_cache_, embed_cache_, _blocks, {_blocks}));
-  } else {
-    CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
-        embed_cache_, embed_host_cache_, _blocks, {_blocks}));
-  }
-}
 } // namespace xllm
