forked from zendev-sh/goai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembed.go
More file actions
149 lines (124 loc) · 3.62 KB
/
embed.go
File metadata and controls
149 lines (124 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package goai
import (
"context"
"fmt"
"slices"
"sync"
"github.com/zendev-sh/goai/provider"
)
// EmbedResult is the result of a single embedding generation.
// Embed returns a pointer to this type on success.
type EmbedResult struct {
// Embedding is the generated vector. Embed fills it with the first
// embedding of the provider result.
Embedding []float64
// Usage tracks token consumption, copied from the provider result.
Usage provider.Usage
}
// EmbedManyResult is the result of multiple embedding generations.
// EmbedMany returns a pointer to this type on success.
type EmbedManyResult struct {
// Embeddings contains the generated vectors (intended to be one per input
// value; the count is not checked against the input). When EmbedMany splits
// the input into chunks, the per-chunk vectors are concatenated in chunk
// order.
Embeddings [][]float64
// Usage is the aggregated token consumption: the provider's usage for a
// single call, or the per-chunk usages combined when the input was chunked.
Usage provider.Usage
}
// Embed generates an embedding vector for a single value.
//
// Options are resolved with applyOptions. When a positive Timeout is set, the
// provider call and its retries receive a context that expires after that
// duration. The call itself is wrapped in withRetry using the configured
// MaxRetries.
//
// On success it returns the first embedding of the provider result together
// with the reported usage. It returns an error when the provider call fails,
// or when the provider yields a nil result or no embeddings.
func Embed(ctx context.Context, model provider.EmbeddingModel, value string, opts ...Option) (*EmbedResult, error) {
	o := applyOptions(opts...)

	if o.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, o.Timeout)
		defer cancel()
	}

	embedParams := provider.EmbedParams{
		ProviderOptions: o.EmbeddingProviderOptions,
	}

	result, err := withRetry(ctx, o.MaxRetries, func() (*provider.EmbedResult, error) {
		return model.DoEmbed(ctx, []string{value}, embedParams)
	})
	if err != nil {
		return nil, err
	}

	// A provider may return (nil, nil); treat that like an empty result
	// instead of panicking on the nil dereference.
	if result == nil || len(result.Embeddings) == 0 {
		return nil, fmt.Errorf("goai: no embedding returned")
	}

	return &EmbedResult{
		Embedding: result.Embeddings[0],
		Usage:     result.Usage,
	}, nil
}
// EmbedMany generates embedding vectors for multiple values.
//
// When the model reports no per-call limit (MaxValuesPerCall <= 0) or the
// input fits within that limit, a single provider call is made. Otherwise the
// input is split into chunks of at most MaxValuesPerCall values and the chunks
// are processed in parallel, with at most MaxParallelCalls calls in flight
// (controlled by WithMaxParallelCalls; 4 is used when the option is unset or
// non-positive).
//
// Chunk results are concatenated in input order and their usage is combined
// with addUsage. If any chunk fails, the chunks that are still pending or
// running are cancelled and the first error observed is returned. A nil
// provider result is reported as an error rather than dereferenced.
func EmbedMany(ctx context.Context, model provider.EmbeddingModel, values []string, opts ...Option) (*EmbedManyResult, error) {
	o := applyOptions(opts...)

	if o.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, o.Timeout)
		defer cancel()
	}

	embedParams := provider.EmbedParams{
		ProviderOptions: o.EmbeddingProviderOptions,
	}

	maxPerCall := model.MaxValuesPerCall()

	// Single call when no chunking is needed.
	if maxPerCall <= 0 || len(values) <= maxPerCall {
		result, err := withRetry(ctx, o.MaxRetries, func() (*provider.EmbedResult, error) {
			return model.DoEmbed(ctx, values, embedParams)
		})
		if err != nil {
			return nil, err
		}
		if result == nil {
			return nil, fmt.Errorf("goai: no embeddings returned")
		}
		return &EmbedManyResult{
			Embeddings: result.Embeddings,
			Usage:      result.Usage,
		}, nil
	}

	// Split into chunks of at most maxPerCall values each.
	chunks := slices.Collect(slices.Chunk(values, maxPerCall))

	maxParallel := o.MaxParallelCalls
	if maxParallel <= 0 {
		maxParallel = 4
	}

	// chunkCtx is cancelled as soon as one chunk fails so that the other
	// chunks stop instead of issuing calls whose results would be discarded.
	chunkCtx, stop := context.WithCancel(ctx)
	defer stop()

	var (
		wg       sync.WaitGroup
		failOnce sync.Once
		firstErr error
	)
	// fail records the first error only; later errors are typically just the
	// cancellation triggered by that first failure.
	fail := func(err error) {
		failOnce.Do(func() {
			firstErr = err
			stop()
		})
	}

	// Each goroutine writes only to its own index, so no locking is needed.
	results := make([]*provider.EmbedResult, len(chunks))
	sem := make(chan struct{}, maxParallel)

	for i, chunk := range chunks {
		wg.Go(func() {
			// Use select to avoid blocking forever if the context is
			// cancelled while waiting for the semaphore.
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-chunkCtx.Done():
				fail(chunkCtx.Err())
				return
			}

			// select chooses at random when both cases are ready, so
			// re-check cancellation before starting a provider call.
			if err := chunkCtx.Err(); err != nil {
				fail(err)
				return
			}

			r, err := withRetry(chunkCtx, o.MaxRetries, func() (*provider.EmbedResult, error) {
				return model.DoEmbed(chunkCtx, chunk, embedParams)
			})
			switch {
			case err != nil:
				fail(err)
			case r == nil:
				fail(fmt.Errorf("goai: no embeddings returned"))
			default:
				results[i] = r
			}
		})
	}
	wg.Wait()

	if firstErr != nil {
		return nil, firstErr
	}

	// Combine results in input order.
	allEmbeddings := make([][]float64, 0, len(values))
	var totalUsage provider.Usage
	for _, r := range results {
		allEmbeddings = append(allEmbeddings, r.Embeddings...)
		totalUsage = addUsage(totalUsage, r.Usage)
	}

	return &EmbedManyResult{
		Embeddings: allEmbeddings,
		Usage:      totalUsage,
	}, nil
}