Skip to content

Commit 47d8a98

Browse files
fix: Make sure the vector search feature is null-safe, too. (#3044)
This change introduces the necessary checks and annotations to make a build utilising Nullaway succeed: ``` ./mvnw clean verify -Pnullaway ``` Also, `VectorSearchFragment` has been made `public`, as it is exposed via constructor of the public class `QueryFragmentsAndParameters`. Signed-off-by: Michael Simons <michael@simons.ac>
1 parent f7c355c commit 47d8a98

File tree

7 files changed

+25
-32
lines changed

7 files changed

+25
-32
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ public <T> T save(T instance) {
449449
});
450450
}
451451

452-
private <T> T saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
452+
private <T> T saveImpl(T instance, @Nullable Collection<PropertyFilter.ProjectedPath> includedProperties,
453453
@Nullable NestedRelationshipProcessingStateMachine stateMachine) {
454454

455455
if (stateMachine != null && stateMachine.hasProcessedValue(instance)) {
@@ -1489,12 +1489,7 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
14891489
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
14901490
}
14911491
else {
1492-
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1493-
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1494-
}
1495-
else {
1496-
statement = queryFragments.toStatement();
1497-
}
1492+
statement = queryFragmentsAndParameters.toStatement();
14981493
}
14991494
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
15001495
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -469,13 +469,13 @@ <T, R> Flux<R> doSave(Iterable<R> instances, Class<T> domainType) {
469469
});
470470
}
471471

472-
private <T> Mono<T> saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
472+
private <T> Mono<T> saveImpl(T instance, @Nullable Collection<PropertyFilter.ProjectedPath> includedProperties,
473473
@Nullable NestedRelationshipProcessingStateMachine stateMachine) {
474474
return saveImpl(instance, includedProperties, stateMachine, new HashSet<>());
475475
}
476476

477477
@SuppressWarnings("deprecation")
478-
private <T> Mono<T> saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
478+
private <T> Mono<T> saveImpl(T instance, @Nullable Collection<PropertyFilter.ProjectedPath> includedProperties,
479479
@Nullable NestedRelationshipProcessingStateMachine stateMachine, Collection<Object> knownRelationshipsIds) {
480480

481481
if (stateMachine != null && stateMachine.hasProcessedValue(instance)) {
@@ -1408,13 +1408,7 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
14081408
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
14091409
});
14101410
}
1411-
Statement statement = null;
1412-
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1413-
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1414-
}
1415-
else {
1416-
statement = queryFragments.toStatement();
1417-
}
1411+
Statement statement = queryFragmentsAndParameters.toStatement();
14181412
cypherQuery = this.renderer.render(statement);
14191413
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
14201414
}

src/main/java/org/springframework/data/neo4j/core/TemplateSupport.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ else if (candidate != type) {
137137
return candidate;
138138
}
139139

140-
static PropertyFilter computeIncludePropertyPredicate(Collection<PropertyFilter.ProjectedPath> includedProperties,
141-
NodeDescription<?> nodeDescription) {
140+
static PropertyFilter computeIncludePropertyPredicate(
141+
@Nullable Collection<PropertyFilter.ProjectedPath> includedProperties, NodeDescription<?> nodeDescription) {
142142

143-
return PropertyFilter.from(includedProperties, nodeDescription);
143+
return PropertyFilter.from(Objects.requireNonNullElseGet(includedProperties, List::of), nodeDescription);
144144
}
145145

146146
static void updateVersionPropertyIfPossible(Neo4jPersistentEntity<?> entityMetaData,
@@ -207,8 +207,8 @@ static Map<String, Object> mergeParameters(Statement statement, Map<String, Obje
207207
* @return a new binder function that only works on the included properties.
208208
*/
209209
static <T> FilteredBinderFunction<T> createAndApplyPropertyFilter(
210-
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jPersistentEntity<?> entityMetaData,
211-
Function<T, Map<String, Object>> binderFunction) {
210+
@Nullable Collection<PropertyFilter.ProjectedPath> includedProperties,
211+
Neo4jPersistentEntity<?> entityMetaData, Function<T, Map<String, Object>> binderFunction) {
212212

213213
PropertyFilter includeProperty = TemplateSupport.computeIncludePropertyPredicate(includedProperties,
214214
entityMetaData);

src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
208208
if (this.keysetRequiresSort && theSort.isUnsorted()) {
209209
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
210210
}
211-
if (this.queryMethod.hasVectorSearchAnnotation()) {
211+
if (this.queryMethod.hasVectorSearchAnnotation() && this.vectorSearchParameter != null) {
212212
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
213213
var indexName = vectorSearchAnnotation.indexName();
214214
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
@@ -219,9 +219,8 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
219219
}
220220
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
221221
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
222-
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
223-
vectorSearchFragment, convertedParameters, theSort);
224-
return queryFragmentsAndParameters;
222+
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, vectorSearchFragment,
223+
convertedParameters, theSort);
225224
}
226225
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
227226
}

src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ public Statement toStatement(VectorSearchFragment vectorSearchFragment) {
208208
.with(Cypher.name("node").as(((Node) this.matchOn.get(0)).getRequiredSymbolicName().getValue()),
209209
Cypher.name("score").as(Constants.NAME_OF_SCORE));
210210

211-
StatementBuilder.OngoingReadingWithoutWhere match = null;
211+
StatementBuilder.OngoingReadingWithoutWhere match;
212212
if (vectorSearchFragment.hasScore()) {
213213
match = vectorSearch
214214
.where(Cypher.raw(Constants.NAME_OF_SCORE)

src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.neo4j.cypherdsl.core.PatternElement;
3434
import org.neo4j.cypherdsl.core.RelationshipPattern;
3535
import org.neo4j.cypherdsl.core.SortItem;
36+
import org.neo4j.cypherdsl.core.Statement;
3637

3738
import org.springframework.data.domain.Example;
3839
import org.springframework.data.domain.KeysetScrollPosition;
@@ -63,6 +64,7 @@ public final class QueryFragmentsAndParameters {
6364

6465
private final QueryFragments queryFragments;
6566

67+
@Nullable
6668
private final VectorSearchFragment vectorSearchFragment;
6769

6870
@Nullable
@@ -76,7 +78,7 @@ public final class QueryFragmentsAndParameters {
7678
private NodeDescription<?> nodeDescription;
7779

7880
public QueryFragmentsAndParameters(@Nullable NodeDescription<?> nodeDescription, QueryFragments queryFragments,
79-
VectorSearchFragment vectorSearchFragment, Map<String, Object> parameters, @Nullable Sort sort) {
81+
@Nullable VectorSearchFragment vectorSearchFragment, Map<String, Object> parameters, @Nullable Sort sort) {
8082
this.nodeDescription = nodeDescription;
8183
this.queryFragments = queryFragments;
8284
this.vectorSearchFragment = vectorSearchFragment;
@@ -402,10 +404,6 @@ public boolean hasVectorSearchFragment() {
402404
return this.vectorSearchFragment != null;
403405
}
404406

405-
public VectorSearchFragment getVectorSearchFragment() {
406-
return this.vectorSearchFragment;
407-
}
408-
409407
@Nullable public String getCypherQuery() {
410408
return this.cypherQuery;
411409
}
@@ -418,4 +416,11 @@ public Sort getSort() {
418416
return this.sort;
419417
}
420418

419+
public Statement toStatement() {
420+
if (this.hasVectorSearchFragment()) {
421+
return this.queryFragments.toStatement(Objects.requireNonNull(this.vectorSearchFragment));
422+
}
423+
return this.queryFragments.toStatement();
424+
}
425+
421426
}

src/main/java/org/springframework/data/neo4j/repository/query/VectorSearchFragment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
* @param numberOfNodes number of nodes to fetch from the index search
2626
* @param score score filter
2727
*/
28-
record VectorSearchFragment(String indexName, int numberOfNodes, @Nullable Double score) {
28+
public record VectorSearchFragment(String indexName, int numberOfNodes, @Nullable Double score) {
2929

3030
boolean hasScore() {
3131
return this.score != null;

0 commit comments

Comments
 (0)