Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion xllm/core/framework/block/block_manager_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
AUTO_COUNTER(allocate_blocks_latency_seconds);
DCHECK(sequence != nullptr);
int32_t dp_rank = get_dp_rank(sequence);
const bool started_empty = sequence->kv_state().num_kv_blocks() == 0;
const bool needs_embedding_id = !sequence->has_embedding_id();
if (needs_embedding_id && !allocate_embedding_id(sequence, dp_rank)) {
return false;
}

// first try to allocate shared blocks
if (sequence->kv_state().num_kv_blocks() == 0) {
if (started_empty) {
BlockManagerPool::allocate_shared(sequence);
}

Expand All @@ -215,6 +216,13 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {

const auto blocks = block_managers_[dp_rank]->allocate(num_additional_blocks);
if (blocks.size() != num_additional_blocks) {
if (started_empty) {
block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks());
if (needs_embedding_id) {
deallocate_embedding_id(sequence, dp_rank);
}
sequence->reset();
}
// LOG(ERROR) << " Fail to allocate " << num_additional_blocks << "
// blocks.";
return false;
Expand Down
79 changes: 77 additions & 2 deletions xllm/core/scheduler/continuous_scheduler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ class FakeTokenizer : public Tokenizer {

class FakeEngine : public Engine {
public:
FakeEngine(int32_t num_blocks, int32_t block_size) {
FakeEngine(int32_t num_blocks,
int32_t block_size,
bool enable_prefix_cache = false) {
BlockManagerPool::Options opt;
opt.num_blocks_ = num_blocks;
opt.block_size_ = block_size;
opt.enable_prefix_cache_ = false; // we dont consider prefix cache here
opt.enable_prefix_cache_ = enable_prefix_cache;
fake_tokenizer_ = std::make_unique<FakeTokenizer>();
fake_block_manager_ = std::make_unique<BlockManagerPool>(opt, 1);
}
Expand Down Expand Up @@ -182,6 +184,37 @@ std::vector<std::shared_ptr<Request>> generate_request(
return requests;
}

std::shared_ptr<Request> generate_request_with_prompt_tokens(
const std::vector<int32_t>& prompt_token_ids,
int32_t max_tokens,
int32_t max_context_len) {
RequestSamplingParam sampling_param;
SchedulerParam scheduler_param;

StoppingChecker stopping_checker;
stopping_checker.set_max_generated_tokens(max_tokens);
stopping_checker.set_max_context_len(max_context_len);
stopping_checker.set_ignore_eos(true);

RequestState req_state("x",
prompt_token_ids,
sampling_param,
scheduler_param,
stopping_checker,
prompt_token_ids.size() + 30000,
1,
1,
false,
false,
false,
false,
false,
nullptr,
nullptr);

return std::make_shared<Request>("1", "1", "1", std::move(req_state), "1");
}

// Does not consider speculative decoding.
void update_requests(std::vector<std::shared_ptr<Request>> requests) {
for (auto req : requests) {
Expand Down Expand Up @@ -651,4 +684,46 @@ TEST(ContinuousSchedulerTest, LatencySchedule) {
// EXPECT_TRUE(scheduler->get_running_requests().size() == 2);
}

// Verifies that a failed allocation rolls back all partially-acquired state:
// shared prefix-cache blocks, the embedding id, and the sequence's kv state,
// leaving the pool's free/used block counts untouched so a later request can
// still allocate successfully.
TEST(BlockManagerPoolTest, AllocateFailureRollsBackSharedPrefixBlocks) {
  // 3 blocks of 4 tokens each, with prefix caching enabled.
  auto engine = std::make_unique<FakeEngine>(3, 4, true);
  BlockManagerPool* block_manager_pool = engine->block_manager_pool();

  // 8 tokens -> 2 blocks that will be cached and shared by prefix.
  auto cached_request =
      generate_request_with_prompt_tokens({1, 2, 3, 4, 5, 6, 7, 8}, 1, 30000);
  // 12 tokens -> needs 3 blocks; only 1 non-shared block remains, so the
  // allocation of its tail blocks must fail.
  auto failed_request = generate_request_with_prompt_tokens(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 1, 30000);
  // 4 tokens -> 1 block; must still succeed after the rollback.
  auto later_request =
      generate_request_with_prompt_tokens({20, 21, 22, 23}, 1, 30000);

  // Populate the prefix cache: allocate, mark tokens as cached, release.
  auto* cached_sequence = cached_request->sequences()[0].get();
  ASSERT_TRUE(block_manager_pool->allocate(cached_sequence,
                                           cached_sequence->num_tokens()));
  cached_sequence->kv_state().set_kv_cache_tokens_num(
      cached_sequence->num_tokens());
  block_manager_pool->deallocate(cached_sequence);

  const size_t free_blocks_before_failure =
      util::max(block_manager_pool->num_free_blocks());
  const size_t used_blocks_before_failure =
      util::min(block_manager_pool->num_used_blocks());
  EXPECT_EQ(free_blocks_before_failure, 0);

  // The failing allocation must leave no residue on the sequence...
  auto* failed_sequence = failed_request->sequences()[0].get();
  EXPECT_FALSE(block_manager_pool->allocate(failed_sequence,
                                            failed_sequence->num_tokens()));
  EXPECT_EQ(failed_sequence->kv_state().num_kv_blocks(), 0);
  EXPECT_EQ(failed_sequence->kv_state().shared_kv_blocks_num(), 0);
  // ...nor on the pool's block accounting.
  EXPECT_EQ(util::max(block_manager_pool->num_free_blocks()),
            free_blocks_before_failure);
  EXPECT_EQ(util::min(block_manager_pool->num_used_blocks()),
            used_blocks_before_failure);

  // A subsequent small request can still be served.
  auto* later_sequence = later_request->sequences()[0].get();
  EXPECT_TRUE(block_manager_pool->allocate(later_sequence,
                                           later_sequence->num_tokens()));
  EXPECT_EQ(later_sequence->kv_state().num_kv_blocks(), 1);

  // Destroy the engine normally (not release(), which would leak it and
  // skip BlockManagerImpl's destructor CHECK that all blocks were freed —
  // that check is part of what this test verifies).
  engine.reset();
}

} // namespace xllm
Loading