Skip to content

Commit d820746

Browse files
authored
Validate Query Vector in $vectorSearch stage builder (#2857)
1 parent 16cfe5a commit d820746

File tree

3 files changed

+59
-10
lines changed

3 files changed

+59
-10
lines changed

lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
use Doctrine\ODM\MongoDB\Aggregation\Stage;
99
use Doctrine\ODM\MongoDB\Persisters\DocumentPersister;
1010
use Doctrine\ODM\MongoDB\Query\Expr;
11+
use InvalidArgumentException;
1112
use MongoDB\BSON\Binary;
1213
use MongoDB\BSON\Decimal128;
1314
use MongoDB\BSON\Int64;
1415

16+
use function array_is_list;
17+
use function is_array;
18+
use function sprintf;
19+
1520
/**
1621
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>|Binary
1722
* @phpstan-type VectorSearchStageExpression array{
@@ -28,12 +33,15 @@
2833
*/
2934
class VectorSearch extends Stage
3035
{
31-
private ?bool $exact = null;
32-
private ?Expr $filter = null;
33-
private ?string $index = null;
34-
private ?int $limit = null;
35-
private ?int $numCandidates = null;
36-
private ?string $path = null;
36+
/** @see Binary::TYPE_VECTOR introduced in ext-mongodb 2.2 */
37+
private const BINARY_TYPE_VECTOR = 9;
38+
39+
private ?bool $exact = null;
40+
private array|Expr|null $filter = null;
41+
private ?string $index = null;
42+
private ?int $limit = null;
43+
private ?int $numCandidates = null;
44+
private ?string $path = null;
3745
/** @phpstan-var Vector|null */
3846
private array|Binary|null $queryVector = null;
3947

@@ -50,8 +58,10 @@ public function getExpression(): array
5058
$params['exact'] = $this->exact;
5159
}
5260

53-
if ($this->filter !== null) {
61+
if ($this->filter instanceof Expr) {
5462
$params['filter'] = $this->filter->getQuery();
63+
} elseif (is_array($this->filter)) {
64+
$params['filter'] = $this->filter;
5565
}
5666

5767
if ($this->index !== null) {
@@ -84,7 +94,8 @@ public function exact(bool $exact): static
8494
return $this;
8595
}
8696

87-
public function filter(Expr $filter): static
97+
/** @phpstan-param array<string, mixed>|Expr $filter */
98+
public function filter(array|Expr $filter): static
8899
{
89100
$this->filter = $filter;
90101

@@ -122,6 +133,18 @@ public function path(string $path): static
122133
/** @phpstan-param Vector $queryVector */
123134
public function queryVector(array|Binary $queryVector): static
124135
{
136+
if ($queryVector === []) {
137+
throw new InvalidArgumentException('Query vector cannot be an empty array.');
138+
}
139+
140+
if (is_array($queryVector) && ! array_is_list($queryVector)) {
141+
throw new InvalidArgumentException('Query vector must be a list of numbers, got an associative array.');
142+
}
143+
144+
if ($queryVector instanceof Binary && $queryVector->getType() !== self::BINARY_TYPE_VECTOR) {
145+
throw new InvalidArgumentException(sprintf('Binary query vector must be of type 9 (Vector), got %d.', $queryVector->getType()));
146+
}
147+
125148
$this->queryVector = $queryVector;
126149

127150
return $this;

phpstan-baseline.neon

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ parameters:
330330
count: 1
331331
path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/UnionWith.php
332332

333+
-
334+
message: '#^Property Doctrine\\ODM\\MongoDB\\Aggregation\\Stage\\VectorSearch\:\:\$filter type has no value type specified in iterable type array\.$#'
335+
identifier: missingType.iterableValue
336+
count: 1
337+
path: lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php
338+
333339
-
334340
message: '#^Return type \(Doctrine\\ODM\\MongoDB\\Mapping\\ClassMetadataFactoryInterface\) of method Doctrine\\ODM\\MongoDB\\DocumentManager\:\:getMetadataFactory\(\) should be compatible with return type \(Doctrine\\Persistence\\Mapping\\ClassMetadataFactory\<Doctrine\\Persistence\\Mapping\\ClassMetadata\<object\>\>\) of method Doctrine\\Persistence\\ObjectManager\:\:getMetadataFactory\(\)$#'
335341
identifier: method.childReturnType

tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
1111
use Documents\User;
1212
use Documents\VectorEmbedding;
13+
use InvalidArgumentException;
1314
use MongoDB\BSON\Binary;
1415
use MongoDB\BSON\VectorType;
16+
use PHPUnit\Framework\Attributes\TestWith;
1517

1618
use function enum_exists;
1719

@@ -27,12 +29,19 @@ public function testEmptyStage(): void
2729

2830
public function testExact(): void
2931
{
30-
[$stage, $builder] = $this->createVectorSearchStage();
32+
[$stage] = $this->createVectorSearchStage();
3133
$stage->exact(true);
3234
self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression());
3335
}
3436

35-
public function testFilter(): void
37+
public function testFilterArray(): void
38+
{
39+
[$stage] = $this->createVectorSearchStage();
40+
$stage->filter(['status' => ['$ne' => 'inactive']]);
41+
self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression());
42+
}
43+
44+
public function testFilterExpr(): void
3645
{
3746
[$stage, $builder] = $this->createVectorSearchStage();
3847
$stage->filter($builder->matchExpr()->field('status')->notEqual('inactive'));
@@ -97,6 +106,17 @@ public function testQueryVectorAcceptsBinary(): void
97106
self::assertSame(['$vectorSearch' => ['queryVector' => $binaryVector]], $stage->getExpression());
98107
}
99108

109+
#[TestWith([new Binary("\x03\x00\x01\x02\x03", Binary::TYPE_GENERIC), 'Binary query vector must be of type 9 (Vector), got 0.'])]
110+
#[TestWith([[1 => 1, 2 => 3], 'Query vector must be a list of numbers, got an associative array.'])]
111+
#[TestWith([[], 'Query vector cannot be an empty array.'])]
112+
public function testQueryVectorInvalidType(mixed $queryVector, string $message): void
113+
{
114+
[$stage] = $this->createVectorSearchStage();
115+
$this->expectException(InvalidArgumentException::class);
116+
$this->expectExceptionMessage($message);
117+
$stage->queryVector($queryVector);
118+
}
119+
100120
public function testChainingAllOptions(): void
101121
{
102122
[$stage, $builder] = $this->createVectorSearchStage();

0 commit comments

Comments
 (0)