diff --git a/docs/metrics.md b/docs/metrics.md
index fbdf491..9b920bb 100644
--- a/docs/metrics.md
+++ b/docs/metrics.md
@@ -26,3 +26,5 @@ Currently supported are the following metrics:
 | vllm:lora_requests_info | Running stats on LoRA requests |
 | vllm:kv_cache_usage_perc | The fraction of KV-cache blocks currently in use (from 0 to 1) |
 | vllm:cache_config_info | Information of the LLMEngine CacheConfig |
+| vllm:prefix_cache_hits | Prefix cache hits, in terms of number of cached tokens |
+| vllm:prefix_cache_queries | Prefix cache queries, in terms of number of queried tokens |
diff --git a/pkg/common/config.go b/pkg/common/config.go
index 80ca8a0..82b8025 100644
--- a/pkg/common/config.go
+++ b/pkg/common/config.go
@@ -316,6 +316,11 @@ type Metrics struct {
 	ReqPrefillTimeBucketValues []int `yaml:"prefill-time-buckets-values" json:"prefill-time-buckets-values"`
 	// ReqDecodeTimeBucketValues is an array of values for request decode time buckets.
 	ReqDecodeTimeBucketValues []int `yaml:"decode-time-buckets-values" json:"decode-time-buckets-values"`
+
+	// PrefixCacheHits is the initial value for the prefix cache hits counter (in tokens)
+	PrefixCacheHits *int64 `yaml:"prefix-cache-hits" json:"prefix-cache-hits,omitempty"`
+	// PrefixCacheQueries is the initial value for the prefix cache queries counter (in tokens)
+	PrefixCacheQueries *int64 `yaml:"prefix-cache-queries" json:"prefix-cache-queries,omitempty"`
 }
 
 type LorasMetrics struct {
@@ -689,6 +694,19 @@ func (c *Configuration) validate() error {
 				return errors.New("fake metrics decode-time-buckets-values cannot contain negative values")
 			}
 		}
+		if c.FakeMetrics.PrefixCacheHits != nil && *c.FakeMetrics.PrefixCacheHits < 0 {
+			return errors.New("fake metrics prefix-cache-hits cannot be negative")
+		}
+		if c.FakeMetrics.PrefixCacheQueries != nil && *c.FakeMetrics.PrefixCacheQueries < 0 {
+			return errors.New("fake metrics prefix-cache-queries cannot be negative")
+		}
+		if (c.FakeMetrics.PrefixCacheHits == nil) != (c.FakeMetrics.PrefixCacheQueries == nil) {
+			return errors.New("fake metrics prefix-cache-hits and prefix-cache-queries must be specified together")
+		}
+		if c.FakeMetrics.PrefixCacheHits != nil && c.FakeMetrics.PrefixCacheQueries != nil &&
+			*c.FakeMetrics.PrefixCacheHits > *c.FakeMetrics.PrefixCacheQueries {
+			return errors.New("fake metrics prefix-cache-hits cannot exceed prefix-cache-queries")
+		}
 	}
 
 	if c.DPSize < 1 || c.DPSize > 8 {
diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go
index 6062684..651e5cd 100644
--- a/pkg/common/config_test.go
+++ b/pkg/common/config_test.go
@@ -555,6 +555,36 @@ var _ = Describe("Simulator configuration", func() {
 				"--config", "../../manifests/config.yaml"},
 			expectedError: "fake metrics request-max-generation-tokens cannot contain negative values",
 		},
+		{
+			name: "invalid fake metrics: negative prefix-cache-hits",
+			args: []string{"cmd", "--fake-metrics", "{\"prefix-cache-hits\":-5,\"prefix-cache-queries\":10}",
+				"--config", "../../manifests/config.yaml"},
+			expectedError: "fake metrics prefix-cache-hits cannot be negative",
+		},
+		{
+			name: "invalid fake metrics: negative prefix-cache-queries",
+			args: []string{"cmd", "--fake-metrics", "{\"prefix-cache-hits\":0,\"prefix-cache-queries\":-1}",
+				"--config", "../../manifests/config.yaml"},
+			expectedError: "fake metrics prefix-cache-queries cannot be negative",
+		},
+		{
+			name: "invalid fake metrics: prefix-cache-hits without prefix-cache-queries",
+			args: []string{"cmd", "--fake-metrics", "{\"prefix-cache-hits\":100}",
+				"--config", "../../manifests/config.yaml"},
+			expectedError: "fake metrics prefix-cache-hits and prefix-cache-queries must be specified together",
+		},
+		{
+			name: "invalid fake metrics: prefix-cache-queries without prefix-cache-hits",
+			args: []string{"cmd", "--fake-metrics", "{\"prefix-cache-queries\":100}",
+				"--config", "../../manifests/config.yaml"},
+			expectedError: "fake metrics prefix-cache-hits and prefix-cache-queries must be specified together",
+		},
+		{
+			name: "invalid fake metrics: prefix-cache-hits exceeds prefix-cache-queries",
+			args: []string{"cmd", "--fake-metrics", "{\"prefix-cache-hits\":100,\"prefix-cache-queries\":50}",
+				"--config", "../../manifests/config.yaml"},
+			expectedError: "fake metrics prefix-cache-hits cannot exceed prefix-cache-queries",
+		},
 		{
 			name: "invalid echo mode with dataset",
 			args: []string{"cmd", "--model", "test", "--dataset-path", "my/path",
diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go
index ac73129..4b8bbcd 100644
--- a/pkg/kv-cache/kv_cache.go
+++ b/pkg/kv-cache/kv_cache.go
@@ -28,16 +28,26 @@ import (
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock"
 )
 
+// PrefixCacheStats holds token-level prefix cache statistics for a single request,
+// matching vLLM's PrefixCacheStats semantics where both fields count tokens.
+type PrefixCacheStats struct {
+	// QueriedTokens is the total number of prompt tokens checked against the cache
+	QueriedTokens int
+	// CachedTokens is the number of prompt tokens that were already cached
+	CachedTokens int
+}
+
 type KVCacheHelper struct {
-	tokenizer       tokenizer.Tokenizer
-	tokensProcessor kvblock.TokenProcessor // turns tokens to kv block keys
-	logger          logr.Logger
-	blockCache      *blockCache
-	blockSize       int
+	tokenizer            tokenizer.Tokenizer
+	tokensProcessor      kvblock.TokenProcessor // turns tokens to kv block keys
+	logger               logr.Logger
+	blockCache           *blockCache
+	blockSize            int
+	prefixCacheStatsChan chan PrefixCacheStats
 }
 
 func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageChan chan float64,
-	tokenizer tokenizer.Tokenizer) (*KVCacheHelper, error) {
+	prefixCacheStatsChan chan PrefixCacheStats, tokenizer tokenizer.Tokenizer) (*KVCacheHelper, error) {
 	tokenProcConfig := kvblock.DefaultTokenProcessorConfig()
 	tokenProcConfig.BlockSize = config.TokenBlockSize
 	if config.HashSeed != "" {
@@ -50,11 +60,12 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger, usageCha
 		return nil, fmt.Errorf("failed to create block cache: %w", err)
 	}
 	return &KVCacheHelper{
-		tokenizer:       tokenizer,
-		tokensProcessor: tokensProcessor,
-		blockCache:      blockCache,
-		logger:          logger,
-		blockSize:       config.TokenBlockSize,
+		tokenizer:            tokenizer,
+		tokensProcessor:      tokensProcessor,
+		blockCache:           blockCache,
+		logger:               logger,
+		blockSize:            config.TokenBlockSize,
+		prefixCacheStatsChan: prefixCacheStatsChan,
 	}, nil
 }
 
@@ -92,7 +103,8 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.Request) (float64
 		return 0, err
 	}
 
-	vllmReq.SetNumberOfCachedPromptTokens(nBlocksAlreadyInCache * h.blockSize)
+	cachedTokens := nBlocksAlreadyInCache * h.blockSize
+	vllmReq.SetNumberOfCachedPromptTokens(cachedTokens)
 
 	totalBlocks := len(blockHashes)
 	cachedBlocks := h.blockCache.countCachedBlockPrefix(blockHashes)
@@ -102,6 +114,13 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.Request) (float64
 		hitRate = float64(cachedBlocks) / float64(totalBlocks)
 	}
 
+	if h.prefixCacheStatsChan != nil {
+		common.WriteToChannel(h.prefixCacheStatsChan, PrefixCacheStats{
+			QueriedTokens: len(tokens),
+			CachedTokens:  cachedTokens,
+		}, h.logger, "prefixCacheStatsChan")
+	}
+
 	return hitRate, nil
 }
 
diff --git a/pkg/llm-d-inference-sim/context.go b/pkg/llm-d-inference-sim/context.go
index 426849a..7da1d90 100644
--- a/pkg/llm-d-inference-sim/context.go
+++ b/pkg/llm-d-inference-sim/context.go
@@ -124,6 +124,12 @@ type metricsData struct {
 	requestParamsMaxTokens *prometheus.HistogramVec
 	// requestSuccessTotal is prometheus counter for total number of successful requests
 	requestSuccessTotal *prometheus.CounterVec
+	// prefixCacheHits is prometheus counter for total cached tokens (prefix cache hits)
+	prefixCacheHits *prometheus.CounterVec
+	// prefixCacheQueries is prometheus counter for total queried tokens (prefix cache queries)
+	prefixCacheQueries *prometheus.CounterVec
+	// prefixCacheStatsChan is a channel to update prefix cache hit/query counters
+	prefixCacheStatsChan chan kvcache.PrefixCacheStats
 }
 
 // LoRAs usage info for requests execution
@@ -190,7 +196,7 @@ func (s *simContext) initialize(ctx context.Context) error {
 
 	if s.config.EnableKVCache {
 		s.kvcacheHelper, err = kvcache.NewKVCacheHelper(s.config, s.logger,
-			s.metrics.kvCacheUsageChan, s.tokenizer)
+			s.metrics.kvCacheUsageChan, s.metrics.prefixCacheStatsChan, s.tokenizer)
 		if err != nil {
 			return err
 		}
diff --git a/pkg/llm-d-inference-sim/metrics.go b/pkg/llm-d-inference-sim/metrics.go
index 8bb570f..98b3865 100644
--- a/pkg/llm-d-inference-sim/metrics.go
+++ b/pkg/llm-d-inference-sim/metrics.go
@@ -29,6 +29,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
+	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
 	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
 )
 
@@ -53,6 +54,8 @@ const (
 	reqWaitingMetricName         = "vllm:num_requests_waiting"
 	kvCacheUsageMetricName       = "vllm:kv_cache_usage_perc"
 	cacheConfigName              = "vllm:cache_config_info"
+	prefixCacheHitsMetricName    = "vllm:prefix_cache_hits"
+	prefixCacheQueriesMetricName = "vllm:prefix_cache_queries"
 )
 
 // createAndRegisterPrometheus creates and registers prometheus metrics used by vLLM simulator
@@ -271,6 +274,35 @@ func (s *simContext) createAndRegisterPrometheus(ctx context.Context) error {
 	s.metrics.kvCacheUsageChan = make(chan float64, maxNumberOfRequests)
 	go s.kvCacheUsageUpdater(ctx)
 
+	s.metrics.prefixCacheHits = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Subsystem: "",
+			Name:      prefixCacheHitsMetricName,
+			Help:      "Prefix cache hits, in terms of number of cached tokens.",
+		},
+		[]string{vllmapi.PromLabelModelName},
+	)
+	if err := s.metrics.registry.Register(s.metrics.prefixCacheHits); err != nil {
+		s.logger.Error(err, "prometheus prefix_cache_hits counter register failed")
+		return err
+	}
+
+	s.metrics.prefixCacheQueries = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Subsystem: "",
+			Name:      prefixCacheQueriesMetricName,
+			Help:      "Prefix cache queries, in terms of number of queried tokens.",
+		},
+		[]string{vllmapi.PromLabelModelName},
+	)
+	if err := s.metrics.registry.Register(s.metrics.prefixCacheQueries); err != nil {
+		s.logger.Error(err, "prometheus prefix_cache_queries counter register failed")
+		return err
+	}
+
+	s.metrics.prefixCacheStatsChan = make(chan kvcache.PrefixCacheStats, maxNumberOfRequests)
+	go s.prefixCacheStatsUpdater(ctx)
+
 	s.metrics.requestPromptTokens = prometheus.NewHistogramVec(
 		prometheus.HistogramOpts{
 			Subsystem: "",
@@ -457,6 +489,12 @@ func (s *simContext) setInitialPrometheusMetrics(cacheConfig *prometheus.GaugeVe
 		if s.config.FakeMetrics.ReqDecodeTimeBucketValues != nil {
 			s.initFakeHistogram(s.metrics.reqDecodeTime, common.RequestLatencyBucketsBoundaries, s.config.FakeMetrics.ReqDecodeTimeBucketValues)
 		}
+		if s.config.FakeMetrics.PrefixCacheQueries != nil {
+			s.metrics.prefixCacheQueries.WithLabelValues(modelName).Add(float64(*s.config.FakeMetrics.PrefixCacheQueries))
+		}
+		if s.config.FakeMetrics.PrefixCacheHits != nil {
+			s.metrics.prefixCacheHits.WithLabelValues(modelName).Add(float64(*s.config.FakeMetrics.PrefixCacheHits))
+		}
 	}
 
 	s.metrics.runningRequests.WithLabelValues(modelName).Set(nRunningReqs)
@@ -621,6 +659,32 @@ func (s *simContext) kvCacheUsageUpdater(ctx context.Context) {
 	}
 }
 
+// prefixCacheStatsUpdater increments prefix cache hit/query counters by listening on the relevant channel
+func (s *simContext) prefixCacheStatsUpdater(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case stats := <-s.metrics.prefixCacheStatsChan:
+			s.reportPrefixCacheStats(stats)
+		}
+	}
+}
+
+// reportPrefixCacheStats increments the prefix cache counters
+func (s *simContext) reportPrefixCacheStats(stats kvcache.PrefixCacheStats) {
+	if s.config.FakeMetrics != nil {
+		return
+	}
+	modelName := s.getDisplayedModelName(s.config.Model)
+	if s.metrics.prefixCacheQueries != nil {
+		s.metrics.prefixCacheQueries.WithLabelValues(modelName).Add(float64(stats.QueriedTokens))
+	}
+	if s.metrics.prefixCacheHits != nil {
+		s.metrics.prefixCacheHits.WithLabelValues(modelName).Add(float64(stats.CachedTokens))
+	}
+}
+
 // ttftUpdater updates the time to first token metric by listening on the relevant channel
 func (s *simContext) ttftUpdater(ctx context.Context) {
 	for {
diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go
index 4a3a820..0e98171 100644
--- a/pkg/llm-d-inference-sim/metrics_test.go
+++ b/pkg/llm-d-inference-sim/metrics_test.go
@@ -601,6 +601,57 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		wg.Wait()
 	})
 
+	It("Should increment prefix cache counters for requests with shared prefixes", func() {
+		ctx := context.TODO()
+		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+			"--enable-kvcache", "true", "--kv-cache-size", "64", "--block-size", "8",
+			"--time-to-first-token", "100"}
+
+		client, err := startServerWithArgsAndEnv(ctx, common.ModeRandom, args, map[string]string{"POD_IP": "localhost"})
+		Expect(err).NotTo(HaveOccurred())
+
+		openaiclient := openai.NewClient(
+			option.WithBaseURL(baseURL),
+			option.WithHTTPClient(client))
+
+		// Send requests sequentially so the cache is populated between requests
+		prompts := []string{
+			"What is the weather like in Haifa today?",
+			"What is the weather like in Haifa today? Is it cold?",
+		}
+		for _, prompt := range prompts {
+			_, err = openaiclient.Completions.New(ctx, openai.CompletionNewParams{
+				Prompt: openai.CompletionNewParamsPromptUnion{
+					OfString: openai.String(prompt),
+				},
+				Model: openai.CompletionNewParamsModel(qwenModelName),
+			})
+			Expect(err).NotTo(HaveOccurred())
+		}
+
+		time.Sleep(500 * time.Millisecond)
+		metricsResp, err := client.Get(metricsUrl)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(metricsResp.StatusCode).To(Equal(http.StatusOK))
+
+		data, err := io.ReadAll(metricsResp.Body)
+		Expect(err).NotTo(HaveOccurred())
+		metricsLines := strings.Split(string(data), "\n")
+
+		// prefix_cache_queries should reflect total prompt tokens across both requests
+		queries := findIntMetric(metricsLines, getCountMetricPrefix(qwenModelName, prefixCacheQueriesMetricName))
+		Expect(queries).NotTo(BeNil())
+		Expect(*queries).To(BeNumerically(">", 0))
+
+		// The second request shares a prefix with the first, so hits should be non-zero
+		hits := findIntMetric(metricsLines, getCountMetricPrefix(qwenModelName, prefixCacheHitsMetricName))
+		Expect(hits).NotTo(BeNil())
+		Expect(*hits).To(BeNumerically(">", 0))
+
+		// Hits cannot exceed queries
+		Expect(*hits).To(BeNumerically("<=", *queries))
+	})
+
 	It("Should send correct kv cache config metrics", func() {
 		ctx := context.TODO()
 		args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
@@ -641,6 +692,8 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 				`"request-params-max-tokens":[10,20,30],` +
 				`"ttft-buckets-values":[1,2,3],` +
 				`"tpot-buckets-values":[0,0,1,2,3],` +
+				`"prefix-cache-hits":750,` +
+				`"prefix-cache-queries":2000,` +
 				`"loras":[` +
 				`{` +
 				`"running":"lora4,lora2",` +
@@ -713,6 +766,9 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 			Expect(metrics).To(ContainSubstring(`vllm:request_success_total{finish_reason="remote_decode",model_name="testmodel"} 0`))
 			Expect(metrics).To(ContainSubstring(`vllm:request_success_total{finish_reason="stop",model_name="testmodel"} 20`))
 			Expect(metrics).To(ContainSubstring(`vllm:request_success_total{finish_reason="tool_calls",model_name="testmodel"} 0`))
+
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(testModel, prefixCacheHitsMetricName, 750)))
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(testModel, prefixCacheQueriesMetricName, 2000)))
 		})
 
 		It("Should use TotalPromptTokens and TotalGenerationTokens if provided", func() {
@@ -748,6 +804,67 @@ var _ = Describe("Simulator metrics", Ordered, func() {
 		})
 	})
 
+	Context("fake prefix cache metrics", func() {
+		It("Should respond with fake prefix cache metrics to /metrics", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", testModel, "--mode", common.ModeRandom,
+				"--fake-metrics",
+				`{"prefix-cache-hits":500,"prefix-cache-queries":1000}`,
+			}
+
+			client, err := startServerWithArgs(ctx, args)
+			Expect(err).NotTo(HaveOccurred())
+
+			resp, err := client.Get(metricsUrl)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			data, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+			metrics := string(data)
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(testModel, prefixCacheQueriesMetricName, 1000)))
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(testModel, prefixCacheHitsMetricName, 500)))
+		})
+
+		It("Should not update prefix cache counters from real requests when fake metrics are set", func() {
+			ctx := context.TODO()
+			args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
+				"--enable-kvcache", "true", "--kv-cache-size", "16", "--block-size", "8",
+				"--fake-metrics",
+				`{"prefix-cache-hits":100,"prefix-cache-queries":200}`,
+			}
+
+			client, err := startServerWithArgsAndEnv(ctx, common.ModeRandom, args, map[string]string{"POD_IP": "localhost"})
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient := openai.NewClient(
+				option.WithBaseURL(baseURL),
+				option.WithHTTPClient(client))
+
+			// Send a request; counters should not change from the fake values
+			_, err = openaiclient.Completions.New(ctx, openai.CompletionNewParams{
+				Prompt: openai.CompletionNewParamsPromptUnion{
+					OfString: openai.String("What is the weather like in Haifa today?"),
+				},
+				Model: openai.CompletionNewParamsModel(qwenModelName),
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			time.Sleep(500 * time.Millisecond)
+
+			resp, err := client.Get(metricsUrl)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.StatusCode).To(Equal(http.StatusOK))
+
+			data, err := io.ReadAll(resp.Body)
+			Expect(err).NotTo(HaveOccurred())
+			metrics := string(data)
+			// Fake values should be unchanged; reportPrefixCacheStats returns early when FakeMetrics is set
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(qwenModelName, prefixCacheQueriesMetricName, 200)))
+			Expect(metrics).To(ContainSubstring(getCountMetricLine(qwenModelName, prefixCacheHitsMetricName, 100)))
+		})
+	})
+
 	Context("fake ttft metrics", func() {
 		It("Should respond with fake ttft metrics to /metrics", func() {
 			ctx := context.TODO()