diff --git a/hibernate-reactive-core/src/main/java/org/hibernate/reactive/query/sqm/mutation/internal/cte/ReactiveCteInsertHandler.java b/hibernate-reactive-core/src/main/java/org/hibernate/reactive/query/sqm/mutation/internal/cte/ReactiveCteInsertHandler.java index 9361df67f..ecda80dbc 100644 --- a/hibernate-reactive-core/src/main/java/org/hibernate/reactive/query/sqm/mutation/internal/cte/ReactiveCteInsertHandler.java +++ b/hibernate-reactive-core/src/main/java/org/hibernate/reactive/query/sqm/mutation/internal/cte/ReactiveCteInsertHandler.java @@ -5,8 +5,12 @@ */ package org.hibernate.reactive.query.sqm.mutation.internal.cte; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.PostgreSQLDialect; +import org.hibernate.dialect.function.array.DdlTypeHelper; import org.hibernate.internal.util.MutableObject; import org.hibernate.query.spi.DomainQueryExecutionContext; +import org.hibernate.query.spi.QueryOptions; import org.hibernate.query.sqm.internal.DomainParameterXref; import org.hibernate.query.sqm.internal.SqmJdbcExecutionContextAdapter; import org.hibernate.query.sqm.mutation.internal.cte.CteInsertHandler; @@ -18,15 +22,30 @@ import org.hibernate.reactive.sql.exec.internal.StandardReactiveSelectExecutor; import org.hibernate.reactive.sql.results.spi.ReactiveListResultsConsumer; import org.hibernate.sql.ast.tree.cte.CteTable; +import org.hibernate.sql.ast.tree.expression.JdbcParameter; +import org.hibernate.sql.exec.spi.ExecutionContext; +import org.hibernate.sql.exec.spi.JdbcLockStrategy; +import org.hibernate.sql.exec.spi.JdbcParameterBinder; +import org.hibernate.sql.exec.spi.JdbcParameterBinding; import org.hibernate.sql.exec.spi.JdbcParameterBindings; +import org.hibernate.sql.exec.spi.JdbcSelect; +import org.hibernate.sql.exec.spi.LoadedValuesCollector; +import org.hibernate.sql.exec.spi.StatementAccess; +import org.hibernate.sql.results.jdbc.spi.JdbcValuesMappingProducer; +import org.hibernate.type.spi.TypeConfiguration; import java.lang.invoke.MethodHandles; +import java.sql.Connection; +import java.util.List; +import java.util.Set; import java.util.concurrent.CompletionStage; public class ReactiveCteInsertHandler extends CteInsertHandler implements ReactiveHandler { private static final Log LOG = LoggerFactory.make( Log.class, MethodHandles.lookup() ); + private final Dialect dialect; + public ReactiveCteInsertHandler( CteTable cteTable, SqmInsertStatement sqmStatement, @@ -34,6 +53,7 @@ public ReactiveCteInsertHandler( DomainQueryExecutionContext context, MutableObject firstJdbcParameterBindingsConsumer) { super( cteTable, sqmStatement, domainParameterXref, context, firstJdbcParameterBindingsConsumer ); + this.dialect = context.getSession().getDialect(); } @Override @@ -45,11 +65,21 @@ public int execute(DomainQueryExecutionContext executionContext) { public CompletionStage reactiveExecute( JdbcParameterBindings jdbcParameterBindings, DomainQueryExecutionContext context) { + JdbcSelect jdbcSelect; + + if ( dialect instanceof PostgreSQLDialect ) { + // need to replace parameters with explicit casts see https://github.com/eclipse-vertx/vertx-sql-client/issues/1540 + jdbcSelect = new PostgreSQLCteMutationSelect( getSelect(), jdbcParameterBindings, context ); + } + else { + jdbcSelect = getSelect(); + } + return ( (ReactiveSharedSessionContractImplementor) context.getSession() ) - .reactiveAutoFlushIfRequired( getSelect().getAffectedTableNames() ) + .reactiveAutoFlushIfRequired( jdbcSelect.getAffectedTableNames() ) .thenCompose( v -> StandardReactiveSelectExecutor.INSTANCE .list( - getSelect(), + jdbcSelect, jdbcParameterBindings, SqmJdbcExecutionContextAdapter.omittingLockingAndPaging( context ), row -> row[0], @@ -60,4 +90,145 @@ public CompletionStage reactiveExecute( .thenApply( list -> ( (Number) list.get( 0 ) ).intValue() ) ); } + + /* + * A JdbcSelect wrapper that adds explicit type casts to parameters in the original SQL Select string. + * This is needed for PostgreSQL when using CTEs for mutation statements, + * See https://github.com/eclipse-vertx/vertx-sql-client/issues/1540 . + */ + public static class PostgreSQLCteMutationSelect implements JdbcSelect { + private final JdbcSelect delegate; + private final String sqlString; + + public PostgreSQLCteMutationSelect( + JdbcSelect delegate, + JdbcParameterBindings jdbcParameterBindings, + DomainQueryExecutionContext context) { + this.delegate = delegate; + this.sqlString = getSqlStringWithExplicitParameterCasting( delegate, jdbcParameterBindings, context ); + } + + private static String getSqlStringWithExplicitParameterCasting( + JdbcSelect original, + JdbcParameterBindings jdbcParameterBindings, + DomainQueryExecutionContext context) { + final StringBuilder newSelect = new StringBuilder( original.getSqlString() ); + addExplicitCastToParameters( + jdbcParameterBindings, + newSelect, + context.getSession().getSessionFactory().getMappingMetamodel().getTypeConfiguration() + ); + return newSelect.toString(); + } + + private static void addExplicitCastToParameters( + JdbcParameterBindings jdbcParameterBindings, + StringBuilder newSelect, + TypeConfiguration typeConfiguration) { + jdbcParameterBindings.visitBindings( + (jdbcParameter, jdbcParameterBinding) -> + addExplicitCastToParameter( + newSelect, + typeConfiguration, + jdbcParameter, + jdbcParameterBinding + ) + ); + } + + private static void addExplicitCastToParameter( + StringBuilder newSelect, + TypeConfiguration typeConfiguration, + JdbcParameter jdbcParameter, + JdbcParameterBinding jdbcParameterBinding) { + final int index = jdbcParameter.getParameterId() + 1; + final String parameterToReplace = "$" + index; + final int start = newSelect.indexOf( parameterToReplace ); + newSelect.replace( + start, + start + parameterToReplace.length(), + parameterToReplace + "::" + DdlTypeHelper.getCastTypeName( + jdbcParameterBinding.getBindType(), + typeConfiguration + ) + ); + } + + @Override + public JdbcValuesMappingProducer getJdbcValuesMappingProducer() { + return delegate.getJdbcValuesMappingProducer(); + } + + @Override + public JdbcLockStrategy getLockStrategy() { + return delegate.getLockStrategy(); + } + + @Override + public boolean usesLimitParameters() { + return delegate.usesLimitParameters(); + } + + @Override + public JdbcParameter getLimitParameter() { + return delegate.getLimitParameter(); + } + + @Override + public int getRowsToSkip() { + return delegate.getRowsToSkip(); + } + + @Override + public int getMaxRows() { + return delegate.getMaxRows(); + } + + @Override + public LoadedValuesCollector getLoadedValuesCollector() { + return delegate.getLoadedValuesCollector(); + } + + @Override + public void performPreActions( + StatementAccess jdbcStatementAccess, + Connection jdbcConnection, + ExecutionContext executionContext) { + delegate.performPreActions( jdbcStatementAccess, jdbcConnection, executionContext ); + } + + @Override + public void performPostAction( + boolean succeeded, + StatementAccess jdbcStatementAccess, + Connection jdbcConnection, + ExecutionContext executionContext) { + delegate.performPostAction( succeeded, jdbcStatementAccess, jdbcConnection, executionContext ); + } + + @Override + public boolean dependsOnParameterBindings() { + return delegate.dependsOnParameterBindings(); + } + + @Override + public boolean isCompatibleWith(JdbcParameterBindings jdbcParameterBindings, QueryOptions queryOptions) { + return delegate.isCompatibleWith( jdbcParameterBindings, queryOptions ); + } + + @Override + public Set getAffectedTableNames() { + return delegate.getAffectedTableNames(); + } + + @Override + public String getSqlString() { + return sqlString; + } + + @Override + public List getParameterBinders() { + return delegate.getParameterBinders(); + } + } } diff --git a/hibernate-reactive-core/src/test/java/org/hibernate/reactive/JoinedSubclassInheritanceTest.java b/hibernate-reactive-core/src/test/java/org/hibernate/reactive/JoinedSubclassInheritanceTest.java index 432535b51..e0f50c5e1 100644 --- a/hibernate-reactive-core/src/test/java/org/hibernate/reactive/JoinedSubclassInheritanceTest.java +++ b/hibernate-reactive-core/src/test/java/org/hibernate/reactive/JoinedSubclassInheritanceTest.java @@ -5,8 +5,6 @@ */ package org.hibernate.reactive; -import org.hibernate.reactive.annotations.DisabledFor; - import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -30,7 +28,6 @@ import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; -import static org.hibernate.reactive.containers.DatabaseConfiguration.DBType.POSTGRESQL; @Timeout(value = 10, timeUnit = MINUTES) @@ -189,7 +186,6 @@ public void testQueryUpdateWithParameters(VertxTestContext context) { } @Test - @DisabledFor(value = POSTGRESQL, reason = "https://github.com/hibernate/hibernate-reactive/issues/2412") public void testHqlInsertWithTransaction(VertxTestContext context) { final Integer id = 1; final String title = "Spell Book: A Comprehensive Guide to Magic Spells and Incantations"; @@ -212,6 +208,66 @@ public void testHqlInsertWithTransaction(VertxTestContext context) { ); } + @Test + public void testHqlUpdate(VertxTestContext context) { + final Integer id = 1; + final String title = "Spell Book: A Comprehensive Guide to Magic Spells and Incantations"; + test( context, getMutinySessionFactory().withTransaction( session -> session + .createMutationQuery( "insert into SpellBook (id, title, forbidden) values (:id, :title, :forbidden)" ) + .setParameter( "id", id ) + .setParameter( "title", title ) + .setParameter( "forbidden", true ) + .executeUpdate() ) + .call( () -> getMutinySessionFactory().withTransaction( session -> session + .createMutationQuery( + "update SpellBook set id = :id, title = :newTitle, forbidden = :newForbidden where forbidden = :forbidden and title = :title" ) + .setParameter( "id", id ) + .setParameter( "title", title ) + .setParameter( "forbidden", true ) + .setParameter( "newTitle", "new title" ) + .setParameter( "newForbidden", false ) + .executeUpdate() ) + ) + .call( () -> getMutinySessionFactory().withTransaction( session -> session + .createSelectionQuery( "from SpellBook g where g.id = :id ", SpellBook.class ) + .setParameter( "id", id ) + .getSingleResult() + .invoke( spellBook -> { + assertThat( spellBook.getTitle() ).isEqualTo( "new title" ); + assertThat( spellBook.forbidden ).isFalse(); + } + ) + ) ) + ); + } + + @Test + public void testHqlDelete(VertxTestContext context) { + final Integer id = 1; + final String title = "Spell Book: A Comprehensive Guide to Magic Spells and Incantations"; + test( context, getMutinySessionFactory().withTransaction( session -> session + .createMutationQuery( "insert into SpellBook (id, title, forbidden) values (:id, :title, :forbidden)" ) + .setParameter( "id", id ) + .setParameter( "title", title ) + .setParameter( "forbidden", true ) + .executeUpdate() ) + .call( () -> getMutinySessionFactory().withTransaction( session -> session + .createMutationQuery( + "delete from SpellBook where id = :id and forbidden = :forbidden and title = :title" ) + .setParameter( "id", id ) + .setParameter( "title", title ) + .setParameter( "forbidden", true ) + .executeUpdate() ) + ) + .call( () -> getMutinySessionFactory().withTransaction( session -> session + .createSelectionQuery( "from SpellBook g where g.id = :id ", SpellBook.class ) + .setParameter( "id", id ) + .getSingleResultOrNull() + .invoke( Assertions::assertNull ) ) + ) + ); + } + @Entity(name="SpellBook") @Table(name = "SpellBookJS") @DiscriminatorValue("S")