Skip to content

Commit 4a19be1

Browse files
committed
Improved distance functions tests for Ent
1 parent 7c39a9b commit 4a19be1

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

ent_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,74 @@ func TestEnt(t *testing.T) {
8787
if !reflect.DeepEqual(items[1].SparseEmbedding.Slice(), []float32{1, 1, 2}) {
8888
t.Error()
8989
}
90+
91+
items, err = client.Item.
92+
Query().
93+
Order(func(s *sql.Selector) {
94+
s.OrderExpr(entvec.MaxInnerProduct("embedding", embedding))
95+
}).
96+
Limit(5).
97+
All(ctx)
98+
if err != nil {
99+
panic(err)
100+
}
101+
if items[0].ID != 2 || items[1].ID != 3 || items[2].ID != 1 {
102+
t.Error()
103+
}
104+
105+
items, err = client.Item.
106+
Query().
107+
Order(func(s *sql.Selector) {
108+
s.OrderExpr(entvec.CosineDistance("embedding", embedding))
109+
}).
110+
Limit(5).
111+
All(ctx)
112+
if err != nil {
113+
panic(err)
114+
}
115+
if items[0].ID != 1 || items[1].ID != 2 || items[2].ID != 3 {
116+
t.Error()
117+
}
118+
119+
items, err = client.Item.
120+
Query().
121+
Order(func(s *sql.Selector) {
122+
s.OrderExpr(entvec.L1Distance("embedding", embedding))
123+
}).
124+
Limit(5).
125+
All(ctx)
126+
if err != nil {
127+
panic(err)
128+
}
129+
if items[0].ID != 1 || items[1].ID != 3 || items[2].ID != 2 {
130+
t.Error()
131+
}
132+
133+
items, err = client.Item.
134+
Query().
135+
Order(func(s *sql.Selector) {
136+
s.OrderExpr(entvec.HammingDistance("binary_embedding", "101"))
137+
}).
138+
Limit(5).
139+
All(ctx)
140+
if err != nil {
141+
panic(err)
142+
}
143+
if items[0].ID != 2 || items[1].ID != 3 || items[2].ID != 1 {
144+
t.Error()
145+
}
146+
147+
items, err = client.Item.
148+
Query().
149+
Order(func(s *sql.Selector) {
150+
s.OrderExpr(entvec.JaccardDistance("binary_embedding", "101"))
151+
}).
152+
Limit(5).
153+
All(ctx)
154+
if err != nil {
155+
panic(err)
156+
}
157+
if items[0].ID != 2 || items[1].ID != 3 || items[2].ID != 1 {
158+
t.Error()
159+
}
90160
}

0 commit comments

Comments
 (0)