Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,20 @@ extern "C" {
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);

//
// Memory factory extension
//
// Allows custom memory (KV cache) implementations to be used.
// Factory returns non-null to use custom implementation, null to use default.
// Ownership of a non-null return value passes to the context (it is stored
// in the context's owning pointer and destroyed with the context).
// Call llama_set_memory_factory() before llama_init_from_model().
// NOTE(review): the hook is stored in unsynchronized globals — presumably not
// safe to change while contexts are being created concurrently; confirm.

// model:     the model the context is being created for
// params:    the context parameters passed by the caller
// user_data: the opaque pointer registered via llama_set_memory_factory()
typedef llama_memory_t (*llama_memory_factory_fn)(
const struct llama_model * model,
const struct llama_context_params * params,
void * user_data);

// Register the process-wide factory hook; pass nullptr to restore the default.
LLAMA_API void llama_set_memory_factory(llama_memory_factory_fn factory, void * user_data);

//
// State / sessions
//
Expand Down
22 changes: 21 additions & 1 deletion src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
#include <limits>
#include <stdexcept>

//
// Memory factory extension
//

// Process-wide hook used during context construction to build a custom memory
// (KV cache) implementation. Plain globals with no synchronization — per the
// API contract in llama.h, the hook is installed once, before any context is
// created.
static llama_memory_factory_fn g_memory_factory = nullptr;
static void * g_memory_factory_user_data = nullptr;

// Install (or clear, by passing nullptr) the memory factory hook.
// user_data is an opaque pointer handed back to the factory on every call;
// it is stored first so the hook never observes a stale pointer.
void llama_set_memory_factory(llama_memory_factory_fn factory, void * user_data) {
    g_memory_factory_user_data = user_data;
    g_memory_factory           = factory;
}

//
// llama_context
//
Expand Down Expand Up @@ -249,7 +261,15 @@ llama_context::llama_context(
/*.swa_full =*/ params.swa_full,
};

memory.reset(model.create_memory(params_mem, cparams));
// Try custom memory factory first
if (g_memory_factory) {
memory.reset(g_memory_factory(&model, &params, g_memory_factory_user_data));
}

// Fall back to default if factory returned null or not set
if (!memory) {
memory.reset(model.create_memory(params_mem, cparams));
}
}

// init backends
Expand Down
Loading