Skip to content

Commit ecd1e07

Browse files
committed
Implement vectorize for xtensor Dot
1 parent cc6bed1 commit ecd1e07

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

pytensor/xtensor/math.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,9 @@ def make_node(self, x, y):
566566
out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
567567
return Apply(self, [x, y], [out])
568568

569+
def vectorize_node(self, node, *new_inputs, new_dim):
570+
return self(*new_inputs, return_list=True)
571+
569572

570573
def dot(x, y, dim: str | Sequence[str] | EllipsisType | None = None):
571574
"""Generalized dot product for XTensorVariables.

tests/xtensor/test_math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,13 @@ def test_xelemwise_vectorize():
353353

354354
check_vectorization([ab], [exp(ab)])
355355
check_vectorization([ab, bc], [ab + bc])
356+
357+
358+
def test_dot_vectorize():
359+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
360+
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
361+
362+
check_vectorization([x, y], [x.dot(y)])
363+
check_vectorization([x, y], [x.dot(y, dim=("a", "b"))])
364+
check_vectorization([x, y], [x.dot(y, dim="c")])
365+
check_vectorization([x, y], [x.dot(y, dim=...)])

0 commit comments

Comments
 (0)