88use Doctrine \ODM \MongoDB \Aggregation \Stage ;
99use Doctrine \ODM \MongoDB \Persisters \DocumentPersister ;
1010use Doctrine \ODM \MongoDB \Query \Expr ;
11+ use InvalidArgumentException ;
1112use MongoDB \BSON \Binary ;
1213use MongoDB \BSON \Decimal128 ;
1314use MongoDB \BSON \Int64 ;
1415
16+ use function array_is_list ;
17+ use function is_array ;
18+ use function sprintf ;
19+
1520/**
1621 * @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>|Binary
1722 * @phpstan-type VectorSearchStageExpression array{
2833 */
2934class VectorSearch extends Stage
3035{
31- private ?bool $ exact = null ;
32- private ?Expr $ filter = null ;
33- private ?string $ index = null ;
34- private ?int $ limit = null ;
35- private ?int $ numCandidates = null ;
36- private ?string $ path = null ;
36+ /** @see Binary::TYPE_VECTOR introduced in ext-mongodb 2.2 */
37+ private const BINARY_TYPE_VECTOR = 9 ;
38+
39+ private ?bool $ exact = null ;
40+ private array |Expr |null $ filter = null ;
41+ private ?string $ index = null ;
42+ private ?int $ limit = null ;
43+ private ?int $ numCandidates = null ;
44+ private ?string $ path = null ;
3745 /** @phpstan-var Vector|null */
3846 private array |Binary |null $ queryVector = null ;
3947
@@ -50,8 +58,10 @@ public function getExpression(): array
5058 $ params ['exact ' ] = $ this ->exact ;
5159 }
5260
53- if ($ this ->filter !== null ) {
61+ if ($ this ->filter instanceof Expr ) {
5462 $ params ['filter ' ] = $ this ->filter ->getQuery ();
63+ } elseif (is_array ($ this ->filter )) {
64+ $ params ['filter ' ] = $ this ->filter ;
5565 }
5666
5767 if ($ this ->index !== null ) {
@@ -84,7 +94,8 @@ public function exact(bool $exact): static
8494 return $ this ;
8595 }
8696
87- public function filter (Expr $ filter ): static
97+ /** @phpstan-param array<string, mixed>|Expr $filter */
98+ public function filter (array |Expr $ filter ): static
8899 {
89100 $ this ->filter = $ filter ;
90101
@@ -122,6 +133,18 @@ public function path(string $path): static
122133 /** @phpstan-param Vector $queryVector */
123134 public function queryVector (array |Binary $ queryVector ): static
124135 {
136+ if ($ queryVector === []) {
137+ throw new InvalidArgumentException ('Query vector cannot be an empty array. ' );
138+ }
139+
140+ if (is_array ($ queryVector ) && ! array_is_list ($ queryVector )) {
141+ throw new InvalidArgumentException ('Query vector must be a list of numbers, got an associative array. ' );
142+ }
143+
144+ if ($ queryVector instanceof Binary && $ queryVector ->getType () !== self ::BINARY_TYPE_VECTOR ) {
145+ throw new InvalidArgumentException (sprintf ('Binary query vector must be of type 9 (Vector), got %d. ' , $ queryVector ->getType ()));
146+ }
147+
125148 $ this ->queryVector = $ queryVector ;
126149
127150 return $ this ;
0 commit comments