Skip to content

Commit 7db899e

Browse files
committed
Add session variables support
1 parent aff4cc5 commit 7db899e

File tree

7 files changed

+479
-56
lines changed

7 files changed

+479
-56
lines changed

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import io.r2dbc.spi.TransactionDefinition;
3838
import io.r2dbc.spi.ValidationDepth;
3939
import org.jetbrains.annotations.Nullable;
40-
import org.reactivestreams.Publisher;
4140
import reactor.core.publisher.Flux;
4241
import reactor.core.publisher.Mono;
4342
import reactor.core.publisher.SynchronousSink;
@@ -46,6 +45,7 @@
4645
import java.time.Duration;
4746
import java.time.ZoneId;
4847
import java.time.ZoneOffset;
48+
import java.util.List;
4949
import java.util.function.BiConsumer;
5050
import java.util.function.Function;
5151
import java.util.function.Predicate;
@@ -447,12 +447,12 @@ public Mono<Void> setStatementTimeout(Duration timeout) {
447447
}
448448

449449
return Mono.error(
450-
new R2dbcNonTransientResourceException(
451-
"Statement timeout is not supported by server version " + serverVersion,
452-
"HY000",
453-
-1,
454-
sql
455-
)
450+
new R2dbcNonTransientResourceException(
451+
"Statement timeout is not supported by server version " + serverVersion,
452+
"HY000",
453+
-1,
454+
sql
455+
)
456456
);
457457
}
458458

@@ -467,38 +467,23 @@ boolean isSessionAutoCommit() {
467467
/**
468468
* Initialize a {@link MySqlConnection} after login.
469469
*
470-
* @param client must be logged-in.
471-
* @param codecs the {@link Codecs}.
472-
* @param context must be initialized.
473-
* @param database the database that should be lazy init.
474-
* @param queryCache the cache of {@link Query}.
475-
* @param prepareCache the cache of server-preparing result.
476-
* @param prepare judging for prefer use prepare statement to execute simple query.
470+
* @param client must be logged-in.
471+
* @param codecs the {@link Codecs}.
472+
* @param context must be initialized.
473+
* @param database the database that should be lazy init.
474+
* @param queryCache the cache of {@link Query}.
475+
* @param prepareCache the cache of server-preparing result.
476+
* @param sessionVariables the session variables to set.
477+
* @param prepare judging for prefer use prepare statement to execute simple query.
477478
* @return a {@link Mono} will emit an initialized {@link MySqlConnection}.
478479
*/
479480
static Mono<MySqlConnection> init(
480481
Client client, Codecs codecs, ConnectionContext context, String database,
481482
QueryCache queryCache, PrepareCache prepareCache,
482-
@Nullable Predicate<String> prepare
483+
List<String> sessionVariables, @Nullable Predicate<String> prepare
483484
) {
484-
StringBuilder query = new StringBuilder(128)
485-
.append("SELECT ")
486-
.append(transactionIsolationColumn(context))
487-
.append(",@@innodb_lock_wait_timeout AS l,@@version_comment AS v");
488-
489-
Function<MySqlResult, Publisher<InitData>> handler;
490-
491-
if (context.shouldSetServerZoneId()) {
492-
query.append(",@@system_time_zone AS s,@@time_zone AS t");
493-
handler = MySqlConnection::fullInit;
494-
} else {
495-
handler = MySqlConnection::init;
496-
}
497-
498-
Mono<MySqlConnection> connection = new TextSimpleStatement(client, codecs, context, query.toString())
499-
.execute()
500-
.flatMap(handler)
501-
.last()
485+
Mono<MySqlConnection> connection = initSessionVariables(client, sessionVariables)
486+
.then(loadSessionVariables(client, codecs, context))
502487
.map(data -> {
503488
ZoneId serverZoneId = data.serverZoneId;
504489
if (serverZoneId != null) {
@@ -514,29 +499,83 @@ static Mono<MySqlConnection> init(
514499
return connection;
515500
}
516501

517-
requireNonEmpty(database, "database must not be empty");
502+
return connection.flatMap(c -> initDatabase(client, database).thenReturn(c));
503+
}
504+
505+
private static Mono<Void> initSessionVariables(Client client, List<String> sessionVariables) {
506+
if (sessionVariables.isEmpty()) {
507+
return Mono.empty();
508+
}
509+
510+
StringBuilder query = new StringBuilder(sessionVariables.size() * 32 + 16).append("SET ");
511+
boolean comma = false;
512+
513+
for (String variable : sessionVariables) {
514+
if (variable.isEmpty()) {
515+
continue;
516+
}
517+
518+
if (comma) {
519+
query.append(',');
520+
} else {
521+
comma = true;
522+
}
523+
524+
if (variable.startsWith("@")) {
525+
query.append(variable);
526+
} else {
527+
query.append("SESSION ").append(variable);
528+
}
529+
}
530+
531+
return QueryFlow.executeVoid(client, query.toString());
532+
}
533+
534+
private static Mono<InitData> loadSessionVariables(
535+
Client client, Codecs codecs, ConnectionContext context
536+
) {
537+
StringBuilder query = new StringBuilder(160)
538+
.append("SELECT ")
539+
.append(transactionIsolationColumn(context))
540+
.append(",@@innodb_lock_wait_timeout AS l,@@version_comment AS v");
541+
542+
Function<MySqlResult, Flux<InitData>> handler;
543+
544+
if (context.shouldSetServerZoneId()) {
545+
query.append(",@@system_time_zone AS s,@@time_zone AS t");
546+
handler = MySqlConnection::fullInit;
547+
} else {
548+
handler = MySqlConnection::init;
549+
}
550+
551+
return new TextSimpleStatement(client, codecs, context, query.toString())
552+
.execute()
553+
.flatMap(handler)
554+
.last();
555+
}
518556

519-
return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
557+
private static Mono<Void> initDatabase(Client client, String database) {
558+
return client.exchange(new InitDbMessage(database), INIT_DB)
520559
.last()
521560
.flatMap(success -> {
522561
if (success) {
523-
return Mono.just(conn);
562+
return Mono.empty();
524563
}
525564

526565
String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);
527566

528567
return QueryFlow.executeVoid(client, sql)
529-
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
530-
}));
568+
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then());
569+
});
531570
}
532571

533-
private static Publisher<InitData> init(MySqlResult r) {
572+
private static Flux<InitData> init(MySqlResult r) {
534573
return r.map((row, meta) -> new InitData(convertIsolationLevel(row.get(0, String.class)),
535574
convertLockWaitTimeout(row.get(1, Long.class)),
536575
row.get(2, String.class), null));
537576
}
538577

539-
private static Publisher<InitData> fullInit(MySqlResult r) {
578+
private static Flux<InitData> fullInit(MySqlResult r) {
540579
return r.map((row, meta) -> {
541580
IsolationLevel level = convertIsolationLevel(row.get(0, String.class));
542581
long lockWaitTimeout = convertLockWaitTimeout(row.get(1, Long.class));

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.asyncer.r2dbc.mysql.constant.SslMode;
2121
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
2222
import io.asyncer.r2dbc.mysql.extension.Extension;
23+
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
2324
import io.netty.handler.ssl.SslContextBuilder;
2425
import org.jetbrains.annotations.Nullable;
2526
import org.reactivestreams.Publisher;
@@ -94,6 +95,8 @@ public final class MySqlConnectionConfiguration {
9495
@Nullable
9596
private final Predicate<String> preferPrepareStatement;
9697

98+
private final List<String> sessionVariables;
99+
97100
@Nullable
98101
private final Path loadLocalInfilePath;
99102

@@ -120,6 +123,7 @@ private MySqlConnectionConfiguration(
120123
ZeroDateOption zeroDateOption, @Nullable ZoneId serverZoneId,
121124
String user, @Nullable CharSequence password, @Nullable String database,
122125
boolean createDatabaseIfNotExist, @Nullable Predicate<String> preferPrepareStatement,
126+
List<String> sessionVariables,
123127
@Nullable Path loadLocalInfilePath, int localInfileBufferSize,
124128
int queryCacheSize, int prepareCacheSize,
125129
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@@ -140,13 +144,14 @@ private MySqlConnectionConfiguration(
140144
this.database = database == null || database.isEmpty() ? "" : database;
141145
this.createDatabaseIfNotExist = createDatabaseIfNotExist;
142146
this.preferPrepareStatement = preferPrepareStatement;
147+
this.sessionVariables = sessionVariables;
143148
this.loadLocalInfilePath = loadLocalInfilePath;
144149
this.localInfileBufferSize = localInfileBufferSize;
145150
this.queryCacheSize = queryCacheSize;
146151
this.prepareCacheSize = prepareCacheSize;
147152
this.compressionAlgorithms = compressionAlgorithms;
148153
this.zstdCompressionLevel = zstdCompressionLevel;
149-
this.loopResources = loopResources == null? TcpResources.get() : loopResources;
154+
this.loopResources = loopResources == null ? TcpResources.get() : loopResources;
150155
this.extensions = extensions;
151156
this.passwordPublisher = passwordPublisher;
152157
}
@@ -220,6 +225,10 @@ Predicate<String> getPreferPrepareStatement() {
220225
return preferPrepareStatement;
221226
}
222227

228+
List<String> getSessionVariables() {
229+
return sessionVariables;
230+
}
231+
223232
@Nullable
224233
Path getLoadLocalInfilePath() {
225234
return loadLocalInfilePath;
@@ -281,6 +290,7 @@ public boolean equals(Object o) {
281290
database.equals(that.database) &&
282291
createDatabaseIfNotExist == that.createDatabaseIfNotExist &&
283292
Objects.equals(preferPrepareStatement, that.preferPrepareStatement) &&
293+
sessionVariables.equals(that.sessionVariables) &&
284294
Objects.equals(loadLocalInfilePath, that.loadLocalInfilePath) &&
285295
localInfileBufferSize == that.localInfileBufferSize &&
286296
queryCacheSize == that.queryCacheSize &&
@@ -296,9 +306,9 @@ public boolean equals(Object o) {
296306
public int hashCode() {
297307
return Objects.hash(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay, connectTimeout,
298308
serverZoneId, zeroDateOption, user, password, database, createDatabaseIfNotExist,
299-
preferPrepareStatement, loadLocalInfilePath, localInfileBufferSize, queryCacheSize,
300-
prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, loopResources,
301-
extensions, passwordPublisher);
309+
preferPrepareStatement, sessionVariables, loadLocalInfilePath,
310+
localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms,
311+
zstdCompressionLevel, loopResources, extensions, passwordPublisher);
302312
}
303313

304314
@Override
@@ -310,6 +320,7 @@ public String toString() {
310320
", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password +
311321
", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist +
312322
", preferPrepareStatement=" + preferPrepareStatement +
323+
", sessionVariables=" + sessionVariables +
313324
", loadLocalInfilePath=" + loadLocalInfilePath +
314325
", localInfileBufferSize=" + localInfileBufferSize +
315326
", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize +
@@ -324,6 +335,7 @@ public String toString() {
324335
", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password +
325336
", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist +
326337
", preferPrepareStatement=" + preferPrepareStatement +
338+
", sessionVariables=" + sessionVariables +
327339
", loadLocalInfilePath=" + loadLocalInfilePath +
328340
", localInfileBufferSize=" + localInfileBufferSize +
329341
", queryCacheSize=" + queryCacheSize +
@@ -393,6 +405,8 @@ public static final class Builder {
393405
@Nullable
394406
private Predicate<String> preferPrepareStatement;
395407

408+
private List<String> sessionVariables = Collections.emptyList();
409+
396410
@Nullable
397411
private Path loadLocalInfilePath;
398412

@@ -440,7 +454,7 @@ public MySqlConnectionConfiguration build() {
440454
sslCa, sslKey, sslKeyPassword, sslCert, sslContextBuilderCustomizer);
441455
return new MySqlConnectionConfiguration(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay,
442456
connectTimeout, zeroDateOption, serverZoneId, user, password, database,
443-
createDatabaseIfNotExist, preferPrepareStatement, loadLocalInfilePath,
457+
createDatabaseIfNotExist, preferPrepareStatement, sessionVariables, loadLocalInfilePath,
444458
localInfileBufferSize, queryCacheSize, prepareCacheSize,
445459
compressionAlgorithms, zstdCompressionLevel, loopResources,
446460
Extensions.from(extensions, autodetectExtensions), passwordPublisher);
@@ -801,6 +815,23 @@ public Builder useServerPrepareStatement(Predicate<String> preferPrepareStatemen
801815
return this;
802816
}
803817

818+
/**
819+
* Configure the session variables, used to set session variables immediately after login. Default no
820+
* session variables to set. It should be a list of key-value pairs. e.g.
821+
* {@code ["sql_mode='ANSI_QUOTES,STRICT_TRANS_TABLES'", "time_zone=00:00"]}.
822+
*
823+
* @param sessionVariables the session variables to set.
824+
* @return {@link Builder this}
825+
* @throws IllegalArgumentException if {@code sessionVariables} is {@code null}.
826+
* @since 1.1.2
827+
*/
828+
public Builder sessionVariables(String... sessionVariables) {
829+
requireNonNull(sessionVariables, "sessionVariables must not be null");
830+
831+
this.sessionVariables = InternalArrays.toImmutableList(sessionVariables);
832+
return this;
833+
}
834+
804835
/**
805836
* Configures to allow the {@code LOAD DATA LOCAL INFILE} statement in the given {@code path} or
806837
* disallow the statement. Default to {@code null} which means not allow the statement.
@@ -917,9 +948,9 @@ public Builder compressionAlgorithms(CompressionAlgorithm... compressionAlgorith
917948
* @param level the compression level.
918949
* @return {@link Builder this}.
919950
* @throws IllegalArgumentException if {@code level} is not between 1 and 22.
920-
* @since 1.1.2
921951
* @see <a href="https://dev.mysql.com/doc/refman/8.0/en/connection-options.html">
922952
* MySQL Connection Options --zstd-compression-level</a>
953+
* @since 1.1.2
923954
*/
924955
public Builder zstdCompressionLevel(int level) {
925956
require(level >= 1 && level <= 22, "level must be between 1 and 22");
@@ -929,8 +960,9 @@ public Builder zstdCompressionLevel(int level) {
929960
}
930961

931962
/**
932-
* Configures the {@link LoopResources} for the driver.
933-
* Default to {@link TcpResources#get() global tcp resources}.
963+
* Configures the {@link LoopResources} for the driver. Default to
964+
* {@link TcpResources#get() global tcp resources}.
965+
*
934966
* @param loopResources the {@link LoopResources}.
935967
* @return this {@link Builder}.
936968
* @throws IllegalArgumentException if {@code loopResources} is {@code null}.

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import java.net.InetSocketAddress;
3737
import java.net.SocketAddress;
38+
import java.util.List;
3839
import java.util.Objects;
3940
import java.util.Set;
4041
import java.util.concurrent.locks.ReentrantLock;
@@ -100,6 +101,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
100101
configuration.getServerZoneId()
101102
);
102103
Set<CompressionAlgorithm> compressionAlgorithms = configuration.getCompressionAlgorithms();
104+
List<String> sessionVariables = configuration.getSessionVariables();
103105
Extensions extensions = configuration.getExtensions();
104106
Predicate<String> prepare = configuration.getPreferPrepareStatement();
105107
int prepareCacheSize = configuration.getPrepareCacheSize();
@@ -112,7 +114,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
112114
database, createDbIfNotExist,
113115
user, sslMode,
114116
compressionAlgorithms, zstdCompressionLevel,
115-
context, extensions, prepare,
117+
context, extensions, sessionVariables, prepare,
116118
prepareCacheSize, token
117119
));
118120
}
@@ -123,7 +125,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
123125
database, createDbIfNotExist,
124126
user, sslMode,
125127
compressionAlgorithms, zstdCompressionLevel,
126-
context, extensions, prepare,
128+
context, extensions, sessionVariables, prepare,
127129
prepareCacheSize, password
128130
);
129131
}));
@@ -142,6 +144,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
142144
final int zstdCompressionLevel,
143145
final ConnectionContext context,
144146
final Extensions extensions,
147+
final List<String> sessionVariables,
145148
@Nullable final Predicate<String> prepare,
146149
final int prepareCacheSize,
147150
@Nullable final CharSequence password) {
@@ -163,7 +166,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
163166
registrar.register(allocator, builder));
164167

165168
return MySqlConnection.init(client, builder.build(), context, db, queryCache.get(),
166-
prepareCache, prepare);
169+
prepareCache, sessionVariables, prepare);
167170
});
168171
}
169172

0 commit comments

Comments
 (0)